aboutsummaryrefslogtreecommitdiff
path: root/prototypes/tls-autodetect.go
diff options
context:
space:
mode:
Diffstat (limited to 'prototypes/tls-autodetect.go')
-rw-r--r--prototypes/tls-autodetect.go451
1 files changed, 451 insertions, 0 deletions
diff --git a/prototypes/tls-autodetect.go b/prototypes/tls-autodetect.go
new file mode 100644
index 0000000..0427465
--- /dev/null
+++ b/prototypes/tls-autodetect.go
@@ -0,0 +1,451 @@
+//
+// Copyright (c) 2018, Přemysl Janouch <p@janouch.name>
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+//
+
+//
+// This is an example TLS-autodetecting chat server.
+//
+// These clients are unable to properly shutdown the connection on their exit:
+// telnet localhost 1234
+// openssl s_client -connect localhost:1234
+//
+// While this one doesn't react to an EOF from the server:
+// ncat -C localhost 1234
+// ncat -C --ssl localhost 1234
+//
+package main
+
+import (
+ "bufio"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+)
+
+// --- Utilities ---------------------------------------------------------------
+
+//
+// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom
+// must be at least three octets long for this to work reliably, but that should
+// not pose a problem in practice. We might try waiting for them.
+//
+// SSL2: 1xxx xxxx | xxxx xxxx | <1>
+// (message length) (client hello)
+// SSL3/TLS: <22> | <3> | xxxx xxxx
+// (handshake)| (protocol version)
+//
+func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
+ sysconn.Read(func(fd uintptr) (done bool) {
+ var buf [3]byte
+ n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
+ switch {
+ case n == 3:
+ isTLS = buf[0]&0x80 != 0 && buf[2] == 1
+ fallthrough
+ case n == 2:
+ isTLS = isTLS || buf[0] == 22 && buf[1] == 3
+ case n == 1:
+ isTLS = buf[0] == 22
+ case err == syscall.EAGAIN:
+ return false
+ }
+ return true
+ })
+ return isTLS
+}
+
+// --- Declarations ------------------------------------------------------------
+
+type connCloseWriter interface {
+ net.Conn
+ CloseWrite() error
+}
+
+type client struct {
+ transport net.Conn // underlying connection
+ tls *tls.Conn // TLS, if detected
+ conn connCloseWriter // high-level connection
+ inQ []byte // unprocessed input
+ outQ []byte // unprocessed output
+ reading bool // whether a reading goroutine is running
+ writing bool // whether a writing goroutine is running
+ closing bool // whether we're closing the connection
+ killTimer *time.Timer // timeout
+}
+
+type preparedEvent struct {
+ client *client
+ host string // client's hostname or literal IP address
+ isTLS bool // the client seems to use TLS
+}
+
+type readEvent struct {
+ client *client
+ data []byte // new data from the client
+ err error // read error
+}
+
+type writeEvent struct {
+ client *client
+ written int // amount of bytes written
+ err error // write error
+}
+
+var (
+ sigs = make(chan os.Signal, 1)
+ conns = make(chan net.Conn)
+ prepared = make(chan preparedEvent)
+ reads = make(chan readEvent)
+ writes = make(chan writeEvent)
+ timeouts = make(chan *client)
+
+ tlsConf *tls.Config
+ clients = make(map[*client]bool)
+ listener net.Listener
+ inShutdown bool
+ shutdownTimer <-chan time.Time
+)
+
+// --- Server ------------------------------------------------------------------
+
+// Broadcast to all /other/ clients (telnet-friendly, also in accordance to
+// the plan of extending this to an IRCd).
+func broadcast(line string, except *client) {
+ for c := range clients {
+ if c != except {
+ c.send(line)
+ }
+ }
+}
+
+// Initiate a clean shutdown of the whole daemon.
+func initiateShutdown() {
+ log.Println("shutting down")
+ if err := listener.Close(); err != nil {
+ log.Println(err)
+ }
+ for c := range clients {
+ c.closeLink()
+ }
+
+ shutdownTimer = time.After(3 * time.Second)
+ inShutdown = true
+}
+
+// Forcefully tear down all connections.
+func forceShutdown(reason string) {
+ if !inShutdown {
+ log.Fatalln("forceShutdown called without initiateShutdown")
+ }
+
+ log.Printf("forced shutdown (%s)\n", reason)
+ for c := range clients {
+ c.destroy()
+ }
+}
+
+// --- Client ------------------------------------------------------------------
+
+func (c *client) send(line string) {
+ if c.conn != nil && !c.closing {
+ c.outQ = append(c.outQ, (line + "\r\n")...)
+ c.flushOutQ()
+ }
+}
+
+// Tear down the client connection, trying to do so in a graceful manner.
+func (c *client) closeLink() {
+ if c.closing {
+ return
+ }
+ if c.conn == nil {
+ c.destroy()
+ return
+ }
+
+ // Since we send this goodbye, we don't need to call CloseWrite here.
+ c.send("Goodbye")
+ c.killTimer = time.AfterFunc(3*time.Second, func() {
+ timeouts <- c
+ })
+
+ c.closing = true
+}
+
+// Close the connection and forget about the client.
+func (c *client) destroy() {
+ // Try to send a "close notify" alert if the TLS object is ready,
+ // otherwise just tear down the transport.
+ if c.conn != nil {
+ _ = c.conn.Close()
+ } else {
+ _ = c.transport.Close()
+ }
+
+ // Clean up the goroutine, although a spurious event may still be sent.
+ if c.killTimer != nil {
+ c.killTimer.Stop()
+ }
+
+ log.Println("client destroyed")
+ delete(clients, c)
+}
+
+// Handle the results from initializing the client's connection.
+func (c *client) onPrepared(isTLS bool) {
+ if isTLS {
+ c.tls = tls.Server(c.transport, tlsConf)
+ c.conn = c.tls
+ } else {
+ c.conn = c.transport.(connCloseWriter)
+ }
+
+ // TODO: If we've tried to send any data before now, we need to flushOutQ.
+ go read(c)
+ c.reading = true
+}
+
+// Handle the results from trying to read from the client connection.
+func (c *client) onRead(data []byte, readErr error) {
+ if !c.reading {
+ // Abusing the flag to emulate CloseRead and skip over data, see below.
+ return
+ }
+
+ c.inQ = append(c.inQ, data...)
+ for {
+ advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
+ if advance == 0 {
+ break
+ }
+
+ c.inQ = c.inQ[advance:]
+ line := string(token)
+ fmt.Println(line)
+ broadcast(line, c)
+ }
+
+ if readErr != nil {
+ c.reading = false
+
+ if readErr != io.EOF {
+ log.Println(readErr)
+ c.destroy()
+ } else if c.closing {
+ // Disregarding whether a clean shutdown has happened or not.
+ log.Println("client finished shutdown")
+ c.destroy()
+ } else {
+ log.Println("client EOF")
+ c.closeLink()
+ }
+ } else if len(c.inQ) > 8192 {
+ log.Println("client inQ overrun")
+ // TODO: Inform the client about inQ overrun in the farewell message.
+ c.closeLink()
+
+ // tls.Conn doesn't have the CloseRead method (and it needs to be able
+ // to read from the TCP connection even for writes, so there isn't much
+ // sense in expecting the implementation to do anything useful),
+ // otherwise we'd use it to block incoming packet data.
+ c.reading = false
+ }
+}
+
+// Spawn a goroutine to flush the outQ if possible and necessary.
+func (c *client) flushOutQ() {
+ if !c.writing && c.conn != nil {
+ go write(c, c.outQ)
+ c.writing = true
+ }
+}
+
+// Handle the results from trying to write to the client connection.
+func (c *client) onWrite(written int, writeErr error) {
+ c.outQ = c.outQ[written:]
+ c.writing = false
+
+ if writeErr != nil {
+ log.Println(writeErr)
+ c.destroy()
+ } else if len(c.outQ) > 0 {
+ c.flushOutQ()
+ } else if c.closing {
+ if c.reading {
+ c.conn.CloseWrite()
+ } else {
+ c.destroy()
+ }
+ }
+}
+
+// --- Worker goroutines -------------------------------------------------------
+
+func accept(ln net.Listener) {
+ for {
+ if conn, err := ln.Accept(); err != nil {
+ // TODO: Consider specific cases in error handling, some errors
+ // are transitional while others are fatal.
+ log.Println(err)
+ break
+ } else {
+ conns <- conn
+ }
+ }
+}
+
+func prepare(client *client) {
+ conn := client.transport
+ host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
+ if err != nil {
+ // In effect, we require TCP/UDP, as they have port numbers.
+ log.Fatalln(err)
+ }
+
+ // The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
+ // bothering with pointless contexts.
+ ch := make(chan string, 1)
+ go func() {
+ defer close(ch)
+ if names, err := net.LookupAddr(host); err != nil {
+ log.Println(err)
+ } else {
+ ch <- names[0]
+ }
+ }()
+
+ // While we can't cancel it, we still want to set a timeout on it.
+ select {
+ case <-time.After(5 * time.Second):
+ case resolved, ok := <-ch:
+ if ok {
+ host = resolved
+ }
+ }
+
+ // Note that in this demo application the autodetection prevents non-TLS
+ // clients from receiving any messages until they send something.
+ isTLS := false
+ if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil {
+ // This is just for the TLS detection and doesn't need to be fatal.
+ log.Println(err)
+ } else {
+ isTLS = detectTLS(sysconn)
+ }
+
+ // FIXME: When the client sends no data, we still initialize its conn.
+ prepared <- preparedEvent{client, host, isTLS}
+}
+
+func read(client *client) {
+ // A new buffer is allocated each time we receive some bytes, because of
+ // thread-safety. Therefore the buffer shouldn't be too large, or we'd
+ // need to copy it each time into a precisely sized new buffer.
+ var err error
+ for err == nil {
+ var (
+ buf [512]byte
+ n int
+ )
+ n, err = client.conn.Read(buf[:])
+ reads <- readEvent{client, buf[:n], err}
+ }
+}
+
+// Flush outQ, which is passed by parameter so that there are no data races.
+func write(client *client, data []byte) {
+ // We just write as much as we can, the main goroutine does the looping.
+ n, err := client.conn.Write(data)
+ writes <- writeEvent{client, n, err}
+}
+
+// --- Main --------------------------------------------------------------------
+
+func processOneEvent() {
+ select {
+ case <-sigs:
+ if inShutdown {
+ forceShutdown("requested by user")
+ } else {
+ initiateShutdown()
+ }
+
+ case <-shutdownTimer:
+ forceShutdown("timeout")
+
+ case conn := <-conns:
+ log.Println("accepted client connection")
+ c := &client{transport: conn}
+ clients[c] = true
+ go prepare(c)
+
+ case ev := <-prepared:
+ log.Println("client is ready, resolved to", ev.host)
+ if _, ok := clients[ev.client]; ok {
+ ev.client.onPrepared(ev.isTLS)
+ }
+
+ case ev := <-reads:
+ log.Println("received data from client")
+ if _, ok := clients[ev.client]; ok {
+ ev.client.onRead(ev.data, ev.err)
+ }
+
+ case ev := <-writes:
+ log.Println("sent data to client")
+ if _, ok := clients[ev.client]; ok {
+ ev.client.onWrite(ev.written, ev.err)
+ }
+
+ case c := <-timeouts:
+ if _, ok := clients[c]; ok {
+ log.Println("client timeouted")
+ c.destroy()
+ }
+ }
+}
+
+func main() {
+ // Just deal with unexpected flags, we don't use any ourselves.
+ flag.Parse()
+
+ if len(flag.Args()) != 3 {
+ log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0])
+ }
+
+ cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0))
+ if err != nil {
+ log.Fatalln(err)
+ }
+
+ tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}}
+ listener, err = net.Listen("tcp", flag.Arg(2))
+ if err != nil {
+ log.Fatalln(err)
+ }
+
+ go accept(listener)
+ signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+
+ for !inShutdown || len(clients) > 0 {
+ processOneEvent()
+ }
+}