diff options
Diffstat (limited to 'prototypes/tls-autodetect.go')
-rw-r--r-- | prototypes/tls-autodetect.go | 356 |
1 files changed, 356 insertions, 0 deletions
diff --git a/prototypes/tls-autodetect.go b/prototypes/tls-autodetect.go new file mode 100644 index 0000000..c58ed28 --- /dev/null +++ b/prototypes/tls-autodetect.go @@ -0,0 +1,356 @@ +// +// Copyright (c) 2018, Přemysl Janouch <p@janouch.name> +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +// + +// This is an example TLS-autodetecting chat server. +// +// You may connect to it either using: +// telnet localhost 1234 +// or +// openssl s_client -connect localhost:1234 +package main + +import ( + "bufio" + "crypto/tls" + "flag" + "fmt" + "io" + "log" + "net" + "os" + "os/signal" + "syscall" + "time" +) + +// --- Utilities --------------------------------------------------------------- + +// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom +// must be at least three octets long for this to work reliably, but that should +// not pose a problem in practice. We might try waiting for them. +// +// SSL2: 1xxx xxxx | xxxx xxxx | <1> +// (message length) (client hello) +// SSL3/TLS: <22> | <3> | xxxx xxxx +// (handshake)| (protocol version) +// +func detectTLS(sysconn syscall.RawConn) bool { + isTLS := false + sysconn.Read(func(fd uintptr) (done bool) { + var buf [3]byte + n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK) + switch { + case n == 3: + isTLS = buf[0]&0x80 != 0 && buf[2] == 1 + fallthrough + case n == 2: + isTLS = buf[0] == 22 && buf[1] == 3 + case n == 1: + isTLS = buf[0] == 22 + case err == syscall.EAGAIN: + return false + } + return true + }) + return isTLS +} + +// --- Declarations ------------------------------------------------------------ + +type connCloseWrite interface { + net.Conn + CloseWrite() error +} + +type client struct { + transport net.Conn // underlying connection + tls *tls.Conn // TLS, if detected + conn connCloseWrite // high-level connection + connReady bool // conn is safe to read from the main goroutine + inQ []byte // unprocessed input + outQ []byte // unprocessed output + writing bool // whether a writing goroutine is running + inShutdown bool // whether we're closing connection +} + +type readEvent struct { + client *client // client + data []byte // new data from the client + err error // read error +} + +type writeEvent struct { + client *client // client + written int // amount of bytes written + err error // write error +} + +var ( + sigs = make(chan os.Signal, 1) + conns = make(chan net.Conn) + reads = make(chan readEvent) + writes = make(chan writeEvent) + + tlsConf *tls.Config + clients = make(map[*client]bool) + listener net.Listener + inShutdown bool + shutdownTimer <-chan time.Time +) + +// --- Server ------------------------------------------------------------------ + +// Broadcast to all /other/ clients (telnet-friendly, also in accordance to +// the plan of extending this to an IRCd). +func broadcast(line string, except *client) { + for c := range clients { + if c != except { + c.send(line) + } + } +} + +func initiateShutdown() { + log.Println("shutting down") + if err := listener.Close(); err != nil { + log.Println(err) + } + for c := range clients { + c.kill() + } + + shutdownTimer = time.After(3 * time.Second) + inShutdown = true +} + +func forceShutdown(reason string) { + log.Printf("forced shutdown (%s)\n", reason) + for c := range clients { + c.destroy() + } +} + +// --- Client ------------------------------------------------------------------ + +func (c *client) send(line string) { + if !c.inShutdown { + c.outQ = append(c.outQ, (line + "\r\n")...) + c.flushOutQ() + } +} + +func (c *client) shutdown() { + if c.inShutdown { + log.Println("client double shutdown") + return + } + + // TODO: We must set a timer and destroy the client on timeout. Since we + // have a central event loop, we probably need an event. Since we also + // seem to need an event for TLS autodetection because of conn, we might + // want to send an enumeration value. + c.inShutdown = true + c.conn.CloseWrite() +} + +// Tear down the client connection, trying to do so in a graceful manner. +func (c *client) kill() { + if c.connReady { + c.send("Goodbye") + c.shutdown() + } else { + c.destroy() + } +} + +// Close the connection and forget about the client. +func (c *client) destroy() { + // Try to send a "close notify" alert if the TLS object is ready, + // otherwise just tear down the transport. + if c.connReady { + _ = c.conn.Close() + } else { + _ = c.transport.Close() + } + + delete(clients, c) +} + +// Handle the results from trying to read from the client connection. +func (c *client) onRead(data []byte, readErr error) { + c.inQ = append(c.inQ, data...) + for { + advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) + c.inQ = c.inQ[advance:] + if advance == 0 { + break + } + + line := string(token) + fmt.Println(line) + broadcast(line, c) + } + + // TODO: Inform the client about the inQ overrun in the farewell message. + if len(c.inQ) > 8192 { + c.kill() + return + } + + if readErr == io.EOF { + // TODO: What if we're already in shutdown? + c.shutdown() + } else if readErr != nil { + log.Println(readErr) + c.destroy() + } +} + +// Spawn a goroutine to flush the outQ if possible and necessary. If the +// connection is not ready yet, it needs to be retried as soon as it becomes. +func (c *client) flushOutQ() { + if c.connReady && !c.writing { + go write(c, c.outQ) + c.writing = true + } +} + +// Handle the results from trying to write to the client connection. +func (c *client) onWrite(written int, writeErr error) { + c.outQ = c.outQ[written:] + c.writing = false + + if writeErr != nil { + log.Println(writeErr) + c.destroy() + } else if len(c.outQ) > 0 { + c.flushOutQ() + } else if c.inShutdown { + c.destroy() + } +} + +// --- Worker goroutines ------------------------------------------------------- + +func accept(ln net.Listener) { + // TODO: Consider specific cases in error handling, some errors + // are transitional while others are fatal. + for { + if conn, err := ln.Accept(); err != nil { + log.Println(err) + } else { + conns <- conn + } + } +} + +func read(client *client) { + // TODO: Either here or elsewhere we need to set a timeout. + + client.conn = client.transport.(connCloseWrite) + if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil { + // This is just for the TLS detection and doesn't need to be fatal. + log.Println(err) + } else if detectTLS(sysconn) { + client.tls = tls.Server(client.transport, tlsConf) + client.conn = client.tls + } + + // TODO: Signal the main goroutine that conn is ready. In fact, the upper + // part could be mostly moved to the main goroutine and we'd only spawn + // a thin wrapper around detectTLS, sending back {*client, bool}. Heck, + // I could get rid of connReady. + + // A new buffer is allocated each time we receive some bytes, because of + // thread-safety. Therefore the buffer shouldn't be too large, or we'd + // need to copy it each time into a precisely sized new buffer. + var err error + for err == nil { + var ( + buf [512]byte + n int + ) + n, err = client.conn.Read(buf[:]) + reads <- readEvent{client, buf[:n], err} + } +} + +// Flush outQ, which is passed by parameter so that there are no data races. +func write(client *client, data []byte) { + // We just write as much as we can, the main goroutine does the looping. + n, err := client.conn.Write(data) + writes <- writeEvent{client, n, err} +} + +// --- Main -------------------------------------------------------------------- + +func processOneEvent() { + select { + case <-sigs: + if inShutdown { + forceShutdown("requested by user") + } else { + initiateShutdown() + } + + case <-shutdownTimer: + forceShutdown("timeout") + + case conn := <-conns: + log.Println("accepted client connection") + c := &client{transport: conn} + clients[c] = true + go read(c) + + case ev := <-reads: + log.Println("received data from client") + if _, ok := clients[ev.client]; ok { + ev.client.onRead(ev.data, ev.err) + } + + case ev := <-writes: + log.Println("sent data to client") + if _, ok := clients[ev.client]; ok { + ev.client.onWrite(ev.written, ev.err) + } + } +} + +func main() { + // Just deal with unexpected flags, we don't use any ourselves. + flag.Parse() + + if len(flag.Args()) != 3 { + log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0]) + } + + cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0)) + if err != nil { + log.Fatalln(err) + } + + tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}} + listener, err = net.Listen("tcp", flag.Arg(2)) + if err != nil { + log.Fatalln(err) + } + + go accept(listener) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + for !inShutdown || len(clients) > 0 { + processOneEvent() + } +} |