diff --git a/PROTOCOL.md b/PROTOCOL.md index 56c74ee..91f798b 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -181,8 +181,8 @@ Communication between the master server and clients will be done over a WebSocke 2 - Message - String + Error Code + Packed Unsigned Short ¬R1 @@ -213,8 +213,8 @@ Communication between the master server and clients will be done over a WebSocke 2 - Message - String + Error Code + Packed Unsigned Short @@ -354,12 +354,7 @@ TODO: MAKE THIS SECTION NOT LOOK LIKE SHIT ### Master / Slave -#### M -> S (ID 2) -0x100: KEY SIZE WAS INCORRECT - -0x101: COULD NOT PARSE KEY - -#### M -> S (ID 4) +#### M -> S (ID 1) 0x100: MAX AUTH ATTEMPTS REACHED 0x101: LICENSE DATA INCORRECT @@ -370,4 +365,22 @@ TODO: MAKE THIS SECTION NOT LOOK LIKE SHIT ### Master / Client +#### M -> C (ID 0) + +0x100: MAX AUTH ATTEMPTS REACHED + +0x101: USERNAME DOES NOT EXIST + +0x102: PASSWORD INCORRECT + +#### M -> C (ID 1) + +0x000: OK + +0x100: USERNAME TAKEN + +0x101: EMAIL TAKEN + +0x102: PASSWORD TOO WEAK + ### Slave / Client \ No newline at end of file diff --git a/src/common/sock/packet.cpp b/src/common/sock/packet.cpp index 7cdfad8..d261693 100644 --- a/src/common/sock/packet.cpp +++ b/src/common/sock/packet.cpp @@ -148,6 +148,15 @@ bool sosc::Packet::Check(int region_count, ...) const { return true; } +void sosc::Packet::TrimRegions(const std::vector& ids) { + if(ids.size() == 0) + for(uint32_t id = 0; id < this->regions.size(); ++id) + str::trim(&this->regions[id]); + else + for(auto id : ids) + str::trim(&this->regions[id]); +} + std::string* sosc::Packet::ToString(std::string* packet) const { *packet = std::string(8, 0); (*packet)[0] = 0xB0; diff --git a/src/common/sock/packet.hpp b/src/common/sock/packet.hpp index 0552c62..583cc61 100644 --- a/src/common/sock/packet.hpp +++ b/src/common/sock/packet.hpp @@ -35,6 +35,7 @@ public: int Parse(const std::string& data, std::string* extra = nullptr); bool Check(int region_count, ...) const; + void TrimRegions(const std::vector& ids = {}); inline void SetId(uint8_t id) { this->id = id; diff --git a/src/common/utils/ini.cpp b/src/common/utils/ini.cpp index 03606c0..26273c6 100644 --- a/src/common/utils/ini.cpp +++ b/src/common/utils/ini.cpp @@ -1,20 +1,20 @@ #include "ini.hpp" using namespace sosc::ini; -bool Field::Test() const { +bool Field::Test(const std::string& value) const { try { switch(type) { case INT32: - (int32_t)File::Proxy(name); + (int32_t)File::Proxy(value); break; case UINT32: - (uint32_t)File::Proxy(name); + (uint32_t)File::Proxy(value); break; case DOUBLE: - (double)File::Proxy(name); + (double)File::Proxy(value); break; case BOOL: - (bool)File::Proxy(name); + (bool)File::Proxy(value); break; } @@ -100,13 +100,14 @@ File* File::Open for(auto& section : (*ini)[rule.name].sections) { for(auto &field : rule.required_fields) { - if(section.values.count(str::tolower(field.name)) == 0) + std::string field_name = str::tolower(field.name); + if(section.values.count(field_name) == 0) throw LoadError(ini, -1, str::join({ "Required field '", field.name, "' in section '", rule.name, "' not found." })); - if(!field.Test()) + if(!field.Test(section.values.at(field_name))) throw LoadError(ini, -1, str::join({ "Field '", field.name, "' in section '",rule.name, "' " "cannot be casted to requested type." @@ -121,15 +122,16 @@ File* File::Open std::runtime_error File::LoadError (File* file, int line, const std::string &error) { - delete file; + std::string msg; if(line > 0) - return std::runtime_error(str::join( + msg = str::join( {"LOAD ERROR IN '", file->filename, "' L", TOSTR(line), ": ", error} - )); + ); else - return std::runtime_error(str::join( - {"LOAD ERROR IN '", file->filename, "': ", error} - )); + msg = str::join({"LOAD ERROR IN '", file->filename, "': ", error}); + + delete file; + throw std::runtime_error(msg); } bool File::HasSection(std::string name) const { diff --git a/src/common/utils/ini.hpp b/src/common/utils/ini.hpp index 7ad6efb..a06f385 100644 --- a/src/common/utils/ini.hpp +++ b/src/common/utils/ini.hpp @@ -7,6 +7,8 @@ #include #include +#include + #include "string.hpp" namespace sosc { @@ -23,7 +25,7 @@ struct Field { explicit Field (const std::string& name, kType type = STRING) : name(name), type(type) {} - bool Test() const; + bool Test(const std::string& value) const; std::string name; kType type; diff --git a/src/common/utils/net.hpp b/src/common/utils/net.hpp index c9e0d0d..99cf201 100644 --- a/src/common/utils/net.hpp +++ b/src/common/utils/net.hpp @@ -23,19 +23,19 @@ #undef NTOHLL #undef NTOHULL -#define HTONS (X) sosc::net::htonv(X) +#define HTONS(X) sosc::net::htonv(X) #define HTONUS(X) sosc::net::htonv(X) -#define NTOHS (X) sosc::net::ntohv(X) +#define NTOHS(X) sosc::net::ntohv(X) #define NTOHUS(X) sosc::net::ntohv(X) -#define HTONL (X) sosc::net::htonv(X) +#define HTONL(X) sosc::net::htonv(X) #define HTONUL(X) sosc::net::htonv(X) -#define NTOHL (X) sosc::net::ntohv(X, 0) +#define NTOHL(X) sosc::net::ntohv(X, 0) #define NTOHUL(X) sosc::net::ntohv(X, 0) -#define HTONLL (X) sosc::net::htonv(X) +#define HTONLL(X) sosc::net::htonv(X) #define HTONULL(X) sosc::net::htonv(X) -#define NTOHLL (X) sosc::net::ntohv(X, 0) +#define NTOHLL(X) sosc::net::ntohv(X, 0) #define NTOHULL(X) sosc::net::ntohv(X, 0) namespace sosc { diff --git a/src/common/utils/string.cpp b/src/common/utils/string.cpp index 434937e..f73b7d5 100644 --- a/src/common/utils/string.cpp +++ b/src/common/utils/string.cpp @@ -88,15 +88,15 @@ std::string sosc::str::join(const std::vector& parts, std::string sosc::str::join(const std::vector& parts, std::string delimiter, int count) { - std::string assembled; + std::stringstream ss; int bounds = (count == -1) ? parts.size() : std::min(count, parts.size()); for(int i = 0; i < bounds; ++i) - assembled += (i == 0 ? "" : delimiter) + parts[i]; + ss << (i == 0 ? "" : delimiter) + parts[i]; - return assembled; + return ss.str(); } bool sosc::str::starts diff --git a/src/server/db/database.cpp b/src/server/db/database.cpp index fe11b03..37237a6 100644 --- a/src/server/db/database.cpp +++ b/src/server/db/database.cpp @@ -11,11 +11,12 @@ static struct { bool sosc::db::init_databases(std::string* error) { if(_ctx.ready) return true; + _ctx.ready = true; sqlite3_open(":memory:", &_ctx.mem_db); sqlite3_exec(_ctx.mem_db, _mem_db_sql, nullptr, nullptr, nullptr); - sqlite3_open("scape.db", &_ctx.hard_db); + sqlite3_open(SOSC_RESC("scape.db").c_str(), &_ctx.hard_db); int32_t migrationsExist = db::Query::ScalarInt32( "SELECT COUNT(*) FROM SQLITE_MASTER WHERE TBL_NAME = 'MIGRATIONS'" @@ -27,13 +28,14 @@ bool sosc::db::init_databases(std::string* error) { if(lastMig > _hard_db_sql.size()) { if(error != nullptr) *error = "HARD DB: RECORDED MIGRATION COUNT TOO HIGH"; + _ctx.ready = false; return false; } int id; Query insertMigration( "INSERT INTO MIGRATIONS (ID, SQL_HASH, DATE_RAN) " - "VALUES (?, ?, NOW())" + "VALUES (?, ?, DATETIME('NOW'))" ); Query getMigration("SELECT SQL_HASH FROM MIGRATIONS WHERE ID = ?"); for(id = 0; id < _hard_db_sql.size(); ++id) { @@ -44,6 +46,7 @@ bool sosc::db::init_databases(std::string* error) { if(id < lastMig) { if(error != nullptr) *error = "HARD DB: MIGRATION RECORDS NOT CONTINUOUS"; + _ctx.ready = false; return false; } @@ -56,6 +59,7 @@ bool sosc::db::init_databases(std::string* error) { if(hash != cgc::sha1(_hard_db_sql[id])) { if(error != nullptr) *error = "HARD DB: MIGRATION SQL HASH MISMATCH"; + _ctx.ready = false; return false; } } @@ -106,7 +110,7 @@ sosc::db::Query::Query(const std::string& query, int db) : results(this) { void sosc::db::Query::SetQuery(const std::string &query, int db) { if(!_ctx.ready) return; - if(!this->open) + if(this->open) this->Close(); this->database = db == DB_USE_MEMORY ? _ctx.mem_db : _ctx.hard_db; @@ -120,6 +124,8 @@ void sosc::db::Query::SetQuery(const std::string &query, int db) { if(status == SQLITE_OK) this->open = true; + else + throw std::runtime_error(sqlite3_errmsg(this->database)); } void sosc::db::Query::BindDouble(double value, int i) { diff --git a/src/server/db/database.hpp b/src/server/db/database.hpp index b4b5c4b..d24726f 100644 --- a/src/server/db/database.hpp +++ b/src/server/db/database.hpp @@ -4,6 +4,7 @@ #include "sqlite/sqlite3.h" #include "utils/time.hpp" #include "crypto/sha1.hpp" +#include "common.hpp" #include #include @@ -86,7 +87,7 @@ private: }; // THE FOLLOWING ARE NOT THREAD SAFE !! -// CALL THEM ONLY WHEN MASTER POOL IS INACTIVE +// CALL THEM ONLY WHEN MASTER POOLS ARE INACTIVE bool init_databases(std::string* error); void close_databases(); }} diff --git a/src/server/hosts/master.hpp b/src/server/hosts/master.hpp index 2bdb780..4e1b51b 100644 --- a/src/server/hosts/master.hpp +++ b/src/server/hosts/master.hpp @@ -5,8 +5,8 @@ #include "sock/scapesock.hpp" #include "sock/pool.hpp" +#include "crypto/bcrypt.hpp" #include "db/database.hpp" - #include "ctx/master.hpp" #include @@ -27,7 +27,11 @@ public: ~MasterClient() { this->Close(); }; private: bool ProcessLogin(Packet& pck); + bool LoginError(uint16_t error_code); + bool ProcessRegistration(Packet& pck); + bool RegistrationError(uint16_t error_code); + bool ListServers(Packet& pck); enum MasterToClientId { @@ -79,9 +83,9 @@ private: bool StatusUpdate(Packet& pck); bool AuthenticationFailure - (const std::string& packetId, uint16_t errorCode); + (const std::string& packet_id, uint16_t error_code); - bool NotAuthorized(const std::string& packetId); + bool NotAuthorized(const std::string& packet_id); enum SlaveToMasterId { kAuthentication = 0, diff --git a/src/server/hosts/master_client.cpp b/src/server/hosts/master_client.cpp index 989a74a..3371015 100644 --- a/src/server/hosts/master_client.cpp +++ b/src/server/hosts/master_client.cpp @@ -7,44 +7,50 @@ static struct { /** MASTERCLIENTPOOL CODE **/ void sosc::MasterClientPool::SetupQueries(db::Queries *queries) { -#define QRY_USER_REG_CHECK 0 +#define QRY_USER_NAME_REG_CHECK 0 queries->push_back(new db::Query( "SELECT COUNT(*) FROM `USERS` " - "WHERE `USERNAME` = ? OR `EMAIL` = ?" + "WHERE `USERNAME` = ?" )); -#define QRY_USER_REGISTER 1 +#define QRY_USER_MAIL_REG_CHECK 1 + queries->push_back(new db::Query( + "SELECT COUNT(*) FROM `USERS` " + "WHERE `USERNAME` = ?" + )); + +#define QRY_USER_REGISTER 2 queries->push_back(new db::Query( "INSERT INTO `USERS` " "(`USERNAME`, `PASS_HASH`, `EMAIL`, `ACTIVATED`, `JOINED`) " "VALUES (?, ?, ?, 0, CURRENT_TIMESTAMP)" )); -#define QRY_USER_NAME_EXISTS 2 +#define QRY_USER_NAME_EXISTS 3 queries->push_back(new db::Query( "SELECT COUNT(*) FROM `USERS` " "WHERE LOWER(`USERNAME`) = LOWER(?)" )); -#define QRY_USER_GET_PWD_HASH 3 +#define QRY_USER_GET_PWD_HASH 4 queries->push_back(new db::Query( "SELECT `ID`, `PASS_HASH` FROM `USERS` " "WHERE LOWER(`USERNAME`) = LOWER(?)" )); -#define QRY_USER_GENERATE_KEY 4 +#define QRY_USER_GENERATE_KEY 5 queries->push_back(new db::Query( "INSERT OR IGNORE INTO `USER_KEYS` " "(`ID`, `SECRET`) VALUES (?, RANDOMBLOB(128))" )); -#define QRY_USER_GET_KEY 5 +#define QRY_USER_GET_KEY 6 queries->push_back(new db::Query( "SELECT `SECRET` FROM `USER_KEYS` " "WHERE `ID` = ?" )); -#define QRY_USER_CHECK_KEY 6 +#define QRY_USER_CHECK_KEY 7 queries->push_back(new db::Query( "SELECT COUNT(*) FROM `USER_KEYS` " "WHERE `ID` = ? AND `SECRET` = ?" @@ -88,14 +94,85 @@ bool sosc::MasterClient::ProcessLogin(Packet &pck) { return true; if(!pck.Check(2, PCK_ANY, PCK_ANY)) return false; + pck.TrimRegions(); + db::ResultSet* results = nullptr; db::Query* query = this->queries->at(QRY_USER_NAME_EXISTS); + query->Reset(); + query->BindText(pck[0], 0); + if(query->ScalarInt32() == 0) + return LoginError(0x101); + query = this->queries->at(QRY_USER_GET_PWD_HASH); + query->Reset(); + query->BindText(pck[0], 0); + results = query->GetResults(); + results->Step(); + + int64_t user_id = results->GetInt64(0); + if(!cgc::bcrypt_check(pck[1], results->GetText(1))) + return LoginError(0x102); + + query = this->queries->at(QRY_USER_GENERATE_KEY); + query->Reset(); + query->BindInt64(user_id, 0); + query->NonQuery(); + + query = this->queries->at(QRY_USER_GET_KEY); + query->Reset(); + query->BindInt64(user_id, 0); + auto secret = query->ScalarBlob(); + + this->sock.Send(Packet(kLoginResponse, {"\1", secret})); + this->authed = true; return true; } -bool sosc::MasterClient::ProcessRegistration(Packet &pck) { +bool sosc::MasterClient::LoginError(uint16_t error_code) { + if(++this->auth_attempts < MAX_AUTH_ATTEMPTS) { + this->sock.Send( + Packet(kLoginResponse, {"\0", HTONUS(error_code)}) + ); + return true; + } else { + return this->Close( + Packet(kLoginResponse, {"\0", HTONUS(0x100)}) + ); + } +} +bool sosc::MasterClient::ProcessRegistration(Packet &pck) { + if(!pck.Check(3, PCK_ANY, PCK_ANY, PCK_ANY)) + return false; + pck.TrimRegions(); + + db::Query* query = this->queries->at(QRY_USER_NAME_REG_CHECK); + query->Reset(); + query->BindText(pck[0], 0); + if(query->ScalarInt32() > 0) + return RegistrationError(0x100); + + query = this->queries->at(QRY_USER_MAIL_REG_CHECK); + query->Reset(); + query->BindText(pck[2], 0); + if(query->ScalarInt32() > 0) + return RegistrationError(0x101); + + query = this->queries->at(QRY_USER_REGISTER); + query->Reset(); + query->BindText(pck[0], 0); + query->BindText(cgc::bcrypt_hash(pck[1]), 1); + query->BindText(pck[2], 2); + query->NonQuery(); + + this->sock.Send(Packet(kRegisterResponse, {"\1", 0x000})); + return true; +} + +bool sosc::MasterClient::RegistrationError(uint16_t error_code) { + this->sock.Send( + Packet(kRegisterResponse, {"\0", HTONUS(error_code)}) + ); return true; } diff --git a/src/server/hosts/master_intra.cpp b/src/server/hosts/master_intra.cpp index 694145f..59fbcf2 100644 --- a/src/server/hosts/master_intra.cpp +++ b/src/server/hosts/master_intra.cpp @@ -152,23 +152,23 @@ bool sosc::MasterIntra::Authentication(sosc::Packet& pck) { } bool sosc::MasterIntra::AuthenticationFailure - (const std::string& packetId, uint16_t errorCode) + (const std::string& packet_id, uint16_t error_code) { if(++this->auth_attempts < MAX_AUTH_ATTEMPTS) { this->sock.Send( - Packet(kNegativeAck, { packetId , net::htonv(errorCode) }) + Packet(kNegativeAck, { packet_id , net::htonv(error_code) }) ); return true; } else { return this->Close( - Packet(kNegativeAck, { packetId, net::htonv(0x100) }) + Packet(kNegativeAck, { packet_id, net::htonv(0x100) }) ); } } -bool sosc::MasterIntra::NotAuthorized(const std::string& packetId) { +bool sosc::MasterIntra::NotAuthorized(const std::string& packet_id) { return this->Close( - Packet(kNegativeAck, { packetId, net::htonv(0x200) }) + Packet(kNegativeAck, { packet_id, net::htonv(0x200) }) ); } diff --git a/src/server/main.cpp b/src/server/main.cpp index 22897aa..7d2535a 100644 --- a/src/server/main.cpp +++ b/src/server/main.cpp @@ -65,8 +65,6 @@ void configure_poolinfo(sosc::poolinfo_t* info, int main(int argc, char **argv) { using namespace sosc; - if(argc < 2) - return -1; ini::File* config; try { @@ -105,7 +103,7 @@ int main(int argc, char **argv) { poolinfo_t info; configure_poolinfo(&_ctx.default_info, (*config)["defaults"][0]); - if((*config)["master"]["run master"]) { + if((*config)["general"]["run master"]) { if(!config->HasSection("master to client") || !config->HasSection("master to slave")) { @@ -122,13 +120,13 @@ int main(int argc, char **argv) { configure_poolinfo(&info, (*config)["master to slave"][0]); _ctx.master_intra = new master_intra_ctx; master_intra_start( - (uint16_t)(*config)["master"]["intra port"], info + (uint16_t)(*config)["master to slave"]["port"], info ); configure_poolinfo(&info, (*config)["master to client"][0]); _ctx.master_client = new master_client_ctx; master_client_start( - (uint16_t)(*config)["master"]["client port"], info + (uint16_t)(*config)["master to client"]["port"], info ); }