diff --git a/src/network/database_connector.cpp b/src/network/database_connector.cpp index 93844c897..30155495b 100644 --- a/src/network/database_connector.cpp +++ b/src/network/database_connector.cpp @@ -1,6 +1,6 @@ // // SuperTuxKart - a fun racing game with go-kart -// Copyright (C) 2013-2015 SuperTuxKart-Team +// Copyright (C) 2024 SuperTuxKart-Team // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License @@ -29,6 +29,61 @@ #include "utils/log.hpp" //----------------------------------------------------------------------------- +/** Prints "?" to the output stream and saves the Binder object to the + * corresponding BinderCollection so that it can produce bind function later + * When we invoke StringUtils::insertValues with a Binder argument, the + * implementation of insertValues ensures that this function is invoked for + * all Binder arguments from left to right. + */ +std::ostream& operator << (std::ostream& os, const Binder& binder) +{ + os << "?"; + binder.m_collection.lock()->m_binders.emplace_back(std::make_shared(binder)); + return os; +} // operator << (Binder) + +//----------------------------------------------------------------------------- +/** Returns a bind function that should be used inside an easySQLQuery. As the + * Binder objects are already ordered in a correct way, the indices just go + * from 1 upwards. Depending on a particular Binder, we can also bind NULL + * instead of a string. + */ +std::function BinderCollection::getBindFunction() const +{ + auto binders = m_binders; + return [binders](sqlite3_stmt* stmt) + { + int idx = 1; + for (std::shared_ptr binder: binders) + { + if (binder) + { + // SQLITE_TRANSIENT to copy string + if (binder->m_use_null_if_empty && binder->m_value.empty()) + { + if (sqlite3_bind_null(stmt, idx) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind NULL for %s.", + binder->m_name.c_str()); + } + } + else + { + if (sqlite3_bind_text(stmt, idx, binder->m_value.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s as %s.", + binder->m_value.c_str(), binder->m_name.c_str()); + } + } + } + ++idx; + } + }; +} // BinderCollection::getBindFunction + +//----------------------------------------------------------------------------- +/** Opens the database, sets its busy handler and variables related to it. */ void DatabaseConnector::initDatabase() { m_last_poll_db_time = StkTime::getMonoTimeMs(); @@ -83,6 +138,7 @@ void DatabaseConnector::initDatabase() } // initDatabase //----------------------------------------------------------------------------- +/** Closes the database. */ void DatabaseConnector::destroyDatabase() { auto peers = STKHost::get()->getPeers(); @@ -93,12 +149,18 @@ void DatabaseConnector::destroyDatabase() } // destroyDatabase //----------------------------------------------------------------------------- -/** Run simple query with write lock waiting and optional function, this - * function has no callback for the return (if any) by the query. - * Return true if no error occurs +/** Runs simple query with optional bind function. If output vector pointer is + * not (default) nullptr, then the output is written there. + * \param query The SQL query with '?'-placeholders for values to bind. + * \param output The 2D vector for output rows. If nullptr, the query output + * is ignored. + * \param bind_function The function for binding missing values. + * \return True if no error occurs. */ -bool DatabaseConnector::easySQLQuery(const std::string& query, - std::function bind_function) const +bool DatabaseConnector::easySQLQuery( + const std::string& query, std::vector>* output, + std::function bind_function, + std::string null_value) const { if (!m_db) return false; @@ -109,6 +171,24 @@ bool DatabaseConnector::easySQLQuery(const std::string& query, if (bind_function) bind_function(stmt); ret = sqlite3_step(stmt); + if (output) + { + output->clear(); + while (ret == SQLITE_ROW) + { + output->emplace_back(); + int columns = sqlite3_column_count(stmt); + for (int i = 0; i < columns; ++i) + { + const char* value = (char*)sqlite3_column_text(stmt, i); + if (value == nullptr) + output->back().push_back(null_value); + else + output->back().push_back(std::string(value)); + } + ret = sqlite3_step(stmt); + } + } ret = sqlite3_finalize(stmt); if (ret != SQLITE_OK) { @@ -129,38 +209,30 @@ bool DatabaseConnector::easySQLQuery(const std::string& query, } // easySQLQuery //----------------------------------------------------------------------------- -/* Write true to result if table name exists in database. */ +/** Performs a query to determine if a certain table exists. + * \param table The searched name. + * \param result The output value. + */ void DatabaseConnector::checkTableExists(const std::string& table, bool& result) { if (!m_db) return; - sqlite3_stmt* stmt = NULL; + result = false; if (!table.empty()) { std::string query = StringUtils::insertValues( "SELECT count(type) FROM sqlite_master " "WHERE type='table' AND name='%s';", table.c_str()); - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) + int number; + if (StringUtils::fromString(output[0][0], number) && number == 1) { - int number = sqlite3_column_int(stmt, 0); - if (number == 1) - { - Log::info("DatabaseConnector", "Table named %s will be used.", - table.c_str()); - result = true; - } - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); + Log::info("DatabaseConnector", "Table named %s will be used.", + table.c_str()); + result = true; } } } @@ -172,6 +244,12 @@ void DatabaseConnector::checkTableExists(const std::string& table, bool& result) } // checkTableExists //----------------------------------------------------------------------------- +/** Queries the database's IP mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const { if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) @@ -185,34 +263,21 @@ std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), addr.getIP()); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; + cc_code = output[0][0]; } return cc_code; } // ip2Country //----------------------------------------------------------------------------- +/** Queries the database's IPv6 mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const { if (!m_db || !m_ipv6_geolocation_table_exists) @@ -227,34 +292,16 @@ std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), ipv6.c_str()); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; + cc_code = output[0][0]; } return cc_code; } // ipv62Country // ---------------------------------------------------------------------------- +/** A function invoked within SQLite */ void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, sqlite3_value** argv) { @@ -274,6 +321,9 @@ void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, } // ---------------------------------------------------------------------------- +/** A function that checks within SQLite whether an IPv6 address (argv[1]) + * is located within a specified block (argv[0]) of IPv6 addresses. + */ void DatabaseConnector::insideIPv6CIDRSQL(sqlite3_context* context, int argc, sqlite3_value** argv) { @@ -314,6 +364,10 @@ sqlite3_extension_init(sqlite3* db, char** pzErrMsg, */ //----------------------------------------------------------------------------- +/** When a peer disconnects from the server, this function saves to the + * database peer's disconnection time and statistics (ping and packet loss). + * \param peer Disconnecting peer. + */ void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) { if (m_server_stats_table.empty()) @@ -328,7 +382,11 @@ void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) } // writeDisconnectInfoTable //----------------------------------------------------------------------------- - +/** Creates necessary tables and views if they don't exist yet in the database. + * As the function is invoked during the server launch, it also updates rows + * related to players whose disconnection time wasn't written, and loads + * last used host id. + */ void DatabaseConnector::initServerStatsTable() { if (!ServerConfig::m_sql_management || !m_db) @@ -356,26 +414,10 @@ void DatabaseConnector::initServerStatsTable() " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" ") WITHOUT ROWID;"; std::string query = oss.str(); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret == SQLITE_OK) - m_server_stats_table = table_name; - else - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } + + if (easySQLQuery(query)) + m_server_stats_table = table_name; + if (m_server_stats_table.empty()) return; @@ -501,31 +543,22 @@ void DatabaseConnector::initServerStatsTable() uint32_t last_host_id = 0; query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", m_server_stats_table.c_str()); - ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + + std::vector> output; + if (easySQLQuery(query, &output)) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL) + if (!output.empty() && !output[0].empty() + && StringUtils::fromString(output[0][0], last_host_id)) { - last_host_id = (unsigned)sqlite3_column_int64(stmt, 0); Log::info("DatabaseConnector", "%u was last server session max host id.", last_host_id); } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } } else { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); m_server_stats_table = ""; } + STKHost::get()->setNextHostId(last_host_id); // Update disconnected time (if stk crashed it will not be written) @@ -537,25 +570,41 @@ void DatabaseConnector::initServerStatsTable() } // initServerStatsTable //----------------------------------------------------------------------------- -bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, - STKPeer* reporting, std::shared_ptr reporting_npp, - irr::core::stringw& info) +/** Writes a report of one player about another player. + * \param reporter Peer that sends the report. + * \param reporter_npp Player profile that sends the report. + * \param reporting Peer that is reported. + * \param reporting_npp Player profile that is reported. + * \param info The report message. + * \return True if the database query succeeded. + */ +bool DatabaseConnector::writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info) { std::string query; + + std::shared_ptr coll = std::make_shared(); if (ServerConfig::m_ipv6_connection) { query = StringUtils::insertValues( "INSERT INTO %s " "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " - "VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);", + "VALUES (%s, %u, \"%s\", %u, %s, %s, %u, \"%s\", %u, %s);", ServerConfig::m_player_reports_table.c_str(), + Binder(coll, ServerConfig::m_server_uid, "server_uid"), !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), !reporting->getAddress().isIPv6() ? reporting->getAddress().getIP() : 0, reporting->getAddress().isIPv6() ? reporting->getAddress().toString(false) : "", - reporting_npp->getOnlineId()); + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); } else { @@ -563,46 +612,29 @@ bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptrgetAddress().getIP(), reporter_npp->getOnlineId(), - reporting->getAddress().getIP(), reporting_npp->getOnlineId()); + Binder(coll, ServerConfig::m_server_uid, "server_uid"), + reporter->getAddress().getIP(), + reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), + reporting->getAddress().getIP(), + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); } - return easySQLQuery(query, - [reporter_npp, reporting_npp, info](sqlite3_stmt* stmt) - { - // SQLITE_TRANSIENT to copy string - if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ServerConfig::m_server_uid.c_str()); - } - if (sqlite3_bind_text(stmt, 2, - StringUtils::wideToUtf8(reporter_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporter_npp->getName()).c_str()); - } - if (sqlite3_bind_text(stmt, 3, - StringUtils::wideToUtf8(info).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(info).c_str()); - } - if (sqlite3_bind_text(stmt, 4, - StringUtils::wideToUtf8(reporting_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporting_npp->getName()).c_str()); - } - }); + return easySQLQuery(query, nullptr, coll->getBindFunction()); } // writeReport //----------------------------------------------------------------------------- +/** Gets the rows from IPv4 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IP address to check the database for. If zero, all rows + * will be given. + * \return A vector of rows in the form of IpBanTableData structures. + */ std::vector DatabaseConnector::getIpBanTableData(uint32_t ip) const { @@ -624,26 +656,32 @@ DatabaseConnector::getIpBanTableData(uint32_t ip) const oss << " LIMIT 1"; oss << ";"; std::string query = oss.str(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) - { - std::vector* vec = (std::vector*)ptr; - IpBanTableData element; - if (!StringUtils::fromString(data[0], element.row_id)) - return 0; - if (!StringUtils::fromString(data[1], element.ip_start)) - return 0; - if (!StringUtils::fromString(data[2], element.ip_end)) - return 0; - element.reason = std::string(data[3]); - element.description = std::string(data[4]); - vec->push_back(element); - return 0; - }, &result, NULL); + + std::vector> output; + easySQLQuery(query, &output); + + for (std::vector& row: output) + { + IpBanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + if (!StringUtils::fromString(row[1], element.ip_start)) + continue; + if (!StringUtils::fromString(row[2], element.ip_end)) + continue; + element.reason = row[3]; + element.description = row[4]; + result.push_back(element); + } return result; } // getIpBanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv4, this function increases + * the trigger count. + * \param ip_start Start of IP ban range corresponding to peer. + * \param ip_end End of IP ban range corresponding to peer. + */ void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const { std::string query = StringUtils::insertValues( @@ -655,6 +693,13 @@ void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip } // getIpBanTableData //----------------------------------------------------------------------------- +/** Gets the rows from IPv6 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IPv6 address to check the database for. If empty, all rows + * will be given. + * \return A vector of rows in the form of Ipv6BanTableData structures. + */ std::vector DatabaseConnector::getIpv6BanTableData(std::string ipv6) const { @@ -664,88 +709,68 @@ DatabaseConnector::getIpv6BanTableData(std::string ipv6) const return result; } bool single_ip = !ipv6.empty(); - std::ostringstream oss; - oss << "SELECT rowid, ipv6_cidr, reason, description FROM "; - oss << (std::string)ServerConfig::m_ipv6_ban_table; - oss << " WHERE "; + std::string query; + std::shared_ptr coll = std::make_shared(); + + query = StringUtils::insertValues( + "SELECT rowid, ipv6_cidr, reason, description FROM %s WHERE ", + ServerConfig::m_ipv6_ban_table.c_str() + ); if (single_ip) - oss << "insideIPv6CIDR(ipv6_cidr, ?) = 1 AND "; - oss << "datetime('now') > datetime(starting_time) AND " + query += StringUtils::insertValues( + "insideIPv6CIDR(ipv6_cidr, %s) = 1 AND ", + Binder(coll, ipv6, "ipv6") + ); + + query += "datetime('now') > datetime(starting_time) AND " "(expired_days is NULL OR datetime" "(starting_time, '+'||expired_days||' days') > datetime('now'))"; - if (single_ip) - oss << " LIMIT 1"; - oss << ";"; - std::string query = oss.str(); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + if (single_ip) + query += " LIMIT 1;"; + + std::vector> output; + easySQLQuery(query, &output, coll->getBindFunction()); + + for (std::vector& row: output) { - if (single_ip) - { - if (sqlite3_bind_text(stmt, 1, - ipv6.c_str(), -1, SQLITE_TRANSIENT) - != SQLITE_OK) - { - Log::error("DatabaseConnector", "Error binding ipv6 addr for query: %s", - sqlite3_errmsg(m_db)); - return result; - } - } - ret = sqlite3_step(stmt); - while (ret == SQLITE_ROW) - { - const char* rowid_cstr = (char*)sqlite3_column_text(stmt, 0); - const char* ipv6cidr_cstr = (char*)sqlite3_column_text(stmt, 1); - const char* reason_cstr = (char*)sqlite3_column_text(stmt, 2); - const char* description_cstr = (char*)sqlite3_column_text(stmt, 3); - Ipv6BanTableData element; - if (StringUtils::fromString(rowid_cstr, element.row_id)) - { - element.ipv6_cidr = std::string(ipv6cidr_cstr); - element.reason = std::string(reason_cstr); - element.description = std::string(description_cstr); - result.push_back(element); - } - ret = sqlite3_step(stmt); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return result; + Ipv6BanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + element.ipv6_cidr = row[1]; + element.reason = row[2]; + element.description = row[3]; + result.push_back(element); } return result; } // getIpv6BanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv6, this function increases + * the trigger count. + * \param ipv6_cidr Block of IPv6 addresses corresponding to the peer. + */ void DatabaseConnector::increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const { + std::shared_ptr coll = std::make_shared(); std::string query = StringUtils::insertValues( "UPDATE %s SET trigger_count = trigger_count + 1, " "last_trigger = datetime('now') " - "WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str()); - easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ipv6_cidr.c_str()); - } - }); + "WHERE ipv6_cidr = %s;", + ServerConfig::m_ipv6_ban_table.c_str(), + Binder(coll, ipv6_cidr, "ipv6_cidr") + ); + easySQLQuery(query, nullptr, coll->getBindFunction()); } // increaseIpv6BanTriggerCount //----------------------------------------------------------------------------- +/** Gets the rows from online id ban table, either all of them (for polling + * purposes), or those describing a certain online id (if only one peer has + * to be checked). + * \param online_id The online id to check the database for. If empty, all + * rows will be given. + * \return A vector of rows in the form of OnlineIdBanTableData structures. + */ std::vector DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const { @@ -786,6 +811,10 @@ DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const } // getOnlineIdBanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by online id, this function + * increases the trigger count. + * \param online_id Online id of the peer. + */ void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) const { std::string query = StringUtils::insertValues( @@ -797,6 +826,9 @@ void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) cons } // increaseOnlineIdBanTriggerCount //----------------------------------------------------------------------------- +/** Clears reports that are older than a certain number of days + * (specified in the server config). + */ void DatabaseConnector::clearOldReports() { if (m_player_reports_table_exists && @@ -813,6 +845,10 @@ void DatabaseConnector::clearOldReports() } // clearOldReports //----------------------------------------------------------------------------- +/** Sets disconnection times for those peers that already left the server, but + * whose disconnection times wasn't set yet. + * \param present_hosts List of online ids of present peers. + */ void DatabaseConnector::setDisconnectionTimes(std::vector& present_hosts) { if (!hasServerStatsTable()) @@ -841,6 +877,10 @@ void DatabaseConnector::setDisconnectionTimes(std::vector& present_hos } // setDisconnectionTimes //----------------------------------------------------------------------------- +/** Adds a specified IP address to the IPv4 ban table. Usually invoked from + * network console. + * \param addr Address to ban. + */ void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) { if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) @@ -854,22 +894,39 @@ void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) } // saveAddressToIpBanTable //----------------------------------------------------------------------------- +/** Called when the player joins the server, inserts player info into database. + * \param peer The peer that joins. + * \param online_id Player's online id. + * \param player_count Number of players joining using a single peer. + * \param country_code Country code deduced by global or local IP mapping. + */ void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, uint32_t online_id, unsigned player_count, const std::string& country_code) { if (m_server_stats_table.empty() || peer->isAIPeer()) return; std::string query; + std::shared_ptr coll = std::make_shared(); + auto version_os = StringUtils::extractVersionOS(peer->getUserVersion()); if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) { query = StringUtils::insertValues( "INSERT INTO %s " - "(host_id, ip, ipv6 ,port, online_id, username, player_num, " + "(host_id, ip, ipv6, port, online_id, username, player_num, " "country_code, version, os, ping) " - "VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().toString(false), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); + "VALUES (%u, 0, \"%s\", %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().toString(false), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); } else { @@ -877,56 +934,25 @@ void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, "INSERT INTO %s " "(host_id, ip, port, online_id, username, player_num, " "country_code, version, os, ping) " - "VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().getIP(), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); + "VALUES (%u, %u, %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().getIP(), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); } - easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str()); - } - if (country_code.empty()) - { - if (sqlite3_bind_null(stmt, 2) != SQLITE_OK) - { - Log::error("easySQLQuery", - "Failed to bind NULL for country code."); - } - } - else - { - if (sqlite3_bind_text(stmt, 2, country_code.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind country: %s.", - country_code.c_str()); - } - } - auto version_os = - StringUtils::extractVersionOS(peer->getUserVersion()); - if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.first.c_str()); - } - if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.second.c_str()); - } - }); + easySQLQuery(query, nullptr, coll->getBindFunction()); } // onPlayerJoinQueries //----------------------------------------------------------------------------- +/** Prints all rows of the IPv4 ban table. Called from the network console. */ void DatabaseConnector::listBanTable() { if (!m_db) diff --git a/src/network/database_connector.hpp b/src/network/database_connector.hpp index 282fc8598..d22e588b8 100644 --- a/src/network/database_connector.hpp +++ b/src/network/database_connector.hpp @@ -24,18 +24,86 @@ #include "utils/string_utils.hpp" #include "utils/time.hpp" -#include -#include #include +#include #include -#include #include +#include +#include class SocketAddress; class STKPeer; class NetworkPlayerProfile; +/** The purpose of Binder and BinderCollection structures is to allow + * putting values to bind inside StringUtils::insertValues, which is commonly + * used for values that don't require binding (such as integers). + * Unlike previously used approach with separate query formation and binding, + * the arguments are written in the code in the order of query appearance + * (even though real binding still happens later). It also avoids repeated + * binding code. + * + * Syntax looks as follows: + * std::shared_ptr coll = std::make_shared...; + * std::string query_string = StringUtils::insertValues( + * "query contents with wildcards of type %d, %s, %u, ..." + * "where %s is put for values that will be bound later", + * values to insert, ..., Binder(coll, other parameters), ...); + * Then the bind function (e.g. for usage in easySQLQuery) should be + * coll->getBindFunction(). + */ +struct Binder; + +/** BinderCollection is a structure that collects Binder objects used in an + * SQL query formed with insertValues() (see above). For a single query, a + * single instance of BinderCollection should be used. After a query is + * formed, BinderCollection can produce bind function to use with sqlite3. + */ +struct BinderCollection +{ + std::vector> m_binders; + + std::function getBindFunction() const; +}; + +/** Binder is a wrapper for a string to be bound into an SQL query. See above + * for its usage in insertValues(). When it's printed to an output stream + * (in particular, this is done in insertValues implementation), this Binder + * is added to the query's BinderCollection, and the '?'-placeholder is added + * to the query string instead of %s. + * + * When using Binder, make sure that: + * - operator << is invoked on it exactly once; + * - operator << is invoked on several Binders in the order in which they go + * in the query; + * - before calling insertValues, there is a %-wildcard corresponding to the + * Binder in the query string (and not '?'). + * For example, when the query formed inside of a function depends on its + * arguments, it should be formed part by part, from left to right. + * Of course, you can choose the "default" way, binding values separately from + * insertValues() call. + */ +struct Binder +{ + std::weak_ptr m_collection; + std::string m_value; + std::string m_name; + bool m_use_null_if_empty; + + Binder(std::shared_ptr collection, std::string value, + std::string name = "", bool use_null_if_empty = false): + m_collection(collection), m_value(value), + m_name(name), m_use_null_if_empty(use_null_if_empty) {} +}; + +std::ostream& operator << (std::ostream& os, const Binder& binder); + +/** A class that manages the database operations needed for the server to work. + * The SQL queries are intended to be placed only within the implementation + * of this class, while the logic corresponding to those queries should not + * belong here. + */ class DatabaseConnector { private: @@ -50,6 +118,7 @@ private: uint64_t m_last_poll_db_time; public: + /** Corresponds to the row of IPv4 ban table. */ struct IpBanTableData { int row_id; @@ -58,13 +127,17 @@ public: std::string reason; std::string description; }; - struct Ipv6BanTableData { + /** Corresponds to the row of IPv6 ban table. */ + struct Ipv6BanTableData + { int row_id; std::string ipv6_cidr; std::string reason; std::string description; }; - struct OnlineIdBanTableData { + /** Corresponds to the row of online id ban table. */ + struct OnlineIdBanTableData + { int row_id; uint32_t online_id; std::string reason; @@ -74,7 +147,9 @@ public: void destroyDatabase(); bool easySQLQuery(const std::string& query, - std::function bind_function = nullptr) const; + std::vector>* output = nullptr, + std::function bind_function = nullptr, + std::string null_value = "") const; void checkTableExists(const std::string& table, bool& result); @@ -83,14 +158,15 @@ public: std::string ipv62Country(const SocketAddress& addr) const; static void upperIPv6SQL(sqlite3_context* context, int argc, - sqlite3_value** argv); + sqlite3_value** argv); static void insideIPv6CIDRSQL(sqlite3_context* context, int argc, - sqlite3_value** argv); + sqlite3_value** argv); void writeDisconnectInfoTable(STKPeer* peer); void initServerStatsTable(); - bool writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, - STKPeer* reporting, std::shared_ptr reporting_npp, - irr::core::stringw& info); + bool writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info); bool hasDatabase() const { return m_db != nullptr; } bool hasServerStatsTable() const { return !m_server_stats_table.empty(); } bool hasPlayerReportsTable() const @@ -115,7 +191,5 @@ public: void listBanTable(); }; - - #endif // ifndef DATABASE_CONNECTOR_HPP #endif // ifdef ENABLE_SQLITE3 diff --git a/src/network/protocols/server_lobby.hpp b/src/network/protocols/server_lobby.hpp index a8e9245ad..fc18cfeca 100644 --- a/src/network/protocols/server_lobby.hpp +++ b/src/network/protocols/server_lobby.hpp @@ -34,10 +34,6 @@ #include #include -// #ifdef ENABLE_SQLITE3 -// #include -// #endif - class BareNetworkString; class DatabaseConnector; class NetworkItemManager;