BlockingSslClientSocket: Migrated to cNetwork API.
This commit is contained in:
parent
7dfeb67f01
commit
86f2f82d2a
@ -10,6 +10,80 @@
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// cBlockingSslClientSocketConnectCallbacks:
|
||||
|
||||
class cBlockingSslClientSocketConnectCallbacks:
|
||||
public cNetwork::cConnectCallbacks
|
||||
{
|
||||
/** The socket object that is using this instance of the callbacks. */
|
||||
cBlockingSslClientSocket & m_Socket;
|
||||
|
||||
virtual void OnConnected(cTCPLink & a_Link) override
|
||||
{
|
||||
m_Socket.OnConnected();
|
||||
}
|
||||
|
||||
virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
|
||||
{
|
||||
m_Socket.OnConnectError(a_ErrorMsg);
|
||||
}
|
||||
|
||||
public:
|
||||
cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket):
|
||||
m_Socket(a_Socket)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// cBlockingSslClientSocketLinkCallbacks:
|
||||
|
||||
class cBlockingSslClientSocketLinkCallbacks:
|
||||
public cTCPLink::cCallbacks
|
||||
{
|
||||
cBlockingSslClientSocket & m_Socket;
|
||||
|
||||
virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
|
||||
{
|
||||
m_Socket.SetLink(a_Link);
|
||||
}
|
||||
|
||||
|
||||
virtual void OnReceivedData(const char * a_Data, size_t a_Length)
|
||||
{
|
||||
m_Socket.OnReceivedData(a_Data, a_Length);
|
||||
}
|
||||
|
||||
|
||||
virtual void OnRemoteClosed(void)
|
||||
{
|
||||
m_Socket.OnDisconnected();
|
||||
}
|
||||
|
||||
|
||||
virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg)
|
||||
{
|
||||
m_Socket.OnDisconnected();
|
||||
}
|
||||
public:
|
||||
cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket):
|
||||
m_Socket(a_Socket)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// cBlockingSslClientSocket:
|
||||
|
||||
cBlockingSslClientSocket::cBlockingSslClientSocket(void) :
|
||||
m_Ssl(*this),
|
||||
m_IsConnected(false)
|
||||
@ -32,10 +106,19 @@ bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Po
|
||||
}
|
||||
|
||||
// Connect the underlying socket:
|
||||
m_Socket.CreateSocket(cSocket::IPv4);
|
||||
if (!m_Socket.ConnectIPv4(a_ServerName.c_str(), a_Port))
|
||||
m_ServerName = a_ServerName;
|
||||
if (!cNetwork::Connect(a_ServerName, a_Port,
|
||||
std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
|
||||
std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
|
||||
)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wait for the connection to succeed or fail:
|
||||
m_Event.Wait();
|
||||
if (!m_IsConnected)
|
||||
{
|
||||
Printf(m_LastErrorText, "Socket connect failed: %s", m_Socket.GetLastErrorString().c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -102,7 +185,7 @@ bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
|
||||
ASSERT(m_IsConnected);
|
||||
|
||||
// Keep sending the data until all of it is sent:
|
||||
const char * Data = (const char *)a_Data;
|
||||
const char * Data = reinterpret_cast<const char *>(a_Data);
|
||||
size_t NumBytes = a_NumBytes;
|
||||
for (;;)
|
||||
{
|
||||
@ -156,7 +239,8 @@ void cBlockingSslClientSocket::Disconnect(void)
|
||||
}
|
||||
|
||||
m_Ssl.NotifyClose();
|
||||
m_Socket.CloseSocket();
|
||||
m_Socket->Close();
|
||||
m_Socket.reset();
|
||||
m_IsConnected = false;
|
||||
}
|
||||
|
||||
@ -166,13 +250,25 @@ void cBlockingSslClientSocket::Disconnect(void)
|
||||
|
||||
int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
|
||||
{
|
||||
int res = m_Socket.Receive((char *)a_Buffer, a_NumBytes, 0);
|
||||
if (res < 0)
|
||||
// Wait for any incoming data, if there is none:
|
||||
cCSLock Lock(m_CSIncomingData);
|
||||
while (m_IsConnected && m_IncomingData.empty())
|
||||
{
|
||||
cCSUnlock Unlock(Lock);
|
||||
m_Event.Wait();
|
||||
}
|
||||
|
||||
// If we got disconnected, report an error after processing all data:
|
||||
if (!m_IsConnected && m_IncomingData.empty())
|
||||
{
|
||||
// PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
|
||||
return POLARSSL_ERR_NET_RECV_FAILED;
|
||||
}
|
||||
return res;
|
||||
|
||||
// Copy the data from the incoming buffer into the specified space:
|
||||
size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size());
|
||||
memcpy(a_Buffer, m_IncomingData.data(), NumToCopy);
|
||||
m_IncomingData.erase(0, NumToCopy);
|
||||
return static_cast<int>(NumToCopy);
|
||||
}
|
||||
|
||||
|
||||
@ -181,13 +277,69 @@ int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t
|
||||
|
||||
int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
|
||||
{
|
||||
int res = m_Socket.Send((const char *)a_Buffer, a_NumBytes);
|
||||
if (res < 0)
|
||||
cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket.
|
||||
if (Socket == nullptr)
|
||||
{
|
||||
return POLARSSL_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
if (!Socket->Send(a_Buffer, a_NumBytes))
|
||||
{
|
||||
// PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
|
||||
return POLARSSL_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
return res;
|
||||
return static_cast<int>(a_NumBytes);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void cBlockingSslClientSocket::OnConnected(void)
|
||||
{
|
||||
m_IsConnected = true;
|
||||
m_Event.Set();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg)
|
||||
{
|
||||
LOG("Cannot connect to %s: %s", m_ServerName.c_str(), a_ErrorMsg.c_str());
|
||||
m_Event.Set();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
|
||||
{
|
||||
{
|
||||
cCSLock Lock(m_CSIncomingData);
|
||||
m_IncomingData.append(a_Data, a_Size);
|
||||
}
|
||||
m_Event.Set();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link)
|
||||
{
|
||||
m_Socket = a_Link;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void cBlockingSslClientSocket::OnDisconnected(void)
|
||||
{
|
||||
m_Socket.reset();
|
||||
m_IsConnected = false;
|
||||
m_Event.Set();
|
||||
}
|
||||
|
||||
|
||||
|
@ -9,8 +9,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "OSSupport/Network.h"
|
||||
#include "CallbackSslContext.h"
|
||||
#include "../OSSupport/Socket.h"
|
||||
|
||||
|
||||
|
||||
@ -51,25 +51,56 @@ public:
|
||||
const AString & GetLastErrorText(void) const { return m_LastErrorText; }
|
||||
|
||||
protected:
|
||||
friend class cBlockingSslClientSocketConnectCallbacks;
|
||||
friend class cBlockingSslClientSocketLinkCallbacks;
|
||||
|
||||
/** The SSL context used for the socket */
|
||||
cCallbackSslContext m_Ssl;
|
||||
|
||||
/** The underlying socket to the SSL server */
|
||||
cSocket m_Socket;
|
||||
cTCPLinkPtr m_Socket;
|
||||
|
||||
/** The object used to signal state changes in the socket (the cause of the blocking). */
|
||||
cEvent m_Event;
|
||||
|
||||
/** The trusted CA root cert store, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */
|
||||
cX509CertPtr m_CACerts;
|
||||
|
||||
/** The expected SSL peer's name, if we are to verify the cert strictly. Set by SetTrustedRootCertsFromString(). */
|
||||
AString m_ExpectedPeerName;
|
||||
|
||||
/** The hostname to which the socket is connecting (stored for error reporting). */
|
||||
AString m_ServerName;
|
||||
|
||||
/** Text of the last error that has occurred. */
|
||||
AString m_LastErrorText;
|
||||
|
||||
/** Set to true if the connection established successfully. */
|
||||
bool m_IsConnected;
|
||||
|
||||
/** Protects m_IncomingData against multithreaded access. */
|
||||
cCriticalSection m_CSIncomingData;
|
||||
|
||||
/** Buffer for the data incoming on the network socket.
|
||||
Protected by m_CSIncomingData. */
|
||||
AString m_IncomingData;
|
||||
|
||||
|
||||
/** Called when the connection is established successfully. */
|
||||
void OnConnected(void);
|
||||
|
||||
/** Called when an error occurs while connecting the socket. */
|
||||
void OnConnectError(const AString & a_ErrorMsg);
|
||||
|
||||
/** Called when there's incoming data from the socket. */
|
||||
void OnReceivedData(const char * a_Data, size_t a_Size);
|
||||
|
||||
/** Called when the link for the connection is created. */
|
||||
void SetLink(cTCPLinkPtr a_Link);
|
||||
|
||||
/** Called when the link is disconnected, either gracefully or by an error. */
|
||||
void OnDisconnected(void);
|
||||
|
||||
// cCallbackSslContext::cDataCallbacks overrides:
|
||||
virtual int ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes) override;
|
||||
virtual int SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes) override;
|
||||
|
@ -70,7 +70,7 @@ int cSslContext::Initialize(bool a_IsClient, const SharedPtr<cCtrDrbgContext> &
|
||||
// so they're disabled until someone needs them
|
||||
ssl_set_dbg(&m_Ssl, &SSLDebugMessage, this);
|
||||
ssl_set_verify(&m_Ssl, &SSLVerifyCert, this);
|
||||
*/
|
||||
//*/
|
||||
|
||||
/*
|
||||
// Set ciphersuite to the easiest one to decode, so that the connection can be wireshark-decoded:
|
||||
|
Loading…
x
Reference in New Issue
Block a user