Get rid of repeated code, add comments

easySQLQuery was generalized to return output rows if requested,
which allowed to shorten the code for many queries.

The code for binding values to sqlite statement was also repeated
many times. Two auxiliary structures were introduced, so that it's
possible to provide at the same time both those parameters which
require and those which don't require binding, in a single
StringUtils::insertValues() call.
This commit is contained in:
kimden 2024-07-03 23:13:40 +04:00
parent a17a2d5024
commit 9fb7448abe
3 changed files with 390 additions and 294 deletions

View File

@ -1,6 +1,6 @@
//
// SuperTuxKart - a fun racing game with go-kart
// Copyright (C) 2013-2015 SuperTuxKart-Team
// Copyright (C) 2024 SuperTuxKart-Team
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
@ -29,6 +29,61 @@
#include "utils/log.hpp"
//-----------------------------------------------------------------------------
/** Prints "?" to the output stream and saves the Binder object to the
* corresponding BinderCollection so that it can produce bind function later
* When we invoke StringUtils::insertValues with a Binder argument, the
* implementation of insertValues ensures that this function is invoked for
* all Binder arguments from left to right.
*/
std::ostream& operator << (std::ostream& os, const Binder& binder)
{
os << "?";
binder.m_collection.lock()->m_binders.emplace_back(std::make_shared<Binder>(binder));
return os;
} // operator << (Binder)
//-----------------------------------------------------------------------------
/** Returns a bind function that should be used inside an easySQLQuery. As the
* Binder objects are already ordered in a correct way, the indices just go
* from 1 upwards. Depending on a particular Binder, we can also bind NULL
* instead of a string.
*/
std::function<void(sqlite3_stmt* stmt)> BinderCollection::getBindFunction() const
{
auto binders = m_binders;
return [binders](sqlite3_stmt* stmt)
{
int idx = 1;
for (std::shared_ptr<Binder> binder: binders)
{
if (binder)
{
// SQLITE_TRANSIENT to copy string
if (binder->m_use_null_if_empty && binder->m_value.empty())
{
if (sqlite3_bind_null(stmt, idx) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind NULL for %s.",
binder->m_name.c_str());
}
}
else
{
if (sqlite3_bind_text(stmt, idx, binder->m_value.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s as %s.",
binder->m_value.c_str(), binder->m_name.c_str());
}
}
}
++idx;
}
};
} // BinderCollection::getBindFunction
//-----------------------------------------------------------------------------
/** Opens the database, sets its busy handler and variables related to it. */
void DatabaseConnector::initDatabase()
{
m_last_poll_db_time = StkTime::getMonoTimeMs();
@ -83,6 +138,7 @@ void DatabaseConnector::initDatabase()
} // initDatabase
//-----------------------------------------------------------------------------
/** Closes the database. */
void DatabaseConnector::destroyDatabase()
{
auto peers = STKHost::get()->getPeers();
@ -93,12 +149,18 @@ void DatabaseConnector::destroyDatabase()
} // destroyDatabase
//-----------------------------------------------------------------------------
/** Run simple query with write lock waiting and optional function, this
* function has no callback for the return (if any) by the query.
* Return true if no error occurs
/** Runs simple query with optional bind function. If output vector pointer is
* not (default) nullptr, then the output is written there.
* \param query The SQL query with '?'-placeholders for values to bind.
* \param output The 2D vector for output rows. If nullptr, the query output
* is ignored.
* \param bind_function The function for binding missing values.
* \return True if no error occurs.
*/
bool DatabaseConnector::easySQLQuery(const std::string& query,
std::function<void(sqlite3_stmt* stmt)> bind_function) const
bool DatabaseConnector::easySQLQuery(
const std::string& query, std::vector<std::vector<std::string>>* output,
std::function<void(sqlite3_stmt* stmt)> bind_function,
std::string null_value) const
{
if (!m_db)
return false;
@ -109,6 +171,24 @@ bool DatabaseConnector::easySQLQuery(const std::string& query,
if (bind_function)
bind_function(stmt);
ret = sqlite3_step(stmt);
if (output)
{
output->clear();
while (ret == SQLITE_ROW)
{
output->emplace_back();
int columns = sqlite3_column_count(stmt);
for (int i = 0; i < columns; ++i)
{
const char* value = (char*)sqlite3_column_text(stmt, i);
if (value == nullptr)
output->back().push_back(null_value);
else
output->back().push_back(std::string(value));
}
ret = sqlite3_step(stmt);
}
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
@ -129,38 +209,30 @@ bool DatabaseConnector::easySQLQuery(const std::string& query,
} // easySQLQuery
//-----------------------------------------------------------------------------
/* Write true to result if table name exists in database. */
/** Performs a query to determine if a certain table exists.
* \param table The searched name.
* \param result The output value.
*/
void DatabaseConnector::checkTableExists(const std::string& table, bool& result)
{
if (!m_db)
return;
sqlite3_stmt* stmt = NULL;
result = false;
if (!table.empty())
{
std::string query = StringUtils::insertValues(
"SELECT count(type) FROM sqlite_master "
"WHERE type='table' AND name='%s';", table.c_str());
int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
std::vector<std::vector<std::string>> output;
if (easySQLQuery(query, &output) && !output.empty())
{
ret = sqlite3_step(stmt);
if (ret == SQLITE_ROW)
int number;
if (StringUtils::fromString(output[0][0], number) && number == 1)
{
int number = sqlite3_column_int(stmt, 0);
if (number == 1)
{
Log::info("DatabaseConnector", "Table named %s will be used.",
table.c_str());
result = true;
}
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
Log::info("DatabaseConnector", "Table named %s will be used.",
table.c_str());
result = true;
}
}
}
@ -172,6 +244,12 @@ void DatabaseConnector::checkTableExists(const std::string& table, bool& result)
} // checkTableExists
//-----------------------------------------------------------------------------
/** Queries the database's IP mapping to determine the country code for an
* address.
* \param addr Queried address.
* \return A country code string if the address is found in the mapping,
* and an empty string otherwise.
*/
std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const
{
if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN())
@ -185,34 +263,21 @@ std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const
ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(),
addr.getIP());
sqlite3_stmt* stmt = NULL;
int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
std::vector<std::vector<std::string>> output;
if (easySQLQuery(query, &output) && !output.empty())
{
ret = sqlite3_step(stmt);
if (ret == SQLITE_ROW)
{
const char* country_code = (char*)sqlite3_column_text(stmt, 0);
cc_code = country_code;
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
}
}
else
{
Log::error("DatabaseConnector", "Error preparing database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
return "";
cc_code = output[0][0];
}
return cc_code;
} // ip2Country
//-----------------------------------------------------------------------------
/** Queries the database's IPv6 mapping to determine the country code for an
* address.
* \param addr Queried address.
* \return A country code string if the address is found in the mapping,
* and an empty string otherwise.
*/
std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const
{
if (!m_db || !m_ipv6_geolocation_table_exists)
@ -227,34 +292,16 @@ std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const
ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(),
ipv6.c_str());
sqlite3_stmt* stmt = NULL;
int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
std::vector<std::vector<std::string>> output;
if (easySQLQuery(query, &output) && !output.empty())
{
ret = sqlite3_step(stmt);
if (ret == SQLITE_ROW)
{
const char* country_code = (char*)sqlite3_column_text(stmt, 0);
cc_code = country_code;
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
}
}
else
{
Log::error("DatabaseConnector", "Error preparing database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
return "";
cc_code = output[0][0];
}
return cc_code;
} // ipv62Country
// ----------------------------------------------------------------------------
/** A function invoked within SQLite */
void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc,
sqlite3_value** argv)
{
@ -274,6 +321,9 @@ void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc,
}
// ----------------------------------------------------------------------------
/** A function that checks within SQLite whether an IPv6 address (argv[1])
* is located within a specified block (argv[0]) of IPv6 addresses.
*/
void DatabaseConnector::insideIPv6CIDRSQL(sqlite3_context* context, int argc,
sqlite3_value** argv)
{
@ -314,6 +364,10 @@ sqlite3_extension_init(sqlite3* db, char** pzErrMsg,
*/
//-----------------------------------------------------------------------------
/** When a peer disconnects from the server, this function saves to the
* database peer's disconnection time and statistics (ping and packet loss).
* \param peer Disconnecting peer.
*/
void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer)
{
if (m_server_stats_table.empty())
@ -328,7 +382,11 @@ void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer)
} // writeDisconnectInfoTable
//-----------------------------------------------------------------------------
/** Creates necessary tables and views if they don't exist yet in the database.
* As the function is invoked during the server launch, it also updates rows
* related to players whose disconnection time wasn't written, and loads
* last used host id.
*/
void DatabaseConnector::initServerStatsTable()
{
if (!ServerConfig::m_sql_management || !m_db)
@ -356,26 +414,10 @@ void DatabaseConnector::initServerStatsTable()
" packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n"
") WITHOUT ROWID;";
std::string query = oss.str();
sqlite3_stmt* stmt = NULL;
int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
{
ret = sqlite3_step(stmt);
ret = sqlite3_finalize(stmt);
if (ret == SQLITE_OK)
m_server_stats_table = table_name;
else
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
}
}
else
{
Log::error("DatabaseConnector", "Error preparing database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
}
if (easySQLQuery(query))
m_server_stats_table = table_name;
if (m_server_stats_table.empty())
return;
@ -501,31 +543,22 @@ void DatabaseConnector::initServerStatsTable()
uint32_t last_host_id = 0;
query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;",
m_server_stats_table.c_str());
ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
std::vector<std::vector<std::string>> output;
if (easySQLQuery(query, &output))
{
ret = sqlite3_step(stmt);
if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL)
if (!output.empty() && !output[0].empty()
&& StringUtils::fromString(output[0][0], last_host_id))
{
last_host_id = (unsigned)sqlite3_column_int64(stmt, 0);
Log::info("DatabaseConnector", "%u was last server session max host id.",
last_host_id);
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
m_server_stats_table = "";
}
}
else
{
Log::error("DatabaseConnector", "Error preparing database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
m_server_stats_table = "";
}
STKHost::get()->setNextHostId(last_host_id);
// Update disconnected time (if stk crashed it will not be written)
@ -537,25 +570,41 @@ void DatabaseConnector::initServerStatsTable()
} // initServerStatsTable
//-----------------------------------------------------------------------------
bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptr<NetworkPlayerProfile> reporter_npp,
STKPeer* reporting, std::shared_ptr<NetworkPlayerProfile> reporting_npp,
irr::core::stringw& info)
/** Writes a report of one player about another player.
* \param reporter Peer that sends the report.
* \param reporter_npp Player profile that sends the report.
* \param reporting Peer that is reported.
* \param reporting_npp Player profile that is reported.
* \param info The report message.
* \return True if the database query succeeded.
*/
bool DatabaseConnector::writeReport(
STKPeer* reporter, std::shared_ptr<NetworkPlayerProfile> reporter_npp,
STKPeer* reporting, std::shared_ptr<NetworkPlayerProfile> reporting_npp,
irr::core::stringw& info)
{
std::string query;
std::shared_ptr<BinderCollection> coll = std::make_shared<BinderCollection>();
if (ServerConfig::m_ipv6_connection)
{
query = StringUtils::insertValues(
"INSERT INTO %s "
"(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, "
"info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) "
"VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);",
"VALUES (%s, %u, \"%s\", %u, %s, %s, %u, \"%s\", %u, %s);",
ServerConfig::m_player_reports_table.c_str(),
Binder(coll, ServerConfig::m_server_uid, "server_uid"),
!reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0,
reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "",
reporter_npp->getOnlineId(),
Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"),
Binder(coll, StringUtils::wideToUtf8(info), "info"),
!reporting->getAddress().isIPv6() ? reporting->getAddress().getIP() : 0,
reporting->getAddress().isIPv6() ? reporting->getAddress().toString(false) : "",
reporting_npp->getOnlineId());
reporting_npp->getOnlineId(),
Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name")
);
}
else
{
@ -563,46 +612,29 @@ bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptr<NetworkPl
"INSERT INTO %s "
"(server_uid, reporter_ip, reporter_online_id, reporter_username, "
"info, reporting_ip, reporting_online_id, reporting_username) "
"VALUES (?, %u, %u, ?, ?, %u, %u, ?);",
"VALUES (%s, %u, %u, %s, %s, %u, %u, %s);",
ServerConfig::m_player_reports_table.c_str(),
reporter->getAddress().getIP(), reporter_npp->getOnlineId(),
reporting->getAddress().getIP(), reporting_npp->getOnlineId());
Binder(coll, ServerConfig::m_server_uid, "server_uid"),
reporter->getAddress().getIP(),
reporter_npp->getOnlineId(),
Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"),
Binder(coll, StringUtils::wideToUtf8(info), "info"),
reporting->getAddress().getIP(),
reporting_npp->getOnlineId(),
Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name")
);
}
return easySQLQuery(query,
[reporter_npp, reporting_npp, info](sqlite3_stmt* stmt)
{
// SQLITE_TRANSIENT to copy string
if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
ServerConfig::m_server_uid.c_str());
}
if (sqlite3_bind_text(stmt, 2,
StringUtils::wideToUtf8(reporter_npp->getName()).c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
StringUtils::wideToUtf8(reporter_npp->getName()).c_str());
}
if (sqlite3_bind_text(stmt, 3,
StringUtils::wideToUtf8(info).c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
StringUtils::wideToUtf8(info).c_str());
}
if (sqlite3_bind_text(stmt, 4,
StringUtils::wideToUtf8(reporting_npp->getName()).c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
StringUtils::wideToUtf8(reporting_npp->getName()).c_str());
}
});
return easySQLQuery(query, nullptr, coll->getBindFunction());
} // writeReport
//-----------------------------------------------------------------------------
/** Gets the rows from IPv4 ban table, either all of them (for polling
* purposes), or those describing a certain address (if only one peer has to
* be checked).
* \param ip The IP address to check the database for. If zero, all rows
* will be given.
* \return A vector of rows in the form of IpBanTableData structures.
*/
std::vector<DatabaseConnector::IpBanTableData>
DatabaseConnector::getIpBanTableData(uint32_t ip) const
{
@ -624,26 +656,32 @@ DatabaseConnector::getIpBanTableData(uint32_t ip) const
oss << " LIMIT 1";
oss << ";";
std::string query = oss.str();
sqlite3_exec(m_db, query.c_str(),
[](void* ptr, int count, char** data, char** columns)
{
std::vector<IpBanTableData>* vec = (std::vector<IpBanTableData>*)ptr;
IpBanTableData element;
if (!StringUtils::fromString(data[0], element.row_id))
return 0;
if (!StringUtils::fromString(data[1], element.ip_start))
return 0;
if (!StringUtils::fromString(data[2], element.ip_end))
return 0;
element.reason = std::string(data[3]);
element.description = std::string(data[4]);
vec->push_back(element);
return 0;
}, &result, NULL);
std::vector<std::vector<std::string>> output;
easySQLQuery(query, &output);
for (std::vector<std::string>& row: output)
{
IpBanTableData element;
if (!StringUtils::fromString(row[0], element.row_id))
continue;
if (!StringUtils::fromString(row[1], element.ip_start))
continue;
if (!StringUtils::fromString(row[2], element.ip_end))
continue;
element.reason = row[3];
element.description = row[4];
result.push_back(element);
}
return result;
} // getIpBanTableData
//-----------------------------------------------------------------------------
/** For a peer that turned out to be banned by IPv4, this function increases
* the trigger count.
* \param ip_start Start of IP ban range corresponding to peer.
* \param ip_end End of IP ban range corresponding to peer.
*/
void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const
{
std::string query = StringUtils::insertValues(
@ -655,6 +693,13 @@ void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip
} // getIpBanTableData
//-----------------------------------------------------------------------------
/** Gets the rows from IPv6 ban table, either all of them (for polling
* purposes), or those describing a certain address (if only one peer has to
* be checked).
* \param ip The IPv6 address to check the database for. If empty, all rows
* will be given.
* \return A vector of rows in the form of Ipv6BanTableData structures.
*/
std::vector<DatabaseConnector::Ipv6BanTableData>
DatabaseConnector::getIpv6BanTableData(std::string ipv6) const
{
@ -664,88 +709,68 @@ DatabaseConnector::getIpv6BanTableData(std::string ipv6) const
return result;
}
bool single_ip = !ipv6.empty();
std::ostringstream oss;
oss << "SELECT rowid, ipv6_cidr, reason, description FROM ";
oss << (std::string)ServerConfig::m_ipv6_ban_table;
oss << " WHERE ";
std::string query;
std::shared_ptr<BinderCollection> coll = std::make_shared<BinderCollection>();
query = StringUtils::insertValues(
"SELECT rowid, ipv6_cidr, reason, description FROM %s WHERE ",
ServerConfig::m_ipv6_ban_table.c_str()
);
if (single_ip)
oss << "insideIPv6CIDR(ipv6_cidr, ?) = 1 AND ";
oss << "datetime('now') > datetime(starting_time) AND "
query += StringUtils::insertValues(
"insideIPv6CIDR(ipv6_cidr, %s) = 1 AND ",
Binder(coll, ipv6, "ipv6")
);
query += "datetime('now') > datetime(starting_time) AND "
"(expired_days is NULL OR datetime"
"(starting_time, '+'||expired_days||' days') > datetime('now'))";
if (single_ip)
oss << " LIMIT 1";
oss << ";";
std::string query = oss.str();
sqlite3_stmt* stmt = NULL;
int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0);
if (ret == SQLITE_OK)
if (single_ip)
query += " LIMIT 1;";
std::vector<std::vector<std::string>> output;
easySQLQuery(query, &output, coll->getBindFunction());
for (std::vector<std::string>& row: output)
{
if (single_ip)
{
if (sqlite3_bind_text(stmt, 1,
ipv6.c_str(), -1, SQLITE_TRANSIENT)
!= SQLITE_OK)
{
Log::error("DatabaseConnector", "Error binding ipv6 addr for query: %s",
sqlite3_errmsg(m_db));
return result;
}
}
ret = sqlite3_step(stmt);
while (ret == SQLITE_ROW)
{
const char* rowid_cstr = (char*)sqlite3_column_text(stmt, 0);
const char* ipv6cidr_cstr = (char*)sqlite3_column_text(stmt, 1);
const char* reason_cstr = (char*)sqlite3_column_text(stmt, 2);
const char* description_cstr = (char*)sqlite3_column_text(stmt, 3);
Ipv6BanTableData element;
if (StringUtils::fromString(rowid_cstr, element.row_id))
{
element.ipv6_cidr = std::string(ipv6cidr_cstr);
element.reason = std::string(reason_cstr);
element.description = std::string(description_cstr);
result.push_back(element);
}
ret = sqlite3_step(stmt);
}
ret = sqlite3_finalize(stmt);
if (ret != SQLITE_OK)
{
Log::error("DatabaseConnector",
"Error finalize database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
}
}
else
{
Log::error("DatabaseConnector", "Error preparing database for query %s: %s",
query.c_str(), sqlite3_errmsg(m_db));
return result;
Ipv6BanTableData element;
if (!StringUtils::fromString(row[0], element.row_id))
continue;
element.ipv6_cidr = row[1];
element.reason = row[2];
element.description = row[3];
result.push_back(element);
}
return result;
} // getIpv6BanTableData
//-----------------------------------------------------------------------------
/** For a peer that turned out to be banned by IPv6, this function increases
* the trigger count.
* \param ipv6_cidr Block of IPv6 addresses corresponding to the peer.
*/
void DatabaseConnector::increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const
{
std::shared_ptr<BinderCollection> coll = std::make_shared<BinderCollection>();
std::string query = StringUtils::insertValues(
"UPDATE %s SET trigger_count = trigger_count + 1, "
"last_trigger = datetime('now') "
"WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str());
easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt)
{
if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
ipv6_cidr.c_str());
}
});
"WHERE ipv6_cidr = %s;",
ServerConfig::m_ipv6_ban_table.c_str(),
Binder(coll, ipv6_cidr, "ipv6_cidr")
);
easySQLQuery(query, nullptr, coll->getBindFunction());
} // increaseIpv6BanTriggerCount
//-----------------------------------------------------------------------------
/** Gets the rows from online id ban table, either all of them (for polling
* purposes), or those describing a certain online id (if only one peer has
* to be checked).
* \param online_id The online id to check the database for. If empty, all
* rows will be given.
* \return A vector of rows in the form of OnlineIdBanTableData structures.
*/
std::vector<DatabaseConnector::OnlineIdBanTableData>
DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const
{
@ -786,6 +811,10 @@ DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const
} // getOnlineIdBanTableData
//-----------------------------------------------------------------------------
/** For a peer that turned out to be banned by online id, this function
* increases the trigger count.
* \param online_id Online id of the peer.
*/
void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) const
{
std::string query = StringUtils::insertValues(
@ -797,6 +826,9 @@ void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) cons
} // increaseOnlineIdBanTriggerCount
//-----------------------------------------------------------------------------
/** Clears reports that are older than a certain number of days
* (specified in the server config).
*/
void DatabaseConnector::clearOldReports()
{
if (m_player_reports_table_exists &&
@ -813,6 +845,10 @@ void DatabaseConnector::clearOldReports()
} // clearOldReports
//-----------------------------------------------------------------------------
/** Sets disconnection times for those peers that already left the server, but
* whose disconnection times wasn't set yet.
* \param present_hosts List of online ids of present peers.
*/
void DatabaseConnector::setDisconnectionTimes(std::vector<uint32_t>& present_hosts)
{
if (!hasServerStatsTable())
@ -841,6 +877,10 @@ void DatabaseConnector::setDisconnectionTimes(std::vector<uint32_t>& present_hos
} // setDisconnectionTimes
//-----------------------------------------------------------------------------
/** Adds a specified IP address to the IPv4 ban table. Usually invoked from
* network console.
* \param addr Address to ban.
*/
void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr)
{
if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists)
@ -854,22 +894,39 @@ void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr)
} // saveAddressToIpBanTable
//-----------------------------------------------------------------------------
/** Called when the player joins the server, inserts player info into database.
* \param peer The peer that joins.
* \param online_id Player's online id.
* \param player_count Number of players joining using a single peer.
* \param country_code Country code deduced by global or local IP mapping.
*/
void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr<STKPeer> peer,
uint32_t online_id, unsigned player_count, const std::string& country_code)
{
if (m_server_stats_table.empty() || peer->isAIPeer())
return;
std::string query;
std::shared_ptr<BinderCollection> coll = std::make_shared<BinderCollection>();
auto version_os = StringUtils::extractVersionOS(peer->getUserVersion());
if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6())
{
query = StringUtils::insertValues(
"INSERT INTO %s "
"(host_id, ip, ipv6 ,port, online_id, username, player_num, "
"(host_id, ip, ipv6, port, online_id, username, player_num, "
"country_code, version, os, ping) "
"VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);",
m_server_stats_table.c_str(), peer->getHostId(),
peer->getAddress().toString(false), peer->getAddress().getPort(),
online_id, player_count, peer->getAveragePing());
"VALUES (%u, 0, \"%s\", %u, %u, %s, %u, %s, %s, %s, %u);",
m_server_stats_table.c_str(),
peer->getHostId(),
peer->getAddress().toString(false),
peer->getAddress().getPort(),
online_id,
Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"),
player_count,
Binder(coll, country_code, "country_code", true),
Binder(coll, version_os.first, "version"),
Binder(coll, version_os.second, "os"),
peer->getAveragePing()
);
}
else
{
@ -877,56 +934,25 @@ void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr<STKPeer> peer,
"INSERT INTO %s "
"(host_id, ip, port, online_id, username, player_num, "
"country_code, version, os, ping) "
"VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);",
m_server_stats_table.c_str(), peer->getHostId(),
peer->getAddress().getIP(), peer->getAddress().getPort(),
online_id, player_count, peer->getAveragePing());
"VALUES (%u, %u, %u, %u, %s, %u, %s, %s, %s, %u);",
m_server_stats_table.c_str(),
peer->getHostId(),
peer->getAddress().getIP(),
peer->getAddress().getPort(),
online_id,
Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"),
player_count,
Binder(coll, country_code, "country_code", true),
Binder(coll, version_os.first, "version"),
Binder(coll, version_os.second, "os"),
peer->getAveragePing()
);
}
easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt)
{
if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8(
peer->getPlayerProfiles()[0]->getName()).c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
StringUtils::wideToUtf8(
peer->getPlayerProfiles()[0]->getName()).c_str());
}
if (country_code.empty())
{
if (sqlite3_bind_null(stmt, 2) != SQLITE_OK)
{
Log::error("easySQLQuery",
"Failed to bind NULL for country code.");
}
}
else
{
if (sqlite3_bind_text(stmt, 2, country_code.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind country: %s.",
country_code.c_str());
}
}
auto version_os =
StringUtils::extractVersionOS(peer->getUserVersion());
if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
version_os.first.c_str());
}
if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(),
-1, SQLITE_TRANSIENT) != SQLITE_OK)
{
Log::error("easySQLQuery", "Failed to bind %s.",
version_os.second.c_str());
}
});
easySQLQuery(query, nullptr, coll->getBindFunction());
} // onPlayerJoinQueries
//-----------------------------------------------------------------------------
/** Prints all rows of the IPv4 ban table. Called from the network console. */
void DatabaseConnector::listBanTable()
{
if (!m_db)

View File

@ -24,18 +24,86 @@
#include "utils/string_utils.hpp"
#include "utils/time.hpp"
#include <vector>
#include <iostream>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <sqlite3.h>
#include <string>
#include <vector>
class SocketAddress;
class STKPeer;
class NetworkPlayerProfile;
/** The purpose of Binder and BinderCollection structures is to allow
* putting values to bind inside StringUtils::insertValues, which is commonly
* used for values that don't require binding (such as integers).
* Unlike previously used approach with separate query formation and binding,
* the arguments are written in the code in the order of query appearance
* (even though real binding still happens later). It also avoids repeated
* binding code.
*
* Syntax looks as follows:
* std::shared_ptr<BinderCollection> coll = std::make_shared...;
* std::string query_string = StringUtils::insertValues(
* "query contents with wildcards of type %d, %s, %u, ..."
* "where %s is put for values that will be bound later",
* values to insert, ..., Binder(coll, other parameters), ...);
* Then the bind function (e.g. for usage in easySQLQuery) should be
* coll->getBindFunction().
*/
struct Binder;
/** BinderCollection is a structure that collects Binder objects used in an
* SQL query formed with insertValues() (see above). For a single query, a
* single instance of BinderCollection should be used. After a query is
* formed, BinderCollection can produce bind function to use with sqlite3.
*/
struct BinderCollection
{
std::vector<std::shared_ptr<Binder>> m_binders;
std::function<void(sqlite3_stmt* stmt)> getBindFunction() const;
};
/** Binder is a wrapper for a string to be bound into an SQL query. See above
* for its usage in insertValues(). When it's printed to an output stream
* (in particular, this is done in insertValues implementation), this Binder
* is added to the query's BinderCollection, and the '?'-placeholder is added
* to the query string instead of %s.
*
* When using Binder, make sure that:
* - operator << is invoked on it exactly once;
* - operator << is invoked on several Binders in the order in which they go
* in the query;
* - before calling insertValues, there is a %-wildcard corresponding to the
* Binder in the query string (and not '?').
* For example, when the query formed inside of a function depends on its
* arguments, it should be formed part by part, from left to right.
* Of course, you can choose the "default" way, binding values separately from
* insertValues() call.
*/
struct Binder
{
std::weak_ptr<BinderCollection> m_collection;
std::string m_value;
std::string m_name;
bool m_use_null_if_empty;
Binder(std::shared_ptr<BinderCollection> collection, std::string value,
std::string name = "", bool use_null_if_empty = false):
m_collection(collection), m_value(value),
m_name(name), m_use_null_if_empty(use_null_if_empty) {}
};
std::ostream& operator << (std::ostream& os, const Binder& binder);
/** A class that manages the database operations needed for the server to work.
* The SQL queries are intended to be placed only within the implementation
* of this class, while the logic corresponding to those queries should not
* belong here.
*/
class DatabaseConnector
{
private:
@ -50,6 +118,7 @@ private:
uint64_t m_last_poll_db_time;
public:
/** Corresponds to the row of IPv4 ban table. */
struct IpBanTableData
{
int row_id;
@ -58,13 +127,17 @@ public:
std::string reason;
std::string description;
};
struct Ipv6BanTableData {
/** Corresponds to the row of IPv6 ban table. */
struct Ipv6BanTableData
{
int row_id;
std::string ipv6_cidr;
std::string reason;
std::string description;
};
struct OnlineIdBanTableData {
/** Corresponds to the row of online id ban table. */
struct OnlineIdBanTableData
{
int row_id;
uint32_t online_id;
std::string reason;
@ -74,7 +147,9 @@ public:
void destroyDatabase();
bool easySQLQuery(const std::string& query,
std::function<void(sqlite3_stmt* stmt)> bind_function = nullptr) const;
std::vector<std::vector<std::string>>* output = nullptr,
std::function<void(sqlite3_stmt* stmt)> bind_function = nullptr,
std::string null_value = "") const;
void checkTableExists(const std::string& table, bool& result);
@ -83,14 +158,15 @@ public:
std::string ipv62Country(const SocketAddress& addr) const;
static void upperIPv6SQL(sqlite3_context* context, int argc,
sqlite3_value** argv);
sqlite3_value** argv);
static void insideIPv6CIDRSQL(sqlite3_context* context, int argc,
sqlite3_value** argv);
sqlite3_value** argv);
void writeDisconnectInfoTable(STKPeer* peer);
void initServerStatsTable();
bool writeReport(STKPeer* reporter, std::shared_ptr<NetworkPlayerProfile> reporter_npp,
STKPeer* reporting, std::shared_ptr<NetworkPlayerProfile> reporting_npp,
irr::core::stringw& info);
bool writeReport(
STKPeer* reporter, std::shared_ptr<NetworkPlayerProfile> reporter_npp,
STKPeer* reporting, std::shared_ptr<NetworkPlayerProfile> reporting_npp,
irr::core::stringw& info);
bool hasDatabase() const { return m_db != nullptr; }
bool hasServerStatsTable() const { return !m_server_stats_table.empty(); }
bool hasPlayerReportsTable() const
@ -115,7 +191,5 @@ public:
void listBanTable();
};
#endif // ifndef DATABASE_CONNECTOR_HPP
#endif // ifdef ENABLE_SQLITE3

View File

@ -34,10 +34,6 @@
#include <mutex>
#include <set>
// #ifdef ENABLE_SQLITE3
// #include <sqlite3.h>
// #endif
class BareNetworkString;
class DatabaseConnector;
class NetworkItemManager;