From 1fe107875b05cc07cf62c714c0136026eef7b93a Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Sun, 25 Oct 2020 14:50:07 -0400 Subject: [PATCH] Overhaul network I/O to be async for real Had to totally cut off OpenSSL from the network fd because obviously OpenSSL is just going to wreck our shit --- include/server.h | 13 +- include/tls.h | 2 +- src/serve.c | 1 + src/server.c | 314 +++++++++++++++++++++++++++++++++-------------- src/tls.c | 33 +++-- 5 files changed, 252 insertions(+), 111 deletions(-) diff --git a/include/server.h b/include/server.h index db68e33..5795426 100644 --- a/include/server.h +++ b/include/server.h @@ -10,9 +10,11 @@ struct gmnisrv_server; -enum response_state { - RESPOND_HEADER, - RESPOND_BODY, +enum client_state { + CLIENT_STATE_REQUEST, + CLIENT_STATE_SSL, + CLIENT_STATE_HEADER, + CLIENT_STATE_BODY, }; struct gmnisrv_client { @@ -24,13 +26,12 @@ struct gmnisrv_client { struct pollfd *pollfd; SSL *ssl; - BIO *bio, *sbio; + BIO *rbio, *wbio; char buf[4096]; - static_assert(GEMINI_MAX_URL + 3 < 4096, "GEMINI_MAX_URL is too high"); size_t bufix, bufln; - enum response_state state; + enum client_state state, next; enum gemini_status status; char *meta; FILE *body; diff --git a/include/tls.h b/include/tls.h index 81ff613..0882ff0 100644 --- a/include/tls.h +++ b/include/tls.h @@ -5,7 +5,7 @@ struct gmnisrv_config; int tls_init(struct gmnisrv_config *conf); void tls_finish(struct gmnisrv_config *conf); -SSL *tls_get_ssl(struct gmnisrv_config *conf, int fd); +SSL *tls_get_ssl(struct gmnisrv_config *conf); void tls_set_host(SSL *ssl, struct gmnisrv_host *host); #endif diff --git a/src/serve.c b/src/serve.c index df08d08..b798e7b 100644 --- a/src/serve.c +++ b/src/serve.c @@ -18,6 +18,7 @@ void client_submit_response(struct gmnisrv_client *client, enum gemini_status status, const char *meta, FILE *body) { + client->state = CLIENT_STATE_HEADER; client->status = status; client->meta = strdup(meta); client->body = body; diff --git a/src/server.c b/src/server.c index 6412f1b..65b8204 100644 --- a/src/server.c +++ b/src/server.c @@ -186,9 +186,6 @@ disconnect_client(struct gmnisrv_server *server, struct gmnisrv_client *client) client->path ? client->path : "(none)", ms, client->bbytes, (int)client->status, client->meta); } - if (client->bio) { - BIO_free_all(client->bio); - } if (client->ssl) { SSL_free(client->ssl); } @@ -211,7 +208,7 @@ disconnect_client(struct gmnisrv_server *server, struct gmnisrv_client *client) static int client_init_ssl(struct gmnisrv_server *server, struct gmnisrv_client *client) { - client->ssl = tls_get_ssl(server->conf, client->sockfd); + client->ssl = tls_get_ssl(server->conf); if (!client->ssl) { client_error(&client->addr, "unable to initialize SSL, disconnecting"); @@ -219,151 +216,280 @@ client_init_ssl(struct gmnisrv_server *server, struct gmnisrv_client *client) return 1; } - int r = SSL_accept(client->ssl); - if (r != 1) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_READ || r == SSL_ERROR_WANT_WRITE) { - return 1; - } - client_error(&client->addr, "SSL accept error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return 1; - } + client->rbio = BIO_new(BIO_s_mem()); + client->wbio = BIO_new(BIO_s_mem()); - client->sbio = BIO_new(BIO_f_ssl()); - BIO_set_ssl(client->sbio, client->ssl, 0); - client->bio = BIO_new(BIO_f_buffer()); - BIO_push(client->bio, client->sbio); + SSL_set_accept_state(client->ssl); + SSL_set_bio(client->ssl, client->rbio, client->wbio); return 0; } -enum client_state { - CLIENT_CONNECTED, - CLIENT_DISCONNECTED, +enum connection_state { + CONNECTED, + DISCONNECTED, }; -static enum client_state +static enum connection_state client_readable(struct gmnisrv_server *server, struct gmnisrv_client *client) { if (!client->ssl && client_init_ssl(server, client) != 0) { - return CLIENT_DISCONNECTED; + return DISCONNECTED; } + + char buf[BUFSIZ]; + ssize_t n = read(client->sockfd, buf, sizeof(buf)); + if (n <= 0) { + disconnect_client(server, client); + return DISCONNECTED; + } + + size_t w = 0; + while (w < (size_t)n) { + int r = BIO_write(client->rbio, &buf[w], n - w); + if (r <= 0) { + client_error(&client->addr, + "Error writing to client RBIO: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + w += r; + } + + if (!SSL_is_init_finished(client->ssl)) { + int r = SSL_accept(client->ssl); + switch ((r = SSL_get_error(client->ssl, r))) { + case SSL_ERROR_NONE: + break; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + goto queue_ssl_write; + case SSL_ERROR_SSL: + client_error(&client->addr, + "SSL accept error: %s", + ERR_error_string(ERR_get_error(), NULL)); + disconnect_client(server, client); + return DISCONNECTED; + default: + client_error(&client->addr, + "SSL accept error: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + + if (!SSL_is_init_finished(client->ssl)) { + return CONNECTED; + } + } + if (!client->host) { + client_log(&client->addr, "missing client host"); const char *error = "This server requires clients to support the TLS SNI (server name identification) extension"; client_submit_response(client, GEMINI_STATUS_BAD_REQUEST, error, NULL); - return CLIENT_CONNECTED; + return CONNECTED; } - int r = BIO_gets(client->bio, client->buf, sizeof(client->buf)); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_READ) { - return CLIENT_CONNECTED; + int r, e; + do { + if (client->bufln >= sizeof(client->buf)) { + client_log(&client->addr, "overlong"); + const char *error = "Protocol error: malformed request"; + client_submit_response(client, + GEMINI_STATUS_BAD_REQUEST, error, NULL); + return CONNECTED; } - client_error(&client->addr, "SSL read error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } - client->buf[r] = '\0'; + + r = SSL_read(client->ssl, + &client->buf[client->bufln], + sizeof(client->buf) - client->bufln); + + switch ((e = SSL_get_error(client->ssl, r))) { + case SSL_ERROR_NONE: + client->bufln += r; + break; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + break; + case SSL_ERROR_SSL: + client_error(&client->addr, + "SSL read error: %s", + ERR_error_string(ERR_get_error(), NULL)); + disconnect_client(server, client); + return DISCONNECTED; + default: + client_error(&client->addr, + "SSL read error: %s", + ERR_error_string(e, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + } while (r > 0); + + client->buf[client->bufln] = '\0'; char *newline = strstr(client->buf, "\r\n"); if (!newline) { const char *error = "Protocol error: malformed request"; - client_submit_response(client, - GEMINI_STATUS_BAD_REQUEST, error, NULL); - return CLIENT_CONNECTED; + switch (e) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + goto queue_ssl_write; + default: + client_submit_response(client, + GEMINI_STATUS_BAD_REQUEST, error, NULL); + return CONNECTED; + } } *newline = 0; if (!request_validate(client, &client->path)) { - return CLIENT_CONNECTED; + return CONNECTED; } serve_request(client); - return CLIENT_CONNECTED; + return CONNECTED; + +queue_ssl_write: + client->bufln = 0; + client->state = CLIENT_STATE_SSL; + client->next = CLIENT_STATE_REQUEST; + do { + assert(client->bufln < sizeof(client->buf)); + r = BIO_read(client->wbio, + &client->buf[client->bufln], + sizeof(client->buf) - client->bufln); + if (r <= 0) { + if (BIO_should_retry(client->wbio)) { + continue; + } + client_error(&client->addr, + "BIO read error: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } else { + client->bufln += r; + } + } while (r > 0); + client->pollfd->events = POLLOUT; + return CONNECTED; } -static enum client_state +static enum connection_state client_writable(struct gmnisrv_server *server, struct gmnisrv_client *client) { int r; ssize_t n; + char buf[BUFSIZ]; switch (client->state) { - case RESPOND_HEADER: + case CLIENT_STATE_REQUEST: + assert(0); // Invariant + case CLIENT_STATE_SSL: + assert(client->bufln > 0); + n = write(client->sockfd, client->buf, client->bufln); + if (n <= 0) { + client_log(&client->addr, "write error: %s", + strerror(errno)); + disconnect_client(server, client); + return DISCONNECTED; + } + client->bufln -= n; + if (client->bufln == 0) { + client->state = client->next; + if (client->state == CLIENT_STATE_REQUEST) { + client->pollfd->events = POLLIN; + } + } + return CONNECTED; + case CLIENT_STATE_HEADER: if (client->bufix == 0) { assert(strlen(client->meta) <= 1024); - n = snprintf(client->buf, sizeof(client->buf), + int n = snprintf(client->buf, sizeof(client->buf), "%02d %s\r\n", (int)client->status, client->meta); assert(n > 0); client->bufln = n; } - r = BIO_write(client->sbio, &client->buf[client->bufix], - client->bufln - client->bufix); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_WRITE) { - return CLIENT_CONNECTED; - } - client->status = GEMINI_STATUS_NONE; - client_error(&client->addr, - "header write error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } - client->bufix += r; - if (client->bufix >= client->bufln) { - if (!client->body) { - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } else { - client->state = RESPOND_BODY; - client->bufix = client->bufln = 0; - return CLIENT_CONNECTED; - } - } break; - case RESPOND_BODY: + case CLIENT_STATE_BODY: if (client->bufix >= client->bufln) { - n = fread(client->buf, 1, sizeof(client->buf), - client->body); + int n = fread(client->buf, 1, + sizeof(client->buf), client->body); if (n == -1) { client_error(&client->addr, "Error reading response body: %s", strerror(errno)); disconnect_client(server, client); - return CLIENT_DISCONNECTED; + return DISCONNECTED; } if (n == 0) { // EOF disconnect_client(server, client); - return CLIENT_DISCONNECTED; + return DISCONNECTED; } client->bbytes += n; client->bufln = n; client->bufix = 0; } - r = BIO_write(client->sbio, &client->buf[client->bufix], - client->bufln - client->bufix); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_WRITE) { - return CLIENT_CONNECTED; - } - client->status = GEMINI_STATUS_NONE; - client_error(&client->addr, "body write error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } - client->bufix += r; break; } - return false; + + r = SSL_write(client->ssl, &client->buf[client->bufix], + client->bufln - client->bufix); + if (r <= 0) { + r = SSL_get_error(client->ssl, r); + assert(r == SSL_ERROR_WANT_WRITE); // Hmm? + client->status = GEMINI_STATUS_NONE; + client_error(&client->addr, + "header write error %s, disconnecting", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + client->bufix += r; + + while (r > 0) { + r = BIO_read(client->wbio, buf, sizeof(buf)); + if (r < 0 && !BIO_should_retry(client->wbio)) { + client_error(&client->addr, + "BIO read error: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + + for (int w = 0; w < r; ) { + int q = write(client->sockfd, &buf[w], r - w); + if (q < 0) { + assert(0); // TODO: handle write errors + } + w += q; + } + } + + switch (client->state) { + case CLIENT_STATE_REQUEST: + case CLIENT_STATE_SSL: + assert(0); // Invariant + case CLIENT_STATE_HEADER: + if (client->bufix >= client->bufln) { + if (!client->body) { + disconnect_client(server, client); + return DISCONNECTED; + } else { + client->state = CLIENT_STATE_BODY; + client->bufix = client->bufln = 0; + return CONNECTED; + } + } + break; + case CLIENT_STATE_BODY: + break; + } + + return CONNECTED; } static long @@ -450,18 +576,18 @@ server_run(struct gmnisrv_server *server) for (size_t i = 0; i < server->nclients; ++i) { int pi = i + server->nlisten; - enum client_state s = CLIENT_CONNECTED; + enum connection_state s = CONNECTED; if ((server->fds[pi].revents & (POLLHUP | POLLERR))) { disconnect_client(server, &server->clients[i]); - s = CLIENT_DISCONNECTED; + s = DISCONNECTED; } - if (s == CLIENT_CONNECTED && (server->fds[pi].revents & POLLIN)) { + if (s == CONNECTED && (server->fds[pi].revents & POLLIN)) { s = client_readable(server, &server->clients[i]); } - if (s == CLIENT_CONNECTED && (server->fds[pi].revents & POLLOUT)) { + if (s == CONNECTED && (server->fds[pi].revents & POLLOUT)) { s = client_writable(server, &server->clients[i]); } - if (s == CLIENT_DISCONNECTED) { + if (s == DISCONNECTED) { --i; } } diff --git a/src/tls.c b/src/tls.c index cde4b25..ca8a307 100644 --- a/src/tls.c +++ b/src/tls.c @@ -161,12 +161,31 @@ tls_init(struct gmnisrv_config *conf) SSL_load_error_strings(); ERR_load_crypto_strings(); - conf->tls.ssl_ctx = SSL_CTX_new(TLS_method()); + conf->tls.ssl_ctx = SSL_CTX_new(TLS_server_method()); assert(conf->tls.ssl_ctx); + int r = SSL_CTX_set_min_proto_version(conf->tls.ssl_ctx, TLS1_2_VERSION); + assert(r == 1); + + r = SSL_CTX_set_cipher_list(conf->tls.ssl_ctx, + "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" + "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" + "DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" + "TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:" + "TLS_CHACHA20_POLY1305_SHA256"); + assert(r == 1); + SSL_CTX_set_tlsext_servername_callback(conf->tls.ssl_ctx, NULL); - int r; + // TLS re-negotiation is a fucking STUPID idea + // I'm gating this behind an #ifdef based on an optimistic assumption + // that someday it will be removed from OpenSSL entirely because of how + // fucking stupid this fucking godawful idea is +#ifdef SSL_OP_NO_RENEGOTIATION + SSL_CTX_set_options(conf->tls.ssl_ctx, SSL_OP_NO_RENEGOTIATION); +#endif + for (struct gmnisrv_host *host = conf->hosts; host; host = host->next) { r = tls_host_init(&conf->tls, host); if (r != 0) { @@ -188,15 +207,9 @@ tls_finish(struct gmnisrv_config *conf) } SSL * -tls_get_ssl(struct gmnisrv_config *conf, int fd) +tls_get_ssl(struct gmnisrv_config *conf) { - SSL *ssl = SSL_new(conf->tls.ssl_ctx); - if (!ssl) { - return NULL; - } - int r = SSL_set_fd(ssl, fd); - assert(r == 1); - return ssl; + return SSL_new(conf->tls.ssl_ctx); } void