// Copyright (c) 2022, Přemysl Eric Janouch <p@janouch.name>
// SPDX-License-Identifier: 0BSD

package main

import (
	"bufio"
	"context"
	"encoding/binary"
	"encoding/json"
	"flag"
	"fmt"
	"html/template"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"time"

	"nhooyr.io/websocket"
)

var (
	// debug is the -debug command line flag.  When set, the relay and client
	// helpers in this file additionally log the traffic they pass along.
	debug = flag.Bool("debug", false, "enable debug output")

	// The following are filled in from positional arguments by main:
	//  - addressBind is the address the HTTP server listens on,
	//  - addressConnect is the TCP address dialled for each WebSocket client,
	//  - addressWS is the optional WebSocket URI handed to the frontend;
	//    when left empty, handleDefault derives one from the request's Host.
	addressBind    string
	addressConnect string
	addressWS      string
)

// -----------------------------------------------------------------------------

// relayReadFrame reads a single frame from r: a big-endian 32-bit length
// followed by that many bytes of payload.  It returns the payload without
// the length prefix, or nil when reading fails (the failure is logged).
//
// In debug mode, the payload is furthermore dumped to the log, both raw
// and decoded as a RelayEventMessage rendered to JSON; a frame that cannot
// be decoded this way is rejected by returning nil.
func relayReadFrame(r io.Reader) []byte {
	var prefix [4]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return nil
	}

	payload := make([]byte, binary.BigEndian.Uint32(prefix[:]))
	if _, err := io.ReadFull(r, payload); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return nil
	}

	// Outside of debug mode, frames are passed along without inspection.
	if !*debug {
		return payload
	}

	log.Printf("<? %v\n", payload)

	var message RelayEventMessage
	rest, ok := message.ConsumeFrom(payload)
	switch {
	case !ok:
		log.Println("Event deserialization failed")
		return nil
	case len(rest) != 0:
		log.Println("Event deserialization failed: trailing data")
		return nil
	}

	encoded, err := message.MarshalJSON()
	if err != nil {
		log.Println("Event marshalling failed: " + err.Error())
		return nil
	}

	log.Printf("<- %s\n", encoded)
	return payload
}

// relayMakeReceiver starts a goroutine that keeps reading frames from conn
// through relayReadFrame, and returns the channel they are delivered on.
//
// The goroutine closes the channel and exits when a frame cannot be read,
// or when ctx is done while it is waiting to hand a frame over.
func relayMakeReceiver(ctx context.Context, conn net.Conn) <-chan []byte {
	// The usual event message rarely gets above 1 kilobyte,
	// thus this is set to buffer up at most 1 megabyte or so.
	frames := make(chan []byte, 1000)
	reader := bufio.NewReaderSize(conn, 65536)

	pump := func() {
		defer close(frames)

		frame := relayReadFrame(reader)
		for ; frame != nil; frame = relayReadFrame(reader) {
			select {
			case <-ctx.Done():
				return
			case frames <- frame:
			}
		}
	}

	go pump()
	return frames
}

// relayWriteJSON decodes the JSON in j as a RelayCommandMessage,
// serializes it, and writes it to conn as one frame: a big-endian 32-bit
// payload length followed by the payload.
//
// It returns true when the frame has been written.  Any failure
// is logged and reported by returning false.
func relayWriteJSON(conn net.Conn, j []byte) bool {
	var command RelayCommandMessage
	if err := json.Unmarshal(j, &command); err != nil {
		log.Println("Command unmarshalling failed: " + err.Error())
		return false
	}

	// Reserve room for the length prefix up front, so that the whole frame
	// ends up in one buffer and goes out in a single Write call.
	const prefixSize = 4
	frame, ok := command.AppendTo(make([]byte, prefixSize))
	if !ok {
		log.Println("Command serialization failed")
		return false
	}

	payloadSize := uint32(len(frame) - prefixSize)
	binary.BigEndian.PutUint32(frame[:prefixSize], payloadSize)

	_, err := conn.Write(frame)
	if err != nil {
		log.Println("Command send failed: " + err.Error())
		return false
	}

	if *debug {
		log.Printf("-> %v\n", frame)
	}
	return true
}

// -----------------------------------------------------------------------------

// clientReadJSON receives the next message from the WebSocket client
// and returns its contents.  Only text messages are accepted.
//
// It returns nil when the read fails or when the message is not a text
// message; either case is logged.
func clientReadJSON(ctx context.Context, ws *websocket.Conn) []byte {
	kind, data, err := ws.Read(ctx)
	switch {
	case err != nil:
		log.Println("Command receive failed: " + err.Error())
		return nil
	case kind != websocket.MessageText:
		log.Println(
			"Command receive failed: " + "binary messages are not supported")
		return nil
	}

	if *debug {
		log.Printf("?> %s\n", data)
	}
	return data
}

// clientWriteBinary sends b to the WebSocket client as one binary message.
// It returns true on success; a failure is logged and yields false.
func clientWriteBinary(ctx context.Context, ws *websocket.Conn, b []byte) bool {
	err := ws.Write(ctx, websocket.MessageBinary, b)
	if err == nil {
		return true
	}

	log.Println("Event send failed: " + err.Error())
	return false
}

