aboutsummaryrefslogtreecommitdiff
path: root/common.c
diff options
context:
space:
mode:
Diffstat (limited to 'common.c')
-rw-r--r--common.c1096
1 files changed, 1096 insertions, 0 deletions
diff --git a/common.c b/common.c
new file mode 100644
index 0000000..ac83776
--- /dev/null
+++ b/common.c
@@ -0,0 +1,1096 @@
+/*
+ * common.c: common functionality
+ *
+ * Copyright (c) 2014 - 2022, Přemysl Eric Janouch <p@janouch.name>
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+ * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ *
+ */
+
+#define LIBERTY_WANT_SSL
+#define LIBERTY_WANT_ASYNC
+#define LIBERTY_WANT_POLLER
+#define LIBERTY_WANT_PROTO_IRC
+
+#ifdef WANT_SYSLOG_LOGGING
+#define print_fatal_data ((void *) LOG_ERR)
+#define print_error_data ((void *) LOG_ERR)
+#define print_warning_data ((void *) LOG_WARNING)
+#define print_status_data ((void *) LOG_INFO)
+#define print_debug_data ((void *) LOG_DEBUG)
+#endif // WANT_SYSLOG_LOGGING
+
+#include "liberty/liberty.c"
+#include <arpa/inet.h>
+#include <netinet/tcp.h>
+
+static void
+init_openssl (void)
+{
+#if OPENSSL_VERSION_NUMBER < 0x10100000L || LIBRESSL_VERSION_NUMBER
+ SSL_library_init ();
+ // XXX: this list is probably not complete
+ atexit (EVP_cleanup);
+ SSL_load_error_strings ();
+ atexit (ERR_free_strings);
+#else
+ // Cleanup is done automatically via atexit()
+ OPENSSL_init_ssl (0, NULL);
+#endif
+}
+
+static char *
+gai_reconstruct_address (struct addrinfo *ai)
+{
+ char host[NI_MAXHOST] = {}, port[NI_MAXSERV] = {};
+ int err = getnameinfo (ai->ai_addr, ai->ai_addrlen,
+ host, sizeof host, port, sizeof port,
+ NI_NUMERICHOST | NI_NUMERICSERV);
+ if (err)
+ {
+ print_debug ("%s: %s", "getnameinfo", gai_strerror (err));
+ return xstrdup ("?");
+ }
+ return format_host_port_pair (host, port);
+}
+
+static bool
+accept_error_is_transient (int err)
+{
+ // OS kernels may return a wide range of unforeseeable errors.
+ // Assuming that they're either transient or caused by
+ // a connection that we've just extracted from the queue.
+ switch (err)
+ {
+ case EBADF:
+ case EINVAL:
+ case ENOTSOCK:
+ case EOPNOTSUPP:
+ return false;
+ default:
+ return true;
+ }
+}
+
+/// Destructively tokenize an address into a host part, and a port part.
+/// The port is only overwritten if that part is found, allowing for defaults.
+static const char *
+tokenize_host_port (char *address, const char **port)
+{
+ // Unwrap IPv6 addresses in format_host_port_pair() format.
+ char *rbracket = strchr (address, ']');
+ if (*address == '[' && rbracket)
+ {
+ if (rbracket[1] == ':')
+ {
+ *port = rbracket + 2;
+ return *rbracket = 0, address + 1;
+ }
+ if (!rbracket[1])
+ return *rbracket = 0, address + 1;
+ }
+
+ char *colon = strchr (address, ':');
+ if (colon)
+ {
+ *port = colon + 1;
+ return *colon = 0, address;
+ }
+ return address;
+}
+
+// --- To be moved to liberty --------------------------------------------------
+
+// FIXME: in xssl_get_error() we rely on error reasons never being NULL (i.e.,
+// all loaded), which isn't very robust.
+// TODO: check all places where this is used and see if we couldn't gain better
+// information by piecing together some other subset of data from the error
+// stack. Most often, this is used in an error_set() context, which would
+// allow us to allocate memory instead of returning static strings.
+static const char *
+xerr_describe_error (void)
+{
+ unsigned long err = ERR_get_error ();
+ if (!err)
+ return "undefined error";
+
+ const char *reason = ERR_reason_error_string (err);
+ do
+ // Not thread-safe, not a concern right now--need a buffer
+ print_debug ("%s", ERR_error_string (err, NULL));
+ while ((err = ERR_get_error ()));
+
+ if (!reason)
+ return "cannot retrieve error description";
+ return reason;
+}
+
+static struct str
+str_from_cstr (const char *cstr)
+{
+ struct str self;
+ self.alloc = (self.len = strlen (cstr)) + 1;
+ self.str = memcpy (xmalloc (self.alloc), cstr, self.alloc);
+ return self;
+}
+
+static ssize_t
+strv_find (const struct strv *v, const char *s)
+{
+ for (size_t i = 0; i < v->len; i++)
+ if (!strcmp (v->vector[i], s))
+ return i;
+ return -1;
+}
+
+static time_t
+unixtime_msec (long *msec)
+{
+#ifdef _POSIX_TIMERS
+ struct timespec tp;
+ hard_assert (clock_gettime (CLOCK_REALTIME, &tp) != -1);
+ *msec = tp.tv_nsec / 1000000;
+#else // ! _POSIX_TIMERS
+ struct timeval tp;
+ hard_assert (gettimeofday (&tp, NULL) != -1);
+ *msec = tp.tv_usec / 1000;
+#endif // ! _POSIX_TIMERS
+ return tp.tv_sec;
+}
+
+// --- Logging -----------------------------------------------------------------
+
+static void
+log_message_syslog (void *user_data, const char *quote, const char *fmt,
+ va_list ap)
+{
+ int prio = (int) (intptr_t) user_data;
+
+ va_list va;
+ va_copy (va, ap);
+ int size = vsnprintf (NULL, 0, fmt, va);
+ va_end (va);
+ if (size < 0)
+ return;
+
+ char buf[size + 1];
+ if (vsnprintf (buf, sizeof buf, fmt, ap) >= 0)
+ syslog (prio, "%s%s", quote, buf);
+}
+
+// --- SOCKS 5/4a --------------------------------------------------------------
+
+// Asynchronous SOCKS connector. Adds more stuff on top of the regular one.
+
+// Note that the `username' is used differently in SOCKS 4a and 5. In the
+// former version, it is the username that you can get ident'ed against.
+// In the latter version, it forms a pair with the password field and doesn't
+// need to be an actual user on your machine.
+
+struct socks_addr
+{
+ enum socks_addr_type
+ {
+ SOCKS_IPV4 = 1, ///< IPv4 address
+ SOCKS_DOMAIN = 3, ///< Domain name to be resolved
+ SOCKS_IPV6 = 4 ///< IPv6 address
+ }
+ type; ///< The type of this address
+ union
+ {
+ uint8_t ipv4[4]; ///< IPv4 address, network octet order
+ char *domain; ///< Domain name
+ uint8_t ipv6[16]; ///< IPv6 address, network octet order
+ }
+ data; ///< The address itself
+};
+
+static void
+socks_addr_free (struct socks_addr *self)
+{
+ if (self->type == SOCKS_DOMAIN)
+ free (self->data.domain);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+struct socks_target
+{
+ LIST_HEADER (struct socks_target)
+
+ char *address_str; ///< Target address as a string
+ struct socks_addr address; ///< Target address
+ uint16_t port; ///< Target service port
+};
+
+enum socks_protocol
+{
+ SOCKS_5, ///< SOCKS5
+ SOCKS_4A, ///< SOCKS4A
+ SOCKS_MAX ///< End of protocol
+};
+
+static inline const char *
+socks_protocol_to_string (enum socks_protocol self)
+{
+ switch (self)
+ {
+ case SOCKS_5: return "SOCKS5";
+ case SOCKS_4A: return "SOCKS4A";
+ default: return NULL;
+ }
+}
+
+struct socks_connector
+{
+ struct connector *connector; ///< Proxy server iterator (effectively)
+ enum socks_protocol protocol_iter; ///< Protocol iterator
+ struct socks_target *targets_iter; ///< Targets iterator
+
+ // Negotiation:
+
+ struct poller_timer timeout; ///< Timeout timer
+
+ int socket_fd; ///< Current socket file descriptor
+ struct poller_fd socket_event; ///< Socket can be read from/written to
+ struct str read_buffer; ///< Read buffer
+ struct str write_buffer; ///< Write buffer
+
+ bool done; ///< Tunnel succesfully established
+ uint8_t bound_address_len; ///< Length of domain name
+ size_t data_needed; ///< How much data "on_data" needs
+
+ /// Process incoming data if there's enough of it available
+ bool (*on_data) (struct socks_connector *, struct msg_unpacker *);
+
+ // Configuration:
+
+ char *hostname; ///< SOCKS server hostname
+ char *service; ///< SOCKS server service name or port
+
+ char *username; ///< Username for authentication
+ char *password; ///< Password for authentication
+
+ struct socks_target *targets; ///< Targets
+ struct socks_target *targets_tail; ///< Tail of targets
+
+ void *user_data; ///< User data for callbacks
+
+ // Additional results:
+
+ struct socks_addr bound_address; ///< Bound address at the server
+ uint16_t bound_port; ///< Bound port at the server
+
+ // You may destroy the connector object in these two main callbacks:
+
+ /// Connection has been successfully established
+ void (*on_connected) (void *user_data, int socket, const char *hostname);
+ /// Failed to establish a connection to either target
+ void (*on_failure) (void *user_data);
+
+ // Optional:
+
+ /// Connecting to a new address
+ void (*on_connecting) (void *user_data,
+ const char *address, const char *via, const char *version);
+ /// Connecting to the last address has failed
+ void (*on_error) (void *user_data, const char *error);
+};
+
+// I've tried to make the actual protocol handlers as simple as possible
+
+#define SOCKS_FAIL(...) \
+ BLOCK_START \
+ char *error = xstrdup_printf (__VA_ARGS__); \
+ if (self->on_error) \
+ self->on_error (self->user_data, error); \
+ free (error); \
+ return false; \
+ BLOCK_END
+
+#define SOCKS_DATA_CB(name) static bool name \
+ (struct socks_connector *self, struct msg_unpacker *unpacker)
+
+#define SOCKS_GO(name, data_needed_) \
+ self->on_data = name; \
+ self->data_needed = data_needed_; \
+ return true
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+SOCKS_DATA_CB (socks_4a_finish)
+{
+ uint8_t null = 0, status = 0;
+ hard_assert (msg_unpacker_u8 (unpacker, &null));
+ hard_assert (msg_unpacker_u8 (unpacker, &status));
+
+ if (null != 0)
+ SOCKS_FAIL ("protocol error");
+
+ switch (status)
+ {
+ case 90:
+ self->done = true;
+ return false;
+ case 91:
+ SOCKS_FAIL ("request rejected or failed");
+ case 92:
+ SOCKS_FAIL ("%s: %s", "request rejected",
+ "SOCKS server cannot connect to identd on the client");
+ case 93:
+ SOCKS_FAIL ("%s: %s", "request rejected",
+ "identd reports different user-id");
+ default:
+ SOCKS_FAIL ("protocol error");
+ }
+}
+
+static bool
+socks_4a_start (struct socks_connector *self)
+{
+ struct socks_target *target = self->targets_iter;
+ const void *dest_ipv4 = "\x00\x00\x00\x01";
+ const char *dest_domain = NULL;
+
+ char buf[INET6_ADDRSTRLEN];
+ switch (target->address.type)
+ {
+ case SOCKS_IPV4:
+ dest_ipv4 = target->address.data.ipv4;
+ break;
+ case SOCKS_IPV6:
+ // About the best thing we can do, not sure if it works anywhere at all
+ if (!inet_ntop (AF_INET6, &target->address.data.ipv6, buf, sizeof buf))
+ SOCKS_FAIL ("%s: %s", "inet_ntop", strerror (errno));
+ dest_domain = buf;
+ break;
+ case SOCKS_DOMAIN:
+ dest_domain = target->address.data.domain;
+ }
+
+ struct str *wb = &self->write_buffer;
+ str_pack_u8 (wb, 4); // version
+ str_pack_u8 (wb, 1); // connect
+
+ str_pack_u16 (wb, target->port); // port
+ str_append_data (wb, dest_ipv4, 4); // destination address
+
+ if (self->username)
+ str_append (wb, self->username);
+ str_append_c (wb, '\0');
+
+ if (dest_domain)
+ {
+ str_append (wb, dest_domain);
+ str_append_c (wb, '\0');
+ }
+
+ SOCKS_GO (socks_4a_finish, 8);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+SOCKS_DATA_CB (socks_5_request_port)
+{
+ hard_assert (msg_unpacker_u16 (unpacker, &self->bound_port));
+ self->done = true;
+ return false;
+}
+
+SOCKS_DATA_CB (socks_5_request_ipv4)
+{
+ memcpy (self->bound_address.data.ipv4, unpacker->data, unpacker->len);
+ SOCKS_GO (socks_5_request_port, 2);
+}
+
+SOCKS_DATA_CB (socks_5_request_ipv6)
+{
+ memcpy (self->bound_address.data.ipv6, unpacker->data, unpacker->len);
+ SOCKS_GO (socks_5_request_port, 2);
+}
+
+SOCKS_DATA_CB (socks_5_request_domain_data)
+{
+ self->bound_address.data.domain = xstrndup (unpacker->data, unpacker->len);
+ SOCKS_GO (socks_5_request_port, 2);
+}
+
+SOCKS_DATA_CB (socks_5_request_domain)
+{
+ hard_assert (msg_unpacker_u8 (unpacker, &self->bound_address_len));
+ SOCKS_GO (socks_5_request_domain_data, self->bound_address_len);
+}
+
+SOCKS_DATA_CB (socks_5_request_finish)
+{
+ uint8_t version = 0, status = 0, reserved = 0, type = 0;
+ hard_assert (msg_unpacker_u8 (unpacker, &version));
+ hard_assert (msg_unpacker_u8 (unpacker, &status));
+ hard_assert (msg_unpacker_u8 (unpacker, &reserved));
+ hard_assert (msg_unpacker_u8 (unpacker, &type));
+
+ if (version != 0x05)
+ SOCKS_FAIL ("protocol error");
+
+ switch (status)
+ {
+ case 0x00:
+ break;
+ case 0x01: SOCKS_FAIL ("general SOCKS server failure");
+ case 0x02: SOCKS_FAIL ("connection not allowed by ruleset");
+ case 0x03: SOCKS_FAIL ("network unreachable");
+ case 0x04: SOCKS_FAIL ("host unreachable");
+ case 0x05: SOCKS_FAIL ("connection refused");
+ case 0x06: SOCKS_FAIL ("TTL expired");
+ case 0x07: SOCKS_FAIL ("command not supported");
+ case 0x08: SOCKS_FAIL ("address type not supported");
+ default: SOCKS_FAIL ("protocol error");
+ }
+
+ switch ((self->bound_address.type = type))
+ {
+ case SOCKS_IPV4:
+ SOCKS_GO (socks_5_request_ipv4, sizeof self->bound_address.data.ipv4);
+ case SOCKS_IPV6:
+ SOCKS_GO (socks_5_request_ipv6, sizeof self->bound_address.data.ipv6);
+ case SOCKS_DOMAIN:
+ SOCKS_GO (socks_5_request_domain, 1);
+ default:
+ SOCKS_FAIL ("protocol error");
+ }
+}
+
+static bool
+socks_5_request_start (struct socks_connector *self)
+{
+ struct socks_target *target = self->targets_iter;
+ struct str *wb = &self->write_buffer;
+ str_pack_u8 (wb, 0x05); // version
+ str_pack_u8 (wb, 0x01); // connect
+ str_pack_u8 (wb, 0x00); // reserved
+ str_pack_u8 (wb, target->address.type);
+
+ switch (target->address.type)
+ {
+ case SOCKS_IPV4:
+ str_append_data (wb,
+ target->address.data.ipv4, sizeof target->address.data.ipv4);
+ break;
+ case SOCKS_DOMAIN:
+ {
+ size_t dlen = strlen (target->address.data.domain);
+ if (dlen > 255)
+ dlen = 255;
+
+ str_pack_u8 (wb, dlen);
+ str_append_data (wb, target->address.data.domain, dlen);
+ break;
+ }
+ case SOCKS_IPV6:
+ str_append_data (wb,
+ target->address.data.ipv6, sizeof target->address.data.ipv6);
+ break;
+ }
+ str_pack_u16 (wb, target->port);
+
+ SOCKS_GO (socks_5_request_finish, 4);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+SOCKS_DATA_CB (socks_5_userpass_finish)
+{
+ uint8_t version = 0, status = 0;
+ hard_assert (msg_unpacker_u8 (unpacker, &version));
+ hard_assert (msg_unpacker_u8 (unpacker, &status));
+
+ if (version != 0x01)
+ SOCKS_FAIL ("protocol error");
+ if (status != 0x00)
+ SOCKS_FAIL ("authentication failure");
+
+ return socks_5_request_start (self);
+}
+
+static bool
+socks_5_userpass_start (struct socks_connector *self)
+{
+ size_t ulen = strlen (self->username);
+ if (ulen > 255)
+ ulen = 255;
+
+ size_t plen = strlen (self->password);
+ if (plen > 255)
+ plen = 255;
+
+ struct str *wb = &self->write_buffer;
+ str_pack_u8 (wb, 0x01); // version
+ str_pack_u8 (wb, ulen); // username length
+ str_append_data (wb, self->username, ulen);
+ str_pack_u8 (wb, plen); // password length
+ str_append_data (wb, self->password, plen);
+
+ SOCKS_GO (socks_5_userpass_finish, 2);
+}
+
+SOCKS_DATA_CB (socks_5_auth_finish)
+{
+ uint8_t version = 0, method = 0;
+ hard_assert (msg_unpacker_u8 (unpacker, &version));
+ hard_assert (msg_unpacker_u8 (unpacker, &method));
+
+ if (version != 0x05)
+ SOCKS_FAIL ("protocol error");
+
+ bool can_auth = self->username && self->password;
+
+ switch (method)
+ {
+ case 0x02:
+ if (!can_auth)
+ SOCKS_FAIL ("protocol error");
+
+ return socks_5_userpass_start (self);
+ case 0x00:
+ return socks_5_request_start (self);
+ case 0xFF:
+ SOCKS_FAIL ("no acceptable authentication methods");
+ default:
+ SOCKS_FAIL ("protocol error");
+ }
+}
+
+static bool
+socks_5_auth_start (struct socks_connector *self)
+{
+ bool can_auth = self->username && self->password;
+
+ struct str *wb = &self->write_buffer;
+ str_pack_u8 (wb, 0x05); // version
+ str_pack_u8 (wb, 1 + can_auth); // number of authentication methods
+ str_pack_u8 (wb, 0x00); // no authentication required
+ if (can_auth)
+ str_pack_u8 (wb, 0x02); // username/password
+
+ SOCKS_GO (socks_5_auth_finish, 2);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static void socks_connector_start (struct socks_connector *self);
+
+static void
+socks_connector_destroy_connector (struct socks_connector *self)
+{
+ if (self->connector)
+ {
+ connector_free (self->connector);
+ free (self->connector);
+ self->connector = NULL;
+ }
+}
+
+static void
+socks_connector_cancel_events (struct socks_connector *self)
+{
+ // Before calling the final callbacks, we should cancel events that
+ // could potentially fire; caller should destroy us immediately, though
+ poller_fd_reset (&self->socket_event);
+ poller_timer_reset (&self->timeout);
+}
+
+static void
+socks_connector_fail (struct socks_connector *self)
+{
+ socks_connector_cancel_events (self);
+ self->on_failure (self->user_data);
+}
+
+static bool
+socks_connector_step_iterators (struct socks_connector *self)
+{
+ // At the lowest level we iterate over all addresses for the SOCKS server
+ // and just try to connect; this is done automatically by the connector
+
+ // Then we iterate over available protocols
+ if (++self->protocol_iter != SOCKS_MAX)
+ return true;
+
+ // At the highest level we iterate over possible targets
+ self->protocol_iter = 0;
+ if (self->targets_iter && (self->targets_iter = self->targets_iter->next))
+ return true;
+
+ return false;
+}
+
+static void
+socks_connector_step (struct socks_connector *self)
+{
+ if (self->socket_fd != -1)
+ {
+ poller_fd_reset (&self->socket_event);
+ xclose (self->socket_fd);
+ self->socket_fd = -1;
+ }
+
+ socks_connector_destroy_connector (self);
+ if (socks_connector_step_iterators (self))
+ socks_connector_start (self);
+ else
+ socks_connector_fail (self);
+}
+
+static void
+socks_connector_on_timeout (struct socks_connector *self)
+{
+ if (self->on_error)
+ self->on_error (self->user_data, "timeout");
+
+ socks_connector_destroy_connector (self);
+ socks_connector_fail (self);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static void
+socks_connector_on_connected
+ (void *user_data, int socket_fd, const char *hostname)
+{
+ set_blocking (socket_fd, false);
+ (void) hostname;
+
+ struct socks_connector *self = user_data;
+ self->socket_fd = socket_fd;
+ self->socket_event.fd = socket_fd;
+ poller_fd_set (&self->socket_event, POLLIN | POLLOUT);
+ str_reset (&self->read_buffer);
+ str_reset (&self->write_buffer);
+
+ if (!(self->protocol_iter == SOCKS_5 && socks_5_auth_start (self))
+ && !(self->protocol_iter == SOCKS_4A && socks_4a_start (self)))
+ socks_connector_fail (self);
+}
+
+static void
+socks_connector_on_failure (void *user_data)
+{
+ struct socks_connector *self = user_data;
+ // TODO: skip SOCKS server on connection failure
+ socks_connector_step (self);
+}
+
+static void
+socks_connector_on_connecting (void *user_data, const char *via)
+{
+ struct socks_connector *self = user_data;
+ if (!self->on_connecting)
+ return;
+
+ struct socks_target *target = self->targets_iter;
+ char *port = xstrdup_printf ("%u", target->port);
+ char *address = format_host_port_pair (target->address_str, port);
+ free (port);
+ self->on_connecting (self->user_data, address, via,
+ socks_protocol_to_string (self->protocol_iter));
+ free (address);
+}
+
+static void
+socks_connector_on_error (void *user_data, const char *error)
+{
+ struct socks_connector *self = user_data;
+ // TODO: skip protocol on protocol failure
+ if (self->on_error)
+ self->on_error (self->user_data, error);
+}
+
+static void
+socks_connector_start (struct socks_connector *self)
+{
+ hard_assert (!self->connector);
+
+ struct connector *connector =
+ self->connector = xcalloc (1, sizeof *connector);
+ connector_init (connector, self->socket_event.poller);
+
+ connector->user_data = self;
+ connector->on_connected = socks_connector_on_connected;
+ connector->on_connecting = socks_connector_on_connecting;
+ connector->on_error = socks_connector_on_error;
+ connector->on_failure = socks_connector_on_failure;
+
+ connector_add_target (connector, self->hostname, self->service);
+ poller_timer_set (&self->timeout, 60 * 1000);
+ self->done = false;
+
+ self->bound_port = 0;
+ socks_addr_free (&self->bound_address);
+ memset (&self->bound_address, 0, sizeof self->bound_address);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static bool
+socks_try_fill_read_buffer (struct socks_connector *self, size_t n)
+{
+ ssize_t remains = (ssize_t) n - (ssize_t) self->read_buffer.len;
+ if (remains <= 0)
+ return true;
+
+ ssize_t received;
+ str_reserve (&self->read_buffer, remains);
+ do
+ received = recv (self->socket_fd,
+ self->read_buffer.str + self->read_buffer.len, remains, 0);
+ while ((received == -1) && errno == EINTR);
+
+ if (received == 0)
+ SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF");
+ if (received == -1 && errno != EAGAIN)
+ SOCKS_FAIL ("%s: %s", "recv", strerror (errno));
+ if (received > 0)
+ self->read_buffer.len += received;
+ return true;
+}
+
+static bool
+socks_call_on_data (struct socks_connector *self)
+{
+ size_t to_consume = self->data_needed;
+ if (!socks_try_fill_read_buffer (self, to_consume))
+ return false;
+ if (self->read_buffer.len < to_consume)
+ return true;
+
+ struct msg_unpacker unpacker =
+ msg_unpacker_make (self->read_buffer.str, self->read_buffer.len);
+ bool result = self->on_data (self, &unpacker);
+ str_remove_slice (&self->read_buffer, 0, to_consume);
+ return result;
+}
+
+static bool
+socks_try_flush_write_buffer (struct socks_connector *self)
+{
+ struct str *wb = &self->write_buffer;
+ ssize_t n_written;
+
+ while (wb->len)
+ {
+ n_written = send (self->socket_fd, wb->str, wb->len, 0);
+ if (n_written >= 0)
+ {
+ str_remove_slice (wb, 0, n_written);
+ continue;
+ }
+
+ if (errno == EAGAIN)
+ break;
+ if (errno == EINTR)
+ continue;
+
+ SOCKS_FAIL ("%s: %s", "send", strerror (errno));
+ }
+ return true;
+}
+
+static void
+socks_connector_on_ready
+ (const struct pollfd *pfd, struct socks_connector *self)
+{
+ (void) pfd;
+
+ if (socks_call_on_data (self) && socks_try_flush_write_buffer (self))
+ {
+ poller_fd_set (&self->socket_event,
+ self->write_buffer.len ? (POLLIN | POLLOUT) : POLLIN);
+ }
+ else if (self->done)
+ {
+ socks_connector_cancel_events (self);
+
+ int fd = self->socket_fd;
+ self->socket_fd = -1;
+
+ struct socks_target *target = self->targets_iter;
+ set_blocking (fd, true);
+ self->on_connected (self->user_data, fd, target->address_str);
+ }
+ else
+ // We've failed this target, let's try to move on
+ socks_connector_step (self);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static void
+socks_connector_init (struct socks_connector *self, struct poller *poller)
+{
+ memset (self, 0, sizeof *self);
+
+ self->socket_event = poller_fd_make (poller, (self->socket_fd = -1));
+ self->socket_event.dispatcher = (poller_fd_fn) socks_connector_on_ready;
+ self->socket_event.user_data = self;
+
+ self->timeout = poller_timer_make (poller);
+ self->timeout.dispatcher = (poller_timer_fn) socks_connector_on_timeout;
+ self->timeout.user_data = self;
+
+ self->read_buffer = str_make ();
+ self->write_buffer = str_make ();
+}
+
+static void
+socks_connector_free (struct socks_connector *self)
+{
+ socks_connector_destroy_connector (self);
+ socks_connector_cancel_events (self);
+
+ if (self->socket_fd != -1)
+ xclose (self->socket_fd);
+
+ str_free (&self->read_buffer);
+ str_free (&self->write_buffer);
+
+ free (self->hostname);
+ free (self->service);
+ free (self->username);
+ free (self->password);
+
+ LIST_FOR_EACH (struct socks_target, iter, self->targets)
+ {
+ socks_addr_free (&iter->address);
+ free (iter->address_str);
+ free (iter);
+ }
+
+ socks_addr_free (&self->bound_address);
+}
+
+static bool
+socks_connector_add_target (struct socks_connector *self,
+ const char *host, const char *service, struct error **e)
+{
+ unsigned long port;
+ const struct servent *serv;
+ if ((serv = getservbyname (service, "tcp")))
+ port = (uint16_t) ntohs (serv->s_port);
+ else if (!xstrtoul (&port, service, 10) || !port || port > UINT16_MAX)
+ {
+ error_set (e, "invalid port number");
+ return false;
+ }
+
+ struct socks_target *target = xcalloc (1, sizeof *target);
+ if (inet_pton (AF_INET, host, &target->address.data.ipv4) == 1)
+ target->address.type = SOCKS_IPV4;
+ else if (inet_pton (AF_INET6, host, &target->address.data.ipv6) == 1)
+ target->address.type = SOCKS_IPV6;
+ else
+ {
+ target->address.type = SOCKS_DOMAIN;
+ target->address.data.domain = xstrdup (host);
+ }
+
+ target->port = port;
+ target->address_str = xstrdup (host);
+ LIST_APPEND_WITH_TAIL (self->targets, self->targets_tail, target);
+ return true;
+}
+
+static void
+socks_connector_run (struct socks_connector *self,
+ const char *host, const char *service,
+ const char *username, const char *password)
+{
+ hard_assert (self->targets);
+ hard_assert (host && service);
+
+ self->hostname = xstrdup (host);
+ self->service = xstrdup (service);
+
+ if (username) self->username = xstrdup (username);
+ if (password) self->password = xstrdup (password);
+
+ self->targets_iter = self->targets;
+ self->protocol_iter = 0;
+ // XXX: this can fail immediately from an error creating the connector
+ socks_connector_start (self);
+}
+
+// --- CTCP decoding -----------------------------------------------------------
+
+#define CTCP_M_QUOTE '\020'
+#define CTCP_X_DELIM '\001'
+#define CTCP_X_QUOTE '\\'
+
+struct ctcp_chunk
+{
+ LIST_HEADER (struct ctcp_chunk)
+
+ bool is_extended; ///< Is this a tagged extended message?
+ bool is_partial; ///< Unterminated extended message
+ struct str tag; ///< The tag, if any
+ struct str text; ///< Message contents
+};
+
+static struct ctcp_chunk *
+ctcp_chunk_new (void)
+{
+ struct ctcp_chunk *self = xcalloc (1, sizeof *self);
+ self->tag = str_make ();
+ self->text = str_make ();
+ return self;
+}
+
+static void
+ctcp_chunk_destroy (struct ctcp_chunk *self)
+{
+ str_free (&self->tag);
+ str_free (&self->text);
+ free (self);
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static void
+ctcp_low_level_decode (const char *message, struct str *output)
+{
+ bool escape = false;
+ for (const char *p = message; *p; p++)
+ {
+ if (escape)
+ {
+ switch (*p)
+ {
+ case '0': str_append_c (output, '\0'); break;
+ case 'r': str_append_c (output, '\r'); break;
+ case 'n': str_append_c (output, '\n'); break;
+ default: str_append_c (output, *p);
+ }
+ escape = false;
+ }
+ else if (*p == CTCP_M_QUOTE)
+ escape = true;
+ else
+ str_append_c (output, *p);
+ }
+}
+
+static void
+ctcp_intra_decode (const char *chunk, size_t len, struct str *output)
+{
+ bool escape = false;
+ for (size_t i = 0; i < len; i++)
+ {
+ char c = chunk[i];
+ if (escape)
+ {
+ if (c == 'a')
+ str_append_c (output, CTCP_X_DELIM);
+ else
+ str_append_c (output, c);
+ escape = false;
+ }
+ else if (c == CTCP_X_QUOTE)
+ escape = true;
+ else
+ str_append_c (output, c);
+ }
+}
+
+// According to the original CTCP specification we should use
+// ctcp_intra_decode() on all parts, however no one seems to use that
+// and it breaks normal text with backslashes
+#ifndef SUPPORT_CTCP_X_QUOTES
+#define ctcp_intra_decode(s, len, output) str_append_data (output, s, len)
+#endif
+
+static void
+ctcp_parse_tagged (const char *chunk, size_t len, struct ctcp_chunk *output)
+{
+ // We may search for the space before doing the higher level decoding,
+ // as it doesn't concern space characters at all
+ size_t tag_end = len;
+ for (size_t i = 0; i < len; i++)
+ if (chunk[i] == ' ')
+ {
+ tag_end = i;
+ break;
+ }
+
+ output->is_extended = true;
+ ctcp_intra_decode (chunk, tag_end, &output->tag);
+ if (tag_end++ != len)
+ ctcp_intra_decode (chunk + tag_end, len - tag_end, &output->text);
+}
+
+static struct ctcp_chunk *
+ctcp_parse (const char *message)
+{
+ struct str m = str_make ();
+ ctcp_low_level_decode (message, &m);
+
+ struct ctcp_chunk *result = NULL, *result_tail = NULL;
+
+ size_t start = 0;
+ bool in_ctcp = false;
+ for (size_t i = 0; i < m.len; i++)
+ {
+ char c = m.str[i];
+ if (c != CTCP_X_DELIM)
+ continue;
+
+ // Remember the current state
+ size_t my_start = start;
+ bool my_is_ctcp = in_ctcp;
+
+ start = i + 1;
+ in_ctcp = !in_ctcp;
+
+ // Skip empty chunks
+ if (my_start == i)
+ continue;
+
+ struct ctcp_chunk *chunk = ctcp_chunk_new ();
+ if (my_is_ctcp)
+ ctcp_parse_tagged (m.str + my_start, i - my_start, chunk);
+ else
+ ctcp_intra_decode (m.str + my_start, i - my_start, &chunk->text);
+ LIST_APPEND_WITH_TAIL (result, result_tail, chunk);
+ }
+
+ // Finish the last part. Unended tagged chunks are marked as such.
+ if (start != m.len)
+ {
+ struct ctcp_chunk *chunk = ctcp_chunk_new ();
+ if (in_ctcp)
+ {
+ ctcp_parse_tagged (m.str + start, m.len - start, chunk);
+ chunk->is_partial = true;
+ }
+ else
+ ctcp_intra_decode (m.str + start, m.len - start, &chunk->text);
+ LIST_APPEND_WITH_TAIL (result, result_tail, chunk);
+ }
+
+ str_free (&m);
+ return result;
+}
+
+static void
+ctcp_destroy (struct ctcp_chunk *list)
+{
+ LIST_FOR_EACH (struct ctcp_chunk, iter, list)
+ ctcp_chunk_destroy (iter);
+}