diff --git a/protocol.md b/protocol.md index c7238fd..ee4ac60 100644 --- a/protocol.md +++ b/protocol.md @@ -185,8 +185,18 @@ Communication between the master server and clients will be done over a WebSocke 1 - Client Key - Big Int + Server Name + Text + + + 2 + License + Text + + + 3 + Secret + Blob (512b) @@ -472,9 +482,11 @@ TODO: MAKE THIS SECTION NOT LOOK LIKE SHIT 0x02: COULD NOT PARSE KEY #### M -> S (ID 4) -0x01: LICENSE DATA INCORRECT +0x01: MAX AUTH ATTEMPTS REACHED -0x02: LICENSE LIMIT REACHED +0x02: LICENSE DATA INCORRECT + +0x03: LICENSE LIMIT REACHED ### Master / Client diff --git a/server/src/db/database.hpp b/server/src/db/database.hpp index ae33682..000cd94 100644 --- a/server/src/db/database.hpp +++ b/server/src/db/database.hpp @@ -14,6 +14,7 @@ namespace sosc { namespace db { class Query; +typedef std::vector QueryList; class ResultSet { public: diff --git a/server/src/hosts/master.hpp b/server/src/hosts/master.hpp index 1c75cdd..ef4a46b 100644 --- a/server/src/hosts/master.hpp +++ b/server/src/hosts/master.hpp @@ -10,6 +10,8 @@ #include "../db/database.hpp" +#include + namespace sosc { /** MASTER -> CLIENT **/ @@ -36,7 +38,7 @@ protected: class MasterIntra { public: explicit MasterIntra(const IntraClient& client); - bool Process(); + bool Process(const db::QueryList* queries); bool Close(); bool Close(const Packet& message); @@ -45,6 +47,9 @@ private: bool Authentication(Packet& pck); bool StatusUpdate(Packet& pck); + bool AuthenticationFailure + (const std::string& packetId, uint16_t errorCode); + enum SlaveToMasterId { kInitAttempt = 1, kAuthentication, @@ -66,13 +71,21 @@ private: int auth_attempts; const int MAX_AUTH_ATTEMPTS = 3; std::string license; + + const db::QueryList* queries; }; class MasterIntraPool : public Pool { +public: + MasterIntraPool(); protected: bool ProcessClient(MasterIntra& client) override { - return client.Process(); + return client.Process(&this->queries); } + + void Stop() override; +private: + db::QueryList queries; }; } diff --git a/server/src/hosts/master_intra.cpp b/server/src/hosts/master_intra.cpp index f128880..038d31b 100644 --- a/server/src/hosts/master_intra.cpp +++ b/server/src/hosts/master_intra.cpp @@ -1,5 +1,43 @@ #include "master.hpp" #include "../db/database.hpp" +#include + +static struct { + std::mutex license_check_mtx; +} _ctx; + +/** MASTERINTRAPOOL CODE **/ + +sosc::MasterIntraPool::MasterIntraPool() { +#define QRY_LICENSE_CHECK 0 + this->queries.push_back(new db::Query( + "SELECT COUNT(*) FROM SERVER_LICENSES " + "WHERE KEY_ID = ? AND SECRET = ?" + )); + +#define QRY_LICENSE_LIMIT 1 + this->queries.push_back(new db::Query( + "SELECT ALLOWANCE FROM SERVER_LICENSES WHERE KEY_ID = ?" + )); + +#define QRY_LICENSE_ACTIVE_COUNT 2 + this->queries.push_back(new db::Query( + "SELECT COUNT(*) FROM SERVER_LIST WHERE LICENSE = ?" + , DB_USE_MEMORY)); + +#define QRY_LICENSE_ +} + +void sosc::MasterIntraPool::Stop() { + Pool::Stop(); + + for(auto& query : this->queries) { + query->Close(); + delete query; + } +} + +/** MASTERINTRA CODE **/ sosc::MasterIntra::MasterIntra(const IntraClient& client) { this->sock = client; @@ -7,7 +45,7 @@ sosc::MasterIntra::MasterIntra(const IntraClient& client) { this->auth_attempts = 0; } -bool sosc::MasterIntra::Process() { +bool sosc::MasterIntra::Process(const db::QueryList* queries) { Packet pck; int status = this->sock.Receive(&pck); if(status == PCK_ERR) @@ -15,6 +53,7 @@ bool sosc::MasterIntra::Process() { else if(status == PCK_MORE) return true; + this->queries = queries; switch(pck.GetId()) { case kInitAttempt: return this->InitAttempt(pck); @@ -27,7 +66,8 @@ bool sosc::MasterIntra::Process() { } } -bool sosc::MasterIntra::InitAttempt(sosc::Packet &pck) { +bool sosc::MasterIntra::InitAttempt(sosc::Packet &pck) +{ if(!pck.Check(1, key.key_size_bytes)) return this->Close(Packet(kEncryptionError, { "\x01" })); @@ -38,24 +78,62 @@ bool sosc::MasterIntra::InitAttempt(sosc::Packet &pck) { this->sock.Send(response); } -bool sosc::MasterIntra::Authentication(sosc::Packet &pck) { +bool sosc::MasterIntra::Authentication(sosc::Packet &pck) +{ if(this->authed) return true; - if(!pck.Check(2, PCK_ANY, 512)) - return this->Close(Packet(kNegativeAck, { "\x01" })); + std::string packetId = BYTESTR(kAuthentication); + if(!pck.Check(3, PCK_ANY, PCK_ANY, 512)) + return this->Close(); - db::Query = db::Query::ScalarInt32( - "SELECT COUNT(*) FROM SERVER_LICENSES " - "WHERE KEY_ID = ? AND SECRET = ?" - ); + db::Query* query = this->queries->at(QRY_LICENSE_CHECK); + query->Reset(); + query->BindText(pck[1], 0); + query->BindBlob(pck[2], 1); + if(query->ScalarInt32() == 0) + return AuthenticationFailure(packetId, 2); - if(isValid > 0) { + _ctx.license_check_mtx.lock(); + int limit; + query = this->queries->at(QRY_LICENSE_LIMIT); + query->Reset(); + query->BindText(pck[1], 0); + if((limit = query->ScalarInt32()) != 0) { + query = this->queries->at(QRY_LICENSE_ACTIVE_COUNT); + query->Reset(); + query->BindText(pck[1], 0); + if(query->ScalarInt32() < limit) { + _ctx.license_check_mtx.unlock(); + return AuthenticationFailure(packetId, 3); + } + } + + + _ctx.license_check_mtx.unlock(); + + this->authed = true; + return true; +} + +bool sosc::MasterIntra::AuthenticationFailure + (const std::string& packetId, uint16_t errorCode) +{ + if(++this->auth_attempts < MAX_AUTH_ATTEMPTS) { + this->sock.Send( + Packet(kNegativeAck, { packetId , net::htonv(errorCode) }) + ); + return true; + } else { + return this->Close( + Packet(kNegativeAck, { packetId, net::htonv(1) }) + ); } } -bool sosc::MasterIntra::StatusUpdate(sosc::Packet &pck) { +bool sosc::MasterIntra::StatusUpdate(sosc::Packet &pck) +{ } @@ -67,4 +145,4 @@ bool sosc::MasterIntra::Close() { bool sosc::MasterIntra::Close(const Packet &message) { this->sock.Send(message); this->Close(); -} +} \ No newline at end of file diff --git a/server/src/sock/pool.hpp b/server/src/sock/pool.hpp index 89cac5c..fb8c756 100644 --- a/server/src/sock/pool.hpp +++ b/server/src/sock/pool.hpp @@ -40,7 +40,7 @@ public: return this->is_open; } - void Stop(); + virtual void Stop(); protected: virtual bool ProcessClient(T& client) = 0; private: diff --git a/server/src/utils/string.hpp b/server/src/utils/string.hpp index 86f6528..8a279c1 100644 --- a/server/src/utils/string.hpp +++ b/server/src/utils/string.hpp @@ -10,6 +10,10 @@ #undef TOSTR #define TOSTR(X) std::to_string(X) +#undef bytestr +#undef BYTESTR +#define BYTESTR(X) std::string(1, X) + namespace sosc { namespace str { std::string trim (std::string str);