182 lines
3.2 KiB
C++
182 lines
3.2 KiB
C++
|
|
||
|
// SslContext.cpp
|
||
|
|
||
|
// Implements the cSslContext class that holds everything a single SSL context needs to function
|
||
|
|
||
|
#include "Globals.h"
|
||
|
#include "SslContext.h"
|
||
|
#include "EntropyContext.h"
|
||
|
#include "CtrDrbgContext.h"
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
// Creates an SSL context in the "not yet usable" state:
// both state flags start out false; m_IsValid is only raised by Initialize()
// and m_HasHandshaken only by a successful Handshake().
cSslContext::cSslContext(void) :
	m_IsValid(false),        // Initialize() has not been called yet
	m_HasHandshaken(false)   // No handshake has completed yet
{
	// Nothing else to do here, the PolarSSL structures are set up in Initialize()
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
// Tears the context down.
// m_Ssl is only handed to ssl_free() if Initialize() has completed successfully (m_IsValid is set);
// otherwise the structure was never set up and is left alone.
cSslContext::~cSslContext()
{
	if (!m_IsValid)
	{
		// Never initialized, there is nothing to free
		return;
	}
	ssl_free(&m_Ssl);
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
int cSslContext::Initialize(bool a_IsClient, const SharedPtr<cCtrDrbgContext> & a_CtrDrbg)
|
||
|
{
|
||
|
// Check double-initialization:
|
||
|
if (m_IsValid)
|
||
|
{
|
||
|
LOGWARNING("SSL: Double initialization is not supported.");
|
||
|
return POLARSSL_ERR_SSL_MALLOC_FAILED; // There is no return value well-suited for this, reuse this one.
|
||
|
}
|
||
|
|
||
|
// Set the CtrDrbg context, create a new one if needed:
|
||
|
m_CtrDrbg = a_CtrDrbg;
|
||
|
if (m_CtrDrbg.get() == NULL)
|
||
|
{
|
||
|
m_CtrDrbg.reset(new cCtrDrbgContext);
|
||
|
m_CtrDrbg->Initialize("MCServer", 8);
|
||
|
}
|
||
|
|
||
|
// Initialize PolarSSL's structures:
|
||
|
memset(&m_Ssl, 0, sizeof(m_Ssl));
|
||
|
int res = ssl_init(&m_Ssl);
|
||
|
if (res != 0)
|
||
|
{
|
||
|
return res;
|
||
|
}
|
||
|
ssl_set_endpoint(&m_Ssl, a_IsClient ? SSL_IS_CLIENT : SSL_IS_SERVER);
|
||
|
ssl_set_authmode(&m_Ssl, SSL_VERIFY_OPTIONAL);
|
||
|
ssl_set_rng(&m_Ssl, ctr_drbg_random, &m_CtrDrbg->m_CtrDrbg);
|
||
|
ssl_set_bio(&m_Ssl, ReceiveEncrypted, this, SendEncrypted, this);
|
||
|
|
||
|
#ifdef _DEBUG
|
||
|
ssl_set_dbg(&m_Ssl, &SSLDebugMessage, this);
|
||
|
#endif
|
||
|
|
||
|
m_IsValid = true;
|
||
|
return 0;
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
void cSslContext::SetCACerts(const cX509CertPtr & a_CACert, const AString & a_ExpectedPeerName)
|
||
|
{
|
||
|
// Store the data in our internal buffers, to avoid losing the pointers later on
|
||
|
// PolarSSL will need these after this call returns, and the caller may move / delete the data before that:
|
||
|
m_ExpectedPeerName = a_ExpectedPeerName;
|
||
|
m_CACerts = a_CACert;
|
||
|
|
||
|
// Set the trusted CA root cert store:
|
||
|
ssl_set_authmode(&m_Ssl, SSL_VERIFY_REQUIRED);
|
||
|
ssl_set_ca_chain(&m_Ssl, m_CACerts->GetInternal(), NULL, m_ExpectedPeerName.empty() ? NULL : m_ExpectedPeerName.c_str());
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
int cSslContext::WritePlain(const void * a_Data, size_t a_NumBytes)
|
||
|
{
|
||
|
ASSERT(m_IsValid); // Need to call Initialize() first
|
||
|
if (!m_HasHandshaken)
|
||
|
{
|
||
|
int res = Handshake();
|
||
|
if (res != 0)
|
||
|
{
|
||
|
return res;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ssl_write(&m_Ssl, (const unsigned char *)a_Data, a_NumBytes);
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
int cSslContext::ReadPlain(void * a_Data, size_t a_MaxBytes)
|
||
|
{
|
||
|
ASSERT(m_IsValid); // Need to call Initialize() first
|
||
|
if (!m_HasHandshaken)
|
||
|
{
|
||
|
int res = Handshake();
|
||
|
if (res != 0)
|
||
|
{
|
||
|
return res;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ssl_read(&m_Ssl, (unsigned char *)a_Data, a_MaxBytes);
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
int cSslContext::Handshake(void)
|
||
|
{
|
||
|
ASSERT(m_IsValid); // Need to call Initialize() first
|
||
|
ASSERT(!m_HasHandshaken); // Must not call twice
|
||
|
|
||
|
int res = ssl_handshake(&m_Ssl);
|
||
|
if (res == 0)
|
||
|
{
|
||
|
m_HasHandshaken = true;
|
||
|
}
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
int cSslContext::NotifyClose(void)
|
||
|
{
|
||
|
return ssl_close_notify(&m_Ssl);
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
#ifdef _DEBUG
// Debug-build callback that forwards PolarSSL's debug output to the log (registered in Initialize() via ssl_set_dbg()).
// a_UserParam is the opaque pointer given at registration; it is not used here.
// a_Level is the message's debug level; messages above level 3 (the trace messages) are dropped.
// a_Text is the NUL-terminated message; trailing whitespace and control characters (incl. the terminating LF)
// are stripped before logging.
void cSslContext::SSLDebugMessage(void * a_UserParam, int a_Level, const char * a_Text)
{
	if (a_Level > 3)
	{
		// Don't want the trace messages
		return;
	}
	if (a_Text == NULL)
	{
		// Nothing to log; strlen() on a NULL pointer would be undefined behavior
		return;
	}

	// Remove the terminating LF and any other trailing whitespace / control characters.
	// len is kept as a one-past-the-end length so that an empty message cannot underflow the size_t,
	// and the byte is compared as unsigned char so that bytes >= 0x80 are not mistaken for whitespace
	// on platforms where plain char is signed.
	size_t len = strlen(a_Text);
	while ((len > 0) && (static_cast<unsigned char>(a_Text[len - 1]) <= 32))
	{
		len--;
	}
	AString Text(a_Text, len);

	LOGD("SSL (%d): %s", a_Level, Text.c_str());
}
#endif  // _DEBUG
|
||
|
|
||
|
|
||
|
|
||
|
|