diff options
| author | Přemysl Janouch <p.janouch@gmail.com> | 2015-07-03 20:32:31 +0200 | 
|---|---|---|
| committer | Přemysl Janouch <p.janouch@gmail.com> | 2015-07-03 22:19:12 +0200 | 
| commit | 2357f1382ad5eaf100d9a982e48052a790435665 (patch) | |
| tree | 3512373c229418e831093aebad89d8fb917579e7 | |
| parent | 15882dcdf9d430cbefcc949bd3972656c58819e5 (diff) | |
| download | xK-2357f1382ad5eaf100d9a982e48052a790435665.tar.gz xK-2357f1382ad5eaf100d9a982e48052a790435665.tar.xz xK-2357f1382ad5eaf100d9a982e48052a790435665.zip | |
degesch: rewrite to use asynchronous I/O
| -rw-r--r-- | degesch.c | 690 | 
1 files changed, 447 insertions, 243 deletions
| @@ -1049,12 +1049,42 @@ buffer_destroy (struct buffer *self)  // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +enum transport_io_result +{ +	TRANSPORT_IO_OK = 0,                ///< Completed successfully +	TRANSPORT_IO_EOF,                   ///< Connection shut down by peer +	TRANSPORT_IO_ERROR                  ///< Connection error +}; + +// The only real purpose of this is to abstract away TLS/SSL +struct transport +{ +	/// Initialize the transport +	bool (*init) (struct server *s, struct error **e); +	/// Destroy the user data pointer +	void (*cleanup) (struct server *s); + +	/// The underlying socket may have become readable, update `read_buffer' +	enum transport_io_result (*on_readable) (struct server *s); +	/// The underlying socket may have become writeable, flush `write_buffer' +	enum transport_io_result (*on_writeable) (struct server *s); +	/// Return event mask to use in the poller +	int (*get_poll_events) (struct server *s); + +	/// Called just before closing the connection from our side +	void (*in_before_shutdown) (struct server *s); +}; + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +  enum server_state  {  	IRC_DISCONNECTED,                   ///< Not connected  	IRC_CONNECTING,                     ///< Connecting to the server  	IRC_CONNECTED,                      ///< Trying to register -	IRC_REGISTERED                      ///< We can chat now +	IRC_REGISTERED,                     ///< We can chat now +	IRC_CLOSING,                        ///< Flushing output before shutdown +	IRC_HALF_CLOSED                     ///< Connection shutdown from our side  };  /// Convert an IRC identifier character to lower-case @@ -1079,10 +1109,11 @@ struct server  	int socket;                         ///< Socket FD of the server  	struct str read_buffer;             ///< Input yet to be processed -	struct poller_fd read_event;        ///< We can read from the socket +	struct str write_buffer;            ///< Outut yet to be be sent out +	struct poller_fd socket_event;      ///< We can read from the socket -	SSL_CTX *ssl_ctx;                   ///< SSL context -	SSL *ssl;                           ///< SSL connection +	struct transport *transport;        ///< Transport method +	void *transport_data;               ///< Transport data  	// Events: @@ -1177,6 +1208,7 @@ server_init (struct server *self, struct poller *poller)  	self->socket = -1;  	str_init (&self->read_buffer); +	str_init (&self->write_buffer);  	self->state = IRC_DISCONNECTED;  	poller_timer_init (&self->timeout_tmr, poller); @@ -1214,17 +1246,19 @@ server_free (struct server *self)  		connector_free (self->connector);  		free (self->connector);  	} + +	if (self->transport +	 && self->transport->cleanup) +		self->transport->cleanup (self); +  	if (self->socket != -1)  	{  		xclose (self->socket); -		poller_fd_reset (&self->read_event); +		self->socket_event.closed = true; +		poller_fd_reset (&self->socket_event);  	}  	str_free (&self->read_buffer); - -	if (self->ssl) -		SSL_free (self->ssl); -	if (self->ssl_ctx) -		SSL_CTX_free (self->ssl_ctx); +	str_free (&self->write_buffer);  	str_map_free (&self->irc_users);  	str_map_free (&self->irc_channels); @@ -3080,10 +3114,6 @@ irc_set_casemapping (struct server *s,  // --- Core functionality ------------------------------------------------------ -// Most of the core IRC code comes from ZyklonB which is mostly blocking. -// While it's fairly easy to follow, it also stinks.  It needs to be rewritten -// to be as asynchronous as possible.  See kike.c for reference. -  static bool  irc_is_connected (struct server *s)  { @@ -3091,6 +3121,16 @@ irc_is_connected (struct server *s)  }  static void +irc_update_poller (struct server *s, const struct pollfd *pfd) +{ +	int new_events = s->transport->get_poll_events (s); +	hard_assert (new_events != 0); + +	if (!pfd || pfd->events != new_events) +		poller_fd_set (&s->socket_event, new_events); +} + +static void  irc_cancel_timers (struct server *s)  {  	poller_timer_reset (&s->timeout_tmr); @@ -3125,115 +3165,6 @@ irc_queue_reconnect (struct server *s)  // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -static bool -irc_initialize_ssl_ctx (struct server *s, struct error **e) -{ -	// XXX: maybe we should call SSL_CTX_set_options() for some workarounds - -	bool verify = get_config_boolean (s->config, "ssl_verify"); -	if (!verify) -		SSL_CTX_set_verify (s->ssl_ctx, SSL_VERIFY_NONE, NULL); - -	const char *ca_file = get_config_string (s->config, "ssl_ca_file"); -	const char *ca_path = get_config_string (s->config, "ssl_ca_path"); - -	struct error *error = NULL; -	if (ca_file || ca_path) -	{ -		if (SSL_CTX_load_verify_locations (s->ssl_ctx, ca_file, ca_path)) -			return true; - -		error_set (&error, "%s: %s", -			"Failed to set locations for the CA certificate bundle", -			ERR_reason_error_string (ERR_get_error ())); -		goto ca_error; -	} - -	if (!SSL_CTX_set_default_verify_paths (s->ssl_ctx)) -	{ -		error_set (&error, "%s: %s", -			"Couldn't load the default CA certificate bundle", -			ERR_reason_error_string (ERR_get_error ())); -		goto ca_error; -	} -	return true; - -ca_error: -	if (verify) -	{ -		error_propagate (e, error); -		return false; -	} - -	// Only inform the user if we're not actually verifying -	log_server_error (s, s->buffer, "#s", error->message); -	error_free (error); -	return true; -} - -static bool -irc_initialize_ssl (struct server *s, struct error **e) -{ -	const char *error_info = NULL; -	s->ssl_ctx = SSL_CTX_new (SSLv23_client_method ()); -	if (!s->ssl_ctx) -		goto error_ssl_1; -	if (!irc_initialize_ssl_ctx (s, e)) -		goto error_ssl_2; - -	s->ssl = SSL_new (s->ssl_ctx); -	if (!s->ssl) -		goto error_ssl_2; - -	const char *ssl_cert = get_config_string (s->config, "ssl_cert"); -	if (ssl_cert) -	{ -		char *path = resolve_config_filename (ssl_cert); -		if (!path) -			log_server_error (s, s->buffer, -				"#s: #s", "Cannot open file", ssl_cert); -		// XXX: perhaps we should read the file ourselves for better messages -		else if (!SSL_use_certificate_file (s->ssl, path, SSL_FILETYPE_PEM) -			|| !SSL_use_PrivateKey_file (s->ssl, path, SSL_FILETYPE_PEM)) -			log_server_error (s, s->buffer, -				"#s: #s", "Setting the SSL client certificate failed", -				ERR_error_string (ERR_get_error (), NULL)); -		free (path); -	} - -	SSL_set_connect_state (s->ssl); -	if (!SSL_set_fd (s->ssl, s->socket)) -		goto error_ssl_3; -	// Avoid SSL_write() returning SSL_ERROR_WANT_READ -	SSL_set_mode (s->ssl, SSL_MODE_AUTO_RETRY); - -	switch (xssl_get_error (s->ssl, SSL_connect (s->ssl), &error_info)) -	{ -	case SSL_ERROR_NONE: -		return true; -	case SSL_ERROR_ZERO_RETURN: -		error_info = "server closed the connection"; -	default: -		break; -	} - -error_ssl_3: -	SSL_free (s->ssl); -	s->ssl = NULL; -error_ssl_2: -	SSL_CTX_free (s->ssl_ctx); -	s->ssl_ctx = NULL; -error_ssl_1: -	// XXX: these error strings are really nasty; also there could be -	//   multiple errors on the OpenSSL stack. -	if (!error_info) -		error_info = ERR_error_string (ERR_get_error (), NULL); -	error_set (e, "%s: %s", "could not initialize SSL", error_info); -	return false; -} - -// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -  // As of 2015, everything should be in UTF-8.  And if it's not, we'll decode it  // as ISO Latin 1.  This function should not be called on the whole message.  static char * @@ -3272,6 +3203,8 @@ irc_send (struct server *s, const char *format, ...)  		print_debug ("tried sending a message to a dead server connection");  		return;  	} +	if (s->state == IRC_CLOSING) +		return;  	va_list ap;  	va_start (ap, format); @@ -3293,33 +3226,45 @@ irc_send (struct server *s, const char *format, ...)  		input_show (&s->ctx->input);  	} -	str_append (&str, "\r\n"); - -	if (s->ssl) -	{ -		// TODO: call SSL_get_error() to detect if a clean shutdown has occured -		if (SSL_write (s->ssl, str.str, str.len) != (int) str.len) -			LOG_FUNC_FAILURE ("SSL_write", -				ERR_error_string (ERR_get_error (), NULL)); -	} -	else if (write (s->socket, str.str, str.len) != (ssize_t) str.len) -		LOG_LIBC_FAILURE ("write"); +	str_append_str (&s->write_buffer, &str);  	str_free (&str); +	str_append (&s->write_buffer, "\r\n"); +	irc_update_poller (s, NULL);  }  // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -  static void +irc_real_shutdown (struct server *s) +{ +	hard_assert (irc_is_connected (s) && s->state != IRC_HALF_CLOSED); + +	if (s->transport +	 && s->transport->in_before_shutdown) +		s->transport->in_before_shutdown (s); + +	while (shutdown (s->socket, SHUT_WR) == -1) +		if (!soft_assert (errno == EINTR)) +			break; + +	s->state = IRC_HALF_CLOSED; +} + +static void  irc_shutdown (struct server *s)  { -	// Generally non-critical -	if (s->ssl) -		soft_assert (SSL_shutdown (s->ssl) != -1); -	else -		soft_assert (shutdown (s->socket, SHUT_WR) == 0); +	if (s->state == IRC_CLOSING +	 || s->state == IRC_HALF_CLOSED) +		return; -	// TODO: set a timer after which we cut the connection +	// TODO: set a timer to cut the connection if we don't receive an EOF +	s->state = IRC_CLOSING; + +	// Either there's still some data in the write buffer and we wait +	// until they're sent, or we send an EOF to the server right away +	if (!s->write_buffer.len) +		irc_real_shutdown (s);  }  static void @@ -3372,7 +3317,6 @@ initiate_quit (struct app_context *ctx)  		if (irc_is_connected (s))  		{ -			// XXX: when we go async, we'll have to flush output buffers first  			irc_shutdown (s);  			s->manual_disconnect = true;  		} @@ -3390,20 +3334,16 @@ on_irc_disconnected (struct server *s)  	hard_assert (irc_is_connected (s));  	// Get rid of the dead socket -	if (s->ssl) -	{ -		SSL_free (s->ssl); -		s->ssl = NULL; -		SSL_CTX_free (s->ssl_ctx); -		s->ssl_ctx = NULL; -	} +	if (s->transport +	 && s->transport->cleanup) +		s->transport->cleanup (s);  	xclose (s->socket);  	s->socket = -1;  	s->state = IRC_DISCONNECTED; -	s->read_event.closed = true; -	poller_fd_reset (&s->read_event); +	s->socket_event.closed = true; +	poller_fd_reset (&s->socket_event);  	// All of our timers have lost their meaning now  	irc_cancel_timers (s); @@ -3474,131 +3414,392 @@ on_irc_timeout (void *user_data)  	irc_send (s, "PING :%" PRIi64, (int64_t) time (NULL));  } -// --- Processing server output ------------------------------------------------ +// --- Server I/O --------------------------------------------------------------  static void irc_process_message  	(const struct irc_message *msg, const char *raw, void *user_data); -enum irc_read_result +static void +on_irc_ready (const struct pollfd *pfd, struct server *s)  { -	IRC_READ_OK,                        ///< Some data were read successfully -	IRC_READ_EOF,                       ///< The server has closed connection -	IRC_READ_AGAIN,                     ///< No more data at the moment -	IRC_READ_ERROR                      ///< General connection failure +	struct transport *transport = s->transport; +	enum transport_io_result result; + +	if ((result = transport->on_readable (s)) == TRANSPORT_IO_ERROR) +		goto error; +	bool read_eof = result == TRANSPORT_IO_EOF; + +	if (s->read_buffer.len >= (1 << 20)) +	{ +		// XXX: this is stupid; if anything, count it in dependence of time +		log_server_error (s, s->buffer, +			"The IRC server seems to spew out data frantically"); +		goto disconnect; +	} +	if (s->read_buffer.len) +		irc_process_buffer (&s->read_buffer, irc_process_message, s); + +	if ((result = transport->on_writeable (s)) == TRANSPORT_IO_ERROR) +		goto error; +	bool write_eof = result == TRANSPORT_IO_EOF; + +	// FIXME: this may probably fire multiple times if we're flushing after it, +	//   we should probably store this information next to the state +	if (read_eof || write_eof) +		log_server_error (s, s->buffer, "The IRC server closed the connection"); + +	// It makes no sense to flush anything if the write needs to read +	// and we receive an EOF -> disconnect right away +	if (write_eof) +		goto disconnect; + +	// If we've been asked to flush the write buffer and our job is complete, +	// we send an EOF to the server, changing the state to IRC_HALF_CLOSED +	if (s->state == IRC_CLOSING && !s->write_buffer.len) +		irc_real_shutdown (s); + +	if (read_eof) +	{ +		// Both ends closed, we're done +		if (s->state == IRC_HALF_CLOSED) +			goto disconnect; + +		// Otherwise we want to flush the write buffer +		irc_shutdown (s); + +		// If that went well, we can disconnect now +		if (s->state == IRC_HALF_CLOSED) +			goto disconnect; +	} + +	// XXX: shouldn't we rather wait for PONG messages? +	irc_reset_connection_timeouts (s); +	irc_update_poller (s, pfd); +	return; + +error: +	log_server_error (s, s->buffer, "Reading from the IRC server failed"); +disconnect: +	on_irc_disconnected (s); +} + +// --- Plain transport --------------------------------------------------------- + +static enum transport_io_result +transport_plain_on_readable (struct server *s) +{ +	struct str *buf = &s->read_buffer; +	ssize_t n_read; + +	while (true) +	{ +		str_ensure_space (buf, 512); +		n_read = recv (s->socket, buf->str + buf->len, +			buf->alloc - buf->len - 1 /* null byte */, 0); + +		if (n_read > 0) +		{ +			buf->str[buf->len += n_read] = '\0'; +			continue; +		} +		if (n_read == 0) +			return TRANSPORT_IO_EOF; + +		if (errno == EAGAIN) +			return TRANSPORT_IO_OK; +		if (errno == EINTR) +			continue; + +		LOG_LIBC_FAILURE ("recv"); +		return TRANSPORT_IO_ERROR; +	} +} + +static enum transport_io_result +transport_plain_on_writeable (struct server *s) +{ +	struct str *buf = &s->write_buffer; +	ssize_t n_written; + +	while (buf->len) +	{ +		n_written = send (s->socket, buf->str, buf->len, 0); +		if (n_written >= 0) +		{ +			str_remove_slice (buf, 0, n_written); +			continue; +		} + +		if (errno == EAGAIN) +			return TRANSPORT_IO_OK; +		if (errno == EINTR) +			continue; + +		LOG_LIBC_FAILURE ("send"); +		return TRANSPORT_IO_ERROR; +	} +	return TRANSPORT_IO_OK; +} + +static int +transport_plain_get_poll_events (struct server *s) +{ +	int events = POLLIN; +	if (s->write_buffer.len) +		events |= POLLOUT; +	return events; +} + +static struct transport g_transport_plain = +{ +	.on_readable      = transport_plain_on_readable, +	.on_writeable     = transport_plain_on_writeable, +	.get_poll_events  = transport_plain_get_poll_events,  }; -static enum irc_read_result -irc_fill_read_buffer_ssl (struct server *s, struct str *buf) +// --- SSL/TLS transport ------------------------------------------------------- + +struct transport_tls_data  { -	int n_read; -start: -	n_read = SSL_read (s->ssl, buf->str + buf->len, -		buf->alloc - buf->len - 1 /* null byte */); +	SSL_CTX *ssl_ctx;                   ///< SSL context +	SSL *ssl;                           ///< SSL/TLS connection +	bool ssl_rx_want_tx;                ///< SSL_read() wants to write +	bool ssl_tx_want_rx;                ///< SSL_write() wants to read +}; -	const char *error_info = NULL; -	switch (xssl_get_error (s->ssl, n_read, &error_info)) +static bool +transport_tls_init_ctx (struct server *s, SSL_CTX *ssl_ctx, struct error **e) +{ +	bool verify = get_config_boolean (s->config, "ssl_verify"); +	if (!verify) +		SSL_CTX_set_verify (ssl_ctx, SSL_VERIFY_NONE, NULL); + +	const char *ca_file = get_config_string (s->config, "ssl_ca_file"); +	const char *ca_path = get_config_string (s->config, "ssl_ca_path"); + +	struct error *error = NULL; +	if (ca_file || ca_path)  	{ -	case SSL_ERROR_NONE: -		buf->str[buf->len += n_read] = '\0'; -		return IRC_READ_OK; -	case SSL_ERROR_ZERO_RETURN: -		return IRC_READ_EOF; -	case SSL_ERROR_WANT_READ: -		return IRC_READ_AGAIN; -	case SSL_ERROR_WANT_WRITE: +		if (SSL_CTX_load_verify_locations (ssl_ctx, ca_file, ca_path)) +			return true; + +		error_set (&error, "%s: %s", +			"Failed to set locations for the CA certificate bundle", +			ERR_reason_error_string (ERR_get_error ())); +		goto ca_error; +	} + +	if (!SSL_CTX_set_default_verify_paths (ssl_ctx))  	{ -		// Let it finish the handshake as we don't poll for writability; -		// any errors are to be collected by SSL_read() in the next iteration -		struct pollfd pfd = { .fd = s->socket, .events = POLLOUT }; -		soft_assert (poll (&pfd, 1, 0) > 0); -		goto start; +		error_set (&error, "%s: %s", +			"Couldn't load the default CA certificate bundle", +			ERR_reason_error_string (ERR_get_error ())); +		goto ca_error;  	} -	case XSSL_ERROR_TRY_AGAIN: -		goto start; -	default: -		LOG_FUNC_FAILURE ("SSL_read", error_info); -		return IRC_READ_ERROR; + +	// XXX: maybe we should call SSL_CTX_set_options() for some workarounds +	SSL_CTX_set_mode (ssl_ctx, +		SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); +	return true; + +ca_error: +	if (verify) +	{ +		error_propagate (e, error); +		return false;  	} + +	// Only inform the user if we're not actually verifying +	log_server_error (s, s->buffer, "#s", error->message); +	error_free (error); +	return true;  } -static enum irc_read_result -irc_fill_read_buffer (struct server *s, struct str *buf) +static bool +transport_tls_init_cert (struct server *s, SSL *ssl, struct error **e)  { -	ssize_t n_read; -start: -	n_read = recv (s->socket, buf->str + buf->len, -		buf->alloc - buf->len - 1 /* null byte */, 0); +	const char *ssl_cert = get_config_string (s->config, "ssl_cert"); +	if (!ssl_cert) +		return true; + +	bool result = false; +	char *path = resolve_config_filename (ssl_cert); +	if (!path) +		error_set (e, "%s: %s", "Cannot open file", ssl_cert); +	// XXX: perhaps we should read the file ourselves for better messages +	else if (!SSL_use_certificate_file (ssl, path, SSL_FILETYPE_PEM) +		|| !SSL_use_PrivateKey_file (ssl, path, SSL_FILETYPE_PEM)) +		error_set (e, "%s: %s", "Setting the SSL client certificate failed", +			ERR_error_string (ERR_get_error (), NULL)); +	else +		result = true; +	free (path); +	return result; +} -	if (n_read > 0) +static bool +transport_tls_init (struct server *s, struct error **e) +{ +	const char *error_info = NULL; +	SSL_CTX *ssl_ctx = SSL_CTX_new (SSLv23_client_method ()); +	if (!ssl_ctx) +		goto error_ssl_1; +	if (!transport_tls_init_ctx (s, ssl_ctx, e)) +		goto error_ssl_2; + +	SSL *ssl = SSL_new (ssl_ctx); +	if (!ssl) +		goto error_ssl_2; + +	struct error *error = NULL; +	if (!transport_tls_init_cert (s, ssl, &error))  	{ -		buf->str[buf->len += n_read] = '\0'; -		return IRC_READ_OK; +		// XXX: is this a reason to abort the connection? +		log_server_error (s, s->buffer, "#s", error->message); +		error_free (error);  	} -	if (n_read == 0) -		return IRC_READ_EOF; -	if (errno == EAGAIN) -		return IRC_READ_AGAIN; -	if (errno == EINTR) -		goto start; +	SSL_set_connect_state (ssl); +	if (!SSL_set_fd (ssl, s->socket)) +		goto error_ssl_3; + +	// XXX: maybe set `ssl_rx_want_tx' to force a handshake? +	struct transport_tls_data *data = xcalloc (1, sizeof *data); +	data->ssl_ctx = ssl_ctx; +	data->ssl = ssl; -	LOG_LIBC_FAILURE ("recv"); -	return IRC_READ_ERROR; +	s->transport_data = data; +	return true; + +error_ssl_3: +	SSL_free (ssl); +error_ssl_2: +	SSL_CTX_free (ssl_ctx); +error_ssl_1: +	// XXX: these error strings are really nasty; also there could be +	//   multiple errors on the OpenSSL stack. +	if (!error_info) +		error_info = ERR_error_string (ERR_get_error (), NULL); +	error_set (e, "%s: %s", "could not initialize SSL/TLS", error_info); +	return false;  }  static void -on_irc_readable (const struct pollfd *fd, struct server *s) +transport_tls_cleanup (struct server *s)  { -	if (fd->revents & ~(POLLIN | POLLHUP | POLLERR)) -		print_debug ("fd %d: unexpected revents: %d", fd->fd, fd->revents); +	struct transport_tls_data *data = s->transport_data; +	if (data->ssl) +		SSL_free (data->ssl); +	if (data->ssl_ctx) +		SSL_CTX_free (data->ssl_ctx); +	free (data); +} -	(void) set_blocking (s->socket, false); +static enum transport_io_result +transport_tls_on_readable (struct server *s) +{ +	struct transport_tls_data *data = s->transport_data; +	if (data->ssl_tx_want_rx) +		return TRANSPORT_IO_OK;  	struct str *buf = &s->read_buffer; -	enum irc_read_result (*fill_buffer)(struct server *, struct str *) -		= s->ssl -		? irc_fill_read_buffer_ssl -		: irc_fill_read_buffer; -	bool disconnected = false; +	data->ssl_rx_want_tx = false;  	while (true)  	{  		str_ensure_space (buf, 512); -		switch (fill_buffer (s, buf)) +		int n_read = SSL_read (data->ssl, buf->str + buf->len, +			buf->alloc - buf->len - 1 /* null byte */); + +		const char *error_info = NULL; +		switch (xssl_get_error (data->ssl, n_read, &error_info))  		{ -		case IRC_READ_AGAIN: -			goto end; -		case IRC_READ_ERROR: -			log_server_error (s, s->buffer, -				"Reading from the IRC server failed"); -			disconnected = true; -			goto end; -		case IRC_READ_EOF: -			log_server_error (s, s->buffer, -				"The IRC server closed the connection"); -			disconnected = true; -			goto end; -		case IRC_READ_OK: -			break; +		case SSL_ERROR_NONE: +			buf->str[buf->len += n_read] = '\0'; +			continue; +		case SSL_ERROR_ZERO_RETURN: +			return TRANSPORT_IO_EOF; +		case SSL_ERROR_WANT_READ: +			return TRANSPORT_IO_OK; +		case SSL_ERROR_WANT_WRITE: +			data->ssl_rx_want_tx = true; +			return TRANSPORT_IO_OK; +		case XSSL_ERROR_TRY_AGAIN: +			continue; +		default: +			LOG_FUNC_FAILURE ("SSL_read", error_info); +			return TRANSPORT_IO_ERROR;  		} +	} +} -		if (buf->len >= (1 << 20)) +static enum transport_io_result +transport_tls_on_writeable (struct server *s) +{ +	struct transport_tls_data *data = s->transport_data; +	if (data->ssl_rx_want_tx) +		return TRANSPORT_IO_OK; + +	struct str *buf = &s->write_buffer; +	data->ssl_tx_want_rx = false; +	while (buf->len) +	{ +		int n_written = SSL_write (data->ssl, buf->str, buf->len); + +		const char *error_info = NULL; +		switch (xssl_get_error (data->ssl, n_written, &error_info))  		{ -			log_server_error (s, s->buffer, -				"The IRC server seems to spew out data frantically"); -			irc_shutdown (s); -			goto end; +		case SSL_ERROR_NONE: +			str_remove_slice (buf, 0, n_written); +			continue; +		case SSL_ERROR_ZERO_RETURN: +			return TRANSPORT_IO_EOF; +		case SSL_ERROR_WANT_WRITE: +			return TRANSPORT_IO_OK; +		case SSL_ERROR_WANT_READ: +			data->ssl_tx_want_rx = true; +			return TRANSPORT_IO_OK; +		case XSSL_ERROR_TRY_AGAIN: +			continue; +		default: +			LOG_FUNC_FAILURE ("SSL_write", error_info); +			return TRANSPORT_IO_ERROR;  		}  	} -end: -	(void) set_blocking (s->socket, true); -	irc_process_buffer (buf, irc_process_message, s); +	return TRANSPORT_IO_OK; +} -	if (disconnected) -		on_irc_disconnected (s); -	else -		irc_reset_connection_timeouts (s); +static int +transport_tls_get_poll_events (struct server *s) +{ +	struct transport_tls_data *data = s->transport_data; + +	int events = POLLIN; +	if (s->write_buffer.len || data->ssl_rx_want_tx) +		events |= POLLOUT; + +	// While we're waiting for an opposite event, we ignore the original +	if (data->ssl_rx_want_tx)  events &= ~POLLIN; +	if (data->ssl_tx_want_rx)  events &= ~POLLOUT; +	return events;  } +static void +transport_tls_in_before_shutdown (struct server *s) +{ +	struct transport_tls_data *data = s->transport_data; +	(void) SSL_shutdown (data->ssl); +} + +static struct transport g_transport_tls = +{ +	.init               = transport_tls_init, +	.cleanup            = transport_tls_cleanup, +	.on_readable        = transport_tls_on_readable, +	.on_writeable       = transport_tls_on_writeable, +	.get_poll_events    = transport_tls_get_poll_events, +	.in_before_shutdown = transport_tls_in_before_shutdown, +}; +  // --- Connection establishment ------------------------------------------------  static bool @@ -3667,11 +3868,14 @@ irc_finish_connection (struct server *s, int socket)  {  	struct app_context *ctx = s->ctx; +	set_blocking (socket, false);  	s->socket = socket; +	s->transport = get_config_boolean (s->config, "ssl") +		? &g_transport_tls +		: &g_transport_plain;  	struct error *e = NULL; -	bool use_ssl = get_config_boolean (s->config, "ssl"); -	if (use_ssl && !irc_initialize_ssl (s, &e)) +	if (s->transport->init && !s->transport->init (s, &e))  	{  		log_server_error (s, s->buffer, "Connection failed: #s", e->message);  		error_free (e); @@ -3679,21 +3883,21 @@ irc_finish_connection (struct server *s, int socket)  		xclose (s->socket);  		s->socket = -1; -		irc_queue_reconnect (s); +		s->transport = NULL;  		return;  	}  	log_server_status (s, s->buffer, "Connection established");  	s->state = IRC_CONNECTED; -	poller_fd_init (&s->read_event, &ctx->poller, s->socket); -	s->read_event.dispatcher = (poller_fd_fn) on_irc_readable; -	s->read_event.user_data = s; +	poller_fd_init (&s->socket_event, &ctx->poller, s->socket); +	s->socket_event.dispatcher = (poller_fd_fn) on_irc_ready; +	s->socket_event.user_data = s; -	poller_fd_set (&s->read_event, POLLIN); +	irc_update_poller (s, NULL);  	irc_reset_connection_timeouts (s); -  	irc_register (s); +  	refresh_prompt (s->ctx);  } | 
