diff options
Diffstat (limited to 'src/zyklonb.c')
-rw-r--r-- | src/zyklonb.c | 3320 |
1 files changed, 3320 insertions, 0 deletions
diff --git a/src/zyklonb.c b/src/zyklonb.c new file mode 100644 index 0000000..a182d17 --- /dev/null +++ b/src/zyklonb.c @@ -0,0 +1,3320 @@ +/* + * zyklonb.c: the experimental IRC bot + * + * Copyright (c) 2014, Přemysl Janouch <p.janouch@gmail.com> + * All rights reserved. + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION + * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#define _POSIX_C_SOURCE 199309L +#define _XOPEN_SOURCE 500 + +#include <stdio.h> +#include <stdlib.h> +#include <errno.h> +#include <string.h> +#include <stdarg.h> +#include <stdint.h> +#include <stdbool.h> +#include <ctype.h> + +#include <unistd.h> +#include <sys/wait.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <poll.h> +#include <signal.h> +#include <strings.h> +#include <regex.h> +#include <libgen.h> + +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> + +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif // ! NI_MAXHOST + +#include <getopt.h> +#include <openssl/ssl.h> +#include <openssl/err.h> +#include "siphash.h" + +#define PROGRAM_NAME "ZyklonB" +#define PROGRAM_VERSION "alpha" + +extern char **environ; + +#if defined __GNUC__ +#define ATTRIBUTE_PRINTF(x, y) __attribute__ ((format (printf, x, y))) +#else // ! __GNUC__ +#define ATTRIBUTE_PRINTF(x, y) +#endif // ! __GNUC__ + +#if defined __GNUC__ && __GNUC__ >= 4 +#define ATTRIBUTE_SENTINEL __attribute__ ((sentinel)) +#else // ! __GNUC__ || __GNUC__ < 4 +#define ATTRIBUTE_SENTINEL +#endif // ! __GNUC__ || __GNUC__ < 4 + +#define N_ELEMENTS(a) (sizeof (a) / sizeof ((a)[0])) + +#define BLOCK_START do { +#define BLOCK_END } while (0) + +// --- Utilities --------------------------------------------------------------- + +static void +print_message (FILE *stream, const char *type, const char *fmt, ...) + ATTRIBUTE_PRINTF (3, 4); + +static void +print_message (FILE *stream, const char *type, const char *fmt, ...) +{ + va_list ap; + + va_start (ap, fmt); + fprintf (stream, "%s ", type); + vfprintf (stream, fmt, ap); + fputs ("\n", stream); + va_end (ap); +} + +#define print_fatal(...) print_message (stderr, "fatal:", __VA_ARGS__) +#define print_error(...) print_message (stderr, "error:", __VA_ARGS__) +#define print_warning(...) print_message (stderr, "warning:", __VA_ARGS__) +#define print_status(...) print_message (stdout, "--", __VA_ARGS__) + +// --- Debugging and assertions ------------------------------------------------ + +// We should check everything that may possibly fail with at least a soft +// assertion, so that any causes for problems don't slip us by silently. +// +// `g_soft_asserts_are_deadly' may be useful while running inside a debugger. + +static bool g_debug_mode; ///< Debug messages are printed +static bool g_soft_asserts_are_deadly; ///< soft_assert() aborts as well + +#define print_debug(...) \ + BLOCK_START \ + if (g_debug_mode) \ + print_message (stderr, "debug:", __VA_ARGS__); \ + BLOCK_END + +static void +assertion_failure_handler (bool is_fatal, const char *file, int line, + const char *function, const char *condition) +{ + if (is_fatal) + { + print_fatal ("assertion failed [%s:%d in function %s]: %s", + file, line, function, condition); + abort (); + } + else + print_debug ("assertion failed [%s:%d in function %s]: %s", + file, line, function, condition); +} + +#define soft_assert(condition) \ + ((condition) ? true : \ + (assertion_failure_handler (g_soft_asserts_are_deadly, \ + __FILE__, __LINE__, __func__, #condition), false)) + +#define hard_assert(condition) \ + ((condition) ? (void) 0 : \ + assertion_failure_handler (true, \ + __FILE__, __LINE__, __func__, #condition)) + +// --- Safe memory management -------------------------------------------------- + +// When a memory allocation fails and we need the memory, we're usually pretty +// much fucked. Use the non-prefixed versions when there's a legitimate +// worry that an unrealistic amount of memory may be requested for allocation. + +// XXX: it's not a good idea to use print_message() as it may want to allocate +// further memory for printf() and the output streams. That may fail. + +static void * +xmalloc (size_t n) +{ + void *p = malloc (n); + if (!p) + { + print_fatal ("malloc: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + return p; +} + +static void * +xcalloc (size_t n, size_t m) +{ + void *p = calloc (n, m); + if (!p && n && m) + { + print_fatal ("calloc: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + return p; +} + +static void * +xrealloc (void *o, size_t n) +{ + void *p = realloc (o, n); + if (!p && n) + { + print_fatal ("realloc: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + return p; +} + +static void * +xreallocarray (void *o, size_t n, size_t m) +{ + if (m && n > SIZE_MAX / m) + { + errno = ENOMEM; + print_fatal ("reallocarray: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + return xrealloc (o, n * m); +} + +static char * +xstrdup (const char *s) +{ + return strcpy (xmalloc (strlen (s) + 1), s); +} + +static char * +xstrndup (const char *s, size_t n) +{ + size_t size = strlen (s); + if (n > size) + n = size; + + char *copy = xmalloc (n + 1); + memcpy (copy, s, n); + copy[n] = '\0'; + return copy; +} + +// --- Double-linked list helpers ---------------------------------------------- + +// The links of the list need to have the members `prev' and `next'. + +#define LIST_PREPEND(head, link) \ + BLOCK_START \ + (link)->prev = NULL; \ + (link)->next = (head); \ + if ((link)->next) \ + (link)->next->prev = (link); \ + (head) = (link); \ + BLOCK_END + +#define LIST_UNLINK(head, link) \ + BLOCK_START \ + if ((link)->prev) \ + (link)->prev->next = (link)->next; \ + else \ + (head) = (link)->next; \ + if ((link)->next) \ + (link)->next->prev = (link)->prev; \ + BLOCK_END + +// --- Dynamically allocated string array -------------------------------------- + +struct str_vector +{ + char **vector; + size_t len; + size_t alloc; +}; + +static void +str_vector_init (struct str_vector *self) +{ + self->alloc = 4; + self->len = 0; + self->vector = xcalloc (sizeof *self->vector, self->alloc); +} + +static void +str_vector_free (struct str_vector *self) +{ + unsigned i; + for (i = 0; i < self->len; i++) + free (self->vector[i]); + + free (self->vector); + self->vector = NULL; +} + +static void +str_vector_add_owned (struct str_vector *self, char *s) +{ + self->vector[self->len] = s; + if (++self->len >= self->alloc) + self->vector = xreallocarray (self->vector, + sizeof *self->vector, (self->alloc <<= 1)); + self->vector[self->len] = NULL; +} + +static void +str_vector_add (struct str_vector *self, const char *s) +{ + str_vector_add_owned (self, xstrdup (s)); +} + +static void +str_vector_add_args (struct str_vector *self, const char *s, ...) + ATTRIBUTE_SENTINEL; + +static void +str_vector_add_args (struct str_vector *self, const char *s, ...) +{ + va_list ap; + + va_start (ap, s); + while (s) + { + str_vector_add (self, s); + s = va_arg (ap, const char *); + } + va_end (ap); +} + +static void +str_vector_add_vector (struct str_vector *self, char **vector) +{ + while (*vector) + str_vector_add (self, *vector++); +} + +static void +str_vector_remove (struct str_vector *self, size_t i) +{ + hard_assert (i < self->len); + free (self->vector[i]); + memmove (self->vector + i, self->vector + i + 1, + (self->len-- - i) * sizeof *self->vector); +} + +// --- Dynamically allocated strings ------------------------------------------- + +// Basically a string builder to abstract away manual memory management. + +struct str +{ + char *str; ///< String data, null terminated + size_t alloc; ///< How many bytes are allocated + size_t len; ///< How long the string actually is +}; + +/// We don't care about allocations that are way too large for the content, as +/// long as the allocation is below the given threshold. (Trivial heuristics.) +#define STR_SHRINK_THRESHOLD (1 << 20) + +static void +str_init (struct str *self) +{ + self->alloc = 16; + self->len = 0; + self->str = strcpy (xmalloc (self->alloc), ""); +} + +static void +str_free (struct str *self) +{ + free (self->str); + self->str = NULL; + self->alloc = 0; + self->len = 0; +} + +static void +str_reset (struct str *self) +{ + str_free (self); + str_init (self); +} + +static char * +str_steal (struct str *self) +{ + char *str = self->str; + self->str = NULL; + str_free (self); + return str; +} + +static void +str_ensure_space (struct str *self, size_t n) +{ + // We allocate at least one more byte for the terminating null character + size_t new_alloc = self->alloc; + while (new_alloc <= self->len + n) + new_alloc <<= 1; + if (new_alloc != self->alloc) + self->str = xrealloc (self->str, (self->alloc = new_alloc)); +} + +static void +str_append_data (struct str *self, const char *data, size_t n) +{ + str_ensure_space (self, n); + memcpy (self->str + self->len, data, n); + self->len += n; + self->str[self->len] = '\0'; +} + +static void +str_append_c (struct str *self, char c) +{ + str_append_data (self, &c, 1); +} + +static void +str_append (struct str *self, const char *s) +{ + str_append_data (self, s, strlen (s)); +} + +static void +str_append_str (struct str *self, const struct str *another) +{ + str_append_data (self, another->str, another->len); +} + +static int +str_append_vprintf (struct str *self, const char *fmt, va_list va) +{ + va_list ap; + int size; + + va_copy (ap, va); + size = vsnprintf (NULL, 0, fmt, ap); + va_end (ap); + + if (size < 0) + return -1; + + va_copy (ap, va); + str_ensure_space (self, size); + size = vsnprintf (self->str + self->len, self->alloc - self->len, fmt, ap); + va_end (ap); + + if (size > 0) + self->len += size; + + return size; +} + +static int +str_append_printf (struct str *self, const char *fmt, ...) + ATTRIBUTE_PRINTF (2, 3); + +static int +str_append_printf (struct str *self, const char *fmt, ...) +{ + va_list ap; + + va_start (ap, fmt); + int size = str_append_vprintf (self, fmt, ap); + va_end (ap); + return size; +} + +static void +str_remove_slice (struct str *self, size_t start, size_t length) +{ + size_t end = start + length; + hard_assert (end <= self->len); + memmove (self->str + start, self->str + end, self->len - end); + self->str[self->len -= length] = '\0'; + + // Shrink the string if the allocation becomes way too large + if (self->alloc >= STR_SHRINK_THRESHOLD && self->len < (self->alloc >> 2)) + self->str = xrealloc (self->str, self->alloc >>= 2); +} + +// --- Errors ------------------------------------------------------------------ + +// Error reporting utilities. Inspired by GError, only much simpler. + +struct error +{ + size_t domain; ///< The domain of the error + int id; ///< The concrete error ID + char *message; ///< Textual description of the event +}; + +static size_t +error_resolve_domain (size_t *tag) +{ + // This method is fairly sensitive to the order in which resolution + // requests come in, does not provide a good way of decoding the number + // back to a meaningful identifier, and may not play all too well with + // dynamic libraries when a module is e.g. statically linked into multiple + // libraries, but it's fast, simple, and more than enough for our purposes. + static size_t domain_counter; + + if (!*tag) + *tag = ++domain_counter; + return *tag; +} + +static void +error_set (struct error **e, size_t domain, int id, + const char *message, ...) ATTRIBUTE_PRINTF (4, 5); + +static void +error_set (struct error **e, size_t domain, int id, + const char *message, ...) +{ + if (!e) + return; + + va_list ap; + va_start (ap, message); + int size = snprintf (NULL, 0, message, ap); + va_end (ap); + + hard_assert (size >= 0); + + struct error *tmp = xmalloc (sizeof *tmp); + tmp->domain = domain; + tmp->id = id; + tmp->message = xmalloc (size + 1); + + va_start (ap, message); + size = snprintf (tmp->message, size + 1, message, ap); + va_end (ap); + + hard_assert (size >= 0); + + soft_assert (*e == NULL); + *e = tmp; +} + +static void +error_free (struct error *e) +{ + free (e->message); + free (e); +} + +static void +error_propagate (struct error **destination, struct error *source) +{ + if (!destination) + { + error_free (source); + return; + } + + soft_assert (*destination == NULL); + *destination = source; +} + +// --- String hash map --------------------------------------------------------- + +// The most basic <string, managed pointer> map (or associative array). + +struct str_map_link +{ + struct str_map_link *next; ///< The next link in a chain + struct str_map_link *prev; ///< The previous link in a chain + + void *data; ///< Payload + size_t key_length; ///< Length of the key without '\0' + char key[]; ///< The key for this link +}; + +struct str_map +{ + struct str_map_link **map; ///< The hash table data itself + size_t alloc; ///< Number of allocated entries + size_t len; ///< Number of entries in the table + void (*free) (void *); ///< Callback to destruct the payload +}; + +#define STR_MAP_MIN_ALLOC 16 + +typedef void (*str_map_free_func) (void *); + +static void +str_map_init (struct str_map *self) +{ + self->alloc = STR_MAP_MIN_ALLOC; + self->len = 0; + self->free = NULL; + self->map = xcalloc (self->alloc, sizeof *self->map); +} + +static void +str_map_free (struct str_map *self) +{ + struct str_map_link **iter, **end = self->map + self->alloc; + struct str_map_link *link, *tmp; + + for (iter = self->map; iter < end; iter++) + for (link = *iter; link; link = tmp) + { + tmp = link->next; + if (self->free) + self->free (link->data); + free (link); + } + + free (self->map); + self->map = NULL; +} + +static uint64_t +str_map_hash (const char *s, size_t len) +{ + static unsigned char key[16] = "SipHash 2-4 key!"; + return siphash (key, (const void *) s, len); +} + +static uint64_t +str_map_pos (struct str_map *self, const char *s) +{ + size_t mask = self->alloc - 1; + return str_map_hash (s, strlen (s)) & mask; +} + +static uint64_t +str_map_link_hash (struct str_map_link *self) +{ + return str_map_hash (self->key, self->key_length); +} + +static void +str_map_resize (struct str_map *self, size_t new_size) +{ + struct str_map_link **old_map = self->map; + size_t i, old_size = self->alloc; + + // Only powers of two, so that we don't need to compute the modulo + hard_assert ((new_size & (new_size - 1)) == 0); + size_t mask = new_size - 1; + + self->alloc = new_size; + self->map = xcalloc (self->alloc, sizeof *self->map); + for (i = 0; i < old_size; i++) + { + struct str_map_link *iter = old_map[i], *next_iter; + while (iter) + { + next_iter = iter->next; + uint64_t pos = str_map_link_hash (iter) & mask; + LIST_PREPEND (self->map[pos], iter); + iter = next_iter; + } + } + + free (old_map); +} + +static void +str_map_set (struct str_map *self, const char *key, void *value) +{ + uint64_t pos = str_map_pos (self, key); + struct str_map_link *iter = self->map[pos]; + for (; iter; iter = iter->next) + { + if (strcmp (key, iter->key)) + continue; + + // Storing the same data doesn't destroy it + if (self->free && value != iter->data) + self->free (iter->data); + + if (value) + { + iter->data = value; + return; + } + + LIST_UNLINK (self->map[pos], iter); + free (iter); + self->len--; + + // The array should be at least 1/4 full + if (self->alloc >= (STR_MAP_MIN_ALLOC << 2) + && self->len < (self->alloc >> 2)) + str_map_resize (self, self->alloc >> 2); + return; + } + + if (!value) + return; + + if (self->len >= self->alloc) + { + str_map_resize (self, self->alloc << 1); + pos = str_map_pos (self, key); + } + + // Link in a new element for the given <key, value> pair + size_t key_length = strlen (key); + struct str_map_link *link = xmalloc (sizeof *link + key_length + 1); + link->data = value; + link->key_length = key_length; + memcpy (link->key, key, key_length + 1); + + LIST_PREPEND (self->map[pos], link); + self->len++; +} + +static void * +str_map_find (struct str_map *self, const char *key) +{ + struct str_map_link *iter = self->map[str_map_pos (self, key)]; + for (; iter; iter = iter->next) + if (!strcmp (key, (char *) iter + sizeof *iter)) + return iter->data; + return NULL; +} + +// --- File descriptor utilities ----------------------------------------------- + +static void +set_cloexec (int fd) +{ + soft_assert (fcntl (fd, F_SETFD, fcntl (fd, F_GETFD) | FD_CLOEXEC) != -1); +} + +static bool +set_blocking (int fd, bool blocking) +{ + int flags = fcntl (fd, F_GETFL); + hard_assert (flags != -1); + + bool prev = !(flags & O_NONBLOCK); + if (blocking) + flags &= ~O_NONBLOCK; + else + flags |= O_NONBLOCK; + + hard_assert (fcntl (fd, F_SETFL, flags) != -1); + return prev; +} + +static void +xclose (int fd) +{ + while (close (fd) == -1) + if (!soft_assert (errno == EINTR)) + break; +} + +// --- Polling ----------------------------------------------------------------- + +// Basically the poor man's GMainLoop/libev/libuv. It might make some sense +// to instead use those tested and proven libraries but we don't need much +// and it's interesting to implement. + +// At the moment the FD's are stored in an unsorted array. This is not ideal +// complexity-wise but I don't think I have much of a choice with poll(), +// and neither with epoll for that matter. +// +// unsorted array sorted array +// search O(n) O(log n) [O(log log n)] +// insert by fd O(n) O(n) +// delete by fd O(n) O(n) +// +// Insertion in the unsorted array can be reduced to O(1) if I maintain a +// bitmap of present FD's but that's still not a huge win. +// +// I don't expect this to be much of an issue, as there are typically not going +// to be that many FD's to watch, and the linear approach is cache-friendly. + +typedef void (*poller_dispatcher_func) (const struct pollfd *, void *); + +#define POLLER_MIN_ALLOC 16 + +#ifdef __linux__ + +// I don't really need this, I've basically implemented this just because I can. + +#include <sys/epoll.h> + +struct poller_info +{ + int fd; ///< Our file descriptor + uint32_t events; ///< The events we registered + poller_dispatcher_func dispatcher; ///< Event dispatcher + void *user_data; ///< User data +}; + +struct poller +{ + int epoll_fd; ///< The epoll FD + struct poller_info **info; ///< Information associated with each FD + struct epoll_event *revents; ///< Output array for epoll_wait() + size_t len; ///< Number of polled descriptors + size_t alloc; ///< Number of entries allocated + + /// Index of the element in `revents' that's currently being dispatched, + /// or -1 if we're not dispatching at the moment. + int dispatch_iterator; + + /// The total number of entries stored in `revents' by epoll_wait(). + int dispatch_total; +}; + +static void +poller_init (struct poller *self) +{ + self->epoll_fd = epoll_create (POLLER_MIN_ALLOC); + hard_assert (self->epoll_fd != -1); + set_cloexec (self->epoll_fd); + + self->len = 0; + self->alloc = POLLER_MIN_ALLOC; + self->info = xcalloc (self->alloc, sizeof *self->info); + self->revents = xcalloc (self->alloc, sizeof *self->revents); + + self->dispatch_iterator = -1; + self->dispatch_total = 0; +} + +static void +poller_free (struct poller *self) +{ + for (size_t i = 0; i < self->len; i++) + { + struct poller_info *info = self->info[i]; + hard_assert (epoll_ctl (self->epoll_fd, + EPOLL_CTL_DEL, info->fd, (void *) "") != -1); + free (info); + } + + xclose (self->epoll_fd); + free (self->info); + free (self->revents); +} + +static ssize_t +poller_find_by_fd (struct poller *self, int fd) +{ + for (size_t i = 0; i < self->len; i++) + if (self->info[i]->fd == fd) + return i; + return -1; +} + +static void +poller_ensure_space (struct poller *self) +{ + if (self->len < self->alloc) + return; + + self->alloc <<= 1; + self->revents = xreallocarray + (self->revents, sizeof *self->revents, self->alloc); + self->info = xreallocarray + (self->info, sizeof *self->info, self->alloc); +} + +static int +poller_epoll_to_poll_events (int events) +{ + int result = 0; + if (events & EPOLLIN) result |= POLLIN; + if (events & EPOLLOUT) result |= POLLOUT; + if (events & EPOLLERR) result |= POLLERR; + if (events & EPOLLHUP) result |= POLLHUP; + if (events & EPOLLPRI) result |= POLLPRI; + return result; +} + +static uint32_t +poller_poll_to_epoll_events (uint32_t events) +{ + uint32_t result = 0; + if (events & POLLIN) result |= EPOLLIN; + if (events & POLLOUT) result |= EPOLLOUT; + if (events & POLLERR) result |= EPOLLERR; + if (events & POLLHUP) result |= EPOLLHUP; + if (events & POLLPRI) result |= EPOLLPRI; + return result; +} + +static void +poller_set (struct poller *self, int fd, short int events, + poller_dispatcher_func dispatcher, void *data) +{ + ssize_t index = poller_find_by_fd (self, fd); + bool modifying = true; + if (index == -1) + { + poller_ensure_space (self); + self->info[index = self->len++] = xcalloc (1, sizeof **self->info); + modifying = false; + } + + struct poller_info *info = self->info[index]; + info->fd = fd; + info->dispatcher = dispatcher; + info->user_data = data; + + struct epoll_event event; + event.events = poller_poll_to_epoll_events (events); + event.data.ptr = info; + hard_assert (epoll_ctl (self->epoll_fd, + modifying ? EPOLL_CTL_MOD : EPOLL_CTL_ADD, fd, &event) != -1); +} + +static void +poller_remove_from_dispatch (struct poller *self, + const struct poller_info *info) +{ + if (self->dispatch_iterator == -1) + return; + + int i; + for (i = self->dispatch_iterator; i < self->dispatch_total; i++) + if (self->revents[i].data.ptr == info) + break; + if (i == self->dispatch_total) + return; + + if (i != --self->dispatch_total) + self->revents[i] = self->revents[self->dispatch_total]; + + // We've removed the element we're currently processing; go back one entry + // so that we don't skip the one we might have replaced it with. + if (i == self->dispatch_iterator) + self->dispatch_iterator--; +} + +static void +poller_remove_at_index (struct poller *self, size_t index) +{ + hard_assert (index < self->len); + struct poller_info *info = self->info[index]; + + poller_remove_from_dispatch (self, info); + hard_assert (epoll_ctl (self->epoll_fd, + EPOLL_CTL_DEL, info->fd, (void *) "") != -1); + + free (info); + if (index != --self->len) + self->info[index] = self->info[self->len]; +} + +static void +poller_run (struct poller *self) +{ + // Not reentrant + hard_assert (self->dispatch_iterator == -1); + + int n_fds; + do + n_fds = epoll_wait (self->epoll_fd, self->revents, self->len, -1); + while (n_fds == -1 && errno == EINTR); + + if (n_fds == -1) + { + print_fatal ("%s: %s", "epoll", strerror (errno)); + exit (EXIT_FAILURE); + } + + for (int i = 0; i < n_fds; i++) + { + struct epoll_event *revents = self->revents + i; + struct poller_info *info = revents->data.ptr; + + struct pollfd pfd; + pfd.fd = info->fd; + pfd.revents = poller_epoll_to_poll_events (revents->events); + pfd.events = poller_epoll_to_poll_events (info->events); + + self->dispatch_iterator = i; + self->dispatch_total = n_fds; + + info->dispatcher (&pfd, info->user_data); + + i = self->dispatch_iterator; + n_fds = self->dispatch_total; + } + + self->dispatch_iterator = -1; + self->dispatch_total = 0; +} + +#else // !__linux__ + +struct poller_info +{ + poller_dispatcher_func dispatcher; ///< Event dispatcher + void *user_data; ///< User data +}; + +struct poller +{ + struct pollfd *fds; ///< Polled descriptors + struct poller_info *fds_info; ///< Additional information for each FD + size_t len; ///< Number of polled descriptors + size_t alloc; ///< Number of entries allocated + + int dispatch_index; ///< The currently dispatched FD or -1 +}; + +static void +poller_init (struct poller *self) +{ + self->alloc = POLLER_MIN_ALLOC; + self->len = 0; + self->fds = xcalloc (self->alloc, sizeof *self->fds); + self->fds_info = xcalloc (self->alloc, sizeof *self->fds_info); + self->dispatch_index = -1; +} + +static void +poller_free (struct poller *self) +{ + free (self->fds); + free (self->fds_info); +} + +static ssize_t +poller_find_by_fd (struct poller *self, int fd) +{ + for (size_t i = 0; i < self->len; i++) + if (self->fds[i].fd == fd) + return i; + return -1; +} + +static void +poller_ensure_space (struct poller *self) +{ + if (self->len < self->alloc) + return; + + self->alloc <<= 1; + self->fds = xreallocarray (self->fds, sizeof *self->fds, self->alloc); + self->fds_info = xreallocarray + (self->fds_info, sizeof *self->fds_info, self->alloc); +} + +static void +poller_set (struct poller *self, int fd, short int events, + poller_dispatcher_func dispatcher, void *data) +{ + ssize_t index = poller_find_by_fd (self, fd); + if (index == -1) + { + poller_ensure_space (self); + index = self->len++; + } + + struct pollfd *new_entry = self->fds + index; + memset (new_entry, 0, sizeof *new_entry); + new_entry->fd = fd; + new_entry->events = events; + + self->fds_info[self->len] = (struct poller_info) { dispatcher, data }; +} + +static void +poller_remove_at_index (struct poller *self, size_t index) +{ + hard_assert (index < self->len); + if (index == --self->len) + return; + + // Make sure that we don't disrupt the dispatch loop; kind of crude + if ((int) index < self->dispatch_index) + { + memmove (self->fds + index, self->fds + index + 1, + (self->len - index) * sizeof *self->fds); + memmove (self->fds_info + index, self->fds_info + index + 1, + (self->len - index) * sizeof *self->fds_info); + } + else + { + self->fds[index] = self->fds[self->len]; + self->fds_info[index] = self->fds_info[self->len]; + } + + if ((int) index <= self->dispatch_index) + self->dispatch_index--; +} + +static void +poller_run (struct poller *self) +{ + // Not reentrant + hard_assert (self->dispatch_index == -1); + + int result; + do + result = poll (self->fds, self->len, -1); + while (result == -1 && errno == EINTR); + + if (result == -1) + { + print_fatal ("%s: %s", "poll", strerror (errno)); + exit (EXIT_FAILURE); + } + + for (int i = 0; i < (int) self->len; i++) + { + struct pollfd pfd = self->fds[i]; + if (!pfd.revents) + continue; + + struct poller_info *info = self->fds_info + i; + self->dispatch_index = i; + info->dispatcher (&pfd, info->user_data); + i = self->dispatch_index; + } + + self->dispatch_index = -1; +} + +#endif // !__linux__ + +// --- Utilities --------------------------------------------------------------- + +static void +split_str_ignore_empty (const char *s, char delimiter, struct str_vector *out) +{ + const char *begin = s, *end; + + while ((end = strchr (begin, delimiter))) + { + if (begin != end) + str_vector_add_owned (out, xstrndup (begin, end - begin)); + begin = ++end; + } + + if (*begin) + str_vector_add (out, begin); +} + +static char * +strip_str_in_place (char *s, const char *stripped_chars) +{ + char *end = s + strlen (s); + while (end > s && strchr (stripped_chars, end[-1])) + *--end = '\0'; + + char *start = s + strspn (s, stripped_chars); + if (start > s) + memmove (s, start, end - start + 1); + return s; +} + +static bool +str_append_env_path (struct str *output, const char *var, bool only_absolute) +{ + const char *value = getenv (var); + + if (!value || (only_absolute && *value != '/')) + return false; + + str_append (output, value); + return true; +} + +static void +get_xdg_home_dir (struct str *output, const char *var, const char *def) +{ + str_reset (output); + if (!str_append_env_path (output, var, true)) + { + str_append_env_path (output, "HOME", false); + str_append_c (output, '/'); + str_append (output, def); + } +} + +static size_t io_error_domain_tag; +#define IO_ERROR (error_resolve_domain (&io_error_domain_tag)) + +enum +{ + IO_ERROR_FAILED +}; + +static bool +ensure_directory_existence (const char *path, struct error **e) +{ + struct stat st; + + if (stat (path, &st)) + { + if (mkdir (path, S_IRWXU | S_IRWXG | S_IRWXO)) + { + error_set (e, IO_ERROR, IO_ERROR_FAILED, + "cannot create directory `%s': %s", + path, strerror (errno)); + return false; + } + } + else if (!S_ISDIR (st.st_mode)) + { + error_set (e, IO_ERROR, IO_ERROR_FAILED, + "cannot create directory `%s': %s", + path, "file exists but is not a directory"); + return false; + } + return true; +} + +static bool +mkdir_with_parents (char *path, struct error **e) +{ + char *p = path; + + // XXX: This is prone to the TOCTTOU problem. The solution would be to + // rewrite the function using the {mkdir,fstat}at() functions from + // POSIX.1-2008, ideally returning a file descriptor to the open + // directory, with the current code as a fallback. Or to use chdir(). + while ((p = strchr (p + 1, '/'))) + { + *p = '\0'; + bool success = ensure_directory_existence (path, e); + *p = '/'; + + if (!success) + return false; + } + + return ensure_directory_existence (path, e); +} + +static bool +set_boolean_if_valid (bool *out, const char *s) +{ + if (!strcasecmp (s, "yes")) *out = true; + else if (!strcasecmp (s, "no")) *out = false; + else if (!strcasecmp (s, "on")) *out = true; + else if (!strcasecmp (s, "off")) *out = false; + else if (!strcasecmp (s, "true")) *out = true; + else if (!strcasecmp (s, "false")) *out = false; + else return false; + + return true; +} + +static void +regerror_to_str (int code, const regex_t *preg, struct str *out) +{ + size_t required = regerror (code, preg, NULL, 0); + str_ensure_space (out, required); + out->len += regerror (code, preg, + out->str + out->len, out->alloc - out->len) - 1; +} + +static size_t regex_error_domain_tag; +#define REGEX_ERROR (error_resolve_domain (®ex_error_domain_tag)) + +enum +{ + REGEX_ERROR_COMPILATION_FAILED +}; + +static bool +regex_match (const char *regex, const char *s, struct error **e) +{ + regex_t re; + int err = regcomp (&re, regex, REG_EXTENDED | REG_NOSUB); + if (err) + { + struct str desc; + + str_init (&desc); + regerror_to_str (err, &re, &desc); + error_set (e, REGEX_ERROR, REGEX_ERROR_COMPILATION_FAILED, + "failed to compile regular expression: %s", desc.str); + str_free (&desc); + return false; + } + + bool result = regexec (&re, s, 0, NULL, 0) != REG_NOMATCH; + regfree (&re); + return result; +} + +static bool +read_line (FILE *fp, struct str *s) +{ + int c; + bool at_end = true; + + str_reset (s); + while ((c = fgetc (fp)) != EOF) + { + at_end = false; + if (c == '\r') + continue; + if (c == '\n') + break; + str_append_c (s, c); + } + + return !at_end; +} + +// --- IRC utilities ----------------------------------------------------------- + +struct irc_message +{ + char *prefix; + char *command; + struct str_vector params; +}; + +static void +irc_parse_message (struct irc_message *msg, const char *line) +{ + msg->prefix = NULL; + msg->command = NULL; + str_vector_init (&msg->params); + + // Prefix + if (*line == ':') + { + size_t prefix_len = strcspn (++line, " "); + msg->prefix = xstrndup (line, prefix_len); + line += prefix_len; + } + + // Command name + { + while (*line == ' ') + line++; + + size_t cmd_len = strcspn (line, " "); + msg->command = xstrndup (line, cmd_len); + line += cmd_len; + } + + // Arguments + while (true) + { + while (*line == ' ') + line++; + + if (*line == ':') + { + str_vector_add (&msg->params, ++line); + break; + } + + size_t param_len = strcspn (line, " "); + if (!param_len) + break; + + str_vector_add_owned (&msg->params, xstrndup (line, param_len)); + line += param_len; + } +} + +static void +irc_free_message (struct irc_message *msg) +{ + free (msg->prefix); + free (msg->command); + str_vector_free (&msg->params); +} + +static void +irc_process_buffer (struct str *buf, + void (*callback)(const struct irc_message *, const char *, void *), + void *user_data) +{ + char *start = buf->str; + char *end = start + buf->len; + + for (char *p = start; p + 1 < end; p++) + { + // Split the input on newlines + if (p[0] != '\r' || p[1] != '\n') + continue; + + *p = 0; + + struct irc_message msg; + irc_parse_message (&msg, start); + callback (&msg, start, user_data); + irc_free_message (&msg); + + start = p + 2; + } + + str_remove_slice (buf, 0, start - buf->str); +} + +// --- Configuration ----------------------------------------------------------- + +// The keys are stripped of surrounding whitespace, the values are not. + +static size_t config_error_domain_tag; +#define CONFIG_ERROR (error_resolve_domain (&config_error_domain_tag)) + +enum +{ + CONFIG_ERROR_MALFORMED +}; + +struct config_item +{ + const char *key; + const char *default_value; + const char *description; +}; + +static FILE * +get_config_file (void) +{ + struct str_vector paths; + struct str config_home, file; + const char *xdg_config_dirs; + unsigned i; + FILE *fp = NULL; + + str_vector_init (&paths); + + str_init (&config_home); + get_xdg_home_dir (&config_home, "XDG_CONFIG_HOME", ".config"); + str_vector_add (&paths, config_home.str); + str_free (&config_home); + + if ((xdg_config_dirs = getenv ("XDG_CONFIG_DIRS"))) + split_str_ignore_empty (xdg_config_dirs, ':', &paths); + + str_init (&file); + for (i = 0; i < paths.len; i++) + { + // As per spec, relative paths are ignored + if (*paths.vector[i] != '/') + continue; + + str_reset (&file); + str_append (&file, paths.vector[i]); + str_append (&file, "/" PROGRAM_NAME "/" PROGRAM_NAME ".conf"); + + if ((fp = fopen (file.str, "r"))) + break; + } + + str_free (&file); + str_vector_free (&paths); + return fp; +} + +static bool +read_config_file (struct str_map *config, struct error **e) +{ + struct str line; + FILE *fp = get_config_file (); + unsigned line_no = 0; + bool errors = false; + + if (!fp) + return true; + + str_init (&line); + for (line_no = 1; read_line (fp, &line); line_no++) + { + char *start = line.str; + if (*start == '#') + continue; + + while (isspace (*start)) + start++; + + char *end = strchr (start, '='); + if (!end) + { + if (*start) + { + error_set (e, CONFIG_ERROR, CONFIG_ERROR_MALFORMED, + "line %u in config: %s", line_no, "malformed input"); + errors = true; + break; + } + } + else + { + char *value = end + 1; + do + *end = '\0'; + while (isspace (*--end)); + + str_map_set (config, start, xstrdup (value)); + } + } + + str_free (&line); + fclose (fp); + + return !errors; +} + +// --- Configuration (application-specific) ------------------------------------ + +static struct config_item g_config_table[] = +{ + { "nickname", "ZyklonB", "IRC nickname" }, + { "username", "bot", "IRC user name" }, + { "fullname", "ZyklonB IRC bot", "IRC full name/e-mail" }, + + { "irc_host", NULL, "Address of the IRC server" }, + { "irc_port", "6667", "Port of the IRC server" }, + { "ssl_use", "off", "Whether to use SSL" }, + { "ssl_cert", NULL, "Client SSL certificate (PEM)" }, + { "autojoin", NULL, "Channels to join on start" }, + { "reconnect", "on", "Whether to reconnect on error" }, + { "reconnect_delay", "5", "Time between reconnecting" }, + + { "prefix", ":", "The prefix for bot commands" }, + { "admin", NULL, "Host mask for administrators" }, + { "plugins", NULL, "The plugins to load on startup" }, + { "plugin_dir", NULL, "Where to search for plugins" }, + { "recover", "on", "Whether to re-launch on crash" }, +}; + +static void +load_config_defaults (struct str_map *config) +{ + for (size_t i = 0; i < N_ELEMENTS (g_config_table); i++) + { + const struct config_item *item = g_config_table + i; + if (item->default_value) + str_map_set (config, item->key, xstrdup (item->default_value)); + } +} + +// --- Application data -------------------------------------------------------- + +struct plugin_data +{ + struct plugin_data *next; ///< The next link in a chain + struct plugin_data *prev; ///< The previous link in a chain + + struct bot_context *ctx; ///< Parent context + + pid_t pid; ///< PID of the plugin process + char *name; ///< Plugin identifier + + bool is_zombie; ///< Whether the child is a zombie + bool initialized; ///< Ready to exchange IRC messages + struct str queued_output; ///< Output queued up until initialized + + // Since we're doing non-blocking I/O, we need to queue up data so that + // we don't stall on plugins unnecessarily. + + int read_fd; ///< The read end of the comm. pipe + struct str read_buffer; ///< Unprocessed input + + int write_fd; ///< The write end of the comm. pipe + struct str write_buffer; ///< Output yet to be sent out +}; + +static void +plugin_data_init (struct plugin_data *self) +{ + memset (self, 0, sizeof *self); + + self->pid = -1; + str_init (&self->queued_output); + + self->read_fd = -1; + str_init (&self->read_buffer); + self->write_fd = -1; + str_init (&self->write_buffer); +} + +static void +plugin_data_free (struct plugin_data *self) +{ + soft_assert (self->pid == -1); + free (self->name); + + str_free (&self->read_buffer); + if (!soft_assert (self->read_fd == -1)) + xclose (self->read_fd); + + str_free (&self->write_buffer); + if (!soft_assert (self->write_fd == -1)) + xclose (self->write_fd); + + if (!self->initialized) + str_free (&self->queued_output); +} + +static size_t connect_error_domain_tag; +#define CONNECT_ERROR (error_resolve_domain (&connect_error_domain_tag)) + +enum +{ + CONNECT_ERROR_INVALID_CONFIGURATION, + CONNECT_ERROR_FAILED +}; + +struct bot_context +{ + struct str_map config; ///< User configuration + + int irc_fd; ///< Socket FD of the server + struct str read_buffer; ///< Input yet to be processed + bool irc_ready; ///< Whether we may send messages now + + SSL_CTX *ssl_ctx; ///< SSL context + SSL *ssl; ///< SSL connection + + struct plugin_data *plugins; ///< Linked list of plugins + struct str_map plugins_by_name; ///< Indexes @em plugins by their name + + struct poller poller; ///< Manages polled descriptors + bool quitting; ///< User requested quitting + bool polling; ///< The event loop is running +}; + +static void +bot_context_init (struct bot_context *ctx) +{ + str_map_init (&ctx->config); + ctx->config.free = free; + load_config_defaults (&ctx->config); + + ctx->irc_fd = -1; + str_init (&ctx->read_buffer); + ctx->irc_ready = false; + + ctx->ssl = NULL; + ctx->ssl_ctx = NULL; + + ctx->plugins = NULL; + str_map_init (&ctx->plugins_by_name); + + poller_init (&ctx->poller); + ctx->quitting = false; + ctx->polling = false; +} + +static void +bot_context_free (struct bot_context *ctx) +{ + str_map_free (&ctx->config); + str_free (&ctx->read_buffer); + + // TODO: terminate the plugins properly before this is called + struct plugin_data *link, *tmp; + for (link = ctx->plugins; link; link = tmp) + { + tmp = link->next; + plugin_data_free (link); + free (link); + } + + if (ctx->irc_fd != -1) + xclose (ctx->irc_fd); + if (ctx->ssl) + SSL_free (ctx->ssl); + if (ctx->ssl_ctx) + SSL_CTX_free (ctx->ssl_ctx); + + str_map_free (&ctx->plugins_by_name); + poller_free (&ctx->poller); +} + +static void +irc_shutdown (struct bot_context *ctx) +{ + // Generally non-critical + if (ctx->ssl) + soft_assert (SSL_shutdown (ctx->ssl) != -1); + else + soft_assert (shutdown (ctx->irc_fd, SHUT_WR) == 0); +} + +static void +initiate_quit (struct bot_context *ctx) +{ + irc_shutdown (ctx); + ctx->quitting = true; +} + +static void +try_finish_quit (struct bot_context *ctx) +{ + if (!ctx->quitting) + return; + if (ctx->irc_fd == -1 && !ctx->plugins) + ctx->polling = false; +} + +static bool irc_send (struct bot_context *ctx, + const char *format, ...) ATTRIBUTE_PRINTF (2, 3); + +// XXX: is it okay to just ignore the return value and wait until we receive +// it in on_irc_readable()? +static bool +irc_send (struct bot_context *ctx, const char *format, ...) +{ + va_list ap; + + if (g_debug_mode) + { + fputs ("[IRC] <== \"", stderr); + va_start (ap, format); + vfprintf (stderr, format, ap); + va_end (ap); + fputs ("\"\n", stderr); + } + + soft_assert (ctx->irc_fd != -1); + + va_start (ap, format); + struct str str; + str_init (&str); + str_append_vprintf (&str, format, ap); + str_append (&str, "\r\n"); + va_end (ap); + + bool result = true; + if (ctx->ssl) + { + // TODO: call SSL_get_error() to detect if a clean shutdown has occured + if (SSL_write (ctx->ssl, str.str, str.len) != (int) str.len) + { + print_debug ("%s: %s: %s", __func__, "SSL_write", + ERR_error_string (ERR_get_error (), NULL)); + result = false; + } + } + else if (write (ctx->irc_fd, str.str, str.len) != (ssize_t) str.len) + { + print_debug ("%s: %s: %s", __func__, "write", strerror (errno)); + result = false; + } + + str_free (&str); + return result; +} + +static bool +irc_initialize_ssl (struct bot_context *ctx, struct error **e) +{ + ctx->ssl_ctx = SSL_CTX_new (SSLv23_client_method ()); + if (!ctx->ssl_ctx) + goto error_ssl_1; + // We don't care; some encryption is always better than no encryption + SSL_CTX_set_verify (ctx->ssl_ctx, SSL_VERIFY_NONE, NULL); + // XXX: maybe we should call SSL_CTX_set_options() for some workarounds + + ctx->ssl = SSL_new (ctx->ssl_ctx); + if (!ctx->ssl) + goto error_ssl_2; + + const char *ssl_cert = str_map_find (&ctx->config, "ssl_cert"); + if (ssl_cert + && !SSL_use_certificate_file (ctx->ssl, ssl_cert, SSL_FILETYPE_PEM)) + { + // XXX: perhaps we should read the file ourselves for better messages + print_error ("%s: %s", "setting the SSL client certificate failed", + ERR_error_string (ERR_get_error (), NULL)); + } + + SSL_set_connect_state (ctx->ssl); + if (!SSL_set_fd (ctx->ssl, ctx->irc_fd)) + goto error_ssl_3; + // Avoid SSL_write() returning SSL_ERROR_WANT_READ + SSL_set_mode (ctx->ssl, SSL_MODE_AUTO_RETRY); + if (SSL_connect (ctx->ssl) > 0) + return true; + +error_ssl_3: + SSL_free (ctx->ssl); + ctx->ssl = NULL; +error_ssl_2: + SSL_CTX_free (ctx->ssl_ctx); + ctx->ssl_ctx = NULL; +error_ssl_1: + // XXX: these error strings are really nasty; also there could be + // multiple errors on the OpenSSL stack. + error_set (e, CONNECT_ERROR, CONNECT_ERROR_FAILED, + "%s: %s", "could not initialize SSL", + ERR_error_string (ERR_get_error (), NULL)); + return false; +} + +static bool +irc_establish_connection (struct bot_context *ctx, + const char *host, const char *port, bool use_ssl, struct error **e) +{ + struct addrinfo gai_hints, *gai_result, *gai_iter; + memset (&gai_hints, 0, sizeof gai_hints); + + // We definitely want TCP. + gai_hints.ai_socktype = SOCK_STREAM; + + int err = getaddrinfo (host, port, &gai_hints, &gai_result); + if (err) + { + error_set (e, CONNECT_ERROR, CONNECT_ERROR_FAILED, "%s: %s: %s", + "connection failed", "getaddrinfo", gai_strerror (err)); + return false; + } + + int sockfd; + for (gai_iter = gai_result; gai_iter; gai_iter = gai_iter->ai_next) + { + sockfd = socket (gai_iter->ai_family, + gai_iter->ai_socktype, gai_iter->ai_protocol); + if (sockfd == -1) + continue; + set_cloexec (sockfd); + + int yes = 1; + soft_assert (setsockopt (sockfd, SOL_SOCKET, SO_KEEPALIVE, + &yes, sizeof yes) != -1); + + const char *real_host = host; + + // Let's try to resolve the address back into a real hostname; + // we don't really need this, so we can let it quietly fail + char buf[NI_MAXHOST]; + err = getnameinfo (gai_iter->ai_addr, gai_iter->ai_addrlen, + buf, sizeof buf, NULL, 0, 0); + if (err) + print_debug ("%s: %s", "getnameinfo", gai_strerror (err)); + else + real_host = buf; + + // XXX: we shouldn't mix these statuses with `struct error'; choose 1! + print_status ("connecting to `%s:%s'...", real_host, port); + if (!connect (sockfd, gai_iter->ai_addr, gai_iter->ai_addrlen)) + break; + + xclose (sockfd); + } + + freeaddrinfo (gai_result); + + if (!gai_iter) + { + error_set (e, CONNECT_ERROR, CONNECT_ERROR_FAILED, "connection failed"); + return false; + } + + ctx->irc_fd = sockfd; + if (use_ssl && !irc_initialize_ssl (ctx, e)) + { + xclose (ctx->irc_fd); + ctx->irc_fd = -1; + return false; + } + + print_status ("connection established"); + return true; +} + +// --- Signals ----------------------------------------------------------------- + +static int g_signal_pipe[2]; ///< A pipe used to signal... signals + +static struct str_vector + g_original_argv, ///< Original program arguments + g_recovery_env; ///< Environment for re-exec recovery + +/// Program termination has been requested by a signal +static volatile sig_atomic_t g_termination_requested; + +/// Points to startup reason location within `g_recovery_environment' +static char **g_startup_reason_location; +/// The environment variable used to pass the startup reason when re-executing +static const char g_startup_reason_str[] = "STARTUP_REASON"; + +static void +sigchld_handler (int signum) +{ + (void) signum; + + int original_errno = errno; + // Just so that the read end of the pipe wakes up the poller. + // NOTE: Linux has signalfd() and eventfd(), and the BSD's have kqueue. + // All of them are better than this approach, although platform-specific. + if (write (g_signal_pipe[1], "c", 1) == -1) + soft_assert (errno == EAGAIN); + errno = original_errno; +} + +static void +sigterm_handler (int signum) +{ + (void) signum; + + g_termination_requested = true; + + int original_errno = errno; + if (write (g_signal_pipe[1], "t", 1) == -1) + soft_assert (errno == EAGAIN); + errno = original_errno; +} + +static void +setup_signal_handlers (void) +{ + if (pipe (g_signal_pipe) == -1) + { + print_fatal ("pipe: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + + set_cloexec (g_signal_pipe[0]); + set_cloexec (g_signal_pipe[1]); + + // So that the pipe cannot overflow; it would make write() block within + // the signal handler, which is something we really don't want to happen. + // The same holds true for read(). + set_blocking (g_signal_pipe[0], false); + set_blocking (g_signal_pipe[1], false); + + struct sigaction sa; + sa.sa_flags = SA_RESTART; + sa.sa_handler = sigchld_handler; + sigemptyset (&sa.sa_mask); + + if (sigaction (SIGCHLD, &sa, NULL) == -1) + { + print_fatal ("sigaction: %s", strerror (errno)); + exit (EXIT_FAILURE); + } + + signal (SIGPIPE, SIG_IGN); + + sa.sa_handler = sigterm_handler; + if (sigaction (SIGINT, &sa, NULL) == -1 + || sigaction (SIGTERM, &sa, NULL) == -1) + print_error ("sigaction: %s", strerror (errno)); +} + +static void +translate_signal_info (int no, const char **name, int code, const char **reason) +{ + if (code == SI_USER) *reason = "signal sent by kill()"; + if (code == SI_QUEUE) *reason = "signal sent by sigqueue()"; + + switch (no) + { + case SIGILL: + *name = "SIGILL"; + if (code == ILL_ILLOPC) *reason = "illegal opcode"; + if (code == ILL_ILLOPN) *reason = "illegal operand"; + if (code == ILL_ILLADR) *reason = "illegal addressing mode"; + if (code == ILL_ILLTRP) *reason = "illegal trap"; + if (code == ILL_PRVOPC) *reason = "privileged opcode"; + if (code == ILL_PRVREG) *reason = "privileged register"; + if (code == ILL_COPROC) *reason = "coprocessor error"; + if (code == ILL_BADSTK) *reason = "internal stack error"; + break; + case SIGFPE: + *name = "SIGFPE"; + if (code == FPE_INTDIV) *reason = "integer divide by zero"; + if (code == FPE_INTOVF) *reason = "integer overflow"; + if (code == FPE_FLTDIV) *reason = "floating-point divide by zero"; + if (code == FPE_FLTOVF) *reason = "floating-point overflow"; + if (code == FPE_FLTUND) *reason = "floating-point underflow"; + if (code == FPE_FLTRES) *reason = "floating-point inexact result"; + if (code == FPE_FLTINV) *reason = "invalid floating-point operation"; + if (code == FPE_FLTSUB) *reason = "subscript out of range"; + break; + case SIGSEGV: + *name = "SIGSEGV"; + if (code == SEGV_MAPERR) + *reason = "address not mapped to object"; + if (code == SEGV_ACCERR) + *reason = "invalid permissions for mapped object"; + break; + case SIGBUS: + *name = "SIGBUS"; + if (code == BUS_ADRALN) *reason = "invalid address alignment"; + if (code == BUS_ADRERR) *reason = "nonexistent physical address"; + if (code == BUS_OBJERR) *reason = "object-specific hardware error"; + break; + default: + *name = NULL; + } +} + +static void +recovery_handler (int signum, siginfo_t *info, void *context) +{ + (void) context; + + // TODO: maybe try to force a core dump like this: if (fork() == 0) return; + // TODO: maybe we could even send "\r\nQUIT :reason\r\n" to the server. >_> + // As long as we're not connected via TLS, that is. + + const char *signal_name = NULL, *reason = NULL; + translate_signal_info (signum, &signal_name, info->si_code, &reason); + + char buf[128], numbuf[8]; + if (!signal_name) + { + snprintf (numbuf, sizeof numbuf, "%d", signum); + signal_name = numbuf; + } + + if (reason) + snprintf (buf, sizeof buf, "%s=%s: %s: %s", g_startup_reason_str, + "signal received", signal_name, reason); + else + snprintf (buf, sizeof buf, "%s=%s: %s", g_startup_reason_str, + "signal received", signal_name); + *g_startup_reason_location = buf; + + // TODO: maybe pregenerate the path, see the following for some other ways + // that would be illegal to do from within a signal handler: + // http://stackoverflow.com/a/1024937 + // http://stackoverflow.com/q/799679 + // Especially if we change the current working directory in the program. + // + // Note that I can just overwrite g_orig_argv[0]. + + // NOTE: our children will read EOF on the read ends of their pipes as a + // a result of O_CLOEXEC. That should be enough to make them terminate. + + char **argv = g_original_argv.vector, **argp = g_recovery_env.vector; + execve ("/proc/self/exe", argv, argp); // Linux + execve ("/proc/curproc/file", argv, argp); // BSD + execve ("/proc/curproc/exe", argv, argp); // BSD + execve ("/proc/self/path/a.out", argv, argp); // Solaris + execve (argv[0], argv, argp); // unreliable fallback + + // Let's just crash + perror ("execve"); + signal (signum, SIG_DFL); + raise (signum); +} + +static void +prepare_recovery_environment (void) +{ + str_vector_init (&g_recovery_env); + str_vector_add_vector (&g_recovery_env, environ); + + // Prepare a location within the environment where we will put the startup + // (or maybe rather restart) reason in case of an irrecoverable error. + char **iter; + for (iter = g_recovery_env.vector; *iter; iter++) + { + const size_t len = sizeof g_startup_reason_str - 1; + if (!strncmp (*iter, g_startup_reason_str, len) && (*iter)[len] == '=') + break; + } + + if (iter) + g_startup_reason_location = iter; + else + { + str_vector_add (&g_recovery_env, ""); + g_startup_reason_location = + g_recovery_env.vector + g_recovery_env.len - 1; + } +} + +static void +setup_recovery_handler (struct bot_context *ctx) +{ + const char *recover_str = str_map_find (&ctx->config, "recover"); + hard_assert (recover_str != NULL); // We have a default value for this + + bool recover; + if (!set_boolean_if_valid (&recover, recover_str)) + { + print_fatal ("invalid configuration value for `%s'", "recover"); + exit (EXIT_FAILURE); + } + if (!recover) + return; + + // Make sure these signals aren't blocked, otherwise we would be unable + // to handle them, making the critical conditions fatal. + sigset_t mask; + sigemptyset (&mask); + sigaddset (&mask, SIGSEGV); + sigaddset (&mask, SIGBUS); + sigaddset (&mask, SIGFPE); + sigaddset (&mask, SIGILL); + sigprocmask (SIG_UNBLOCK, &mask, NULL); + + struct sigaction sa; + sa.sa_flags = SA_SIGINFO; + sa.sa_sigaction = recovery_handler; + sigemptyset (&sa.sa_mask); + + prepare_recovery_environment (); + + // TODO: also handle SIGABRT... or avoid doing abort() in the first place? + if (sigaction (SIGSEGV, &sa, NULL) == -1 + || sigaction (SIGBUS, &sa, NULL) == -1 + || sigaction (SIGFPE, &sa, NULL) == -1 + || sigaction (SIGILL, &sa, NULL) == -1) + print_error ("sigaction: %s", strerror (errno)); +} + +// --- Plugins ----------------------------------------------------------------- + +/// The name of the special IRC command for interprocess communication +static const char *plugin_ipc_command = "ZYKLONB"; + +static size_t plugin_error_domain_tag; +#define PLUGIN_ERROR (error_resolve_domain (&plugin_error_domain_tag)) + +enum +{ + PLUGIN_ERROR_ALREADY_LOADED, + PLUGIN_ERROR_NOT_LOADED, + PLUGIN_ERROR_LOADING_FAILED +}; + +static struct plugin_data * +plugin_find_by_pid (struct bot_context *ctx, pid_t pid) +{ + struct plugin_data *iter; + for (iter = ctx->plugins; iter; iter = iter->next) + if (iter->pid == pid) + return iter; + return NULL; +} + +static bool +plugin_zombify (struct plugin_data *plugin) +{ + if (plugin->is_zombie) + return false; + + // FIXME: make sure that we don't remove entries from the poller while we + // still may have stuff to read; maybe just check that the read pipe is + // empty before closing it... and then on EOF check if `pid == -1' and + // only then dispose of it (it'd be best to simulate that both of these + // cases may happen). + ssize_t poller_idx = + poller_find_by_fd (&plugin->ctx->poller, plugin->write_fd); + if (poller_idx != -1) + poller_remove_at_index (&plugin->ctx->poller, poller_idx); + + // TODO: try to flush the write buffer (non-blocking)? + + // The plugin should terminate itself after it receives EOF. + xclose (plugin->write_fd); + plugin->write_fd = -1; + + // Make it a pseudo-anonymous zombie. In this state we process any + // remaining commands it attempts to send to us before it finally dies. + str_map_set (&plugin->ctx->plugins_by_name, plugin->name, NULL); + plugin->is_zombie = true; + return true; +} + +static void +on_plugin_writable (const struct pollfd *fd, struct plugin_data *plugin) +{ + struct bot_context *ctx = plugin->ctx; + struct str *buf = &plugin->write_buffer; + size_t written_total = 0; + + // TODO: see "Advanced Programming in the UNIX Environment" Figure C.19; + // check for any unexpected behaviour that might occur + if (fd->revents != POLLOUT) + print_debug ("poller fd %d: revents: %d", fd->fd, fd->revents); + + while (written_total != buf->len) + { + ssize_t n_written = write (fd->fd, buf->str + written_total, + buf->len - written_total); + + if (n_written < 0) + { + if (errno == EAGAIN) + break; + + if (!soft_assert (errno == EINTR) && !plugin->is_zombie) + { + print_debug ("%s: %s", "recv", strerror (errno)); + print_error ("failure on writing to plugin `%s'," + " therefore I'm unloading it", plugin->name); + plugin_zombify (plugin); + break; + } + } + + // This may be equivalent to EAGAIN on some implementations + if (n_written == 0) + break; + + written_total += n_written; + } + + if (written_total != 0) + str_remove_slice (buf, 0, written_total); + + if (buf->len == 0) + { + // Everything has been written, there's no need to end up in here again + ssize_t index = poller_find_by_fd (&ctx->poller, fd->fd); + if (index != -1) + poller_remove_at_index (&ctx->poller, index); + } +} + +static void +plugin_queue_write (struct plugin_data *plugin) +{ + if (plugin->is_zombie) + return; + + // Don't let the write buffer grow infinitely. If there's a ton of data + // waiting to be processed by the plugin, it usually means there's something + // wrong with it (such as someone stopping the process). + if (plugin->write_buffer.len >= (1 << 20)) + { + print_warning ("plugin `%s' does not seem to process messages fast" + " enough, I'm unloading it", plugin->name); + plugin_zombify (plugin); + return; + } + + poller_set (&plugin->ctx->poller, plugin->write_fd, POLLOUT, + (poller_dispatcher_func) on_plugin_writable, plugin); +} + +static void +plugin_send (struct plugin_data *plugin, const char *format, ...) + ATTRIBUTE_PRINTF (2, 3); + +static void +plugin_send (struct plugin_data *plugin, const char *format, ...) +{ + va_list ap; + + if (g_debug_mode) + { + fprintf (stderr, "[%s] <-- \"", plugin->name); + va_start (ap, format); + vfprintf (stderr, format, ap); + va_end (ap); + fputs ("\"\n", stderr); + } + + va_start (ap, format); + str_append_vprintf (&plugin->write_buffer, format, ap); + va_end (ap); + str_append (&plugin->write_buffer, "\r\n"); + + plugin_queue_write (plugin); +} + +static void +plugin_process_message (const struct irc_message *msg, + const char *raw, void *user_data) +{ + struct plugin_data *plugin = user_data; + struct bot_context *ctx = plugin->ctx; + + if (g_debug_mode) + fprintf (stderr, "[%s] --> \"%s\"\n", plugin->name, raw); + + if (!strcasecmp (msg->command, plugin_ipc_command)) + { + // Replies are sent in the order in which they came in, so there's + // no need to attach a special identifier to them. It might be + // desirable in some cases, though. + + if (msg->params.len < 1) + return; + + const char *command = msg->params.vector[0]; + if (!strcasecmp (command, "register")) + { + // Register for relaying of IRC traffic + plugin->initialized = true; + + // Flush any queued up traffic here. The point of queuing it in + // the first place is so that we don't have to wait for plugin + // initialization during startup. + // + // Note that if we start filtering data coming to the plugins e.g. + // based on what it tells us upon registration, we might need to + // filter `queued_output' as well. + str_append_str (&plugin->write_buffer, &plugin->queued_output); + str_free (&plugin->queued_output); + + // NOTE: this may trigger the buffer length check + plugin_queue_write (plugin); + } + else if (!strcasecmp (command, "get_config")) + { + if (msg->params.len < 2) + return; + + const char *value = + str_map_find (&ctx->config, msg->params.vector[1]); + // TODO: escape the value (although there's no need to ATM) + plugin_send (plugin, "%s :%s", + plugin_ipc_command, value ? value : ""); + } + else if (!strcasecmp (command, "print")) + { + if (msg->params.len < 2) + return; + + printf ("%s", msg->params.vector[1]); + } + } + else if (plugin->initialized) + { + // Pass everything else through to the IRC server + irc_send (ctx, "%s", raw); + } +} + +static void +on_plugin_readable (const struct pollfd *fd, struct plugin_data *plugin) +{ + // TODO: see "Advanced Programming in the UNIX Environment" Figure C.19; + // check for any unexpected behaviour that might occur + if (fd->revents != POLLIN) + print_debug ("poller fd %d: revents: %d", fd->fd, fd->revents); + + // TODO: see if I can reuse irc_fill_read_buffer() + struct str *buf = &plugin->read_buffer; + while (true) + { + str_ensure_space (buf, 512 + 1); + ssize_t n_read = read (fd->fd, buf->str + buf->len, + buf->alloc - buf->len - 1); + + if (n_read < 0) + { + if (errno == EAGAIN) + break; + if (soft_assert (errno == EINTR)) + continue; + + if (!plugin->is_zombie) + { + print_error ("failure on reading from plugin `%s'," + " therefore I'm unloading it", plugin->name); + plugin_zombify (plugin); + } + return; + } + + // EOF; hopefully it will die soon (maybe it already has) + if (n_read == 0) + break; + + buf->str[buf->len += n_read] = '\0'; + if (buf->len >= (1 << 20)) + { + // XXX: this isn't really the best flood prevention mechanism, + // but it wasn't even supposed to be one. + if (plugin->is_zombie) + { + print_error ("a zombie of plugin `%s' is trying to flood us," + " therefore I'm killing it", plugin->name); + kill (plugin->pid, SIGKILL); + } + else + { + print_error ("plugin `%s' seems to spew out data frantically," + " therefore I'm unloading it", plugin->name); + plugin_zombify (plugin); + } + return; + } + } + + // Hold it in the buffer while we're disconnected + struct bot_context *ctx = plugin->ctx; + if (ctx->irc_fd != -1 && ctx->irc_ready) + irc_process_buffer (buf, plugin_process_message, plugin); +} + +static bool +is_valid_plugin_name (const char *name) +{ + if (!*name) + return false; + for (const char *p = name; *p; p++) + if (!isgraph (*p) || *p == '/') + return false; + return true; +} + +static bool +plugin_load (struct bot_context *ctx, const char *name, struct error **e) +{ + const char *plugin_dir = str_map_find (&ctx->config, "plugin_dir"); + if (!plugin_dir) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_LOADING_FAILED, + "plugin directory not set"); + return false; + } + + if (!is_valid_plugin_name (name)) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_LOADING_FAILED, + "invalid plugin name"); + return false; + } + + if (str_map_find (&ctx->plugins_by_name, name)) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_ALREADY_LOADED, + "the plugin has already been loaded"); + return false; + } + + int stdin_pipe[2]; + if (pipe (stdin_pipe) == -1) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_LOADING_FAILED, "%s: %s: %s", + "failed to load the plugin", "pipe", strerror (errno)); + goto fail_1; + } + + int stdout_pipe[2]; + if (pipe (stdout_pipe) == -1) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_LOADING_FAILED, "%s: %s: %s", + "failed to load the plugin", "pipe", strerror (errno)); + goto fail_2; + } + + set_cloexec (stdin_pipe[1]); + set_cloexec (stdout_pipe[0]); + + pid_t pid = fork (); + if (pid == -1) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_LOADING_FAILED, "%s: %s: %s", + "failed to load the plugin", "fork", strerror (errno)); + goto fail_3; + } + + if (pid == 0) + { + // Redirect the child's stdin and stdout to the pipes + hard_assert (dup2 (stdin_pipe[0], STDIN_FILENO) != -1); + hard_assert (dup2 (stdout_pipe[1], STDOUT_FILENO) != -1); + + xclose (stdin_pipe[0]); + xclose (stdout_pipe[1]); + + struct str pathname; + str_init (&pathname); + str_append (&pathname, plugin_dir); + str_append_c (&pathname, '/'); + str_append (&pathname, name); + + // Restore some of the signal handling + signal (SIGPIPE, SIG_DFL); + + char *const argv[] = { pathname.str, NULL }; + execve (argv[0], argv, environ); + + // We will collect the failure later via SIGCHLD + print_fatal ("%s: %s: %s", + "failed to load the plugin", "exec", strerror (errno)); + _exit (EXIT_FAILURE); + } + + xclose (stdin_pipe[0]); + xclose (stdout_pipe[1]); + + set_blocking (stdout_pipe[0], false); + set_blocking (stdin_pipe[1], false); + + struct plugin_data *plugin = xmalloc (sizeof *plugin); + plugin_data_init (plugin); + plugin->ctx = ctx; + plugin->pid = pid; + plugin->name = xstrdup (name); + plugin->read_fd = stdout_pipe[0]; + plugin->write_fd = stdin_pipe[1]; + + LIST_PREPEND (ctx->plugins, plugin); + str_map_set (&ctx->plugins_by_name, name, plugin); + + poller_set (&ctx->poller, stdout_pipe[0], POLLIN, + (poller_dispatcher_func) on_plugin_readable, plugin); + return true; + +fail_3: + xclose (stdout_pipe[0]); + xclose (stdout_pipe[1]); +fail_2: + xclose (stdin_pipe[0]); + xclose (stdin_pipe[1]); +fail_1: + return false; +} + +static bool +plugin_unload (struct bot_context *ctx, const char *name, struct error **e) +{ + struct plugin_data *plugin = str_map_find (&ctx->plugins_by_name, name); + + if (!plugin) + { + error_set (e, PLUGIN_ERROR, PLUGIN_ERROR_NOT_LOADED, + "no such plugin is loaded"); + return false; + } + + plugin_zombify (plugin); + + // TODO: add a `kill zombies' command to forcefully get rid of processes + // that do not understand the request. + // TODO: set a timeout before we go for a kill automatically (and if this + // was a reload request, try to bring the plugin back up) + return true; +} + +static void +plugin_load_all_from_config (struct bot_context *ctx) +{ + const char *plugin_list = str_map_find (&ctx->config, "plugins"); + if (!plugin_list) + return; + + struct str_vector plugins; + str_vector_init (&plugins); + + split_str_ignore_empty (plugin_list, ',', &plugins); + for (size_t i = 0; i < plugins.len; i++) + { + char *name = strip_str_in_place (plugins.vector[i], " "); + + struct error *e = NULL; + if (!plugin_load (ctx, name, &e)) + { + print_error ("plugin `%s' failed to load: %s", name, e->message); + error_free (e); + } + } + + str_vector_free (&plugins); +} + +// --- Main program ------------------------------------------------------------ + +static bool +parse_bot_command (const char *s, const char *command, const char **following) +{ + size_t command_len = strlen (command); + if (strncasecmp (s, command, command_len)) + return false; + s += command_len; + + // Expect a word boundary, so that we don't respond to invalid things + if (isalnum (*s)) + return false; + + // Ignore any initial spaces; the rest is the command's argument + while (isblank (*s)) + s++; + *following = s; + return true; +} + +static void +split_bot_command_argument_list (const char *arguments, struct str_vector *out) +{ + split_str_ignore_empty (arguments, ',', out); + for (size_t i = 0; i < out->len; ) + { + if (!*strip_str_in_place (out->vector[i], " \t")) + str_vector_remove (out, i); + else + i++; + } +} + +static bool +is_private_message (const struct irc_message *msg) +{ + hard_assert (msg->params.len); + return !strchr ("#&+!", *msg->params.vector[0]); +} + +static bool +is_sent_by_admin (struct bot_context *ctx, const struct irc_message *msg) +{ + const char *admin = str_map_find (&ctx->config, "admin"); + + // No administrator set -> everyone is an administrator + if (!admin) + return true; + + // TODO: precompile the regex + struct error *e = NULL; + if (regex_match (admin, msg->prefix, NULL)) + return true; + + if (e) + { + print_error ("%s: %s", "invalid admin mask", e->message); + error_free (e); + return true; + } + + return false; +} + +static void respond_to_user (struct bot_context *ctx, const struct + irc_message *msg, const char *format, ...) ATTRIBUTE_PRINTF (3, 4); + +static void +respond_to_user (struct bot_context *ctx, const struct irc_message *msg, + const char *format, ...) +{ + if (!soft_assert (msg->prefix && msg->params.len)) + return; + + char nick[strcspn (msg->prefix, "!") + 1]; + strncpy (nick, msg->prefix, sizeof nick - 1); + nick[sizeof nick - 1] = '\0'; + + struct str text; + va_list ap; + + str_init (&text); + va_start (ap, format); + str_append_vprintf (&text, format, ap); + va_end (ap); + + if (is_private_message (msg)) + irc_send (ctx, "PRIVMSG %s :%s", nick, text.str); + else + irc_send (ctx, "PRIVMSG %s :%s: %s", + msg->params.vector[0], nick, text.str); + + str_free (&text); +} + +static void +process_plugin_load (struct bot_context *ctx, + const struct irc_message *msg, const char *name) +{ + struct error *e = NULL; + if (plugin_load (ctx, name, &e)) + respond_to_user (ctx, msg, "plugin `%s' queued for loading", name); + else + { + respond_to_user (ctx, msg, "plugin `%s' could not be loaded: %s", + name, e->message); + error_free (e); + } +} + +static void +process_plugin_unload (struct bot_context *ctx, + const struct irc_message *msg, const char *name) +{ + struct error *e = NULL; + if (plugin_unload (ctx, name, &e)) + respond_to_user (ctx, msg, "plugin `%s' unloaded", name); + else + { + respond_to_user (ctx, msg, "plugin `%s' could not be unloaded: %s", + name, e->message); + error_free (e); + } +} + +static void +process_plugin_reload (struct bot_context *ctx, + const struct irc_message *msg, const char *name) +{ + // So far the only error that can occur is that the plugin hasn't been + // loaded, which in this case doesn't really matter. + plugin_unload (ctx, name, NULL); + + process_plugin_load (ctx, msg, name); +} + +static void +process_privmsg (struct bot_context *ctx, const struct irc_message *msg) +{ + if (!is_sent_by_admin (ctx, msg)) + return; + if (msg->params.len < 2) + return; + + const char *prefix = str_map_find (&ctx->config, "prefix"); + hard_assert (prefix != NULL); // We have a default value for this + + // For us to recognize the command, it has to start with the prefix, + // with the exception of PM's sent directly to us. + const char *text = msg->params.vector[1]; + if (!strncmp (text, prefix, strlen (prefix))) + text += strlen (prefix); + else if (!is_private_message (msg)) + return; + + const char *following; + struct str_vector list; + str_vector_init (&list); + + if (parse_bot_command (text, "quote", &following)) + // This seems to replace tons of random stupid commands + irc_send (ctx, "%s", following); + else if (parse_bot_command (text, "quit", &following)) + { + // We actually need this command (instead of just `quote') because we + // could try to reconnect to the server automatically otherwise. + if (*following) + irc_send (ctx, "QUIT :%s", following); + else + irc_send (ctx, "QUIT"); + initiate_quit (ctx); + } + else if (parse_bot_command (text, "status", &following)) + { + struct str report; + str_init (&report); + + const char *reason = getenv (g_startup_reason_str); + if (!reason) + reason = "launched normally"; + str_append_printf (&report, "\x02startup reason:\x0f %s, ", reason); + + str_append (&report, "\x02plugins:\x0f "); + size_t zombies = 0; + for (struct plugin_data *plugin = ctx->plugins; + plugin; plugin = plugin->next) + { + if (plugin->is_zombie) + zombies++; + else + str_append_printf (&report, "%s, ", plugin->name); + } + if (!ctx->plugins) + str_append (&report, "\x02none\x0f, "); + str_append_printf (&report, "\x02zombies:\x0f %zu", zombies); + + respond_to_user (ctx, msg, "%s", report.str); + str_free (&report); + } + else if (parse_bot_command (text, "load", &following)) + { + split_bot_command_argument_list (following, &list); + for (size_t i = 0; i < list.len; i++) + process_plugin_load (ctx, msg, list.vector[i]); + } + else if (parse_bot_command (text, "reload", &following)) + { + split_bot_command_argument_list (following, &list); + for (size_t i = 0; i < list.len; i++) + process_plugin_reload (ctx, msg, list.vector[i]); + } + else if (parse_bot_command (text, "unload", &following)) + { + split_bot_command_argument_list (following, &list); + for (size_t i = 0; i < list.len; i++) + process_plugin_unload (ctx, msg, list.vector[i]); + } + + str_vector_free (&list); +} + +static void +irc_process_message (const struct irc_message *msg, + const char *raw, void *user_data) +{ + struct bot_context *ctx = user_data; + if (g_debug_mode) + fprintf (stderr, "[%s] ==> \"%s\"\n", "IRC", raw); + + // This should be as minimal as possible, I don't want to have the whole bot + // written in C, especially when I have this overengineered plugin system. + // Therefore the very basic functionality only. + // + // I should probably even rip out the autojoin... + + // First forward the message to all the plugins + for (struct plugin_data *plugin = ctx->plugins; + plugin; plugin = plugin->next) + { + if (plugin->is_zombie) + continue; + + if (plugin->initialized) + plugin_send (plugin, "%s", raw); + else + // TODO: make sure that this buffer doesn't get too large either + str_append_printf (&plugin->queued_output, "%s\r\n", raw); + } + + if (!strcasecmp (msg->command, "PING")) + { + if (msg->params.len) + irc_send (ctx, "PONG :%s", msg->params.vector[0]); + else + irc_send (ctx, "PONG"); + } + else if (!ctx->irc_ready && (!strcasecmp (msg->command, "MODE") + || !strcasecmp (msg->command, "376") // RPL_ENDOFMOTD + || !strcasecmp (msg->command, "422"))) // ERR_NOMOTD + { + print_status ("successfully connected"); + ctx->irc_ready = true; + + const char *autojoin = str_map_find (&ctx->config, "autojoin"); + if (autojoin) + irc_send (ctx, "JOIN :%s", autojoin); + } + else if (!strcasecmp (msg->command, "PRIVMSG")) + process_privmsg (ctx, msg); +} + +enum irc_read_result +{ + IRC_READ_OK, ///< Some data were read successfully + IRC_READ_EOF, ///< The server has closed connection + IRC_READ_AGAIN, ///< No more data at the moment + IRC_READ_ERROR ///< General connection failure +}; + +static enum irc_read_result +irc_fill_read_buffer_ssl (struct bot_context *ctx, struct str *buf) +{ + int n_read; +start: + n_read = SSL_read (ctx->ssl, buf->str + buf->len, + buf->alloc - buf->len - 1 /* null byte */); + + const char *error_info = NULL; + switch (SSL_get_error (ctx->ssl, n_read)) + { + case SSL_ERROR_NONE: + buf->str[buf->len += n_read] = '\0'; + return IRC_READ_AGAIN; + case SSL_ERROR_ZERO_RETURN: + return IRC_READ_EOF; + case SSL_ERROR_WANT_READ: + return IRC_READ_AGAIN; + case SSL_ERROR_WANT_WRITE: + { + // Let it finish the handshake as we don't poll for writability; + // any errors are to be collected by SSL_read() in the next iteration + struct pollfd pfd = { .fd = ctx->irc_fd, .events = POLLOUT }; + soft_assert (poll (&pfd, 1, 0) > 0); + goto start; + } + case SSL_ERROR_SYSCALL: + { + int err; + if ((err = ERR_get_error ())) + error_info = ERR_error_string (err, NULL); + else if (n_read == 0) + return IRC_READ_EOF; + else + { + if (errno == EINTR) + goto start; + error_info = strerror (errno); + } + break; + } + case SSL_ERROR_SSL: + default: + error_info = ERR_error_string (ERR_get_error (), NULL); + } + + print_debug ("%s: %s: %s", __func__, "SSL_read", error_info); + return IRC_READ_ERROR; +} + +static enum irc_read_result +irc_fill_read_buffer (struct bot_context *ctx, struct str *buf) +{ + ssize_t n_read; +start: + n_read = recv (ctx->irc_fd, buf->str + buf->len, + buf->alloc - buf->len - 1 /* null byte */, 0); + + if (n_read > 0) + { + buf->str[buf->len += n_read] = '\0'; + return IRC_READ_OK; + } + if (n_read == 0) + return IRC_READ_EOF; + + if (errno == EAGAIN) + return IRC_READ_AGAIN; + if (errno == EINTR) + goto start; + + print_debug ("%s: %s: %s", __func__, "recv", strerror (errno)); + return IRC_READ_ERROR; +} + +static bool irc_connect (struct bot_context *ctx, struct error **e); + +static void +irc_try_reconnect (struct bot_context *ctx) +{ + if (!soft_assert (ctx->irc_fd == -1)) + return; + + const char *reconnect_str = str_map_find (&ctx->config, "reconnect"); + hard_assert (reconnect_str != NULL); // We have a default value for this + + bool reconnect; + if (!set_boolean_if_valid (&reconnect, reconnect_str)) + { + print_fatal ("invalid configuration value for `%s'", "recover"); + try_finish_quit (ctx); + return; + } + if (!reconnect) + return; + + const char *delay_str = str_map_find (&ctx->config, "reconnect_delay"); + hard_assert (delay_str != NULL); // We have a default value for this + + char *end_ptr; + errno = 0; + long delay = strtol (delay_str, &end_ptr, 10); + if (errno != 0 || end_ptr == delay_str || *end_ptr) + { + print_error ("invalid configuration value for `%s'", + "reconnect_delay"); + delay = 0; + } + + while (true) + { + // TODO: this would be better suited by a timeout event; + // remember to update try_finish_quit() etc. to reflect this + print_status ("trying to reconnect in %ld seconds...", delay); + sleep (delay); + + struct error *e = NULL; + if (irc_connect (ctx, &e)) + break; + + print_error ("%s", e->message); + error_free (e); + } + + // TODO: inform plugins about the new connection +} + +static void +on_irc_disconnected (struct bot_context *ctx) +{ + // Get rid of the dead socket etc. + if (ctx->ssl) + { + SSL_free (ctx->ssl); + ctx->ssl = NULL; + SSL_CTX_free (ctx->ssl_ctx); + ctx->ssl_ctx = NULL; + } + + ssize_t i = poller_find_by_fd (&ctx->poller, ctx->irc_fd); + if (i != -1) + poller_remove_at_index (&ctx->poller, i); + + xclose (ctx->irc_fd); + ctx->irc_fd = -1; + + // TODO: inform plugins about the disconnect event + + if (ctx->quitting) + { + // Unload all plugins + // TODO: wait for a few seconds and then send SIGKILL to all plugins + for (struct plugin_data *plugin = ctx->plugins; + plugin; plugin = plugin->next) + plugin_zombify (plugin); + + try_finish_quit (ctx); + return; + } + + irc_try_reconnect (ctx); +} + +static void +on_irc_readable (const struct pollfd *fd, struct bot_context *ctx) +{ + if (fd->revents != POLLIN) + print_debug ("poller fd %d: revents: %d", fd->fd, fd->revents); + + (void) set_blocking (ctx->irc_fd, false); + + struct str *buf = &ctx->read_buffer; + enum irc_read_result (*fill_buffer)(struct bot_context *, struct str *) + = ctx->ssl + ? irc_fill_read_buffer_ssl + : irc_fill_read_buffer; + bool disconnected = false; + while (true) + { + str_ensure_space (buf, 512); + switch (fill_buffer (ctx, buf)) + { + case IRC_READ_AGAIN: + goto end; + case IRC_READ_ERROR: + print_error ("reading from the IRC server failed"); + disconnected = true; + goto end; + case IRC_READ_EOF: + print_status ("the IRC server closed the connection"); + disconnected = true; + goto end; + case IRC_READ_OK: + break; + } + + if (buf->len >= (1 << 20)) + { + print_fatal ("the IRC server seems to spew out data frantically"); + irc_shutdown (ctx); + goto end; + } + } +end: + (void) set_blocking (ctx->irc_fd, true); + irc_process_buffer (buf, irc_process_message, ctx); + + if (disconnected) + on_irc_disconnected (ctx); +} + +static bool +irc_connect (struct bot_context *ctx, struct error **e) +{ + const char *irc_host = str_map_find (&ctx->config, "irc_host"); + const char *irc_port = str_map_find (&ctx->config, "irc_port"); + const char *ssl_use_str = str_map_find (&ctx->config, "ssl_use"); + + const char *nickname = str_map_find (&ctx->config, "nickname"); + const char *username = str_map_find (&ctx->config, "username"); + const char *fullname = str_map_find (&ctx->config, "fullname"); + + // We have a default value for these + hard_assert (irc_port && ssl_use_str); + hard_assert (nickname && username && fullname); + + // TODO: again, get rid of `struct error' in here. The question is: how + // do we tell our caller that he should not try to reconnect? + if (!irc_host) + { + error_set (e, CONNECT_ERROR, CONNECT_ERROR_INVALID_CONFIGURATION, + "no hostname specified in configuration"); + return false; + } + + bool use_ssl; + if (!set_boolean_if_valid (&use_ssl, ssl_use_str)) + { + error_set (e, CONNECT_ERROR, CONNECT_ERROR_INVALID_CONFIGURATION, + "invalid configuration value for `%s'", "use_ssl"); + return false; + } + + if (!irc_establish_connection (ctx, irc_host, irc_port, use_ssl, e)) + return false; + + // TODO: set a timeout on the socket, something like 30 minutes, then we + // should ideally send a PING... or just forcefully reconnect. + // + // TODO: in exec try: 1/ set blocking, 2/ setsockopt() SO_LINGER, + // (struct linger) { .l_onoff = true; .l_linger = 1 /* 1s should do */; } + // 3/ /* O_CLOEXEC */ But only if the QUIT message proves unreliable. + poller_set (&ctx->poller, ctx->irc_fd, POLLIN, + (poller_dispatcher_func) on_irc_readable, ctx); + + // TODO: probably check for errors from these calls as well + irc_send (ctx, "NICK %s", nickname); + irc_send (ctx, "USER %s 8 * :%s", username, fullname); + return true; +} + +static void +on_signal_pipe_readable (const struct pollfd *fd, struct bot_context *ctx) +{ + char *dummy; + (void) read (fd->fd, &dummy, 1); + + // XXX: do we need to check if we have a connection? + if (g_termination_requested && !ctx->quitting) + { + irc_send (ctx, "QUIT :Terminated by signal"); + initiate_quit (ctx); + } + + // Reap all dead children (since the pipe may overflow, we ask waitpid() + // to return all the zombies it knows about). + while (true) + { + int status; + pid_t zombie = waitpid (-1, &status, WNOHANG); + + if (zombie == -1) + { + // No children to wait on + if (errno == ECHILD) + break; + + hard_assert (errno == EINTR); + continue; + } + + if (zombie == 0) + break; + + struct plugin_data *plugin = plugin_find_by_pid (ctx, zombie); + // Something has died but we don't recognize it (re-exec?) + if (!soft_assert (plugin != NULL)) + continue; + + // TODO: callbacks on children death, so that we may tell the user + // "plugin `name' died like a dirty jewish pig"; use `status' + if (!plugin->is_zombie && WIFSIGNALED (status)) + { + char *notes = ""; +#ifdef WCOREDUMP + if (WCOREDUMP (status)) + notes = " (core dumped)"; +#endif + print_warning ("Plugin `%s' died from signal %d%s", + plugin->name, WTERMSIG (status), notes); + } + + // Let's go through the zombie state to simplify things a bit + // TODO: might not be a completely bad idea to restart the plugin + plugin_zombify (plugin); + + plugin->pid = -1; + + ssize_t poller_idx = poller_find_by_fd (&ctx->poller, plugin->read_fd); + if (poller_idx != -1) + poller_remove_at_index (&ctx->poller, poller_idx); + + xclose (plugin->read_fd); + plugin->read_fd = -1; + + LIST_UNLINK (ctx->plugins, plugin); + plugin_data_free (plugin); + free (plugin); + + // Living child processes block us from quitting + try_finish_quit (ctx); + } +} + +static void +write_default_configuration (const char *filename) +{ + struct str path, base; + int status = EXIT_SUCCESS; + + str_init (&path); + str_init (&base); + + if (filename) + { + char *tmp = xstrdup (filename); + str_append (&path, dirname (tmp)); + strcpy (tmp, filename); + str_append (&base, basename (tmp)); + free (tmp); + } + else + { + get_xdg_home_dir (&path, "XDG_CONFIG_HOME", ".config"); + str_append (&path, "/" PROGRAM_NAME); + str_append (&base, PROGRAM_NAME ".conf"); + } + + struct error *e = NULL; + if (!mkdir_with_parents (path.str, &e)) + { + print_fatal ("%s", e->message); + status = EXIT_FAILURE; + goto out; + } + + str_append_c (&path, '/'); + str_append_str (&path, &base); + + FILE *fp = fopen (path.str, "w"); + if (!fp) + { + print_fatal ("could not open `%s' for writing: %s", + path.str, strerror (errno)); + status = EXIT_FAILURE; + goto out; + } + + errno = 0; + for (size_t i = 0; i < N_ELEMENTS (g_config_table); i++) + { + const struct config_item *item = g_config_table + i; + fprintf (fp, "# %s\n", item->description); + if (item->default_value) + fprintf (fp, "%s=%s\n", item->key, item->default_value); + else + fprintf (fp, "#%s=\n", item->key); + } + fclose (fp); + if (errno) + { + print_fatal ("writing to `%s' failed: %s", path.str, strerror (errno)); + status = EXIT_FAILURE; + goto out; + } + print_status ("configuration written to `%s'", path.str); + +out: + str_free (&path); + str_free (&base); + exit (status); +} + +static void +print_usage (const char *program_name) +{ + fprintf (stderr, + "Usage: %s [OPTION]...\n" + "Experimental IRC bot.\n" + "\n" + " -d, --debug run in debug mode\n" + " -h, --help display this help and exit\n" + " -V, --version output version information and exit\n" + " --write-default-cfg [filename]\n" + " write a default configuration file and exit\n", + program_name); +} + +int +main (int argc, char *argv[]) +{ + const char *invocation_name = argv[0]; + str_vector_init (&g_original_argv); + str_vector_add_vector (&g_original_argv, argv); + + static struct option opts[] = + { + { "debug", no_argument, NULL, 'd' }, + { "help", no_argument, NULL, 'h' }, + { "version", no_argument, NULL, 'V' }, + { "write-default-cfg", optional_argument, NULL, 'w' }, + { NULL, 0, NULL, 0 } + }; + + while (1) + { + int c, opt_index; + + c = getopt_long (argc, argv, "dhV", opts, &opt_index); + if (c == -1) + break; + + switch (c) + { + case 'd': + g_debug_mode = true; + break; + case 'h': + print_usage (invocation_name); + exit (EXIT_SUCCESS); + case 'V': + printf (PROGRAM_NAME " " PROGRAM_VERSION "\n"); + exit (EXIT_SUCCESS); + case 'w': + write_default_configuration (optarg); + abort (); + default: + print_fatal ("error in options"); + exit (EXIT_FAILURE); + } + } + + print_status (PROGRAM_NAME " " PROGRAM_VERSION " starting"); + setup_signal_handlers (); + + SSL_library_init (); + atexit (EVP_cleanup); + SSL_load_error_strings (); + atexit (ERR_free_strings); + + struct bot_context ctx; + bot_context_init (&ctx); + + struct error *e = NULL; + if (!read_config_file (&ctx.config, &e)) + { + print_fatal ("error loading configuration: %s", e->message); + error_free (e); + exit (EXIT_FAILURE); + } + + setup_recovery_handler (&ctx); + poller_set (&ctx.poller, g_signal_pipe[0], POLLIN, + (poller_dispatcher_func) on_signal_pipe_readable, &ctx); + + plugin_load_all_from_config (&ctx); + if (!irc_connect (&ctx, &e)) + { + print_error ("%s", e->message); + error_free (e); + exit (EXIT_FAILURE); + } + + // TODO: clean re-exec support; to save the state I can either use argv, + // argp, or I can create a temporary file, unlink it and use the FD + // (mkstemp() on a `struct str' constructed from XDG_RUNTIME_DIR, TMPDIR + // or /tmp as a last resort + PROGRAM_NAME + ".XXXXXX" -> unlink(); + // remember to use O_CREAT | O_EXCL). The state needs to be versioned. + // Unfortunately I cannot de/serialize SSL state. + + ctx.polling = true; + while (ctx.polling) + poller_run (&ctx.poller); + + bot_context_free (&ctx); + str_vector_free (&g_original_argv); + return EXIT_SUCCESS; +} + |