1
0

Fix cUrlClient leak (#4125)

Fixes #4040
* The TCP connection is now shutdown after OnBodyFinished
* Any open connections are closed when cNetworkSingleton::Terminate() is called.
* Removed ownership cycles in cUrlClientRequest
* Added a check to the test to ensure there are no leaks.
This commit is contained in:
peterbell10 2018-02-20 17:08:46 +00:00 committed by GitHub
parent cf75d7b2c5
commit 1ea36298d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 41 deletions

View File

@ -54,12 +54,10 @@ public:
m_Callbacks->OnError(a_ErrorMessage); m_Callbacks->OnError(a_ErrorMessage);
// Terminate the request's TCP link: // Terminate the request's TCP link:
auto link = m_Link; if (auto link = m_Link.lock())
if (link != nullptr)
{ {
link->Close(); link->Close();
} }
m_Self.reset();
} }
@ -76,7 +74,7 @@ public:
{ {
return nullptr; return nullptr;
} }
cX509CertPtr cert; cX509CertPtr cert = std::make_shared<cX509Cert>();
if (!cert->Parse(itr->second.data(), itr->second.size())) if (!cert->Parse(itr->second.data(), itr->second.size()))
{ {
LOGD("OwnCert failed to parse"); LOGD("OwnCert failed to parse");
@ -92,7 +90,7 @@ public:
{ {
return nullptr; return nullptr;
} }
cCryptoKeyPtr key; cCryptoKeyPtr key = std::make_shared<cCryptoKey>();
auto passItr = m_Options.find("OwnPrivKeyPassword"); auto passItr = m_Options.find("OwnPrivKeyPassword");
auto pass = (passItr == m_Options.end()) ? AString() : passItr->second; auto pass = (passItr == m_Options.end()) ? AString() : passItr->second;
if (!key->ParsePrivate(itr->second.data(), itr->second.size(), pass)) if (!key->ParsePrivate(itr->second.data(), itr->second.size(), pass))
@ -126,15 +124,15 @@ protected:
/** Extra options to be used for the request. */ /** Extra options to be used for the request. */
AStringMap m_Options; AStringMap m_Options;
/** SharedPtr to self, so that this object can keep itself alive for as long as it needs, /** weak_ptr to self, so that this object can keep itself alive as needed by calling lock(),
and pass self as callbacks to cNetwork functions. */ and pass self as callbacks to cNetwork functions. */
std::shared_ptr<cUrlClientRequest> m_Self; std::weak_ptr<cUrlClientRequest> m_Self;
/** The handler that "talks" the protocol specified in m_UrlScheme, handles all the sending and receiving. */ /** The handler that "talks" the protocol specified in m_UrlScheme, handles all the sending and receiving. */
std::shared_ptr<cSchemeHandler> m_SchemeHandler; std::shared_ptr<cSchemeHandler> m_SchemeHandler;
/** The link handling the request. */ /** The link handling the request. */
cTCPLinkPtr m_Link; std::weak_ptr<cTCPLink> m_Link;
/** The number of redirect attempts that will still be followed. /** The number of redirect attempts that will still be followed.
If the response specifies a redirect and this is nonzero, the redirect is followed. If the response specifies a redirect and this is nonzero, the redirect is followed.
@ -171,7 +169,6 @@ protected:
virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
{ {
m_Callbacks->OnError(Printf("Network error %d (%s)", a_ErrorCode, a_ErrorMsg.c_str())); m_Callbacks->OnError(Printf("Network error %d (%s)", a_ErrorCode, a_ErrorMsg.c_str()));
m_Self.reset();
} }
@ -332,8 +329,8 @@ public:
// cHTTPResponseParser::cCallbacks overrides: // cHTTPResponseParser::cCallbacks overrides:
virtual void OnError(const AString & a_ErrorDescription) override virtual void OnError(const AString & a_ErrorDescription) override
{ {
m_ParentRequest.CallErrorCallback(a_ErrorDescription);
m_Link = nullptr; m_Link = nullptr;
m_ParentRequest.CallErrorCallback(a_ErrorDescription);
} }
@ -430,6 +427,8 @@ public:
else else
{ {
m_ParentRequest.GetCallbacks().OnBodyFinished(); m_ParentRequest.GetCallbacks().OnBodyFinished();
// Finished recieving data, shutdown the link
m_Link->Shutdown();
} }
} }
@ -493,15 +492,21 @@ void cUrlClientRequest::RedirectTo(const AString & a_RedirectUrl)
return; return;
} }
// Keep ourself alive while the link drops ownership
auto Self = m_Self.lock();
// Do the actual redirect: // Do the actual redirect:
m_Link->Close(); if (auto Link = m_Link.lock())
{
Link->Close();
}
m_Url = a_RedirectUrl; m_Url = a_RedirectUrl;
m_NumRemainingRedirects = m_NumRemainingRedirects - 1; m_NumRemainingRedirects = m_NumRemainingRedirects - 1;
auto res = DoRequest(m_Self); auto res = DoRequest(Self);
if (!res.first) if (!res.first)
{ {
m_Callbacks->OnError(Printf("Redirection failed: %s", res.second.c_str())); m_Callbacks->OnError(Printf("Redirection failed: %s", res.second.c_str()));
return;
} }
} }
@ -560,9 +565,6 @@ void cUrlClientRequest::OnRemoteClosed()
{ {
handler->OnRemoteClosed(); handler->OnRemoteClosed();
} }
// Let ourselves be deleted
m_Self.reset();
} }
@ -590,7 +592,8 @@ std::pair<bool, AString> cUrlClientRequest::DoRequest(std::shared_ptr<cUrlClient
return std::make_pair(false, Printf("Unknown Url scheme: %s", m_UrlScheme.c_str())); return std::make_pair(false, Printf("Unknown Url scheme: %s", m_UrlScheme.c_str()));
} }
if (!cNetwork::Connect(m_UrlHost, m_UrlPort, m_Self, m_Self)) // Connect and transfer ownership to the link
if (!cNetwork::Connect(m_UrlHost, m_UrlPort, a_Self, a_Self))
{ {
return std::make_pair(false, "Network connection failed"); return std::make_pair(false, "Network connection failed");
} }

View File

@ -6,6 +6,7 @@
#include "Globals.h" #include "Globals.h"
#include "NetworkSingleton.h" #include "NetworkSingleton.h"
#include "OSSupport/Network.h"
#include <event2/thread.h> #include <event2/thread.h>
#include <event2/bufferevent.h> #include <event2/bufferevent.h>
#include <event2/listener.h> #include <event2/listener.h>
@ -102,11 +103,25 @@ void cNetworkSingleton::Terminate(void)
event_base_loopbreak(m_EventBase); event_base_loopbreak(m_EventBase);
m_EventLoopThread.join(); m_EventLoopThread.join();
// Remove all objects: // Close all open connections:
{ {
cCSLock Lock(m_CS); cCSLock Lock(m_CS);
m_Connections.clear(); // Must take copies because Close will modify lists
m_Servers.clear(); auto Conns = m_Connections;
for (auto & Conn : Conns)
{
Conn->Close();
}
auto Servers = m_Servers;
for (auto & Server : Servers)
{
Server->Close();
}
// Closed handles should have removed themself
ASSERT(m_Connections.empty());
ASSERT(m_Servers.empty());
} }
// Free the underlying LibEvent objects: // Free the underlying LibEvent objects:
@ -167,7 +182,7 @@ void cNetworkSingleton::SignalizeStartup(evutil_socket_t a_Socket, short a_Event
void cNetworkSingleton::AddLink(cTCPLinkImplPtr a_Link) void cNetworkSingleton::AddLink(cTCPLinkPtr a_Link)
{ {
ASSERT(!m_HasTerminated); ASSERT(!m_HasTerminated);
cCSLock Lock(m_CS); cCSLock Lock(m_CS);
@ -178,7 +193,7 @@ void cNetworkSingleton::AddLink(cTCPLinkImplPtr a_Link)
void cNetworkSingleton::RemoveLink(const cTCPLinkImpl * a_Link) void cNetworkSingleton::RemoveLink(const cTCPLink * a_Link)
{ {
ASSERT(!m_HasTerminated); ASSERT(!m_HasTerminated);
cCSLock Lock(m_CS); cCSLock Lock(m_CS);
@ -196,7 +211,7 @@ void cNetworkSingleton::RemoveLink(const cTCPLinkImpl * a_Link)
void cNetworkSingleton::AddServer(cServerHandleImplPtr a_Server) void cNetworkSingleton::AddServer(cServerHandlePtr a_Server)
{ {
ASSERT(!m_HasTerminated); ASSERT(!m_HasTerminated);
cCSLock Lock(m_CS); cCSLock Lock(m_CS);
@ -207,7 +222,7 @@ void cNetworkSingleton::AddServer(cServerHandleImplPtr a_Server)
void cNetworkSingleton::RemoveServer(const cServerHandleImpl * a_Server) void cNetworkSingleton::RemoveServer(const cServerHandle * a_Server)
{ {
ASSERT(!m_HasTerminated); ASSERT(!m_HasTerminated);
cCSLock Lock(m_CS); cCSLock Lock(m_CS);

View File

@ -24,12 +24,12 @@
// fwd: // fwd:
struct event_base; struct event_base;
class cTCPLinkImpl; class cTCPLink;
typedef std::shared_ptr<cTCPLinkImpl> cTCPLinkImplPtr; typedef std::shared_ptr<cTCPLink> cTCPLinkPtr;
typedef std::vector<cTCPLinkImplPtr> cTCPLinkImplPtrs; typedef std::vector<cTCPLinkPtr> cTCPLinkPtrs;
class cServerHandleImpl; class cServerHandle;
typedef std::shared_ptr<cServerHandleImpl> cServerHandleImplPtr; typedef std::shared_ptr<cServerHandle> cServerHandlePtr;
typedef std::vector<cServerHandleImplPtr> cServerHandleImplPtrs; typedef std::vector<cServerHandlePtr> cServerHandlePtrs;
@ -61,20 +61,20 @@ public:
/** Adds the specified link to m_Connections. /** Adds the specified link to m_Connections.
Used by the underlying link implementation when a new link is created. */ Used by the underlying link implementation when a new link is created. */
void AddLink(cTCPLinkImplPtr a_Link); void AddLink(cTCPLinkPtr a_Link);
/** Removes the specified link from m_Connections. /** Removes the specified link from m_Connections.
Used by the underlying link implementation when the link is closed / errored. */ Used by the underlying link implementation when the link is closed / errored. */
void RemoveLink(const cTCPLinkImpl * a_Link); void RemoveLink(const cTCPLink * a_Link);
/** Adds the specified link to m_Servers. /** Adds the specified link to m_Servers.
Used by the underlying server handle implementation when a new listening server is created. Used by the underlying server handle implementation when a new listening server is created.
Only servers that succeed in listening are added. */ Only servers that succeed in listening are added. */
void AddServer(cServerHandleImplPtr a_Server); void AddServer(cServerHandlePtr a_Server);
/** Removes the specified server from m_Servers. /** Removes the specified server from m_Servers.
Used by the underlying server handle implementation when the server is closed. */ Used by the underlying server handle implementation when the server is closed. */
void RemoveServer(const cServerHandleImpl * a_Server); void RemoveServer(const cServerHandle * a_Server);
protected: protected:
@ -82,10 +82,10 @@ protected:
event_base * m_EventBase; event_base * m_EventBase;
/** Container for all client connections, including ones with pending-connect. */ /** Container for all client connections, including ones with pending-connect. */
cTCPLinkImplPtrs m_Connections; cTCPLinkPtrs m_Connections;
/** Container for all servers that are currently active. */ /** Container for all servers that are currently active. */
cServerHandleImplPtrs m_Servers; cServerHandlePtrs m_Servers;
/** Mutex protecting all containers against multithreaded access. */ /** Mutex protecting all containers against multithreaded access. */
cCriticalSection m_CS; cCriticalSection m_CS;

View File

@ -6,6 +6,11 @@
namespace
{
/** Track number of cCallbacks instances alive. */
std::atomic<int> g_ActiveCallbacks{ 0 };
/** Simple callbacks that dump the events to the console and signalize a cEvent when the request is finished. */ /** Simple callbacks that dump the events to the console and signalize a cEvent when the request is finished. */
class cCallbacks: class cCallbacks:
@ -15,12 +20,14 @@ public:
cCallbacks(cEvent & a_Event): cCallbacks(cEvent & a_Event):
m_Event(a_Event) m_Event(a_Event)
{ {
++g_ActiveCallbacks;
LOGD("Created a cCallbacks instance at %p", reinterpret_cast<void *>(this)); LOGD("Created a cCallbacks instance at %p", reinterpret_cast<void *>(this));
} }
virtual ~cCallbacks() override virtual ~cCallbacks() override
{ {
--g_ActiveCallbacks;
LOGD("Deleting the cCallbacks instance at %p", reinterpret_cast<void *>(this)); LOGD("Deleting the cCallbacks instance at %p", reinterpret_cast<void *>(this));
} }
@ -102,7 +109,7 @@ protected:
static int TestRequest1() int TestRequest1()
{ {
LOG("Running test 1"); LOG("Running test 1");
cEvent evtFinished; cEvent evtFinished;
@ -126,7 +133,7 @@ static int TestRequest1()
static int TestRequest2() int TestRequest2()
{ {
LOG("Running test 2"); LOG("Running test 2");
cEvent evtFinished; cEvent evtFinished;
@ -148,7 +155,7 @@ static int TestRequest2()
static int TestRequest3() int TestRequest3()
{ {
LOG("Running test 3"); LOG("Running test 3");
cEvent evtFinished; cEvent evtFinished;
@ -172,7 +179,7 @@ static int TestRequest3()
static int TestRequest4() int TestRequest4()
{ {
LOG("Running test 4"); LOG("Running test 4");
cEvent evtFinished; cEvent evtFinished;
@ -194,7 +201,7 @@ static int TestRequest4()
static int TestRequests() int TestRequests()
{ {
std::function<int(void)> tests[] = std::function<int(void)> tests[] =
{ {
@ -215,6 +222,8 @@ static int TestRequests()
return 0; return 0;
} }
} // namespace (anonymous)
@ -231,6 +240,11 @@ int main()
LOGD("Terminating cNetwork..."); LOGD("Terminating cNetwork...");
cNetworkSingleton::Get().Terminate(); cNetworkSingleton::Get().Terminate();
// No leaked callback instances
LOGD("cCallback instances still alive: %d", g_ActiveCallbacks.load());
assert_test(g_ActiveCallbacks == 0);
LOGD("cUrlClient test finished"); LOGD("cUrlClient test finished");
return res; return res;