/*
 * common.c: common functionality
 *
 * Copyright (c) 2014 - 2015, Přemysl Janouch <p.janouch@gmail.com>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#define LIBERTY_WANT_SSL
#define LIBERTY_WANT_ASYNC
#define LIBERTY_WANT_POLLER
#define LIBERTY_WANT_PROTO_IRC

#ifdef WANT_SYSLOG_LOGGING
	#define print_fatal_data    ((void *) LOG_ERR)
	#define print_error_data    ((void *) LOG_ERR)
	#define print_warning_data  ((void *) LOG_WARNING)
	#define print_status_data   ((void *) LOG_INFO)
	#define print_debug_data    ((void *) LOG_DEBUG)
#endif // WANT_SYSLOG_LOGGING

#include "liberty/liberty.c"
#include <arpa/inet.h>
#include <netinet/tcp.h>

/// Shorthand to set an error and return failure from the function.
/// Expects a `struct error **e' to be in scope and forwards the format string
/// and its arguments to error_set().  It returns 0, so the enclosing function
/// must use a return type where 0 means failure (such as bool).
#define FAIL(...)                                                              \
	BLOCK_START                                                                \
		error_set (e, __VA_ARGS__);                                            \
		return 0;                                                              \
	BLOCK_END

// --- To be moved to liberty --------------------------------------------------

/// Linear search for a string in a vector.
/// Returns the index of the first element equal to `s', or -1 if there's none.
static ssize_t
str_vector_find (const struct str_vector *v, const char *s)
{
	size_t i = 0;
	while (i < v->len && strcmp (v->vector[i], s) != 0)
		i++;
	return i < v->len ? (ssize_t) i : -1;
}

/// Return the current wall-clock time in seconds since the Epoch, and store
/// the sub-second part, in milliseconds, into `*msec'.
/// Reading the clock is checked with hard_assert().
static time_t
unixtime_msec (long *msec)
{
	// _POSIX_TIMERS may be defined as -1 (unsupported) or 0 (support has to be
	// checked at runtime), so merely testing whether it's defined isn't enough;
	// gettimeofday() is a safe fallback in those cases
#if defined _POSIX_TIMERS && _POSIX_TIMERS > 0
	struct timespec tp;
	hard_assert (clock_gettime (CLOCK_REALTIME, &tp) != -1);
	*msec = tp.tv_nsec / 1000000;
#else // ! _POSIX_TIMERS
	struct timeval tp;
	hard_assert (gettimeofday (&tp, NULL) != -1);
	*msec = tp.tv_usec / 1000;
#endif // ! _POSIX_TIMERS
	return tp.tv_sec;
}

/// Build a path for `filename' inside a per-program runtime directory:
/// "$XDG_RUNTIME_DIR/PROGRAM_NAME/filename" when XDG_RUNTIME_DIR holds an
/// absolute path, "/tmp/PROGRAM_NAME/filename" otherwise.
///
/// This differs from the non-unique version in that we expect the filename
/// to be something like a pattern for mkstemp(), so the resulting path can
/// reside in a system-wide directory with no risk of a conflict.
///
/// The file's parent directories are created on a best-effort basis, with
/// errors ignored.  The returned string belongs to the caller.
static char *
resolve_relative_runtime_unique_filename (const char *filename)
{
	const char *base = getenv ("XDG_RUNTIME_DIR");
	if (!base || *base != '/')
		base = "/tmp";

	struct str path;
	str_init (&path);
	str_append (&path, base);
	str_append_printf (&path, "/%s/%s", PROGRAM_NAME, filename);

	// The caller typically wants to create the file right away,
	// so make sure the directory it's going to live in exists
	const char *dir_end = strrchr (path.str, '/');
	if (dir_end && dir_end != path.str)
	{
		char *dir = xstrndup (path.str, dir_end - path.str);
		(void) mkdir_with_parents (dir, NULL);
		free (dir);
	}
	return str_steal (&path);
}

/// Write all `len' bytes of `data' to `fd'.  Partial writes and writes
/// interrupted by a signal are retried.  Any other failure sets `e' from
/// errno and makes the function return false.
static bool
xwrite (int fd, const char *data, size_t len, struct error **e)
{
	size_t remaining = len;
	while (remaining)
	{
		ssize_t n = write (fd, data + (len - remaining), remaining);
		if (n == -1 && errno == EINTR)
			continue;
		if (n < 0)
		{
			error_set (e, "%s", strerror (errno));
			return false;
		}
		remaining -= n;
	}
	return true;
}

// --- Simple network I/O ------------------------------------------------------

// TODO: move to liberty and remove from dwmstatus.c as well

/// Limit on how much socket_io_try_read() lets accumulate in its read buffer
/// (8 MiB); it stops reading once the buffer has grown to at least this size
#define SOCKET_IO_OVERFLOW (8 << 20)    ///< How large a read buffer can be

/// Outcome of a socket I/O attempt by socket_io_try_read()/_write()
enum socket_io_result
{
	SOCKET_IO_OK,                       ///< No error (incl. "would block")
	SOCKET_IO_EOF,                      ///< Connection shut down by peer
	SOCKET_IO_ERROR                     ///< Connection error
};

/// Read as much as is currently available from `socket_fd' into `rb', keeping
/// the buffer NUL-terminated.  The socket is expected to be non-blocking.
/// Reading stops when the call would block, or once the buffer holds at least
/// SOCKET_IO_OVERFLOW bytes.
///
/// Returns SOCKET_IO_OK when nothing more can be read right now,
/// SOCKET_IO_EOF when the peer has shut down the connection,
/// and SOCKET_IO_ERROR with `e' set on any other failure.
static enum socket_io_result
socket_io_try_read (int socket_fd, struct str *rb, struct error **e)
{
	// We allow buffering of a fair amount of data, however within reason,
	// so that it's not so easy to flood us and cause an allocation failure
	while (rb->len < SOCKET_IO_OVERFLOW)
	{
		str_ensure_space (rb, 4096);
		ssize_t n_read = recv (socket_fd, rb->str + rb->len,
			rb->alloc - rb->len - 1 /* null byte */, 0);

		if (n_read > 0)
		{
			rb->str[rb->len += n_read] = '\0';
			continue;
		}
		if (n_read == 0)
			return SOCKET_IO_EOF;

		// POSIX allows either value for "no data on a non-blocking socket",
		// and the two need not be equal
		if (errno == EAGAIN || errno == EWOULDBLOCK)
			return SOCKET_IO_OK;
		if (errno == EINTR)
			continue;

		error_set (e, "%s", strerror (errno));
		return SOCKET_IO_ERROR;
	}
	return SOCKET_IO_OK;
}

