Use MbedTLS instead of Nettle

This commit is contained in:
Benau 2021-02-21 01:06:22 +08:00
parent 39c42c3f0c
commit d753393f4d
4 changed files with 88 additions and 118 deletions

View File

@ -33,7 +33,7 @@ option(USE_IPV6 "Allow create or connect to game server with IPv6 address, syste
option(USE_SYSTEM_WIIUSE "Use system WiiUse instead of the built-in version, when available." OFF)
option(USE_SQLITE3 "Use sqlite to manage server stats and ban list." ON)
option(USE_CRYPTO_OPENSSL "Use OpenSSL instead of Nettle for cryptography in STK." ON)
option(USE_CRYPTO_OPENSSL "Use OpenSSL instead of MbedTLS for cryptography in STK." ON)
CMAKE_DEPENDENT_OPTION(BUILD_RECORDER "Build opengl recorder" ON
"NOT SERVER_ONLY;NOT APPLE" OFF)
CMAKE_DEPENDENT_OPTION(USE_SYSTEM_SQUISH "Use system Squish library instead of the built-in version, when available." ON
@ -551,7 +551,7 @@ if (NOT NO_LIBATOMIC_NEEDED)
target_link_libraries(supertuxkart atomic)
endif()
# CURL and OpenSSL or Nettle
# CURL and OpenSSL or MbedTLS
# 1.0.1d for compatible AES GCM handling
SET(OPENSSL_MINIMUM_VERSION "1.0.1d")
if (MSVC OR LLVM_MINGW)
@ -566,10 +566,10 @@ else()
find_package(CURL REQUIRED)
include_directories(${CURL_INCLUDE_DIRS})
find_path(NETTLE_INCLUDE_DIRS nettle/version.h)
find_library(NETTLE_LIBRARY NAMES nettle libnettle)
find_path(MBEDTLS_INCLUDE_DIRS mbedtls/version.h)
find_library(MBEDCRYPTO_LIBRARY NAMES mbedcrypto libmbedcrypto)
if (NOT NETTLE_INCLUDE_DIRS OR NOT NETTLE_LIBRARY OR USE_CRYPTO_OPENSSL)
if (NOT MBEDCRYPTO_LIBRARY OR NOT MBEDTLS_INCLUDE_DIRS OR USE_CRYPTO_OPENSSL)
set(USE_CRYPTO_OPENSSL ON)
find_package(OpenSSL REQUIRED)
@ -580,7 +580,7 @@ else()
include_directories(${OpenSSL_INCLUDE_DIRS})
else()
set(USE_CRYPTO_OPENSSL OFF)
include_directories(${NETTLE_INCLUDE_DIRS})
include_directories(${MBEDTLS_INCLUDE_DIRS})
endif()
endif()
@ -588,8 +588,8 @@ if (USE_CRYPTO_OPENSSL)
message(STATUS "OpenSSL will be used for cryptography in STK.")
add_definitions(-DENABLE_CRYPTO_OPENSSL)
else()
message(STATUS "Nettle will be used for cryptography in STK.")
add_definitions(-DENABLE_CRYPTO_NETTLE)
message(STATUS "MbedTLS will be used for cryptography in STK.")
add_definitions(-DENABLE_CRYPTO_MBEDTLS)
endif()
# Common library dependencies
@ -612,7 +612,7 @@ endif()
if (USE_CRYPTO_OPENSSL)
target_link_libraries(supertuxkart ${OPENSSL_CRYPTO_LIBRARY})
else()
target_link_libraries(supertuxkart ${NETTLE_LIBRARY})
target_link_libraries(supertuxkart ${MBEDCRYPTO_LIBRARY})
endif()
if(NOT SERVER_ONLY)

View File

@ -22,7 +22,7 @@
#ifdef ENABLE_CRYPTO_OPENSSL
#include "network/crypto_openssl.hpp"
#else
#include "network/crypto_nettle.hpp"
#include "network/crypto_mbedtls.hpp"
#endif
#endif // HEADER_CRYPTO_HPP

View File

@ -16,66 +16,48 @@
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#ifdef ENABLE_CRYPTO_NETTLE
#ifdef ENABLE_CRYPTO_MBEDTLS
#include "network/crypto_nettle.hpp"
#include "network/crypto_mbedtls.hpp"
#include "network/network_config.hpp"
#include "network/network_string.hpp"
#include <nettle/base64.h>
#include <nettle/sha.h>
#include <nettle/version.h>
#if NETTLE_VERSION_MAJOR > 3 || \
(NETTLE_VERSION_MAJOR == 3 && NETTLE_VERSION_MINOR > 3)
typedef const char* NETTLE_CONST_CHAR;
typedef char* NETTLE_CHAR;
#else
typedef const uint8_t* NETTLE_CONST_CHAR;
typedef uint8_t* NETTLE_CHAR;
#endif
#include <mbedtls/base64.h>
#include <mbedtls/sha256.h>
#include <cstring>
// ============================================================================
std::string Crypto::base64(const std::vector<uint8_t>& input)
{
size_t required_size = 0;
mbedtls_base64_encode(NULL, 0, &required_size, &input[0], input.size());
std::string result;
const size_t char_size = ((input.size() + 3 - 1) / 3) * 4;
result.resize(char_size, (char)0);
base64_encode_raw((NETTLE_CHAR)&result[0], input.size(), input.data());
result.resize(required_size, (char)0);
mbedtls_base64_encode((unsigned char*)&result[0], required_size,
&required_size, &input[0], input.size());
// mbedtls_base64_encode includes the null terminator for required size
result.resize(strlen(result.c_str()));
return result;
} // base64
// ============================================================================
std::vector<uint8_t> Crypto::decode64(std::string input)
{
size_t decode_len = calcDecodeLength(input);
std::vector<uint8_t> result(decode_len, 0);
struct base64_decode_ctx ctx;
base64_decode_init(&ctx);
size_t decode_len_by_nettle;
#ifdef DEBUG
int ret = base64_decode_update(&ctx, &decode_len_by_nettle, result.data(),
input.size(), (NETTLE_CONST_CHAR)input.c_str());
assert(ret == 1);
ret = base64_decode_final(&ctx);
assert(ret == 1);
assert(decode_len_by_nettle == decode_len);
#else
base64_decode_update(&ctx, &decode_len_by_nettle, result.data(),
input.size(), (NETTLE_CONST_CHAR)input.c_str());
base64_decode_final(&ctx);
#endif
size_t required_size = 0;
mbedtls_base64_decode(NULL, 0, &required_size, (unsigned char*)&input[0],
input.size());
std::vector<uint8_t> result(required_size, 0);
mbedtls_base64_decode(result.data(), required_size,
&required_size, (unsigned char*)&input[0], input.size());
return result;
} // decode64
// ============================================================================
std::array<uint8_t, 32> Crypto::sha256(const std::string& input)
{
std::array<uint8_t, SHA256_DIGEST_SIZE> result;
struct sha256_ctx hash;
sha256_init(&hash);
sha256_update(&hash, input.size(), (const uint8_t*)input.c_str());
sha256_digest(&hash, SHA256_DIGEST_SIZE, result.data());
std::array<uint8_t, 32> result;
mbedtls_sha256_ret((unsigned char*)&input[0], input.size(),
result.data(), 0/*not 224*/);
return result;
} // sha256
@ -86,9 +68,12 @@ std::string Crypto::m_client_iv;
bool Crypto::encryptConnectionRequest(BareNetworkString& ns)
{
std::vector<uint8_t> cipher(ns.m_buffer.size() + 4, 0);
gcm_aes128_encrypt(&m_aes_encrypt_context, ns.m_buffer.size(),
cipher.data() + 4, ns.m_buffer.data());
gcm_aes128_digest(&m_aes_encrypt_context, 4, cipher.data());
if (mbedtls_gcm_crypt_and_tag(&m_aes_encrypt_context, MBEDTLS_GCM_ENCRYPT,
ns.m_buffer.size(), m_iv.data(), m_iv.size(), NULL, 0,
ns.m_buffer.data(), cipher.data() + 4, 4, cipher.data()) != 0)
{
return false;
}
std::swap(ns.m_buffer, cipher);
return true;
} // encryptConnectionRequest
@ -98,13 +83,12 @@ bool Crypto::decryptConnectionRequest(BareNetworkString& ns)
{
std::vector<uint8_t> pt(ns.m_buffer.size() - 4, 0);
uint8_t* tag = ns.m_buffer.data();
std::array<uint8_t, 4> tag_after = {};
gcm_aes128_decrypt(&m_aes_decrypt_context, ns.m_buffer.size() - 4,
pt.data(), ns.m_buffer.data() + 4);
gcm_aes128_digest(&m_aes_decrypt_context, 4, tag_after.data());
handleAuthentication(tag, tag_after);
if (mbedtls_gcm_auth_decrypt(&m_aes_decrypt_context, pt.size(),
m_iv.data(), m_iv.size(), NULL, 0, tag, 4, ns.m_buffer.data() + 4,
pt.data()) != 0)
{
throw std::runtime_error("Failed authentication.");
}
std::swap(ns.m_buffer, pt);
return true;
} // decryptConnectionRequest
@ -140,11 +124,13 @@ ENetPacket* Crypto::encryptSend(BareNetworkString& ns, bool reliable)
}
uint8_t* packet_start = p->data + 8;
gcm_aes128_set_iv(&m_aes_encrypt_context, 12, iv.data());
gcm_aes128_encrypt(&m_aes_encrypt_context, ns.m_buffer.size(),
packet_start, ns.m_buffer.data());
gcm_aes128_digest(&m_aes_encrypt_context, 4, p->data + 4);
if (mbedtls_gcm_crypt_and_tag(&m_aes_encrypt_context, MBEDTLS_GCM_ENCRYPT,
ns.m_buffer.size(), iv.data(), iv.size(), NULL, 0, ns.m_buffer.data(),
packet_start, 4, p->data + 4) != 0)
{
enet_packet_destroy(p);
return NULL;
}
ul.unlock();
p->data[0] = (val >> 24) & 0xff;
@ -168,13 +154,11 @@ NetworkString* Crypto::decryptRecieve(ENetPacket* p)
uint8_t* packet_start = p->data + 8;
uint8_t* tag = p->data + 4;
std::array<uint8_t, 4> tag_after = {};
gcm_aes128_set_iv(&m_aes_decrypt_context, 12, iv.data());
gcm_aes128_decrypt(&m_aes_decrypt_context, clen, ns->m_buffer.data(),
packet_start);
gcm_aes128_digest(&m_aes_decrypt_context, 4, tag_after.data());
handleAuthentication(tag, tag_after);
if (mbedtls_gcm_auth_decrypt(&m_aes_decrypt_context, clen, iv.data(),
iv.size(), NULL, 0, tag, 4, packet_start, ns->m_buffer.data()) != 0)
{
throw std::runtime_error("Failed authentication.");
}
NetworkString* result = ns.get();
ns.release();

View File

@ -16,16 +16,17 @@
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#ifdef ENABLE_CRYPTO_NETTLE
#ifdef ENABLE_CRYPTO_MBEDTLS
#ifndef HEADER_CRYPTO_NETTLE_HPP
#define HEADER_CRYPTO_NETTLE_HPP
#ifndef HEADER_CRYPTO_MBEDTLS_HPP
#define HEADER_CRYPTO_MBEDTLS_HPP
#include "utils/log.hpp"
#include <enet/enet.h>
#include <nettle/gcm.h>
#include <nettle/yarrow.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/gcm.h>
#include <algorithm>
#include <array>
@ -50,39 +51,10 @@ private:
uint32_t m_packet_counter;
struct gcm_aes128_ctx m_aes_encrypt_context, m_aes_decrypt_context;
mbedtls_gcm_context m_aes_encrypt_context, m_aes_decrypt_context;
std::mutex m_crypto_mutex;
// ------------------------------------------------------------------------
static size_t calcDecodeLength(const std::string& input)
{
// Calculates the length of a decoded string
size_t padding = 0;
const size_t len = input.size();
if (input[len - 1] == '=' && input[len - 2] == '=')
{
// last two chars are =
padding = 2;
}
else if (input[len - 1] == '=')
{
// last char is =
padding = 1;
}
return (len * 3) / 4 - padding;
} // calcDecodeLength
// ------------------------------------------------------------------------
void handleAuthentication(const uint8_t* tag,
const std::array<uint8_t, 4>& tag_after)
{
for (unsigned i = 0; i < tag_after.size(); i++)
{
if (tag[i] != tag_after[i])
throw std::runtime_error("Failed authentication.");
}
}
public:
// ------------------------------------------------------------------------
static std::string base64(const std::vector<uint8_t>& input);
@ -103,16 +75,24 @@ public:
// ------------------------------------------------------------------------
static void initClientAES()
{
struct yarrow256_ctx ctx;
yarrow256_init(&ctx, 0, NULL);
mbedtls_entropy_context entropy;
mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ctr_drbg_init(&ctr_drbg);
std::random_device rd;
std::mt19937 g(rd());
std::array<uint8_t, YARROW256_SEED_FILE_SIZE> seed;
for (unsigned i = 0; i < YARROW256_SEED_FILE_SIZE; i++)
std::array<uint8_t, 28> seed, key_iv;
for (unsigned i = 0; i < 28; i++)
seed[i] = (uint8_t)(g() % 255);
yarrow256_seed(&ctx, YARROW256_SEED_FILE_SIZE, seed.data());
std::array<uint8_t, 28> key_iv;
yarrow256_random(&ctx, key_iv.size(), key_iv.data());
key_iv = seed;
if (mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func,
&entropy, seed.data(), seed.size()) == 0)
{
// If failed use the seed in the beginning
if (mbedtls_ctr_drbg_random((void*)&ctr_drbg, key_iv.data(),
key_iv.size()) != 0)
key_iv = seed;
}
m_client_key = base64({ key_iv.begin(), key_iv.begin() + 16 });
m_client_iv = base64({ key_iv.begin() + 16, key_iv.end() });
}
@ -134,10 +114,16 @@ public:
assert(iv.size() == 12);
std::copy_n(iv.begin(), 12, m_iv.begin());
m_packet_counter = 0;
gcm_aes128_set_key(&m_aes_encrypt_context, key.data());
gcm_aes128_set_iv(&m_aes_encrypt_context, 12, iv.data());
gcm_aes128_set_key(&m_aes_decrypt_context, key.data());
gcm_aes128_set_iv(&m_aes_decrypt_context, 12, iv.data());
mbedtls_gcm_setkey(&m_aes_encrypt_context, MBEDTLS_CIPHER_ID_AES,
key.data(), key.size() * 8);
mbedtls_gcm_setkey(&m_aes_decrypt_context, MBEDTLS_CIPHER_ID_AES,
key.data(), key.size() * 8);
}
// ------------------------------------------------------------------------
~Crypto()
{
mbedtls_gcm_free(&m_aes_encrypt_context);
mbedtls_gcm_free(&m_aes_decrypt_context);
}
// ------------------------------------------------------------------------
bool encryptConnectionRequest(BareNetworkString& ns);
@ -150,6 +136,6 @@ public:
};
#endif // HEADER_CRYPTO_NETTLE_HPP
#endif // HEADER_CRYPTO_MBEDTLS_HPP
#endif