package main

import (
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"html/template"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"time"

	"golang.org/x/net/websocket"
)

// addressBind is the listen address handed to the HTTP server, taken from
// the first command-line argument.
var addressBind string

// addressConnect is the TCP address of the relay that every WebSocket
// client is bridged to, taken from the second command-line argument.
var addressConnect string

// clientToRelay forwards one command from the WebSocket client to the relay.
// The client sends a RelayCommandMessage as JSON text; it is serialized and
// written to conn prefixed with its payload length as a big-endian uint32.
// It returns false on any failure, which ends the caller's pump loop.
// ctx is currently unused.
func clientToRelay(
	ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	var text string
	if err := websocket.Message.Receive(ws, &text); err != nil {
		log.Println("Command receive failed: " + err.Error())
		return false
	}
	log.Printf("?> %s\n", text)

	var command RelayCommandMessage
	if err := json.Unmarshal([]byte(text), &command); err != nil {
		log.Println("Command unmarshalling failed: " + err.Error())
		return false
	}

	// Serialize after a 4-byte placeholder, then patch in the payload length
	// so the whole frame goes out in a single Write.
	const headerLen = 4
	frame, ok := command.AppendTo(make([]byte, headerLen))
	if !ok {
		log.Println("Command serialization failed")
		return false
	}
	payloadLen := len(frame) - headerLen
	binary.BigEndian.PutUint32(frame[:headerLen], uint32(payloadLen))

	if _, err := conn.Write(frame); err != nil {
		log.Println("Command send failed: " + err.Error())
		return false
	}
	log.Printf("-> %v\n", frame)
	return true
}

// relayToClient forwards one event from the relay to the WebSocket client.
// Each event is read from conn as a big-endian uint32 byte count followed by
// that many bytes, decoded into a RelayEventMessage (rejecting trailing
// bytes), and sent to the client as JSON text.
// It returns false on any failure, which ends the caller's pump loop.
// ctx is currently unused.
func relayToClient(
	ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}
	// NOTE(review): the length is not bounded, so a misbehaving relay could
	// make this allocate up to 4 GiB per frame.
	payload := make([]byte, binary.BigEndian.Uint32(header[:]))
	if _, err := io.ReadFull(conn, payload); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}
	log.Printf("<? %v\n", payload)

	var event RelayEventMessage
	rest, ok := event.ConsumeFrom(payload)
	switch {
	case !ok:
		log.Println("Event deserialization failed")
		return false
	case len(rest) != 0:
		log.Println("Event deserialization failed: trailing data")
		return false
	}

	encoded, err := json.Marshal(&event)
	if err != nil {
		log.Println("Event marshalling failed: " + err.Error())
		return false
	}
	if err = websocket.Message.Send(ws, string(encoded)); err != nil {
		log.Println("Event send failed: " + err.Error())
		return false
	}
	log.Printf("<- %s\n", encoded)
	return true
}

// errorToClient reports err to the WebSocket client as a JSON-encoded
// RelayEventMessage carrying a RelayEventDataError, with both the event and
// command sequence numbers set to zero.
// It returns whether the event was marshalled and sent successfully.
func errorToClient(ws *websocket.Conn, err error) bool {
	event := RelayEventMessage{
		EventSeq: 0,
		Data: RelayEventData{
			Interface: RelayEventDataError{
				Event:      RelayEventError,
				CommandSeq: 0,
				Error:      err.Error(),
			},
		},
	}

	// Use distinct names so the reported error is never confused with the
	// ones produced while reporting it.
	encoded, marshalErr := json.Marshal(&event)
	if marshalErr != nil {
		log.Println("Event marshalling failed: " + marshalErr.Error())
		return false
	}
	if sendErr := websocket.Message.Send(ws, string(encoded)); sendErr != nil {
		log.Println("Event send failed: " + sendErr.Error())
		return false
	}
	return true
}

// handleWebSocket bridges one WebSocket client to its own TCP connection to
// the relay at addressConnect. Commands flow client→relay and events flow
// relay→client in two independent goroutines. The session ends as soon as
// either direction fails, and the relay connection is then closed so that
// neither goroutine is left blocked forever.
func handleWebSocket(ws *websocket.Conn) {
	// http.Server's ReadTimeout/WriteTimeout deadlines may remain set on
	// hijacked connections, which would cut every session off after a fixed
	// time. This connection is meant to be long-lived, so clear them.
	if err := ws.SetDeadline(time.Time{}); err != nil {
		log.Println("Deadline reset failed: " + err.Error())
	}

	conn, err := net.Dial("tcp", addressConnect)
	if err != nil {
		log.Println("Relay connection failed: " + err.Error())
		errorToClient(ws, err)
		return
	}
	// Closing conn on return unblocks relayToClient, which would otherwise
	// wait on a read indefinitely and leak together with the socket.
	// NOTE(review): clientToRelay is expected to unblock when the websocket
	// server closes the hijacked connection after this handler returns.
	defer conn.Close()

	// We don't need to intervene, so it's just two separate pipes so far.
	ctx, cancel := context.WithCancel(ws.Request().Context())
	defer cancel()
	go func() {
		for clientToRelay(ctx, ws, conn) {
		}
		cancel()
	}()
	go func() {
		for relayToClient(ctx, ws, conn) {
		}
		cancel()
	}()
	<-ctx.Done()
}

var (
	// staticHandler serves files from the process's current working
	// directory; handleDefault uses it for every path other than "/".
	// NOTE(review): http.FileServer exposes everything in that directory,
	// including directory listings — confirm that is intended.
	staticHandler http.Handler = http.FileServer(http.Dir("."))

	// page is the application shell served at "/". It is executed with the
	// WebSocket URI as its data; html/template escapes that value for the
	// JavaScript string literal it is placed in.
	page = template.Must(template.New("/").Parse(`<!DOCTYPE html>
<html>
<head>
	<title>xP</title>
	<meta charset="utf-8" />
	<link rel="stylesheet" href="xP.css" />
</head>
<body>
	<script src="mithril.js">
	</script>
	<script>
	let proxy = '{{ . }}'
	</script>
	<script src="xP.js">
	</script>
</body>
</html>`))
)

// handleDefault renders the application page for the root path and hands
// every other path to staticHandler. The page receives the WebSocket
// endpoint URI, built from the request's Host header with the ws:// scheme.
// NOTE(review): r.Host is client-supplied, and ws:// would be wrong if the
// page were ever served over HTTPS.
func handleDefault(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/":
		proxyURI := fmt.Sprintf("ws://%s/ws", r.Host)
		if err := page.Execute(w, proxyURI); err != nil {
			log.Println("Template execution failed: " + err.Error())
		}
	default:
		staticHandler.ServeHTTP(w, r)
	}
}

// main takes the BIND and CONNECT addresses from the command line, then
// serves the web frontend on BIND. WebSocket clients on /ws are relayed to
// CONNECT, and everything else goes to handleDefault.
func main() {
	if len(os.Args) != 3 {
		log.Fatalf("usage: %s BIND CONNECT\n", os.Args[0])
	}
	args := os.Args[1:]
	addressBind = args[0]
	addressConnect = args[1]

	http.Handle("/ws", websocket.Handler(handleWebSocket))
	http.HandleFunc("/", handleDefault)

	// NOTE(review): net/http may leave these deadlines set on hijacked
	// connections, which includes the WebSocket sessions on /ws.
	server := &http.Server{
		Addr:           addressBind,
		ReadTimeout:    time.Minute,
		WriteTimeout:   time.Minute,
		MaxHeaderBytes: 32 * 1024,
	}
	err := server.ListenAndServe()
	log.Fatal(err)
}