diff --git a/src/Bindings/LuaState.cpp b/src/Bindings/LuaState.cpp index f11e74e75..ffe1fe4ac 100644 --- a/src/Bindings/LuaState.cpp +++ b/src/Bindings/LuaState.cpp @@ -20,6 +20,10 @@ extern "C" #include "../Entities/Entity.h" #include "../BlockEntities/BlockEntity.h" + + + + // fwd: "SQLite/lsqlite3.c" extern "C" { @@ -39,6 +43,10 @@ extern "C" const cLuaState::cRet cLuaState::Return = {}; +/** Each Lua state stores a pointer to its creating cLuaState in Lua globals, under this name. +This way any cLuaState can reference the main cLuaState's TrackedCallbacks, mutex etc. */ +static const char * g_CanonLuaStateGlobalName = "_CuberiteInternal_CanonLuaState"; + @@ -113,6 +121,72 @@ cLuaStateTracker & cLuaStateTracker::Get(void) +//////////////////////////////////////////////////////////////////////////////// +// cLuaState::cCallback: + +bool cLuaState::cCallback::RefStack(cLuaState & a_LuaState, int a_StackPos) +{ + // Check if the stack contains a function: + if (!lua_isfunction(a_LuaState, a_StackPos)) + { + return false; + } + + // Clear any previous callback: + Clear(); + + // Add self to LuaState's callback-tracking: + a_LuaState.TrackCallback(*this); + + // Store the new callback: + cCSLock Lock(m_CS); + m_Ref.RefStack(a_LuaState, a_StackPos); + return true; +} + + + + + +void cLuaState::cCallback::Clear(void) +{ + // Free the callback reference: + lua_State * luaState = nullptr; + { + cCSLock Lock(m_CS); + if (!m_Ref.IsValid()) + { + return; + } + luaState = m_Ref.GetLuaState(); + m_Ref.UnRef(); + } + + // Remove from LuaState's callback-tracking: + cLuaState(luaState).UntrackCallback(*this); +} + + + + + +void cLuaState::cCallback::Invalidate(void) +{ + cCSLock Lock(m_CS); + if (!m_Ref.IsValid()) + { + LOGD("%s: Invalidating an already invalid callback at %p, this should not happen", + __FUNCTION__, reinterpret_cast(this) + ); + return; + } + m_Ref.UnRef(); +} + + + + + //////////////////////////////////////////////////////////////////////////////// // cLuaState: @@ -170,6 +244,10 @@ void cLuaState::Create(void) luaL_openlibs(m_LuaState); m_IsOwned = true; cLuaStateTracker::Add(*this); + + // Add the CanonLuaState value into the Lua state, so that we can get it from anywhere: + lua_pushlightuserdata(m_LuaState, reinterpret_cast(this)); + lua_setglobal(m_LuaState, g_CanonLuaStateGlobalName); } @@ -206,6 +284,16 @@ void cLuaState::Close(void) Detach(); return; } + + // Invalidate all callbacks: + { + cCSLock Lock(m_CSTrackedCallbacks); + for (auto & c: m_TrackedCallbacks) + { + c->Invalidate(); + } + } + cLuaStateTracker::Del(*this); lua_close(m_LuaState); m_LuaState = nullptr; @@ -871,6 +959,15 @@ bool cLuaState::GetStackValue(int a_StackPos, cRef & a_Ref) +bool cLuaState::GetStackValue(int a_StackPos, cCallback & a_Callback) +{ + return a_Callback.RefStack(*this, a_StackPos); +} + + + + + bool cLuaState::GetStackValue(int a_StackPos, double & a_ReturnedVal) { if (lua_isnumber(m_LuaState, a_StackPos)) @@ -1701,6 +1798,52 @@ int cLuaState::BreakIntoDebugger(lua_State * a_LuaState) +void cLuaState::TrackCallback(cCallback & a_Callback) +{ + // Get the CanonLuaState global from Lua: + auto cb = WalkToNamedGlobal(g_CanonLuaStateGlobalName); + if (!cb.IsValid()) + { + LOGWARNING("%s: Lua state %p has invalid CanonLuaState!", __FUNCTION__, reinterpret_cast(m_LuaState)); + return; + } + auto & canonState = *reinterpret_cast(lua_touserdata(m_LuaState, -1)); + + // Add the callback: + cCSLock Lock(canonState.m_CSTrackedCallbacks); + canonState.m_TrackedCallbacks.push_back(&a_Callback); +} + + + + + +void cLuaState::UntrackCallback(cCallback & a_Callback) +{ + // Get the CanonLuaState global from Lua: + auto cb = WalkToNamedGlobal(g_CanonLuaStateGlobalName); + if (!cb.IsValid()) + { + LOGWARNING("%s: Lua state %p has invalid CanonLuaState!", __FUNCTION__, reinterpret_cast(m_LuaState)); + return; + } + auto & canonState = *reinterpret_cast(lua_touserdata(m_LuaState, -1)); + + // Remove the callback: + cCSLock Lock(canonState.m_CSTrackedCallbacks); + auto & trackedCallbacks = canonState.m_TrackedCallbacks; + trackedCallbacks.erase(std::remove_if(trackedCallbacks.begin(), trackedCallbacks.end(), + [&a_Callback](cCallback * a_StoredCallback) + { + return (a_StoredCallback == &a_Callback); + } + )); +} + + + + + //////////////////////////////////////////////////////////////////////////////// // cLuaState::cRef: @@ -1756,7 +1899,7 @@ void cLuaState::cRef::RefStack(cLuaState & a_LuaState, int a_StackPos) { UnRef(); } - m_LuaState = &a_LuaState; + m_LuaState = a_LuaState; lua_pushvalue(a_LuaState, a_StackPos); // Push a copy of the value at a_StackPos onto the stack m_Ref = luaL_ref(a_LuaState, LUA_REGISTRYINDEX); } @@ -1767,11 +1910,9 @@ void cLuaState::cRef::RefStack(cLuaState & a_LuaState, int a_StackPos) void cLuaState::cRef::UnRef(void) { - ASSERT(m_LuaState->IsValid()); // The reference should be destroyed before destroying the LuaState - if (IsValid()) { - luaL_unref(*m_LuaState, LUA_REGISTRYINDEX, m_Ref); + luaL_unref(m_LuaState, LUA_REGISTRYINDEX, m_Ref); } m_LuaState = nullptr; m_Ref = LUA_REFNIL; diff --git a/src/Bindings/LuaState.h b/src/Bindings/LuaState.h index b795a80d4..bce08f0fe 100644 --- a/src/Bindings/LuaState.h +++ b/src/Bindings/LuaState.h @@ -80,8 +80,11 @@ public: /** Allows to use this class wherever an int (i. e. ref) is to be used */ explicit operator int(void) const { return m_Ref; } + /** Returns the Lua state associated with the value. */ + lua_State * GetLuaState(void) { return m_LuaState; } + protected: - cLuaState * m_LuaState; + lua_State * m_LuaState; int m_Ref; // Remove the copy-constructor: @@ -112,6 +115,69 @@ public: } ; + /** Represents a callback to Lua that C++ code can call. + Is thread-safe and unload-safe. + When the Lua state is unloaded, the callback returns an error instead of calling into non-existent code. + To receive the callback instance from the Lua side, use RefStack() or (better) cLuaState::GetStackValue(). + Note that instances of this class are tracked in the canon LuaState instance, so that they can be invalidated + when the LuaState is unloaded; due to multithreading issues they can only be tracked by-ptr, which has + an unfortunate effect of disabling the copy and move constructors. */ + class cCallback + { + public: + /** Creates an unbound callback instance. */ + cCallback(void) = default; + + ~cCallback() + { + Clear(); + } + + /** Calls the Lua callback, if still available. + Returns true if callback has been called. + Returns false if the Lua state isn't valid anymore. */ + template + bool Call(Args &&... args) + { + cCSLock Lock(m_CS); + if (!m_Ref.IsValid()) + { + return false; + } + cLuaState(m_Ref.GetLuaState()).Call(m_Ref, std::forward(args)...); + return true; + } + + /** Set the contained callback to the function in the specified Lua state's stack position. + If a callback has been previously contained, it is freed first. */ + bool RefStack(cLuaState & a_LuaState, int a_StackPos); + + /** Frees the contained callback, if any. */ + void Clear(void); + + protected: + friend class cLuaState; + + /** The mutex protecting m_Ref against multithreaded access */ + cCriticalSection m_CS; + + /** Reference to the Lua callback */ + cRef m_Ref; + + + /** Invalidates the callback, without untracking it from the cLuaState. + Called only from cLuaState when closing the Lua state. */ + void Invalidate(void); + + /** This class cannot be copied, because it is tracked in the LuaState by-ptr. */ + cCallback(const cCallback &) = delete; + + /** This class cannot be moved, because it is tracked in the LuaState by-ptr. */ + cCallback(cCallback &&) = delete; + }; + typedef SharedPtr cCallbackPtr; + + /** A dummy class that's used only to delimit function args from return values for cLuaState::Call() */ class cRet { @@ -268,6 +334,7 @@ public: bool GetStackValue(int a_StackPos, bool & a_Value); bool GetStackValue(int a_StackPos, cPluginManager::CommandResult & a_Result); bool GetStackValue(int a_StackPos, cRef & a_Ref); + bool GetStackValue(int a_StackPos, cCallback & a_Ref); bool GetStackValue(int a_StackPos, double & a_Value); bool GetStackValue(int a_StackPos, eBlockFace & a_Value); bool GetStackValue(int a_StackPos, eWeather & a_Value); @@ -453,8 +520,7 @@ protected: bool m_IsOwned; /** The subsystem name is used for reporting errors to the console, it is either "plugin %s" or "LuaScript" - whatever is given to the constructor - */ + whatever is given to the constructor. */ AString m_SubsystemName; /** Name of the currently pushed function (for the Push / Call chain) */ @@ -463,6 +529,15 @@ protected: /** Number of arguments currently pushed (for the Push / Call chain) */ int m_NumCurrentFunctionArgs; + /** The tracked callbacks. + This object will invalidate all of these when it is about to be closed. + Protected against multithreaded access by m_CSTrackedCallbacks. */ + std::vector m_TrackedCallbacks; + + /** Protects m_TrackedTallbacks against multithreaded access. */ + cCriticalSection m_CSTrackedCallbacks; + + /** Variadic template terminator: If there's nothing more to push / pop, just call the function. Note that there are no return values either, because those are prefixed by a cRet value, so the arg list is never empty. */ bool PushCallPop(void) @@ -545,6 +620,14 @@ protected: /** Tries to break into the MobDebug debugger, if it is installed. */ static int BreakIntoDebugger(lua_State * a_LuaState); + + /** Adds the specified callback to tracking. + The callback will be invalidated when this Lua state is about to be closed. */ + void TrackCallback(cCallback & a_Callback); + + /** Removes the specified callback from tracking. + The callback will no longer be invalidated when this Lua state is about to be closed. */ + void UntrackCallback(cCallback & a_Callback); } ;