diff options
Diffstat (limited to 'hnc')
-rw-r--r-- | hnc/main.go | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/hnc/main.go b/hnc/main.go new file mode 100644 index 0000000..45254f6 --- /dev/null +++ b/hnc/main.go @@ -0,0 +1,150 @@ +// +// Copyright (c) 2018, Přemysl Janouch <p@janouch.name> +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +// + +// hnc is a netcat-alike that shuts down properly. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "io" + "net" + "os" +) + +// #include <unistd.h> +import "C" + +func isatty(fd uintptr) bool { return C.isatty(C.int(fd)) != 0 } + +func log(format string, args ...interface{}) { + msg := fmt.Sprintf(format+"\n", args...) + if isatty(os.Stderr.Fd()) { + msg = "\x1b[0;1;31m" + msg + "\x1b[m" + } + os.Stderr.WriteString(msg) +} + +var ( + flagTLS = flag.Bool("tls", false, "connect using TLS") + flagCRLF = flag.Bool("crlf", false, "translate LF into CRLF") +) + +// Network connection that can shut down the write end. +type connCloseWriter interface { + net.Conn + CloseWrite() error +} + +func dial(address string) (connCloseWriter, error) { + if *flagTLS { + return tls.Dial("tcp", address, &tls.Config{ + InsecureSkipVerify: true, + }) + } + transport, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + return transport.(connCloseWriter), nil +} + +func expand(raw []byte) []byte { + if !*flagCRLF { + return raw + } + var res []byte + for _, b := range raw { + if b == '\n' { + res = append(res, '\r') + } + res = append(res, b) + } + return res +} + +// Asynchronously delivered result of io.Reader. +type readResult struct { + b []byte + err error +} + +func read(r io.Reader, ch chan<- readResult) { + defer close(ch) + for { + var buf [8192]byte + n, err := r.Read(buf[:]) + ch <- readResult{buf[:n], err} + if err != nil { + break + } + } +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), + "Usage: %s [OPTION]... HOST PORT\n"+ + "Connect to a remote host over TCP/IP.\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.Parse() + if flag.NArg() != 2 { + flag.Usage() + os.Exit(2) + } + + conn, err := dial(net.JoinHostPort(flag.Arg(0), flag.Arg(1))) + if err != nil { + log("dial: %s", err) + os.Exit(1) + } + + fromUser := make(chan readResult) + go read(os.Stdin, fromUser) + + fromConn := make(chan readResult) + go read(conn, fromConn) + + for fromUser != nil || fromConn != nil { + select { + case result := <-fromUser: + if len(result.b) > 0 { + if _, err := conn.Write(expand(result.b)); err != nil { + log("remote: %s", err) + } + } + if result.err != nil { + log("%s: %s", "stdin", result.err) + fromUser = nil + if err := conn.CloseWrite(); err != nil { + log("remote: %s", err) + } + } + case result := <-fromConn: + if len(result.b) > 0 { + if _, err := os.Stdout.Write(result.b); err != nil { + log("stdout: %s", err) + } + } + if result.err != nil { + log("remote: %s", result.err) + fromConn = nil + } + } + } +} |