/*
 * common.c: common functionality
 *
 * Copyright (c) 2014, Přemysl Janouch <p.janouch@gmail.com>
 * All rights reserved.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#define LIBERTY_WANT_SSL
#define LIBERTY_WANT_POLLER
#define LIBERTY_WANT_PROTO_IRC

#ifdef WANT_SYSLOG_LOGGING
	#define print_fatal_data    ((void *) LOG_ERR)
	#define print_error_data    ((void *) LOG_ERR)
	#define print_warning_data  ((void *) LOG_WARNING)
	#define print_status_data   ((void *) LOG_INFO)
	#define print_debug_data    ((void *) LOG_DEBUG)
#endif // WANT_SYSLOG_LOGGING

#include "liberty/liberty.c"
#include <arpa/inet.h>

// --- Logging -----------------------------------------------------------------

// Log handler forwarding formatted messages to syslog.  `user_data' carries
// the syslog priority cast to a pointer, like the print_*_data definitions
// at the top of this file; `quote' is emitted verbatim before the message.
static void
log_message_syslog (void *user_data, const char *quote, const char *fmt,
	va_list ap)
{
	// Measure the formatted length on a copy of `ap', leaving the original
	// intact for the second, real formatting pass
	va_list probe;
	va_copy (probe, ap);
	int needed = vsnprintf (NULL, 0, fmt, probe);
	va_end (probe);

	if (needed >= 0)
	{
		char message[needed + 1];
		if (vsnprintf (message, sizeof message, fmt, ap) >= 0)
			syslog ((int) (intptr_t) user_data, "%s%s", quote, message);
	}
}

// --- SOCKS 5/4a (blocking implementation) ------------------------------------

// These are awkward protocols.  Note that the `username' is used differently
// in SOCKS 4a and 5.  In the former version, it is the username that you can
// get ident'ed against.  In the latter version, it forms a pair with the
// password field and doesn't need to be an actual user on your machine.

// TODO: make a non-blocking poller-based version of this;
//   either use c-ares or (even better) start another thread to do resolution

/// Address types; the numeric values are written to and read from the wire
/// as-is by the SOCKS 5 request and reply code
enum socks_addr_type
{
	SOCKS_IPV4   = 1,                   ///< IPv4 address
	SOCKS_DOMAIN = 3,                   ///< Domain name to be resolved
	SOCKS_IPV6   = 4                    ///< IPv6 address
};

/// An address of one of the kinds above, tagged by `type'
struct socks_addr
{
	enum socks_addr_type type;          ///< Which member of `data' is valid
	union
	{
		uint8_t ipv4[4];                ///< IPv4 address, network octet order
		const char *domain;             ///< Domain name
		uint8_t ipv6[16];               ///< IPv6 address, network octet order
	}
	data;                               ///< The address itself
};

/// Everything one SOCKS CONNECT negotiation needs as input, plus what the
/// server reports back
struct socks_data
{
	// Request parameters
	struct socks_addr address;          ///< Where the proxy should connect
	uint16_t port;                      ///< Target port, host byte order
	const char *username;               ///< Username, or NULL for none
	const char *password;               ///< Password, or NULL for none

	// Filled in from the server's reply
	struct socks_addr bound_address;    ///< Address bound at the server
	uint16_t bound_port;                ///< Port bound at the server
};

// Try each address in `addresses' in turn until a TCP connection to the
// SOCKS server succeeds, storing the connected socket in `fd'.
//
// Each socket is made close-on-exec and gets SO_KEEPALIVE on a best-effort
// basis.  When no address works, the error names the reason the last
// attempt failed, so that the user has something to act on.
static bool
socks_get_socket (struct addrinfo *addresses, int *fd, struct error **e)
{
	int last_errno = 0;
	for (; addresses; addresses = addresses->ai_next)
	{
		int sockfd = socket (addresses->ai_family,
			addresses->ai_socktype, addresses->ai_protocol);
		if (sockfd == -1)
		{
			last_errno = errno;
			continue;
		}
		set_cloexec (sockfd);

		int yes = 1;
		soft_assert (setsockopt (sockfd, SOL_SOCKET, SO_KEEPALIVE,
			&yes, sizeof yes) != -1);

		if (!connect (sockfd, addresses->ai_addr, addresses->ai_addrlen))
		{
			*fd = sockfd;
			return true;
		}

		// Save errno before closing the socket can overwrite it
		last_errno = errno;
		xclose (sockfd);
	}

	if (last_errno)
		error_set (e, "%s: %s", "couldn't connect to the SOCKS server",
			strerror (last_errno));
	else
		error_set (e, "couldn't connect to the SOCKS server");
	return false;
}

// Receive exactly `len' bytes from `sockfd' into `buf'.
//
// On a stream socket a single recv() may return fewer bytes than requested
// even though the rest is still on its way, and it may be interrupted by
// a signal before receiving anything; both cases are retried here.
// Returns the number of bytes received, which is less than `len' only when
// the peer has closed the connection first, or -1 on error with errno set.
static ssize_t
socks_recv_all (int sockfd, void *buf, size_t len)
{
	size_t received = 0;
	while (received < len)
	{
		ssize_t n = recv (sockfd, (char *) buf + received, len - received, 0);
		if (n > 0)
			received += (size_t) n;
		else if (n == 0)
			break;
		else if (errno != EINTR)
			return -1;
	}
	return (ssize_t) received;
}

// Set the error and bail out through the `fail' label of the calling function
#define SOCKS_FAIL(...)                                                        \
	BLOCK_START                                                                \
		error_set (e, __VA_ARGS__);                                            \
		goto fail;                                                             \
	BLOCK_END

// Receive exactly `len' bytes into `buf', failing on errors and premature EOF.
// Expects `sockfd' and an `ssize_t n' in scope, and reports failures through
// whichever SOCKS_FAIL definition is active at the point of expansion.
#define SOCKS_RECV(buf, len)                                                   \
	BLOCK_START                                                                \
		if ((n = socks_recv_all (sockfd, (buf), (len))) == -1)                 \
			SOCKS_FAIL ("%s: %s", "recv", strerror (errno));                   \
		if ((size_t) n != (size_t) (len))                                      \
			SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF");         \
	BLOCK_END

// Open a connection to one of `addresses' and ask that SOCKS 4a server to
// connect to the target described by `data'.  On success the connected socket
// is stored in `fd'; on failure the socket is closed and `e' is set.
static bool
socks_4a_connect (struct addrinfo *addresses, struct socks_data *data,
	int *fd, struct error **e)
{
	int sockfd;
	if (!socks_get_socket (addresses, &sockfd, e))
		return false;

	// Unless the target is a plain IPv4 address, send the placeholder address
	// 0.0.0.1 and append the host name after the user ID
	const void *dest_ipv4 = "\x00\x00\x00\x01";
	const char *dest_domain = NULL;

	char ipv6_text[INET6_ADDRSTRLEN];
	if (data->address.type == SOCKS_IPV4)
		dest_ipv4 = data->address.data.ipv4;
	else if (data->address.type == SOCKS_DOMAIN)
		dest_domain = data->address.data.domain;
	else if (data->address.type == SOCKS_IPV6)
	{
		// There is no IPv6 support in SOCKS 4a, so hand over the address in
		// textual form as though it were a host name; whether any server
		// accepts this at all is doubtful
		if (!inet_ntop (AF_INET6, &data->address.data.ipv6,
			ipv6_text, sizeof ipv6_text))
			SOCKS_FAIL ("%s: %s", "inet_ntop", strerror (errno));
		dest_domain = ipv6_text;
	}

	struct str req;
	str_init (&req);

	// Version 4, command 1 (connect), port high byte first, then the address
	str_append_c (&req, 4);
	str_append_c (&req, 1);
	str_append_c (&req, data->port >> 8);
	str_append_c (&req, data->port);
	str_append_data (&req, dest_ipv4, 4);

	// NUL-terminated user ID, empty when there is no username
	if (data->username)
		str_append (&req, data->username);
	str_append_c (&req, '\0');

	// NUL-terminated host name, present only when the placeholder is used
	if (dest_domain)
	{
		str_append (&req, dest_domain);
		str_append_c (&req, '\0');
	}

	ssize_t n = send (sockfd, req.str, req.len, 0);
	str_free (&req);
	if (n == -1)
		SOCKS_FAIL ("%s: %s", "send", strerror (errno));

	// Only the first two reply bytes are inspected: a zero byte and a status
	uint8_t resp[8];
	SOCKS_RECV (resp, sizeof resp);
	if (resp[0] != 0)
		SOCKS_FAIL ("protocol error");

	uint8_t status = resp[1];
	if (status == 91)
		SOCKS_FAIL ("request rejected or failed");
	if (status == 92)
		SOCKS_FAIL ("%s: %s", "request rejected",
			"SOCKS server cannot connect to identd on the client");
	if (status == 93)
		SOCKS_FAIL ("%s: %s", "request rejected",
			"identd reports different user-id");
	if (status != 90)
		SOCKS_FAIL ("protocol error");

	*fd = sockfd;
	return true;

fail:
	xclose (sockfd);
	return false;
}

// The socks_5_* helpers below have no `fail' label, so from here on
// SOCKS_FAIL sets the error and returns false straight away.  SOCKS_RECV
// expands SOCKS_FAIL where it is used, so it picks up this definition too.
#undef SOCKS_FAIL
#define SOCKS_FAIL(...)                                                        \
	BLOCK_START                                                                \
		error_set (e, __VA_ARGS__);                                            \
		return false;                                                          \
	BLOCK_END

// Carry out the username/password sub-negotiation (sub-version 0x01) on
// `sockfd'.  Both `data->username' and `data->password' must be non-NULL.
// Each is cut to the 255 bytes that its one-octet length field can describe.
static bool
socks_5_userpass_auth (int sockfd, struct socks_data *data, struct error **e)
{
	size_t user_len = strlen (data->username);
	size_t pass_len = strlen (data->password);
	user_len = user_len > 255 ? 255 : user_len;
	pass_len = pass_len > 255 ? 255 : pass_len;

	// Sub-version, username length, username, password length, password
	uint8_t msg[1 + 1 + user_len + 1 + pass_len];
	size_t off = 0;
	msg[off++] = 0x01;
	msg[off++] = (uint8_t) user_len;
	memcpy (msg + off, data->username, user_len);
	off += user_len;
	msg[off++] = (uint8_t) pass_len;
	memcpy (msg + off, data->password, pass_len);
	off += pass_len;

	ssize_t n = send (sockfd, msg, off, 0);
	if (n == -1)
		SOCKS_FAIL ("%s: %s", "send", strerror (errno));

	// Reply: sub-version echo, then status where zero means success
	uint8_t resp[2];
	SOCKS_RECV (resp, sizeof resp);
	if (resp[0] != 0x01)
		SOCKS_FAIL ("protocol error");
	if (resp[1] != 0x00)
		SOCKS_FAIL ("authentication failure");
	return true;
}

// Run the SOCKS 5 method negotiation on `sockfd'.  "No authentication" (0x00)
// is always offered.  Username/password (0x02) is offered as well when both
// credentials are set, and is carried out if the server selects it.
static bool
socks_5_auth (int sockfd, struct socks_data *data, struct error **e)
{
	bool have_credentials = data->username && data->password;

	// Version, number of methods, then the methods; the trailing 0x02 is
	// only sent along when credentials are available
	uint8_t greeting[] = { 0x05, have_credentials ? 2 : 1, 0x00, 0x02 };
	size_t greeting_len = have_credentials ? 4 : 3;

	ssize_t n = send (sockfd, greeting, greeting_len, 0);
	if (n == -1)
		SOCKS_FAIL ("%s: %s", "send", strerror (errno));

	uint8_t resp[2];
	SOCKS_RECV (resp, sizeof resp);
	if (resp[0] != 0x05)
		SOCKS_FAIL ("protocol error");

	uint8_t method = resp[1];
	if (method == 0x00)
		return true;
	if (method == 0xFF)
		SOCKS_FAIL ("no acceptable authentication methods");

	// Anything else must be the username/password method, and only if we
	// actually offered it
	if (method != 0x02 || !have_credentials)
		SOCKS_FAIL ("protocol error");
	return socks_5_userpass_auth (sockfd, data, e);
}

// Send a SOCKS 5 CONNECT request for the target address and port in `data'.
//
// A domain name is prefixed by a single length octet, so names longer than
// 255 bytes cannot be encoded.  They are rejected rather than truncated:
// a truncated name would make the proxy connect to an unrelated host.
static bool
socks_5_send_req (int sockfd, struct socks_data *data, struct error **e)
{
	// Header (4), the largest address form, which is a length octet plus
	// a 255-byte name, and the port (2)
	uint8_t req[4 + 1 + 255 + 2], *p = req;
	*p++ = 0x05;  // version
	*p++ = 0x01;  // connect
	*p++ = 0x00;  // reserved
	*p++ = data->address.type;

	switch (data->address.type)
	{
	case SOCKS_IPV4:
		memcpy (p, data->address.data.ipv4, sizeof data->address.data.ipv4);
		p += sizeof data->address.data.ipv4;
		break;
	case SOCKS_DOMAIN:
	{
		size_t dlen = strlen (data->address.data.domain);
		if (dlen > 255)
			SOCKS_FAIL ("%s: %s", "invalid address", "domain name too long");

		*p++ = dlen;
		memcpy (p, data->address.data.domain, dlen);
		p += dlen;
		break;
	}
	case SOCKS_IPV6:
		memcpy (p, data->address.data.ipv6, sizeof data->address.data.ipv6);
		p += sizeof data->address.data.ipv6;
		break;
	}

	// Port, most significant byte first
	*p++ = data->port >> 8;
	*p++ = data->port;

	if (send (sockfd, req, p - req, 0) == -1)
		SOCKS_FAIL ("%s: %s", "send", strerror (errno));
	return true;
}

// Read the server's reply to a SOCKS 5 CONNECT request, failing with
// a descriptive error unless it reports success.  The bound address and port
// are stored into `data'.
//
// A bound domain name is duplicated onto the heap with xstrdup(), and freeing
// it is left to the caller.  `bound_address.type' is only recorded once the
// matching address data is complete.  A reply cut short half-way therefore
// never leaves the type claiming a domain whose pointer was never assigned,
// which the caller would otherwise go on to free.
static bool
socks_5_process_resp (int sockfd, struct socks_data *data, struct error **e)
{
	uint8_t resp_header[4];
	ssize_t n;
	SOCKS_RECV (resp_header, sizeof resp_header);
	if (resp_header[0] != 0x05)
		SOCKS_FAIL ("protocol error");

	switch (resp_header[1])
	{
	case 0x00:
		break;
	case 0x01:  SOCKS_FAIL ("general SOCKS server failure");
	case 0x02:  SOCKS_FAIL ("connection not allowed by ruleset");
	case 0x03:  SOCKS_FAIL ("network unreachable");
	case 0x04:  SOCKS_FAIL ("host unreachable");
	case 0x05:  SOCKS_FAIL ("connection refused");
	case 0x06:  SOCKS_FAIL ("TTL expired");
	case 0x07:  SOCKS_FAIL ("command not supported");
	case 0x08:  SOCKS_FAIL ("address type not supported");
	default:    SOCKS_FAIL ("protocol error");
	}

	switch (resp_header[3])
	{
	case SOCKS_IPV4:
		SOCKS_RECV (data->bound_address.data.ipv4,
			sizeof data->bound_address.data.ipv4);
		data->bound_address.type = SOCKS_IPV4;
		break;
	case SOCKS_IPV6:
		SOCKS_RECV (data->bound_address.data.ipv6,
			sizeof data->bound_address.data.ipv6);
		data->bound_address.type = SOCKS_IPV6;
		break;
	case SOCKS_DOMAIN:
	{
		uint8_t len;
		SOCKS_RECV (&len, sizeof len);

		char domain[len + 1];
		SOCKS_RECV (domain, len);
		domain[len] = '\0';

		// Set the type only now that the pointer is valid and owned
		data->bound_address.data.domain = xstrdup (domain);
		data->bound_address.type = SOCKS_DOMAIN;
		break;
	}
	default:
		SOCKS_FAIL ("protocol error");
	}

	uint16_t port;
	SOCKS_RECV (&port, sizeof port);
	data->bound_port = ntohs (port);
	return true;
}

// The SOCKS_* macros depend on locals of the functions above (`e', `sockfd',
// `n'); keep them from leaking into the rest of the file
#undef SOCKS_FAIL
#undef SOCKS_RECV

// Open a connection to one of `addresses' and negotiate a SOCKS 5 CONNECT to
// the target in `data': method selection, the request, then the reply.
// On success the connected socket is stored in `fd'.  Otherwise it is closed,
// and `e' is set by whichever step failed.
static bool
socks_5_connect (struct addrinfo *addresses, struct socks_data *data,
	int *fd, struct error **e)
{
	int sockfd;
	if (!socks_get_socket (addresses, &sockfd, e))
		return false;

	bool ok = socks_5_auth (sockfd, data, e)
		&& socks_5_send_req (sockfd, data, e)
		&& socks_5_process_resp (sockfd, data, e);

	if (ok)
		*fd = sockfd;
	else
		xclose (sockfd);
	return ok;
}

// Connect to `host' on `port' through the SOCKS proxy at `socks_host' and
// `socks_port'.  Returns the connected socket, or -1 on failure.
//
// `port' is either a TCP service name known to getservbyname() or a decimal
// number from 1 to 65535.  `host' is sent as an IPv4 or IPv6 address when it
// parses as one, and as a domain name otherwise.  `username' and `password'
// may be NULL.  SOCKS 5 is tried first and any error it produces is discarded.
// If it fails, SOCKS 4a is tried instead, and only its error reaches `e'.
static int
socks_connect (const char *socks_host, const char *socks_port,
	const char *host, const char *port,
	const char *username, const char *password, struct error **e)
{
	int result = -1;
	struct addrinfo gai_hints, *gai_result;
	memset (&gai_hints, 0, sizeof gai_hints);
	gai_hints.ai_socktype = SOCK_STREAM;

	unsigned long port_no;
	const struct servent *serv;
	if ((serv = getservbyname (port, "tcp")))
		port_no = (uint16_t) ntohs (serv->s_port);
	else if (!xstrtoul (&port_no, port, 10) || !port_no || port_no > UINT16_MAX)
	{
		error_set (e, "invalid port number");
		goto fail;
	}

	int err = getaddrinfo (socks_host, socks_port, &gai_hints, &gai_result);
	if (err)
	{
		error_set (e, "%s: %s", "getaddrinfo", gai_strerror (err));
		goto fail;
	}

	struct socks_data data =
		{ .username = username, .password = password, .port = port_no };

	// Implicit zero-initialization of a union is only guaranteed to cover its
	// first member (`ipv4' here, and GCC 15 no longer clears the rest), which
	// may be narrower than the `domain' pointer.  The cleanup below frees that
	// pointer whenever the bound address type says "domain", so it has to
	// start out as a proper NULL.
	data.bound_address.data.domain = NULL;

	if      (inet_pton (AF_INET,  host, &data.address.data.ipv4) == 1)
		data.address.type = SOCKS_IPV4;
	else if (inet_pton (AF_INET6, host, &data.address.data.ipv6) == 1)
		data.address.type = SOCKS_IPV6;
	else
	{
		data.address.type = SOCKS_DOMAIN;
		data.address.data.domain = host;
	}

	// The SOCKS 5 error is deliberately dropped; 4a gets a fresh connection
	if (!socks_5_connect (gai_result, &data, &result, NULL))
		socks_4a_connect (gai_result, &data, &result, e);

	if (data.bound_address.type == SOCKS_DOMAIN)
		free ((char *) data.bound_address.data.domain);
	freeaddrinfo (gai_result);
fail:
	return result;
}