/// Send as much of `wb' through `socket_fd' as the socket accepts, removing
/// the sent data from the buffer.  The socket is expected to be non-blocking.
///
/// Returns SOCKET_IO_OK when the buffer has been emptied or the call would
/// block, and SOCKET_IO_ERROR with `e' set on any other failure.
static enum socket_io_result
socket_io_try_write (int socket_fd, struct str *wb, struct error **e)
{
	while (wb->len)
	{
		ssize_t n_written = send (socket_fd, wb->str, wb->len, 0);
		if (n_written >= 0)
		{
			str_remove_slice (wb, 0, n_written);
			continue;
		}

		// POSIX allows either value when the send would block,
		// and the two need not be equal
		if (errno == EAGAIN || errno == EWOULDBLOCK)
			return SOCKET_IO_OK;
		if (errno == EINTR)
			continue;

		error_set (e, "%s", strerror (errno));
		return SOCKET_IO_ERROR;
	}
	return SOCKET_IO_OK;
}

// --- Logging -----------------------------------------------------------------

/// Logging handler that forwards messages to syslog(3).
///
/// `user_data' carries the syslog priority cast to a pointer, as set up by the
/// print_*_data macros.  `quote' is prepended verbatim to the message that is
/// formatted from `fmt' and `ap'.  Messages that fail to format, or for which
/// memory can't be allocated, are dropped.
static void
log_message_syslog (void *user_data, const char *quote, const char *fmt,
	va_list ap)
{
	int prio = (int) (intptr_t) user_data;

	va_list va;
	va_copy (va, ap);
	int size = vsnprintf (NULL, 0, fmt, va);
	va_end (va);
	if (size < 0)
		return;

	// The message may be arbitrarily long, so don't risk a huge VLA
	// overflowing the stack
	size_t buf_size = (size_t) size + 1;
	char *buf = malloc (buf_size);
	if (!buf)
		return;

	if (vsnprintf (buf, buf_size, fmt, ap) >= 0)
		syslog (prio, "%s%s", quote, buf);
	free (buf);
}

// --- SOCKS 5/4a --------------------------------------------------------------

// Asynchronous SOCKS connector.  It uses the regular connector to reach
// the SOCKS server and adds the SOCKS negotiation on top of it.

// Note that the `username' is used differently in SOCKS 4a and 5.  In the
// former version, it is the username that you can get ident'ed against.
// In the latter version, it forms a pair with the password field and doesn't
// need to be an actual user on your machine.

/// An address in one of the forms that SOCKS can carry.  The `type' values
/// are also sent and received on the wire as the SOCKS5 address type byte.
struct socks_addr
{
	enum socks_addr_type
	{
		SOCKS_IPV4 = 1,                 ///< IPv4 address
		SOCKS_DOMAIN = 3,               ///< Domain name to be resolved
		SOCKS_IPV6 = 4                  ///< IPv6 address
	}
	type;                               ///< The type of this address
	union
	{
		uint8_t ipv4[4];                ///< IPv4 address, network octet order
		char *domain;                   ///< Domain name, heap-allocated and
		                                ///< released by socks_addr_free()
		uint8_t ipv6[16];               ///< IPv6 address, network octet order
	}
	data;                               ///< The address itself
};

