aboutsummaryrefslogtreecommitdiff
path: root/prototypes/tls-autodetect.go
diff options
context:
space:
mode:
authorPřemysl Janouch <p@janouch.name>2018-07-15 10:45:12 +0200
committerPřemysl Janouch <p@janouch.name>2018-07-15 12:51:06 +0200
commit728fa4e54800a505d7fac1b4a4b83862a30f94b2 (patch)
treea742357aa5c732bd3f2623c9c6f1345e3bdab664 /prototypes/tls-autodetect.go
parentb5b64db075b062c6b722d8ab1f8adf0a9dc63a41 (diff)
downloadhaven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.tar.gz
haven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.tar.xz
haven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.zip
tls-autodetect: put most of the server code in place
So far we act up when it is the client who initializes the shutdown.
Diffstat (limited to 'prototypes/tls-autodetect.go')
-rw-r--r--prototypes/tls-autodetect.go188
1 files changed, 136 insertions, 52 deletions
diff --git a/prototypes/tls-autodetect.go b/prototypes/tls-autodetect.go
index c58ed28..ce88379 100644
--- a/prototypes/tls-autodetect.go
+++ b/prototypes/tls-autodetect.go
@@ -13,12 +13,17 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//
+//
// This is an example TLS-autodetecting chat server.
//
-// You may connect to it either using:
+// You may connect to it using either of these:
+// ncat -C localhost 1234
+// ncat -C --ssl localhost 1234
+//
+// These clients are unable to properly shutdown the connection:
// telnet localhost 1234
-// or
// openssl s_client -connect localhost:1234
+//
package main
import (
@@ -37,6 +42,7 @@ import (
// --- Utilities ---------------------------------------------------------------
+//
// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom
// must be at least three octets long for this to work reliably, but that should
// not pose a problem in practice. We might try waiting for them.
@@ -46,8 +52,7 @@ import (
// SSL3/TLS: <22> | <3> | xxxx xxxx
// (handshake)| (protocol version)
//
-func detectTLS(sysconn syscall.RawConn) bool {
- isTLS := false
+func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
sysconn.Read(func(fd uintptr) (done bool) {
var buf [3]byte
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
@@ -78,30 +83,38 @@ type client struct {
transport net.Conn // underlying connection
tls *tls.Conn // TLS, if detected
conn connCloseWrite // high-level connection
- connReady bool // conn is safe to read from the main goroutine
inQ []byte // unprocessed input
outQ []byte // unprocessed output
writing bool // whether a writing goroutine is running
inShutdown bool // whether we're closing connection
+ killTimer *time.Timer // timeout
+}
+
+type preparedEvent struct {
+ client *client
+ host string // client's hostname or literal IP address
+ isTLS bool // the client seems to use TLS
}
type readEvent struct {
- client *client // client
- data []byte // new data from the client
- err error // read error
+ client *client
+ data []byte // new data from the client
+ err error // read error
}
type writeEvent struct {
- client *client // client
- written int // amount of bytes written
- err error // write error
+ client *client
+ written int // amount of bytes written
+ err error // write error
}
var (
- sigs = make(chan os.Signal, 1)
- conns = make(chan net.Conn)
- reads = make(chan readEvent)
- writes = make(chan writeEvent)
+ sigs = make(chan os.Signal, 1)
+ conns = make(chan net.Conn)
+ prepared = make(chan preparedEvent)
+ reads = make(chan readEvent)
+ writes = make(chan writeEvent)
+ timeouts = make(chan *client)
tlsConf *tls.Config
clients = make(map[*client]bool)
@@ -122,6 +135,7 @@ func broadcast(line string, except *client) {
}
}
+// Initiate a clean shutdown of the whole daemon.
func initiateShutdown() {
log.Println("shutting down")
if err := listener.Close(); err != nil {
@@ -135,7 +149,12 @@ func initiateShutdown() {
inShutdown = true
}
+// Forcefully tear down all connections.
func forceShutdown(reason string) {
+ if !inShutdown {
+ log.Fatalln("forceShutdown called without initiateShutdown")
+ }
+
log.Printf("forced shutdown (%s)\n", reason)
for c := range clients {
c.destroy()
@@ -151,67 +170,84 @@ func (c *client) send(line string) {
}
}
-func (c *client) shutdown() {
+// Tear down the client connection, trying to do so in a graceful manner.
+func (c *client) kill() {
if c.inShutdown {
- log.Println("client double shutdown")
return
}
-
- // TODO: We must set a timer and destroy the client on timeout. Since we
- // have a central event loop, we probably need an event. Since we also
- // seem to need an event for TLS autodetection because of conn, we might
- // want to send an enumeration value.
- c.inShutdown = true
- c.conn.CloseWrite()
-}
-
-// Tear down the client connection, trying to do so in a graceful manner.
-func (c *client) kill() {
- if c.connReady {
- c.send("Goodbye")
- c.shutdown()
- } else {
+ if c.conn == nil {
c.destroy()
+ return
}
+
+ // Since we send this goodbye, we don't need to call CloseWrite.
+ c.send("Goodbye")
+ c.killTimer = time.AfterFunc(3*time.Second, func() {
+ timeouts <- c
+ })
+
+ c.inShutdown = true
}
// Close the connection and forget about the client.
func (c *client) destroy() {
// Try to send a "close notify" alert if the TLS object is ready,
// otherwise just tear down the transport.
- if c.connReady {
+ if c.conn != nil {
_ = c.conn.Close()
} else {
_ = c.transport.Close()
}
+ // Clean up the goroutine, although a spurious event may still be sent.
+ if c.killTimer != nil {
+ c.killTimer.Stop()
+ }
+
delete(clients, c)
}
+// Handle the results from initializing the client's connection.
+func (c *client) onPrepared(host string, isTLS bool) {
+ if isTLS {
+ c.tls = tls.Server(c.transport, tlsConf)
+ c.conn = c.tls
+ } else {
+ c.conn = c.transport.(connCloseWrite)
+ }
+
+ // TODO: Save the host in the client structure.
+ go read(c)
+}
+
// Handle the results from trying to read from the client connection.
func (c *client) onRead(data []byte, readErr error) {
c.inQ = append(c.inQ, data...)
for {
advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
- c.inQ = c.inQ[advance:]
if advance == 0 {
break
}
+ c.inQ = c.inQ[advance:]
line := string(token)
fmt.Println(line)
broadcast(line, c)
}
// TODO: Inform the client about the inQ overrun in the farewell message.
+ // TODO: We should stop receiving any more data from this client.
if len(c.inQ) > 8192 {
c.kill()
return
}
if readErr == io.EOF {
- // TODO: What if we're already in shutdown?
- c.shutdown()
+ if c.inShutdown {
+ c.destroy()
+ } else {
+ c.kill()
+ }
} else if readErr != nil {
log.Println(readErr)
c.destroy()
@@ -221,7 +257,7 @@ func (c *client) onRead(data []byte, readErr error) {
// Spawn a goroutine to flush the outQ if possible and necessary. If the
// connection is not ready yet, it needs to be retried as soon as it becomes.
func (c *client) flushOutQ() {
- if c.connReady && !c.writing {
+ if c.conn != nil && !c.writing {
go write(c, c.outQ)
c.writing = true
}
@@ -238,41 +274,77 @@ func (c *client) onWrite(written int, writeErr error) {
} else if len(c.outQ) > 0 {
c.flushOutQ()
} else if c.inShutdown {
- c.destroy()
+ if c.conn != nil {
+ // FIXME: This is only correct for when /we/ initiate the shutdown,
+ // otherwise we should perhaps just Close. Though even if we
+ // Close, there's a/ no writer to fail on it, and b/ the reader
+ // has already exited, too, which is why the client stays alive
+ // up until the timeout. It seems that in that case we need to
+ // call c.destroy().
+ c.conn.CloseWrite()
+ } else {
+ c.destroy()
+ }
}
}
// --- Worker goroutines -------------------------------------------------------
func accept(ln net.Listener) {
- // TODO: Consider specific cases in error handling, some errors
- // are transitional while others are fatal.
for {
if conn, err := ln.Accept(); err != nil {
+ // TODO: Consider specific cases in error handling, some errors
+ // are transitional while others are fatal.
log.Println(err)
+ break
} else {
conns <- conn
}
}
}
-func read(client *client) {
- // TODO: Either here or elsewhere we need to set a timeout.
+func prepare(client *client) {
+ conn := client.transport
+ host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
+ if err != nil {
+ // In effect, we require TCP/UDP, as they have port numbers.
+ log.Fatalln(err)
+ }
+
+ // The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
+ // bothering with pointless contexts.
+ ch := make(chan string)
+ go func() {
+ defer close(ch)
+ if names, err := net.LookupAddr(host); err != nil {
+ log.Println(err)
+ } else {
+ ch <- names[0]
+ }
+ }()
+
+ // While we can't cancel it, we still want to set a timeout on it.
+ select {
+ case <-time.After(5 * time.Second):
+ case resolved, ok := <-ch:
+ if ok {
+ host = resolved
+ }
+ }
- client.conn = client.transport.(connCloseWrite)
- if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil {
+ isTLS := false
+ if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil {
// This is just for the TLS detection and doesn't need to be fatal.
log.Println(err)
- } else if detectTLS(sysconn) {
- client.tls = tls.Server(client.transport, tlsConf)
- client.conn = client.tls
+ } else {
+ isTLS = detectTLS(sysconn)
}
- // TODO: Signal the main goroutine that conn is ready. In fact, the upper
- // part could be mostly moved to the main goroutine and we'd only spawn
- // a thin wrapper around detectTLS, sending back {*client, bool}. Heck,
- // I could get rid of connReady.
+ prepared <- preparedEvent{client, host, isTLS}
+}
+
+func read(client *client) {
// A new buffer is allocated each time we receive some bytes, because of
// thread-safety. Therefore the buffer shouldn't be too large, or we'd
// need to copy it each time into a precisely sized new buffer.
@@ -312,7 +384,13 @@ func processOneEvent() {
log.Println("accepted client connection")
c := &client{transport: conn}
clients[c] = true
- go read(c)
+ go prepare(c)
+
+ case ev := <-prepared:
+ log.Println("client is ready:", ev.host)
+ if _, ok := clients[ev.client]; ok {
+ ev.client.onPrepared(ev.host, ev.isTLS)
+ }
case ev := <-reads:
log.Println("received data from client")
@@ -325,6 +403,12 @@ func processOneEvent() {
if _, ok := clients[ev.client]; ok {
ev.client.onWrite(ev.written, ev.err)
}
+
+ case c := <-timeouts:
+ if _, ok := clients[c]; ok {
+ log.Println("client timeouted")
+ c.destroy()
+ }
}
}