diff options
Diffstat (limited to 'prototypes/tls-autodetect.go')
-rw-r--r-- | prototypes/tls-autodetect.go | 451 |
1 files changed, 451 insertions, 0 deletions
diff --git a/prototypes/tls-autodetect.go b/prototypes/tls-autodetect.go new file mode 100644 index 0000000..0427465 --- /dev/null +++ b/prototypes/tls-autodetect.go @@ -0,0 +1,451 @@ +// +// Copyright (c) 2018, Přemysl Janouch <p@janouch.name> +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +// + +// +// This is an example TLS-autodetecting chat server. +// +// These clients are unable to properly shutdown the connection on their exit: +// telnet localhost 1234 +// openssl s_client -connect localhost:1234 +// +// While this one doesn't react to an EOF from the server: +// ncat -C localhost 1234 +// ncat -C --ssl localhost 1234 +// +package main + +import ( + "bufio" + "crypto/tls" + "flag" + "fmt" + "io" + "log" + "net" + "os" + "os/signal" + "syscall" + "time" +) + +// --- Utilities --------------------------------------------------------------- + +// +// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom +// must be at least three octets long for this to work reliably, but that should +// not pose a problem in practice. We might try waiting for them. +// +// SSL2: 1xxx xxxx | xxxx xxxx | <1> +// (message length) (client hello) +// SSL3/TLS: <22> | <3> | xxxx xxxx +// (handshake)| (protocol version) +// +func detectTLS(sysconn syscall.RawConn) (isTLS bool) { + sysconn.Read(func(fd uintptr) (done bool) { + var buf [3]byte + n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK) + switch { + case n == 3: + isTLS = buf[0]&0x80 != 0 && buf[2] == 1 + fallthrough + case n == 2: + isTLS = isTLS || buf[0] == 22 && buf[1] == 3 + case n == 1: + isTLS = buf[0] == 22 + case err == syscall.EAGAIN: + return false + } + return true + }) + return isTLS +} + +// --- Declarations ------------------------------------------------------------ + +type connCloseWriter interface { + net.Conn + CloseWrite() error +} + +type client struct { + transport net.Conn // underlying connection + tls *tls.Conn // TLS, if detected + conn connCloseWriter // high-level connection + inQ []byte // unprocessed input + outQ []byte // unprocessed output + reading bool // whether a reading goroutine is running + writing bool // whether a writing goroutine is running + closing bool // whether we're closing the connection + killTimer *time.Timer // timeout +} + +type preparedEvent struct { + client *client + host string // client's hostname or literal IP address + isTLS bool // the client seems to use TLS +} + +type readEvent struct { + client *client + data []byte // new data from the client + err error // read error +} + +type writeEvent struct { + client *client + written int // amount of bytes written + err error // write error +} + +var ( + sigs = make(chan os.Signal, 1) + conns = make(chan net.Conn) + prepared = make(chan preparedEvent) + reads = make(chan readEvent) + writes = make(chan writeEvent) + timeouts = make(chan *client) + + tlsConf *tls.Config + clients = make(map[*client]bool) + listener net.Listener + inShutdown bool + shutdownTimer <-chan time.Time +) + +// --- Server ------------------------------------------------------------------ + +// Broadcast to all /other/ clients (telnet-friendly, also in accordance to +// the plan of extending this to an IRCd). +func broadcast(line string, except *client) { + for c := range clients { + if c != except { + c.send(line) + } + } +} + +// Initiate a clean shutdown of the whole daemon. +func initiateShutdown() { + log.Println("shutting down") + if err := listener.Close(); err != nil { + log.Println(err) + } + for c := range clients { + c.closeLink() + } + + shutdownTimer = time.After(3 * time.Second) + inShutdown = true +} + +// Forcefully tear down all connections. +func forceShutdown(reason string) { + if !inShutdown { + log.Fatalln("forceShutdown called without initiateShutdown") + } + + log.Printf("forced shutdown (%s)\n", reason) + for c := range clients { + c.destroy() + } +} + +// --- Client ------------------------------------------------------------------ + +func (c *client) send(line string) { + if c.conn != nil && !c.closing { + c.outQ = append(c.outQ, (line + "\r\n")...) + c.flushOutQ() + } +} + +// Tear down the client connection, trying to do so in a graceful manner. +func (c *client) closeLink() { + if c.closing { + return + } + if c.conn == nil { + c.destroy() + return + } + + // Since we send this goodbye, we don't need to call CloseWrite here. + c.send("Goodbye") + c.killTimer = time.AfterFunc(3*time.Second, func() { + timeouts <- c + }) + + c.closing = true +} + +// Close the connection and forget about the client. +func (c *client) destroy() { + // Try to send a "close notify" alert if the TLS object is ready, + // otherwise just tear down the transport. + if c.conn != nil { + _ = c.conn.Close() + } else { + _ = c.transport.Close() + } + + // Clean up the goroutine, although a spurious event may still be sent. + if c.killTimer != nil { + c.killTimer.Stop() + } + + log.Println("client destroyed") + delete(clients, c) +} + +// Handle the results from initializing the client's connection. +func (c *client) onPrepared(isTLS bool) { + if isTLS { + c.tls = tls.Server(c.transport, tlsConf) + c.conn = c.tls + } else { + c.conn = c.transport.(connCloseWriter) + } + + // TODO: If we've tried to send any data before now, we need to flushOutQ. + go read(c) + c.reading = true +} + +// Handle the results from trying to read from the client connection. +func (c *client) onRead(data []byte, readErr error) { + if !c.reading { + // Abusing the flag to emulate CloseRead and skip over data, see below. + return + } + + c.inQ = append(c.inQ, data...) + for { + advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) + if advance == 0 { + break + } + + c.inQ = c.inQ[advance:] + line := string(token) + fmt.Println(line) + broadcast(line, c) + } + + if readErr != nil { + c.reading = false + + if readErr != io.EOF { + log.Println(readErr) + c.destroy() + } else if c.closing { + // Disregarding whether a clean shutdown has happened or not. + log.Println("client finished shutdown") + c.destroy() + } else { + log.Println("client EOF") + c.closeLink() + } + } else if len(c.inQ) > 8192 { + log.Println("client inQ overrun") + // TODO: Inform the client about inQ overrun in the farewell message. + c.closeLink() + + // tls.Conn doesn't have the CloseRead method (and it needs to be able + // to read from the TCP connection even for writes, so there isn't much + // sense in expecting the implementation to do anything useful), + // otherwise we'd use it to block incoming packet data. + c.reading = false + } +} + +// Spawn a goroutine to flush the outQ if possible and necessary. +func (c *client) flushOutQ() { + if !c.writing && c.conn != nil { + go write(c, c.outQ) + c.writing = true + } +} + +// Handle the results from trying to write to the client connection. +func (c *client) onWrite(written int, writeErr error) { + c.outQ = c.outQ[written:] + c.writing = false + + if writeErr != nil { + log.Println(writeErr) + c.destroy() + } else if len(c.outQ) > 0 { + c.flushOutQ() + } else if c.closing { + if c.reading { + c.conn.CloseWrite() + } else { + c.destroy() + } + } +} + +// --- Worker goroutines ------------------------------------------------------- + +func accept(ln net.Listener) { + for { + if conn, err := ln.Accept(); err != nil { + // TODO: Consider specific cases in error handling, some errors + // are transitional while others are fatal. + log.Println(err) + break + } else { + conns <- conn + } + } +} + +func prepare(client *client) { + conn := client.transport + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + // In effect, we require TCP/UDP, as they have port numbers. + log.Fatalln(err) + } + + // The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not + // bothering with pointless contexts. + ch := make(chan string, 1) + go func() { + defer close(ch) + if names, err := net.LookupAddr(host); err != nil { + log.Println(err) + } else { + ch <- names[0] + } + }() + + // While we can't cancel it, we still want to set a timeout on it. + select { + case <-time.After(5 * time.Second): + case resolved, ok := <-ch: + if ok { + host = resolved + } + } + + // Note that in this demo application the autodetection prevents non-TLS + // clients from receiving any messages until they send something. + isTLS := false + if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil { + // This is just for the TLS detection and doesn't need to be fatal. + log.Println(err) + } else { + isTLS = detectTLS(sysconn) + } + + // FIXME: When the client sends no data, we still initialize its conn. + prepared <- preparedEvent{client, host, isTLS} +} + +func read(client *client) { + // A new buffer is allocated each time we receive some bytes, because of + // thread-safety. Therefore the buffer shouldn't be too large, or we'd + // need to copy it each time into a precisely sized new buffer. + var err error + for err == nil { + var ( + buf [512]byte + n int + ) + n, err = client.conn.Read(buf[:]) + reads <- readEvent{client, buf[:n], err} + } +} + +// Flush outQ, which is passed by parameter so that there are no data races. +func write(client *client, data []byte) { + // We just write as much as we can, the main goroutine does the looping. + n, err := client.conn.Write(data) + writes <- writeEvent{client, n, err} +} + +// --- Main -------------------------------------------------------------------- + +func processOneEvent() { + select { + case <-sigs: + if inShutdown { + forceShutdown("requested by user") + } else { + initiateShutdown() + } + + case <-shutdownTimer: + forceShutdown("timeout") + + case conn := <-conns: + log.Println("accepted client connection") + c := &client{transport: conn} + clients[c] = true + go prepare(c) + + case ev := <-prepared: + log.Println("client is ready, resolved to", ev.host) + if _, ok := clients[ev.client]; ok { + ev.client.onPrepared(ev.isTLS) + } + + case ev := <-reads: + log.Println("received data from client") + if _, ok := clients[ev.client]; ok { + ev.client.onRead(ev.data, ev.err) + } + + case ev := <-writes: + log.Println("sent data to client") + if _, ok := clients[ev.client]; ok { + ev.client.onWrite(ev.written, ev.err) + } + + case c := <-timeouts: + if _, ok := clients[c]; ok { + log.Println("client timeouted") + c.destroy() + } + } +} + +func main() { + // Just deal with unexpected flags, we don't use any ourselves. + flag.Parse() + + if len(flag.Args()) != 3 { + log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0]) + } + + cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0)) + if err != nil { + log.Fatalln(err) + } + + tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}} + listener, err = net.Listen("tcp", flag.Arg(2)) + if err != nil { + log.Fatalln(err) + } + + go accept(listener) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + for !inShutdown || len(clients) > 0 { + processOneEvent() + } +} |