/// Release whatever an address owns.  Only domain names own memory.
/// The structure itself is not freed.
static void
socks_addr_free (struct socks_addr *self)
{
	switch (self->type)
	{
	case SOCKS_DOMAIN:
		free (self->data.domain);
		break;
	default:
		break;
	}
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// One destination to be reached through the SOCKS server.  Targets form
/// a linked list that is tried in order.
struct socks_target
{
	LIST_HEADER (struct socks_target)

	char *address_str;                  ///< Target address as a string
	struct socks_addr address;          ///< Target address, parsed
	uint16_t port;                      ///< Target service port, host order
};

/// SOCKS protocol versions, tried in this order for every target
enum socks_protocol
{
	SOCKS_5,                            ///< SOCKS5
	SOCKS_4A,                           ///< SOCKS4A
	SOCKS_MAX                           ///< End of protocol
};

/// Human-readable name of a protocol version, or NULL if it's out of range
static inline const char *
socks_protocol_to_string (enum socks_protocol self)
{
	static const char *const names[SOCKS_MAX] =
	{
		[SOCKS_5]  = "SOCKS5",
		[SOCKS_4A] = "SOCKS4A",
	};
	return (unsigned) self < (unsigned) SOCKS_MAX ? names[self] : NULL;
}

/// State of an asynchronous attempt to connect through a SOCKS server.
/// Prepare it with socks_connector_init() and socks_connector_add_target(),
/// fill in the callbacks, then start it with socks_connector_run().
struct socks_connector
{
	struct connector *connector;        ///< Proxy server iterator (effectively)
	enum socks_protocol protocol_iter;  ///< Protocol iterator
	struct socks_target *targets_iter;  ///< Targets iterator

	// Negotiation:

	struct poller_timer timeout;        ///< Timeout timer

	int socket_fd;                      ///< Current socket file descriptor
	struct poller_fd socket_event;      ///< Socket can be read from/written to
	struct str read_buffer;             ///< Read buffer
	struct str write_buffer;            ///< Write buffer

	bool done;                          ///< Tunnel successfully established
	uint8_t bound_address_len;          ///< Length of bound domain name
	size_t data_needed;                 ///< How much data "on_data" needs

	/// Process incoming data if there's enough of it available.
	/// Returns true to keep negotiating, and false when finished or failed;
	/// the two cases are told apart by `done'.
	bool (*on_data) (struct socks_connector *, struct msg_unpacker *);

	// Configuration:

	char *hostname;                     ///< SOCKS server hostname
	char *service;                      ///< SOCKS server service name or port

	char *username;                     ///< Username for authentication
	char *password;                     ///< Password for authentication

	struct socks_target *targets;       ///< Targets
	struct socks_target *targets_tail;  ///< Tail of targets

	void *user_data;                    ///< User data for callbacks

	// Additional results:

	struct socks_addr bound_address;    ///< Bound address at the server
	uint16_t bound_port;                ///< Bound port at the server

	// You may destroy the connector object in these two main callbacks:

	/// Connection has been successfully established.  `hostname' is the
	/// target's address string, and `socket' now belongs to the callee.
	void (*on_connected) (void *user_data, int socket, const char *hostname);
	/// Failed to establish a connection to any of the targets
	void (*on_failure) (void *user_data);

	// Optional:

	/// Connecting to a new address
	void (*on_connecting) (void *user_data,
		const char *address, const char *via, const char *version);
	/// Connecting to the last address has failed
	void (*on_error) (void *user_data, const char *error);
};

// I've tried to make the actual protocol handlers as simple as possible

/// Report a formatted error through the optional on_error callback and return
/// false from the calling function.  Expects `self' to be in scope.
#define SOCKS_FAIL(...)                                                        \
	BLOCK_START                                                                \
		char *error = xstrdup_printf (__VA_ARGS__);                            \
		if (self->on_error)                                                    \
			self->on_error (self->user_data, error);                           \
		free (error);                                                          \
		return false;                                                          \
	BLOCK_END

/// Define a protocol step handler.  It receives the connector and an unpacker
/// over the buffered data, which is at least `data_needed' bytes long.
#define SOCKS_DATA_CB(name) static bool name                                   \
	(struct socks_connector *self, struct msg_unpacker *unpacker)

/// Make `name' the next step, to be called once `data_needed_' bytes have
/// arrived, and return true from the calling function.  This expands to
/// several plain statements, so it must not be used as an unbraced if/else body.
#define SOCKS_GO(name, data_needed_)                                           \
	self->on_data = name;                                                      \
	self->data_needed = data_needed_;                                          \
	return true

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// Handle the 8-byte SOCKS4 reply.  Only its leading null byte and the status
/// code are examined.  On success, the tunnel is marked as established.
/// Always returns false, because there's nothing more to read either way.
SOCKS_DATA_CB (socks_4a_finish)
{
	uint8_t null, status;
	hard_assert (msg_unpacker_u8 (unpacker, &null));
	hard_assert (msg_unpacker_u8 (unpacker, &status));

	if (null != 0)
		SOCKS_FAIL ("protocol error");

	if (status == 90)
	{
		self->done = true;
		return false;
	}
	if (status == 91)
		SOCKS_FAIL ("request rejected or failed");
	if (status == 92)
		SOCKS_FAIL ("%s: %s", "request rejected",
			"SOCKS server cannot connect to identd on the client");
	if (status == 93)
		SOCKS_FAIL ("%s: %s", "request rejected",
			"identd reports different user-id");

	SOCKS_FAIL ("protocol error");
}

static bool
socks_4a_start (struct socks_connector *self)
{
	struct socks_target *target = self->targets_iter;
	const void *dest_ipv4 = "\x00\x00\x00\x01";
	const char *dest_domain = NULL;

	char buf[INET6_ADDRSTRLEN];
	switch (target->address.type)
	{
	case SOCKS_IPV4:
		dest_ipv4 = target->address.data.ipv4;
		break;
	case SOCKS_IPV6:
		// About the best thing we can do, not sure if it works anywhere at all
		if (!inet_ntop (AF_INET6, &target->address.data.ipv6, buf, sizeof buf))
			SOCKS_FAIL ("%s: %s", "inet_ntop", strerror (errno));
		dest_domain = buf;
		break;
	case SOCKS_DOMAIN:
		dest_domain = target->address.data.domain;
	}

	struct str *wb = &self->write_buffer;
	str_pack_u8 (wb, 4);                  // version
	str_pack_u8 (wb, 1);                  // connect

	str_pack_u16 (wb, target->port);      // port
	str_append_data (wb, dest_ipv4, 4);   // destination address

	if (self->username)
		str_append (wb, self->username);
	str_append_c (wb, '\0');

	if (dest_domain)
	{
		str_append (wb, dest_domain);
		str_append_c (wb, '\0');
	}

	SOCKS_GO (socks_4a_finish, 8);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// Last part of the SOCKS5 reply: read the bound port and mark the tunnel
/// as established.  Returns false, since there's nothing more to read.
SOCKS_DATA_CB (socks_5_request_port)
{
	uint16_t *port = &self->bound_port;
	hard_assert (msg_unpacker_u16 (unpacker, port));

	self->done = true;
	return false;
}

/// Store the 4-byte IPv4 bound address from the SOCKS5 reply,
/// then expect the 2-byte port
SOCKS_DATA_CB (socks_5_request_ipv4)
{
	// Copy exactly what the destination can hold rather than whatever the
	// unpacker spans, which would overflow the address if it ever exceeded 4
	size_t len = sizeof self->bound_address.data.ipv4;
	hard_assert (unpacker->len >= len);
	memcpy (self->bound_address.data.ipv4, unpacker->data, len);
	SOCKS_GO (socks_5_request_port, 2);
}

/// Store the 16-byte IPv6 bound address from the SOCKS5 reply,
/// then expect the 2-byte port
SOCKS_DATA_CB (socks_5_request_ipv6)
{
	// Copy exactly what the destination can hold rather than whatever the
	// unpacker spans, which would overflow the address if it ever exceeded 16
	size_t len = sizeof self->bound_address.data.ipv6;
	hard_assert (unpacker->len >= len);
	memcpy (self->bound_address.data.ipv6, unpacker->data, len);
	SOCKS_GO (socks_5_request_port, 2);
}

/// Store the bound domain name, whose length was announced in the previous
/// step, then expect the 2-byte port
SOCKS_DATA_CB (socks_5_request_domain_data)
{
	// Use the announced length rather than the unpacker's, so that any
	// further buffered bytes (e.g. the port) never end up in the name
	size_t len = self->bound_address_len;
	hard_assert (unpacker->len >= len);
	self->bound_address.data.domain = xstrndup (unpacker->data, len);
	SOCKS_GO (socks_5_request_port, 2);
}

/// Read the one-byte length of the bound domain name, then expect the name
SOCKS_DATA_CB (socks_5_request_domain)
{
	uint8_t *len = &self->bound_address_len;
	hard_assert (msg_unpacker_u8 (unpacker, len));
	SOCKS_GO (socks_5_request_domain_data, *len);
}

/// Handle the fixed 4-byte head of the SOCKS5 reply: check the version and
/// the status, then read the bound address in the form its type indicates.
SOCKS_DATA_CB (socks_5_request_finish)
{
	uint8_t version, status, reserved, type;
	hard_assert (msg_unpacker_u8 (unpacker, &version));
	hard_assert (msg_unpacker_u8 (unpacker, &status));
	hard_assert (msg_unpacker_u8 (unpacker, &reserved));
	hard_assert (msg_unpacker_u8 (unpacker, &type));

	if (version != 0x05)
		SOCKS_FAIL ("protocol error");

	// Failure reasons, indexed by the status code; 0x00 means success
	static const char *const failures[] =
	{
		[0x01] = "general SOCKS server failure",
		[0x02] = "connection not allowed by ruleset",
		[0x03] = "network unreachable",
		[0x04] = "host unreachable",
		[0x05] = "connection refused",
		[0x06] = "TTL expired",
		[0x07] = "command not supported",
		[0x08] = "address type not supported",
	};
	if (status >= sizeof failures / sizeof *failures)
		SOCKS_FAIL ("protocol error");
	if (status != 0x00)
		SOCKS_FAIL ("%s", failures[status]);

	self->bound_address.type = type;
	if (self->bound_address.type == SOCKS_IPV4)
	{
		SOCKS_GO (socks_5_request_ipv4, sizeof self->bound_address.data.ipv4);
	}
	if (self->bound_address.type == SOCKS_IPV6)
	{
		SOCKS_GO (socks_5_request_ipv6, sizeof self->bound_address.data.ipv6);
	}
	if (self->bound_address.type == SOCKS_DOMAIN)
	{
		SOCKS_GO (socks_5_request_domain, 1);
	}
	SOCKS_FAIL ("protocol error");
}

static bool
socks_5_request_start (struct socks_connector *self)
{
	struct socks_target *target = self->targets_iter;
	struct str *wb = &self->write_buffer;
	str_pack_u8 (wb, 0x05);  // version
	str_pack_u8 (wb, 0x01);  // connect
	str_pack_u8 (wb, 0x00);  // reserved
	str_pack_u8 (wb, target->address.type);

	switch (target->address.type)
	{
	case SOCKS_IPV4:
		str_append_data (wb,
			target->address.data.ipv4, sizeof target->address.data.ipv4);
		break;
	case SOCKS_DOMAIN:
	{
		size_t dlen = strlen (target->address.data.domain);
		if (dlen > 255)
			dlen = 255;

		str_pack_u8 (wb, dlen);
		str_append_data (wb, target->address.data.domain, dlen);
		break;
	}
	case SOCKS_IPV6:
		str_append_data (wb,
			target->address.data.ipv6, sizeof target->address.data.ipv6);
		break;
	}
	str_pack_u16 (wb, target->port);

	SOCKS_GO (socks_5_request_finish, 4);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// Handle the 2-byte reply to username/password authentication.  On success,
/// continue with the actual CONNECT request.
SOCKS_DATA_CB (socks_5_userpass_finish)
{
	uint8_t version, status;
	hard_assert (msg_unpacker_u8 (unpacker, &version));
	hard_assert (msg_unpacker_u8 (unpacker, &status));

	bool version_ok = version == 0x01;
	if (!version_ok)
		SOCKS_FAIL ("protocol error");
	if (status)
		SOCKS_FAIL ("authentication failure");

	// Authenticated, now we can ask for the tunnel
	return socks_5_request_start (self);
}

/// Queue a username/password authentication request and expect its 2-byte
/// reply.  Both credentials must be set.  Credentials that don't fit the
/// one-byte length fields are refused instead of being sent truncated.
static bool
socks_5_userpass_start (struct socks_connector *self)
{
	// A truncated credential would only fail with a confusing
	// "authentication failure", so report the actual problem
	size_t ulen = strlen (self->username);
	if (ulen > UINT8_MAX)
		SOCKS_FAIL ("username too long");

	size_t plen = strlen (self->password);
	if (plen > UINT8_MAX)
		SOCKS_FAIL ("password too long");

	struct str *wb = &self->write_buffer;
	str_pack_u8 (wb, 0x01);  // version
	str_pack_u8 (wb, ulen);  // username length
	str_append_data (wb, self->username, ulen);
	str_pack_u8 (wb, plen);  // password length
	str_append_data (wb, self->password, plen);

	SOCKS_GO (socks_5_userpass_finish, 2);
}

/// Handle the server's choice of authentication method.  Continue with either
/// username/password authentication or directly with the CONNECT request.
SOCKS_DATA_CB (socks_5_auth_finish)
{
	uint8_t version, method;
	hard_assert (msg_unpacker_u8 (unpacker, &version));
	hard_assert (msg_unpacker_u8 (unpacker, &method));

	if (version != 0x05)
		SOCKS_FAIL ("protocol error");

	if (method == 0x00)                 // no authentication required
		return socks_5_request_start (self);
	if (method == 0xFF)
		SOCKS_FAIL ("no acceptable authentication methods");

	// Username/password is only acceptable if we've offered it
	if (method != 0x02 || !self->username || !self->password)
		SOCKS_FAIL ("protocol error");
	return socks_5_userpass_start (self);
}

/// Queue the SOCKS5 greeting and expect the server's 2-byte method choice.
/// "No authentication" is always offered; username/password is offered too
/// when both credentials have been configured.
static bool
socks_5_auth_start (struct socks_connector *self)
{
	uint8_t methods[2] = { 0x00 };      // no authentication required
	size_t n_methods = 1;
	if (self->username && self->password)
		methods[n_methods++] = 0x02;    // username/password

	struct str *wb = &self->write_buffer;
	str_pack_u8 (wb, 0x05);             // version
	str_pack_u8 (wb, n_methods);        // number of authentication methods
	for (size_t i = 0; i < n_methods; i++)
		str_pack_u8 (wb, methods[i]);

	SOCKS_GO (socks_5_auth_finish, 2);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

static void socks_connector_start (struct socks_connector *self);

/// Dispose of the connector used for reaching the SOCKS server, if any
static void
socks_connector_destroy_connector (struct socks_connector *self)
{
	struct connector *connector = self->connector;
	if (!connector)
		return;

	connector_free (connector);
	free (connector);
	self->connector = NULL;
}

/// Stop watching the socket and disarm the timeout.
/// Before calling the final callbacks, we should cancel events that
/// could potentially fire; the caller should destroy us immediately, though.
static void
socks_connector_cancel_events (struct socks_connector *self)
{
	struct poller_fd *event = &self->socket_event;
	struct poller_timer *timer = &self->timeout;

	poller_fd_reset (event);
	poller_timer_reset (timer);
}

/// Give up on the whole run: cancel pending events and report failure.
/// The on_failure callback is allowed to destroy us, so it goes last.
static void
socks_connector_fail (struct socks_connector *self)
{
	void *user_data = self->user_data;
	socks_connector_cancel_events (self);
	self->on_failure (user_data);
}

/// Advance to the next protocol/target combination.
/// Returns false once every combination has been tried.
static bool
socks_connector_step_iterators (struct socks_connector *self)
{
	// At the lowest level we iterate over all addresses for the SOCKS server
	// and just try to connect; this is done automatically by the connector

	// Then we iterate over available protocols...
	self->protocol_iter++;
	if (self->protocol_iter != SOCKS_MAX)
		return true;

	// ...and at the highest level over possible targets
	self->protocol_iter = 0;
	if (!self->targets_iter)
		return false;

	self->targets_iter = self->targets_iter->next;
	return self->targets_iter != NULL;
}

/// Abandon the current attempt and start the next protocol/target
/// combination, or report failure once all of them have been exhausted.
static void
socks_connector_step (struct socks_connector *self)
{
	// Drop the socket of the failed attempt, if we got that far
	int fd = self->socket_fd;
	if (fd != -1)
	{
		poller_fd_reset (&self->socket_event);
		xclose (fd);
		self->socket_fd = -1;
	}

	socks_connector_destroy_connector (self);

	bool have_next = socks_connector_step_iterators (self);
	if (!have_next)
	{
		socks_connector_fail (self);
		return;
	}
	socks_connector_start (self);
}

/// Timer handler: the current attempt took too long.
/// Reports the timeout, then gives up on the whole run.
static void
socks_connector_on_timeout (struct socks_connector *self)
{
	void (*report) (void *, const char *) = self->on_error;
	if (report)
		report (self->user_data, "timeout");

	socks_connector_destroy_connector (self);
	socks_connector_fail (self);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// The connector has reached the SOCKS server.  Take over the socket in
/// non-blocking mode and begin negotiating with the current protocol.
static void
socks_connector_on_connected
	(void *user_data, int socket_fd, const char *hostname)
{
	struct socks_connector *self = user_data;
	(void) hostname;

	set_blocking (socket_fd, false);
	self->socket_fd = socket_fd;
	self->socket_event.fd = socket_fd;
	poller_fd_set (&self->socket_event, POLLIN | POLLOUT);
	str_reset (&self->read_buffer);
	str_reset (&self->write_buffer);

	bool started = false;
	switch (self->protocol_iter)
	{
	case SOCKS_5:
		started = socks_5_auth_start (self);
		break;
	case SOCKS_4A:
		started = socks_4a_start (self);
		break;
	default:
		break;
	}
	if (!started)
		socks_connector_fail (self);
}

/// The connector couldn't reach the SOCKS server at any of its addresses:
/// move on to the next protocol/target combination
static void
socks_connector_on_failure (void *user_data)
{
	// TODO: skip SOCKS server on connection failure
	socks_connector_step ((struct socks_connector *) user_data);
}

/// Relay the connector's progress to the user.  Reports the target being
/// reached as host:port, the proxy address it's reached via, and the protocol.
static void
socks_connector_on_connecting (void *user_data, const char *via)
{
	struct socks_connector *self = user_data;
	if (self->on_connecting)
	{
		struct socks_target *target = self->targets_iter;
		const char *version = socks_protocol_to_string (self->protocol_iter);

		char *port = xstrdup_printf ("%u", (unsigned) target->port);
		char *address = format_host_port_pair (target->address_str, port);
		free (port);

		self->on_connecting (self->user_data, address, via, version);
		free (address);
	}
}

/// Forward an error from the connector to the user, if they're listening
static void
socks_connector_on_error (void *user_data, const char *error)
{
	struct socks_connector *self = user_data;
	// TODO: skip protocol on protocol failure
	if (!self->on_error)
		return;
	self->on_error (self->user_data, error);
}

/// Begin an attempt with the current protocol/target combination.
/// Creates a fresh connector to the SOCKS server, (re)arms the timeout and
/// forgets results of any previous attempt.  The SOCKS negotiation itself
/// begins in socks_connector_on_connected().
static void
socks_connector_start (struct socks_connector *self)
{
	// The previous attempt's connector must have been destroyed by now
	hard_assert (!self->connector);

	struct connector *connector =
		self->connector = xcalloc (1, sizeof *connector);
	connector_init (connector, self->socket_event.poller);

	connector->user_data     = self;
	connector->on_connected  = socks_connector_on_connected;
	connector->on_connecting = socks_connector_on_connecting;
	connector->on_error      = socks_connector_on_error;
	connector->on_failure    = socks_connector_on_failure;

	connector_add_target (connector, self->hostname, self->service);
	// Each attempt gets its own deadline; the unit is presumably milliseconds
	// (i.e. one minute) -- NOTE(review): confirm against poller_timer_set()
	poller_timer_set (&self->timeout, 60 * 1000);
	self->done = false;

	// Clear what the server told us during a previous attempt
	self->bound_port = 0;
	socks_addr_free (&self->bound_address);
	memset (&self->bound_address, 0, sizeof self->bound_address);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

static bool
socks_try_fill_read_buffer (struct socks_connector *self, size_t n)
{
	ssize_t remains = (ssize_t) n - (ssize_t) self->read_buffer.len;
	if (remains <= 0)
		return true;

	ssize_t received;
	str_ensure_space (&self->read_buffer, remains);
	do
		received = recv (self->socket_fd,
			self->read_buffer.str + self->read_buffer.len, remains, 0);
	while ((received == -1) && errno == EINTR);

	if (received == 0)
		SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF");
	if (received == -1 && errno != EAGAIN)
		SOCKS_FAIL ("%s: %s", "recv", strerror (errno));
	if (received > 0)
		self->read_buffer.len += received;
	return true;
}

/// Run the current step's on_data handler once enough data has arrived for
/// it, then drop that data from the read buffer.
///
/// Returns true when negotiation should continue, including while still
/// waiting for data.  Returns false when the handler has finished or failed,
/// or when reading has failed.
static bool
socks_call_on_data (struct socks_connector *self)
{
	size_t to_consume = self->data_needed;
	if (!socks_try_fill_read_buffer (self, to_consume))
		return false;
	if (self->read_buffer.len < to_consume)
		return true;

	// Only expose the bytes this step has asked for: some handlers take
	// the length of their data from the unpacker
	struct msg_unpacker unpacker;
	msg_unpacker_init (&unpacker, self->read_buffer.str, to_consume);
	bool result = self->on_data (self, &unpacker);
	str_remove_slice (&self->read_buffer, 0, to_consume);
	return result;
}

/// Send as much of the write buffer as the socket currently accepts.
/// Returns false, after reporting through on_error, on a send error.
/// Returns true otherwise, even if some data remains queued.
static bool
socks_try_flush_write_buffer (struct socks_connector *self)
{
	struct str *wb = &self->write_buffer;
	while (wb->len)
	{
		ssize_t n_written = send (self->socket_fd, wb->str, wb->len, 0);
		if (n_written >= 0)
		{
			str_remove_slice (wb, 0, n_written);
			continue;
		}

		// POSIX allows either value when the send would block,
		// and the two need not be equal
		if (errno == EAGAIN || errno == EWOULDBLOCK)
			break;
		if (errno == EINTR)
			continue;

		SOCKS_FAIL ("%s: %s", "send", strerror (errno));
	}
	return true;
}

/// Socket readiness handler during negotiation.  Feeds received data to the
/// current step and flushes queued output.  Then it either keeps waiting,
/// hands the established tunnel over to the user, or moves on to the next
/// protocol/target combination after a failure.
static void
socks_connector_on_ready
	(const struct pollfd *pfd, struct socks_connector *self)
{
	(void) pfd;

	if (socks_call_on_data (self) && socks_try_flush_write_buffer (self))
	{
		// Only ask for POLLOUT while there's something left to send
		poller_fd_set (&self->socket_event,
			self->write_buffer.len ? (POLLIN | POLLOUT) : POLLIN);
	}
	else if (self->done)
	{
		socks_connector_cancel_events (self);

		// The socket now belongs to the user; make sure we won't close it
		int fd = self->socket_fd;
		self->socket_fd = -1;

		struct socks_target *target = self->targets_iter;
		set_blocking (fd, true);
		self->on_connected (self->user_data, fd, target->address_str);
	}
	else
		// We've failed this target, let's try to move on
		socks_connector_step (self);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

static void
socks_connector_init (struct socks_connector *self, struct poller *poller)
{
	memset (self, 0, sizeof *self);

	poller_fd_init (&self->socket_event, poller, (self->socket_fd = -1));
	self->socket_event.dispatcher = (poller_fd_fn) socks_connector_on_ready;
	self->socket_event.user_data = self;

	poller_timer_init (&self->timeout, poller);
	self->timeout.dispatcher = (poller_timer_fn) socks_connector_on_timeout;
	self->timeout.user_data = self;

	str_init (&self->read_buffer);
	str_init (&self->write_buffer);
}

static void
socks_connector_free (struct socks_connector *self)
{
	socks_connector_destroy_connector (self);
	socks_connector_cancel_events (self);

	if (self->socket_fd != -1)
		xclose (self->socket_fd);

	str_free (&self->read_buffer);
	str_free (&self->write_buffer);

	free (self->hostname);
	free (self->service);
	free (self->username);
	free (self->password);

	LIST_FOR_EACH (struct socks_target, iter, self->targets)
	{
		socks_addr_free (&iter->address);
		free (iter->address_str);
		free (iter);
	}

	socks_addr_free (&self->bound_address);
}

static bool
socks_connector_add_target (struct socks_connector *self,
	const char *host, const char *service, struct error **e)
{
	unsigned long port;
	const struct servent *serv;
	if ((serv = getservbyname (service, "tcp")))
		port = (uint16_t) ntohs (serv->s_port);
	else if (!xstrtoul (&port, service, 10) || !port || port > UINT16_MAX)
	{
		error_set (e, "invalid port number");
		return false;
	}

	struct socks_target *target = xcalloc (1, sizeof *target);
	if      (inet_pton (AF_INET,  host, &target->address.data.ipv4) == 1)
		target->address.type = SOCKS_IPV4;
	else if (inet_pton (AF_INET6, host, &target->address.data.ipv6) == 1)
		target->address.type = SOCKS_IPV6;
	else
	{
		target->address.type = SOCKS_DOMAIN;
		target->address.data.domain = xstrdup (host);
	}

	target->port = port;
	target->address_str = xstrdup (host);
	LIST_APPEND_WITH_TAIL (self->targets, self->targets_tail, target);
	return true;
}

/// Start connecting to the added targets through the SOCKS server at
/// `host' and `service'.  `username' and `password' are optional; SOCKS5 only
/// offers authentication when both are set, and SOCKS4a sends `username' as
/// its user ID.  At least one target must have been added beforehand.
/// Results are delivered through the on_connected/on_failure callbacks.
static void
socks_connector_run (struct socks_connector *self,
	const char *host, const char *service,
	const char *username, const char *password)
{
	hard_assert (self->targets);
	hard_assert (host && service);

	self->hostname = xstrdup (host);
	self->service  = xstrdup (service);

	if (username)
		self->username = xstrdup (username);
	if (password)
		self->password = xstrdup (password);

	// Begin with the first protocol for the first target
	self->protocol_iter = 0;
	self->targets_iter = self->targets;

	// XXX: this can fail immediately from an error creating the connector
	socks_connector_start (self);
}

// --- CTCP decoding -----------------------------------------------------------

/// Low-level quote (^P): escapes NUL, CR and LF, see ctcp_low_level_decode()
#define CTCP_M_QUOTE '\020'
/// Delimits extended (tagged) messages within a message (^A)
#define CTCP_X_DELIM '\001'
/// Quote used inside extended messages, see ctcp_intra_decode()
#define CTCP_X_QUOTE '\\'

/// A piece of a decoded CTCP message.  It is either plain text, or an
/// extended message made of a tag and optional text after it.  The chunks
/// of a message form a linked list, as returned by ctcp_parse().
struct ctcp_chunk
{
	LIST_HEADER (struct ctcp_chunk)

	bool is_extended;                   ///< Is this a tagged extended message?
	bool is_partial;                    ///< Unterminated extended message
	struct str tag;                     ///< The tag, if any
	struct str text;                    ///< Message contents
};

/// Allocate an empty, non-extended chunk; release it with ctcp_chunk_destroy()
static struct ctcp_chunk *
ctcp_chunk_new (void)
{
	struct ctcp_chunk *chunk = xcalloc (1, sizeof *chunk);
	str_init (&chunk->text);
	str_init (&chunk->tag);
	return chunk;
}

/// Free a single chunk along with its strings; it must be unlinked already
/// or part of a list that's being torn down as a whole
static void
ctcp_chunk_destroy (struct ctcp_chunk *self)
{
	struct str *strings[] = { &self->tag, &self->text };
	for (size_t i = 0; i < sizeof strings / sizeof *strings; i++)
		str_free (strings[i]);
	free (self);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

/// Undo CTCP low-level quoting of `message', appending the result to `output'.
/// M-QUOTE followed by '0', 'r' or 'n' yields NUL, CR or LF respectively.
/// Followed by anything else, the quote is dropped and the character kept.
/// A trailing lone M-QUOTE is discarded.  The output may contain NULs.
static void
ctcp_low_level_decode (const char *message, struct str *output)
{
	const char *p = message;
	while (*p)
	{
		char c = *p++;
		if (c != CTCP_M_QUOTE)
		{
			str_append_c (output, c);
			continue;
		}

		// Quoted: interpret the following character, if there's any
		if (!*p)
			break;

		c = *p++;
		if      (c == '0') c = '\0';
		else if (c == 'r') c = '\r';
		else if (c == 'n') c = '\n';
		str_append_c (output, c);
	}
}

/// Undo CTCP extended-message quoting of the `len' bytes at `chunk',
/// appending the result to `output'.  X-QUOTE followed by 'a' yields X-DELIM,
/// and followed by anything else yields that character.  A trailing lone
/// X-QUOTE is dropped.
static void
ctcp_intra_decode (const char *chunk, size_t len, struct str *output)
{
	size_t i = 0;
	while (i < len)
	{
		char c = chunk[i++];
		if (c == CTCP_X_QUOTE)
		{
			if (i == len)
				break;

			c = chunk[i++];
			if (c == 'a')
				c = CTCP_X_DELIM;
		}
		str_append_c (output, c);
	}
}

/// Fill `output' from the `len'-byte body of an extended message: the tag
/// runs up to the first space and everything after that space is the text.
/// Both parts are intra-decoded, and the chunk is marked as extended.
static void
ctcp_parse_tagged (const char *chunk, size_t len, struct ctcp_chunk *output)
{
	// We may search for the space before doing the higher level decoding,
	// as it doesn't concern space characters at all
	const char *space = memchr (chunk, ' ', len);
	size_t tag_len = space ? (size_t) (space - chunk) : len;

	output->is_extended = true;
	ctcp_intra_decode (chunk, tag_len, &output->tag);
	if (space)
	{
		size_t text_start = tag_len + 1;
		ctcp_intra_decode (chunk + text_start, len - text_start, &output->text);
	}
}

/// Split a CTCP-encoded IRC message into a linked list of chunks.
/// Low-level quoting is removed from the whole message first.  After that,
/// each X-DELIM switches between plain text and extended (tagged) data.
/// Empty chunks are skipped.  An extended chunk missing its closing X-DELIM is
/// still returned, with `is_partial' set.  Returns NULL when there are no
/// chunks at all.  Free the list with ctcp_destroy().
static struct ctcp_chunk *
ctcp_parse (const char *message)
{
	struct str m;
	str_init (&m);
	ctcp_low_level_decode (message, &m);

	struct ctcp_chunk *result = NULL, *result_tail = NULL;

	// According to the original CTCP specification we should use
	// ctcp_intra_decode() on all parts, however no one seems to
	// use that and it breaks normal text with backslashes
	// `start' is where the current chunk begins, `in_ctcp' whether it's tagged
	size_t start = 0;
	bool in_ctcp = false;
	for (size_t i = 0; i < m.len; i++)
	{
		char c = m.str[i];
		if (c != CTCP_X_DELIM)
			continue;

		// Remember the current state
		size_t my_start = start;
		bool my_is_ctcp = in_ctcp;

		start = i + 1;
		in_ctcp = !in_ctcp;

		// Skip empty chunks
		if (my_start == i)
			continue;

		struct ctcp_chunk *chunk = ctcp_chunk_new ();
		if (my_is_ctcp)
			ctcp_parse_tagged (m.str + my_start, i - my_start, chunk);
		else
			str_append_data (&chunk->text, m.str + my_start, i - my_start);
		LIST_APPEND_WITH_TAIL (result, result_tail, chunk);
	}

	// Finish the last part.  Unended tagged chunks are marked as such.
	if (start != m.len)
	{
		struct ctcp_chunk *chunk = ctcp_chunk_new ();
		if (in_ctcp)
		{
			ctcp_parse_tagged (m.str + start, m.len - start, chunk);
			chunk->is_partial = true;
		}
		else
			str_append_data (&chunk->text, m.str + start, m.len - start);
		LIST_APPEND_WITH_TAIL (result, result_tail, chunk);
	}

	str_free (&m);
	return result;
}

static void
ctcp_destroy (struct ctcp_chunk *list)
{
	LIST_FOR_EACH (struct ctcp_chunk, iter, list)
		ctcp_chunk_destroy (iter);
}