// clientWriteError reports err to the WebSocket client, wrapped in
// a RelayEventMessage carrying a RelayEventError event, with both
// the event and the command sequence numbers set to zero.
//
// It returns true when the event has been serialized and sent.
// A serialization failure is logged here, a send failure is logged
// by clientWriteBinary; both yield false.
func clientWriteError(ctx context.Context, ws *websocket.Conn, err error) bool {
	b, ok := (&RelayEventMessage{
		EventSeq: 0,
		Data: RelayEventData{
			Interface: RelayEventDataError{
				Event:      RelayEventError,
				CommandSeq: 0,
				Error:      err.Error(),
			},
		},
	}).AppendTo(nil)
	// AppendTo signals success through ok, so bail out only when it is false;
	// otherwise the error event would never reach the client.
	if !ok {
		log.Println("Event serialization failed")
		return false
	}
	return clientWriteBinary(ctx, ws, b)
}

// handleWS serves the WebSocket endpoint: it upgrades the request
// to a WebSocket, dials a fresh TCP connection to addressConnect,
// and pipes data between the two.
//
// Text messages from the client are forwarded to the relay through
// relayWriteJSON, while frames read from the relay are forwarded
// to the client as binary messages.  The handler returns, closing both
// connections through its deferred calls, once either direction ends
// or the request's context is cancelled.
func handleWS(w http.ResponseWriter, r *http.Request) {
	ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		// NOTE(review): per the websocket package documentation, this turns
		// off its Origin check, so pages from any origin may connect.
		InsecureSkipVerify: true,
		// Note that Safari can be broken with compression.
		CompressionMode: websocket.CompressionContextTakeover,
		// This is for the payload; set higher to avoid overhead.
		CompressionThreshold: 64 << 10,
	})
	if err != nil {
		log.Println("Client rejected: " + err.Error())
		return
	}
	defer ws.Close(websocket.StatusGoingAway, "Goodbye")

	// Everything below is bound to this context.  It is cancelled when
	// the handler returns, and by either of the two goroutines finishing.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	conn, err := net.Dial("tcp", addressConnect)
	if err != nil {
		log.Println("Connection failed: " + err.Error())
		// Let the client know why there is nothing to talk to.
		clientWriteError(ctx, ws, err)
		return
	}
	defer conn.Close()

	// To decrease latencies, events are received and decoded in parallel
	// to their sending, and we try to batch them together.
	relayFrames := relayMakeReceiver(ctx, conn)
	// batchFrames blocks until one frame is available, then appends to it
	// whatever other frames are already queued, without waiting for more.
	// It returns nil once the receiver has closed its channel.
	batchFrames := func() []byte {
		batch, ok := <-relayFrames
		if !ok {
			return nil
		}
	Batch:
		for {
			select {
			case b, ok := <-relayFrames:
				if !ok {
					// The channel got closed, send out what we have;
					// the next call will see the closure and return nil.
					break Batch
				}
				batch = append(batch, b...)
			default:
				// Nothing more is queued at the moment.
				break Batch
			}
		}
		return batch
	}

	// We don't need to intervene, so it's just two separate pipes so far.
	// Client to relay: runs until a command cannot be read from the client.
	// The result of relayWriteJSON is ignored, its failures are only logged.
	go func() {
		defer cancel()
		for {
			j := clientReadJSON(ctx, ws)
			if j == nil {
				return
			}
			relayWriteJSON(conn, j)
		}
	}()
	// Relay to client: runs until the relay receiver stops delivering frames.
	// The result of clientWriteBinary is ignored, its failures are only logged.
	go func() {
		defer cancel()
		for {
			b := batchFrames()
			if b == nil {
				return
			}
			clientWriteBinary(ctx, ws, b)
		}
	}()
	// Wait for either pipe to end, or for the request to be cancelled.
	<-ctx.Done()
}

// -----------------------------------------------------------------------------

// staticHandler serves files from the process's current working directory.
// handleDefault delegates every request path other than "/" to it.
var staticHandler = http.FileServer(http.Dir("."))

// page is the HTML document served at "/".  It loads the frontend's
// stylesheet and scripts by relative URLs, and exposes the value it is
// executed with as the "proxy" JavaScript variable; handleDefault passes
// the WebSocket URI there.  Parsing happens at package initialization,
// and template.Must panics should the template be invalid.
var page = template.Must(template.New("/").Parse(`<!DOCTYPE html>
<html>
<head>
	<title>xP</title>
	<meta charset="utf-8" />
	<meta name="viewport" content="width=device-width, initial-scale=1">
	<link rel="stylesheet" href="xP.css" />
</head>
<body>
	<script src="mithril.js">
	</script>
	<script>
	let proxy = '{{ . }}'
	</script>
	<script type="module" src="xP.js">
	</script>
</body>
</html>`))

func handleDefault(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		staticHandler.ServeHTTP(w, r)
		return
	}

	wsURI := addressWS
	if wsURI == "" {
		wsURI = fmt.Sprintf("ws://%s/ws", r.Host)
	}
	if err := page.Execute(w, wsURI); err != nil {
		log.Println("Template execution failed: " + err.Error())
	}
}

func main() {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(),
			"Usage: %s [OPTION...] BIND CONNECT [WSURI]\n\n", os.Args[0])
		flag.PrintDefaults()
	}

	flag.Parse()
	if flag.NArg() < 2 || flag.NArg() > 3 {
		flag.Usage()
		os.Exit(1)
	}

	addressBind, addressConnect = flag.Arg(0), flag.Arg(1)
	if flag.NArg() > 2 {
		addressWS = flag.Arg(2)
	}

	http.Handle("/ws", http.HandlerFunc(handleWS))
	http.Handle("/", http.HandlerFunc(handleDefault))

	s := &http.Server{
		Addr:           addressBind,
		ReadTimeout:    60 * time.Second,
		WriteTimeout:   60 * time.Second,
		MaxHeaderBytes: 32 << 10,
	}
	log.Fatal(s.ListenAndServe())
}