diff options
author | Přemysl Janouch <p@janouch.name> | 2018-07-15 10:45:12 +0200 |
---|---|---|
committer | Přemysl Janouch <p@janouch.name> | 2018-07-15 12:51:06 +0200 |
commit | 728fa4e54800a505d7fac1b4a4b83862a30f94b2 (patch) | |
tree | a742357aa5c732bd3f2623c9c6f1345e3bdab664 /prototypes/tls-autodetect.go | |
parent | b5b64db075b062c6b722d8ab1f8adf0a9dc63a41 (diff) | |
download | haven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.tar.gz haven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.tar.xz haven-728fa4e54800a505d7fac1b4a4b83862a30f94b2.zip |
tls-autodetect: put most of the server code in place
So far we act up when it is the client who initializes the shutdown.
Diffstat (limited to 'prototypes/tls-autodetect.go')
-rw-r--r-- | prototypes/tls-autodetect.go | 188 |
1 files changed, 136 insertions, 52 deletions
diff --git a/prototypes/tls-autodetect.go b/prototypes/tls-autodetect.go index c58ed28..ce88379 100644 --- a/prototypes/tls-autodetect.go +++ b/prototypes/tls-autodetect.go @@ -13,12 +13,17 @@ // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // +// // This is an example TLS-autodetecting chat server. // -// You may connect to it either using: +// You may connect to it using either of these: +// ncat -C localhost 1234 +// ncat -C --ssl localhost 1234 +// +// These clients are unable to properly shutdown the connection: // telnet localhost 1234 -// or // openssl s_client -connect localhost:1234 +// package main import ( @@ -37,6 +42,7 @@ import ( // --- Utilities --------------------------------------------------------------- +// // Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom // must be at least three octets long for this to work reliably, but that should // not pose a problem in practice. We might try waiting for them. @@ -46,8 +52,7 @@ import ( // SSL3/TLS: <22> | <3> | xxxx xxxx // (handshake)| (protocol version) // -func detectTLS(sysconn syscall.RawConn) bool { - isTLS := false +func detectTLS(sysconn syscall.RawConn) (isTLS bool) { sysconn.Read(func(fd uintptr) (done bool) { var buf [3]byte n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK) @@ -78,30 +83,38 @@ type client struct { transport net.Conn // underlying connection tls *tls.Conn // TLS, if detected conn connCloseWrite // high-level connection - connReady bool // conn is safe to read from the main goroutine inQ []byte // unprocessed input outQ []byte // unprocessed output writing bool // whether a writing goroutine is running inShutdown bool // whether we're closing connection + killTimer *time.Timer // timeout +} + +type preparedEvent struct { + client *client + host string // client's hostname or literal IP address + isTLS bool // the client seems to use TLS } type readEvent struct { - client *client // client - data []byte // new data from the client - err error // read error + client *client + data []byte // new data from the client + err error // read error } type writeEvent struct { - client *client // client - written int // amount of bytes written - err error // write error + client *client + written int // amount of bytes written + err error // write error } var ( - sigs = make(chan os.Signal, 1) - conns = make(chan net.Conn) - reads = make(chan readEvent) - writes = make(chan writeEvent) + sigs = make(chan os.Signal, 1) + conns = make(chan net.Conn) + prepared = make(chan preparedEvent) + reads = make(chan readEvent) + writes = make(chan writeEvent) + timeouts = make(chan *client) tlsConf *tls.Config clients = make(map[*client]bool) @@ -122,6 +135,7 @@ func broadcast(line string, except *client) { } } +// Initiate a clean shutdown of the whole daemon. func initiateShutdown() { log.Println("shutting down") if err := listener.Close(); err != nil { @@ -135,7 +149,12 @@ func initiateShutdown() { inShutdown = true } +// Forcefully tear down all connections. func forceShutdown(reason string) { + if !inShutdown { + log.Fatalln("forceShutdown called without initiateShutdown") + } + log.Printf("forced shutdown (%s)\n", reason) for c := range clients { c.destroy() @@ -151,67 +170,84 @@ func (c *client) send(line string) { } } -func (c *client) shutdown() { +// Tear down the client connection, trying to do so in a graceful manner. +func (c *client) kill() { if c.inShutdown { - log.Println("client double shutdown") return } - - // TODO: We must set a timer and destroy the client on timeout. Since we - // have a central event loop, we probably need an event. Since we also - // seem to need an event for TLS autodetection because of conn, we might - // want to send an enumeration value. - c.inShutdown = true - c.conn.CloseWrite() -} - -// Tear down the client connection, trying to do so in a graceful manner. -func (c *client) kill() { - if c.connReady { - c.send("Goodbye") - c.shutdown() - } else { + if c.conn == nil { c.destroy() + return } + + // Since we send this goodbye, we don't need to call CloseWrite. + c.send("Goodbye") + c.killTimer = time.AfterFunc(3*time.Second, func() { + timeouts <- c + }) + + c.inShutdown = true } // Close the connection and forget about the client. func (c *client) destroy() { // Try to send a "close notify" alert if the TLS object is ready, // otherwise just tear down the transport. - if c.connReady { + if c.conn != nil { _ = c.conn.Close() } else { _ = c.transport.Close() } + // Clean up the goroutine, although a spurious event may still be sent. + if c.killTimer != nil { + c.killTimer.Stop() + } + delete(clients, c) } +// Handle the results from initializing the client's connection. +func (c *client) onPrepared(host string, isTLS bool) { + if isTLS { + c.tls = tls.Server(c.transport, tlsConf) + c.conn = c.tls + } else { + c.conn = c.transport.(connCloseWrite) + } + + // TODO: Save the host in the client structure. + go read(c) +} + // Handle the results from trying to read from the client connection. func (c *client) onRead(data []byte, readErr error) { c.inQ = append(c.inQ, data...) for { advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) - c.inQ = c.inQ[advance:] if advance == 0 { break } + c.inQ = c.inQ[advance:] line := string(token) fmt.Println(line) broadcast(line, c) } // TODO: Inform the client about the inQ overrun in the farewell message. + // TODO: We should stop receiving any more data from this client. if len(c.inQ) > 8192 { c.kill() return } if readErr == io.EOF { - // TODO: What if we're already in shutdown? - c.shutdown() + if c.inShutdown { + c.destroy() + } else { + c.kill() + } } else if readErr != nil { log.Println(readErr) c.destroy() @@ -221,7 +257,7 @@ func (c *client) onRead(data []byte, readErr error) { // Spawn a goroutine to flush the outQ if possible and necessary. If the // connection is not ready yet, it needs to be retried as soon as it becomes. func (c *client) flushOutQ() { - if c.connReady && !c.writing { + if c.conn != nil && !c.writing { go write(c, c.outQ) c.writing = true } @@ -238,41 +274,77 @@ func (c *client) onWrite(written int, writeErr error) { } else if len(c.outQ) > 0 { c.flushOutQ() } else if c.inShutdown { - c.destroy() + if c.conn != nil { + // FIXME: This is only correct for when /we/ initiate the shutdown, + // otherwise we should perhaps just Close. Though even if we + // Close, there's a/ no writer to fail on it, and b/ the reader + // has already exited, too, which is why the client stays alive + // up until the timeout. It seems that in that case we need to + // call c.destroy(). + c.conn.CloseWrite() + } else { + c.destroy() + } } } // --- Worker goroutines ------------------------------------------------------- func accept(ln net.Listener) { - // TODO: Consider specific cases in error handling, some errors - // are transitional while others are fatal. for { if conn, err := ln.Accept(); err != nil { + // TODO: Consider specific cases in error handling, some errors + // are transitional while others are fatal. log.Println(err) + break } else { conns <- conn } } } -func read(client *client) { - // TODO: Either here or elsewhere we need to set a timeout. +func prepare(client *client) { + conn := client.transport + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + // In effect, we require TCP/UDP, as they have port numbers. + log.Fatalln(err) + } + + // The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not + // bothering with pointless contexts. + ch := make(chan string) + go func() { + defer close(ch) + if names, err := net.LookupAddr(host); err != nil { + log.Println(err) + } else { + ch <- names[0] + } + }() + + // While we can't cancel it, we still want to set a timeout on it. + select { + case <-time.After(5 * time.Second): + case resolved, ok := <-ch: + if ok { + host = resolved + } + } - client.conn = client.transport.(connCloseWrite) - if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil { + isTLS := false + if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil { // This is just for the TLS detection and doesn't need to be fatal. log.Println(err) - } else if detectTLS(sysconn) { - client.tls = tls.Server(client.transport, tlsConf) - client.conn = client.tls + } else { + isTLS = detectTLS(sysconn) } - // TODO: Signal the main goroutine that conn is ready. In fact, the upper - // part could be mostly moved to the main goroutine and we'd only spawn - // a thin wrapper around detectTLS, sending back {*client, bool}. Heck, - // I could get rid of connReady. + prepared <- preparedEvent{client, host, isTLS} +} + +func read(client *client) { // A new buffer is allocated each time we receive some bytes, because of // thread-safety. Therefore the buffer shouldn't be too large, or we'd // need to copy it each time into a precisely sized new buffer. @@ -312,7 +384,13 @@ func processOneEvent() { log.Println("accepted client connection") c := &client{transport: conn} clients[c] = true - go read(c) + go prepare(c) + + case ev := <-prepared: + log.Println("client is ready:", ev.host) + if _, ok := clients[ev.client]; ok { + ev.client.onPrepared(ev.host, ev.isTLS) + } case ev := <-reads: log.Println("received data from client") @@ -325,6 +403,12 @@ func processOneEvent() { if _, ok := clients[ev.client]; ok { ev.client.onWrite(ev.written, ev.err) } + + case c := <-timeouts: + if _, ok := clients[c]; ok { + log.Println("client timeouted") + c.destroy() + } } } |