From 81f2a9e99d7716ac7ab8158cf131424c82866f74 Mon Sep 17 00:00:00 2001 From: Benau Date: Fri, 7 Sep 2018 23:55:01 +0800 Subject: [PATCH] Implement CIDR banning with online id split --- src/config/user_config.hpp | 19 +++-- src/main.cpp | 66 +++++++++++++++ src/network/crypto_nettle.cpp | 2 +- src/network/network_console.cpp | 24 ++++-- src/network/protocols/server_lobby.cpp | 109 ++++++++++++++++++++++--- src/network/protocols/server_lobby.hpp | 14 +++- 6 files changed, 206 insertions(+), 28 deletions(-) diff --git a/src/config/user_config.hpp b/src/config/user_config.hpp index a60836c47..b05a8bed2 100644 --- a/src/config/user_config.hpp +++ b/src/config/user_config.hpp @@ -152,6 +152,10 @@ public: { return m_elements[key]; } + U& at(const T key) + { + return m_elements.at(key); + } }; // MapUserConfigParam typedef MapUserConfigParam UIntToUIntUserConfigParam; typedef MapUserConfigParam StringToUIntUserConfigParam; @@ -765,12 +769,15 @@ namespace UserConfigParams &m_network_group, "Value used to calculate time limit in CTF, which " "is max(3.0, number of players * (time-limit-threshold-ctf + flag-return-timemout / 60.0)) * 60.0," " negative value to disable time limit.")); - PARAM_PREFIX StringToUIntUserConfigParam m_server_ban_list - PARAM_DEFAULT(StringToUIntUserConfigParam("server_ban_list", - "LHS: IP in x.x.x.x format, RHS: online id, if 0 than all players " - "from this IP will be banned.", - { { "0.0.0.0", 0u } } - )); + PARAM_PREFIX StringToUIntUserConfigParam m_server_ip_ban_list + PARAM_DEFAULT(StringToUIntUserConfigParam("server_ip_ban_list", + "LHS: IP in X.X.X.X/Y (CIDR) format, use Y of 32 for a specific ip, " + "RHS: time epoch to expire, if -1 (uint32_t max) than a permanent ban.", + { { "0.0.0.0/0", 0u } })); + PARAM_PREFIX UIntToUIntUserConfigParam m_server_online_id_ban_list + PARAM_DEFAULT(UIntToUIntUserConfigParam("server_online_id_ban_list", + "LHS: online id, RHS: time epoch to expire, if -1 (uint32_t max) than a permanent ban.", + { { 0u, 0u } })); PARAM_PREFIX IntUserConfigParam m_max_ping PARAM_DEFAULT(IntUserConfigParam(300, "max-ping", &m_network_group, "Maximum ping allowed for a player (in ms).")); diff --git a/src/main.cpp b/src/main.cpp index 9f9ffce1d..34c7ac673 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -166,6 +166,7 @@ #include #include #include +#include #include @@ -214,6 +215,7 @@ #include "modes/profile_world.hpp" #include "network/protocols/connect_to_server.hpp" #include "network/protocols/client_lobby.hpp" +#include "network/protocols/server_lobby.hpp" #include "network/game_setup.hpp" #include "network/network_config.hpp" #include "network/network_string.hpp" @@ -2370,6 +2372,70 @@ void runUnitTests() Log::info("UnitTest", "RewindQueue"); RewindQueue::unitTesting(); + Log::info("UnitTest", "IP ban"); + NetworkConfig::get()->unsetNetworking(); + ServerLobby sl; + + UserConfigParams::m_server_ip_ban_list = + { + { "1.2.3.4/32", std::numeric_limits::max() } + }; + sl.updateBanList(); + assert(sl.isBannedForIP(TransportAddress("1.2.3.4"))); + assert(!sl.isBannedForIP(TransportAddress("1.2.3.5"))); + assert(!sl.isBannedForIP(TransportAddress("1.2.3.3"))); + + UserConfigParams::m_server_ip_ban_list = + { + { "1.2.3.4/23", std::numeric_limits::max() } + }; + sl.updateBanList(); + assert(!sl.isBannedForIP(TransportAddress("1.2.1.255"))); + assert(sl.isBannedForIP(TransportAddress("1.2.2.0"))); + assert(sl.isBannedForIP(TransportAddress("1.2.2.3"))); + assert(sl.isBannedForIP(TransportAddress("1.2.2.4"))); + assert(sl.isBannedForIP(TransportAddress("1.2.2.5"))); + assert(sl.isBannedForIP(TransportAddress("1.2.3.3"))); + assert(sl.isBannedForIP(TransportAddress("1.2.3.4"))); + assert(sl.isBannedForIP(TransportAddress("1.2.3.5"))); + assert(sl.isBannedForIP(TransportAddress("1.2.3.255"))); + assert(!sl.isBannedForIP(TransportAddress("1.2.4.0"))); + + UserConfigParams::m_server_ip_ban_list = + { + { "11.12.13.14/22", std::numeric_limits::max() }, + { "12.13.14.15/24", std::numeric_limits::max() }, + { "123.234.56.78/26", std::numeric_limits::max() }, + { "234.123.56.78/25", std::numeric_limits::max() }, + // Test for overlap handling + { "12.13.14.23/32", std::numeric_limits::max() }, + { "12.13.14.255/32", std::numeric_limits::max() } + }; + sl.updateBanList(); + assert(!sl.isBannedForIP(TransportAddress("11.12.11.255"))); + assert(sl.isBannedForIP(TransportAddress("11.12.12.0"))); + assert(sl.isBannedForIP(TransportAddress("11.12.13.14"))); + assert(sl.isBannedForIP(TransportAddress("11.12.15.255"))); + assert(!sl.isBannedForIP(TransportAddress("11.12.16.0"))); + + assert(!sl.isBannedForIP(TransportAddress("12.13.13.255"))); + assert(sl.isBannedForIP(TransportAddress("12.13.14.0"))); + assert(sl.isBannedForIP(TransportAddress("12.13.14.15"))); + assert(sl.isBannedForIP(TransportAddress("12.13.14.255"))); + assert(!sl.isBannedForIP(TransportAddress("12.13.15.0"))); + + assert(!sl.isBannedForIP(TransportAddress("123.234.56.63"))); + assert(sl.isBannedForIP(TransportAddress("123.234.56.64"))); + assert(sl.isBannedForIP(TransportAddress("123.234.56.78"))); + assert(sl.isBannedForIP(TransportAddress("123.234.56.127"))); + assert(!sl.isBannedForIP(TransportAddress("123.234.56.128"))); + + assert(!sl.isBannedForIP(TransportAddress("234.123.55.255"))); + assert(sl.isBannedForIP(TransportAddress("234.123.56.0"))); + assert(sl.isBannedForIP(TransportAddress("234.123.56.78"))); + assert(sl.isBannedForIP(TransportAddress("234.123.56.127"))); + assert(!sl.isBannedForIP(TransportAddress("234.123.56.128"))); + Log::info("UnitTest", "====================="); Log::info("UnitTest", "Testing successful "); Log::info("UnitTest", "====================="); diff --git a/src/network/crypto_nettle.cpp b/src/network/crypto_nettle.cpp index e2a1d5356..d8327fd7a 100644 --- a/src/network/crypto_nettle.cpp +++ b/src/network/crypto_nettle.cpp @@ -25,7 +25,7 @@ #include #include -#if NETTLE_VERSION_MAJOR > 3 || +#if NETTLE_VERSION_MAJOR > 3 || \ (NETTLE_VERSION_MAJOR == 3 && NETTLE_VERSION_MINOR > 3) typedef const char* NETTLE_CONST_CHAR; typedef char* NETTLE_CHAR; diff --git a/src/network/network_console.cpp b/src/network/network_console.cpp index 4612d1823..6f8f427af 100644 --- a/src/network/network_console.cpp +++ b/src/network/network_console.cpp @@ -27,6 +27,7 @@ #include "main_loop.hpp" #include +#include namespace NetworkConsole { @@ -87,9 +88,13 @@ void mainLoop(STKHost* host) if (peer) { peer->kick(); - UserConfigParams::m_server_ban_list - [peer->getAddress().toString(false/*show_port*/)] = 0; - LobbyProtocol::get()->updateBanList(); + // ATM use permanently ban + auto sl = LobbyProtocol::get(); + auto lock = sl->acquireConnectionMutex(); + UserConfigParams::m_server_ip_ban_list + [peer->getAddress().toString(false/*show_port*/) + "/32"] + = std::numeric_limits::max(); + sl->updateBanList(); } else std::cout << "Unknown host id: " << number << std::endl; @@ -107,11 +112,18 @@ void mainLoop(STKHost* host) } else if (str == "listban") { - for (auto& ban : UserConfigParams::m_server_ban_list) + for (auto& ban : UserConfigParams::m_server_ip_ban_list) { - if (ban.first == "0.0.0.0") + if (ban.first == "0.0.0.0/0") continue; - std::cout << "IP: " << ban.first << " online id: " << + std::cout << "IP: " << ban.first << ", expire at: " << + ban.second << std::endl; + } + for (auto& ban : UserConfigParams::m_server_online_id_ban_list) + { + if (ban.first == 0) + continue; + std::cout << "Online id: " << ban.first << ", expire at: " << ban.second << std::endl; } } diff --git a/src/network/protocols/server_lobby.cpp b/src/network/protocols/server_lobby.cpp index 8d57dcfe1..5bd9294aa 100644 --- a/src/network/protocols/server_lobby.cpp +++ b/src/network/protocols/server_lobby.cpp @@ -114,7 +114,8 @@ ServerLobby::ServerLobby() : LobbyProtocol(NULL) */ ServerLobby::~ServerLobby() { - if (NetworkConfig::get()->isWAN()) + if (NetworkConfig::get()->isNetworking() && + NetworkConfig::get()->isWAN()) { unregisterServer(true/*now*/); } @@ -1406,12 +1407,13 @@ void ServerLobby::connectionRequested(Event* event) online_id = data.getUInt32(); encrypted_size = data.getUInt32(); - bool is_banned = false; - auto ret = m_ban_list.find(peer->getAddress().getIP()); - if (ret != m_ban_list.end()) + bool is_banned = isBannedForIP(peer->getAddress()); + if (online_id != 0 && !is_banned) { - // Ban all players if ban list is zero or compare it with online id - if (ret->second == 0 || (online_id != 0 && ret->second == online_id)) + if (m_online_id_ban_list.find(online_id) != + m_online_id_ban_list.end() && + (uint32_t)StkTime::getTimeSinceEpoch() < + m_online_id_ban_list.at(online_id)) { is_banned = true; } @@ -1447,7 +1449,7 @@ void ServerLobby::connectionRequested(Event* event) // Reject non-valiated player joinning if WAN server and not disabled // encforement of validation, unless it's player from localhost or lan // And no duplicated online id or split screen players in ranked server - if ((encrypted_size == 0 && + if (((encrypted_size == 0 || online_id == 0) && !(peer->getAddress().isPublicAddressLocalhost() || peer->getAddress().isLAN()) && NetworkConfig::get()->isWAN() && @@ -1992,14 +1994,68 @@ void ServerLobby::playerFinishedResult(Event *event) //----------------------------------------------------------------------------- void ServerLobby::updateBanList() { - std::lock_guard lock(m_connection_mutex); - m_ban_list.clear(); - for (auto& ban : UserConfigParams::m_server_ban_list) + m_ip_ban_list.clear(); + m_online_id_ban_list.clear(); + + for (auto& ban : UserConfigParams::m_server_ip_ban_list) { - if (ban.first == "0.0.0.0") + if (ban.first == "0.0.0.0/0" || + (uint32_t)StkTime::getTimeSinceEpoch() > ban.second) continue; - m_ban_list[TransportAddress(ban.first).getIP()] = ban.second; + uint32_t netbits = 0; + std::vector ip_and_netbits = + StringUtils::split(ban.first, '/'); + if (ip_and_netbits.size() != 2 || + !StringUtils::fromString(ip_and_netbits[1], netbits) || + netbits > 32) + { + Log::error("STKHost", "Wrong CIDR: %s", ban.first.c_str()); + continue; + } + TransportAddress addr(ip_and_netbits[0]); + if (addr.getIP() == 0) + { + Log::error("STKHost", "Wrong CIDR: %s", ban.first.c_str()); + continue; + } + uint32_t mask = ~((1 << (32 - netbits)) - 1); + uint32_t ip_start = addr.getIP() & mask; + uint32_t ip_end = (addr.getIP() & mask) | ~mask; + m_ip_ban_list[ip_start] = + std::make_tuple(ip_end, ban.first, ban.second); } + + std::map final_ip_ban_list; + for (auto it = m_ip_ban_list.begin(); + it != m_ip_ban_list.end();) + { + auto next_itr = std::next(it); + if (next_itr != m_ip_ban_list.end() && + next_itr->first <= std::get<0>(it->second)) + { + Log::warn("ServerLobby", "%s overlaps %s, removing the first one.", + std::get<1>(next_itr->second).c_str(), + std::get<1>(it->second).c_str()); + m_ip_ban_list.erase(next_itr); + continue; + } + final_ip_ban_list[std::get<1>(it->second)] = + UserConfigParams::m_server_ip_ban_list.at(std::get<1>(it->second)); + it++; + } + UserConfigParams::m_server_ip_ban_list = final_ip_ban_list; + + std::map final_online_id_ban_list; + for (auto& ban : UserConfigParams::m_server_online_id_ban_list) + { + if (ban.first == 0 || + (uint32_t)StkTime::getTimeSinceEpoch() > ban.second) + continue; + m_online_id_ban_list[ban.first] = ban.second; + final_online_id_ban_list[ban.first] = + UserConfigParams::m_server_online_id_ban_list.at(ban.first); + } + UserConfigParams::m_server_online_id_ban_list = final_online_id_ban_list; } // updateBanList //----------------------------------------------------------------------------- @@ -2296,3 +2352,32 @@ void ServerLobby::resetServer() delete server_info; setup(); } // resetServer + +//----------------------------------------------------------------------------- +bool ServerLobby::isBannedForIP(const TransportAddress& addr) const +{ + uint32_t ip_decimal = addr.getIP(); + auto lb = m_ip_ban_list.lower_bound(addr.getIP()); + bool is_banned = false; + if (lb != m_ip_ban_list.end() && ip_decimal >= lb->first/*ip_start*/) + { + if (ip_decimal <= std::get<0>(lb->second)/*ip_end*/ && + (uint32_t)StkTime::getTimeSinceEpoch() < std::get<2>(lb->second)) + is_banned = true; + } + else if (lb != m_ip_ban_list.begin()) + { + lb--; + if (ip_decimal>= lb->first/*ip_start*/ && + ip_decimal <= std::get<0>(lb->second)/*ip_end*/ && + (uint32_t)StkTime::getTimeSinceEpoch() < std::get<2>(lb->second)) + is_banned = true; + } + if (is_banned) + { + Log::info("ServerLobby", "%s is banned by CIDR %s", + addr.toString(false/*show_port*/).c_str(), + std::get<1>(lb->second).c_str()); + } + return is_banned; +} // isBannedForIP diff --git a/src/network/protocols/server_lobby.hpp b/src/network/protocols/server_lobby.hpp index afb531d19..47f026623 100644 --- a/src/network/protocols/server_lobby.hpp +++ b/src/network/protocols/server_lobby.hpp @@ -100,10 +100,15 @@ private: /** Lock this mutex whenever a client is connect / disconnect or * starting race. */ - std::mutex m_connection_mutex; + mutable std::mutex m_connection_mutex; - /** Ban list ip (in decimal) with online user id. */ - std::map m_ban_list; + /** Ban list of ip ranges. */ + std::map > + m_ip_ban_list; + + /** Ban list of online user id. */ + std::map m_online_id_ban_list; TransportAddress m_server_address; @@ -257,12 +262,15 @@ public: void finishedLoadingWorld() OVERRIDE; ServerState getCurrentState() const { return m_state.load(); } void updateBanList(); + std::unique_lock acquireConnectionMutex() const + { return std::unique_lock(m_connection_mutex); } bool waitingForPlayers() const; uint32_t getWaitingPlayersCount() const { return m_waiting_players_counts.load(); } virtual bool allPlayersReady() const OVERRIDE { return m_state.load() >= WAIT_FOR_RACE_STARTED; } virtual bool isRacing() const OVERRIDE { return m_state.load() == RACING; } + bool isBannedForIP(const TransportAddress& addr) const; bool allowJoinedPlayersWaiting() const; }; // class ServerLobby