diff --git a/SharpChat/ChatContext.cs b/SharpChat/ChatContext.cs index 6567165..51a7b12 100644 --- a/SharpChat/ChatContext.cs +++ b/SharpChat/ChatContext.cs @@ -4,7 +4,6 @@ using SharpChat.Packet; using System; using System.Collections.Generic; using System.Linq; -using System.Net; using System.Threading; namespace SharpChat { @@ -12,7 +11,7 @@ namespace SharpChat { public readonly SemaphoreSlim ContextAccess = new(1, 1); public ChannelsContext Channels { get; } = new(); - public List Connections { get; } = new(); + public ConnectionsContext Connections { get; } = new(); public UsersContext Users { get; } = new(); public IEventStorage Events { get; } public ChannelsUsersContext ChannelsUsers { get; } = new(); @@ -79,18 +78,16 @@ namespace SharpChat { } public void Update() { - foreach(ConnectionInfo conn in Connections) - if(!conn.IsDisposed && conn.HasTimedOut) { - conn.Dispose(); - Logger.Write($"Nuked connection {conn.Id} associated with {conn.User}."); - } + ConnectionInfo[] timedOut = Connections.GetTimedOut(); + foreach(ConnectionInfo conn in timedOut) { + Connections.Remove(conn); + conn.Close(1002); - int removed = Connections.RemoveAll(conn => conn.IsDisposed); - if(removed > 0) - Logger.Write($"Removed {removed} nuked connections from the list."); + Logger.Write($"<{conn.RemoteEndPoint}> Nuked timed out connection from user #{conn.UserId}."); + } foreach(UserInfo user in Users.All) - if(!Connections.Any(conn => conn.User == user)) { + if(!Connections.HasUser(user)) { HandleDisconnect(user, UserDisconnectReason.TimeOut); Logger.Write($"Timed out {user} (no more connections)."); } @@ -192,10 +189,13 @@ namespace SharpChat { } else SendTo(user, new ForceDisconnectPacket()); - foreach(ConnectionInfo conn in Connections) - if(conn.User == user) - conn.Dispose(); - Connections.RemoveAll(conn => conn.IsDisposed); + ConnectionInfo[] conns = Connections.GetUser(user); + foreach(ConnectionInfo conn in conns) { + Connections.Remove(conn); + conn.Close(1000); + + Logger.Write($"<{conn.RemoteEndPoint}> Nuked connection from banned user #{conn.UserId}."); + } HandleDisconnect(user, reason); } @@ -372,22 +372,17 @@ namespace SharpChat { } public void Send(SockChatS2CPacket packet) { - foreach(ConnectionInfo conn in Connections) - if(conn.IsAuthed) - conn.Send(packet); + Connections.WithAuthed(conn => conn.Send(packet)); } public void SendTo(UserInfo user, SockChatS2CPacket packet) { - foreach(ConnectionInfo conn in Connections) - if(conn.IsAuthed && conn.User!.UserId == user.UserId) - conn.Send(packet); + Connections.WithUser(user, conn => conn.Send(packet)); } public void SendTo(ChannelInfo channel, SockChatS2CPacket packet) { long[] userIds = ChannelsUsers.GetChannelUserIds(channel); - foreach(ConnectionInfo conn in Connections) - if(conn.IsAuthed && userIds.Contains(conn.User!.UserId)) - conn.Send(packet); + foreach(long userId in userIds) + Connections.WithUser(userId, conn => conn.Send(packet)); } public void SendToUserChannels(UserInfo user, SockChatS2CPacket packet) { @@ -396,10 +391,6 @@ namespace SharpChat { SendTo(chan, packet); } - public IPAddress[] GetRemoteAddresses(UserInfo user) { - return Connections.Where(c => c.IsAlive && c.User == user).Select(c => c.RemoteAddress).Distinct().ToArray(); - } - public void ForceChannel(UserInfo user, ChannelInfo? chan = null) { chan ??= Channels.Get(ChannelsUsers.GetUserLastChannel(user)); if(chan != null) diff --git a/SharpChat/Commands/KickBanCommand.cs b/SharpChat/Commands/KickBanCommand.cs index 2e31da2..d950715 100644 --- a/SharpChat/Commands/KickBanCommand.cs +++ b/SharpChat/Commands/KickBanCommand.cs @@ -61,19 +61,20 @@ namespace SharpChat.Commands { Task.Run(async () => { string userId = banUser.UserId.ToString(); - string userIp = ctx.Chat.GetRemoteAddresses(banUser).FirstOrDefault()?.ToString() ?? string.Empty; - - // obviously it makes no sense to only check for one ip address but that's current misuzu limitations - MisuzuBanInfo? fbi = await Misuzu.CheckBanAsync(userId, userIp); + MisuzuBanInfo? fbi = await Misuzu.CheckBanAsync(userId); if(fbi != null && fbi.IsBanned && !fbi.HasExpired) { ctx.Chat.SendTo(ctx.User, new KickBanNotAllowedErrorPacket(SockChatUtility.GetUserName(banUser))); return; } + string[] userRemoteAddrs = ctx.Chat.Connections.GetUserRemoteAddresses(banUser); + string userRemoteAddr = string.Format(", ", userRemoteAddrs); + + // Misuzu only stores the IP address in private comment and doesn't do any checking, so this is fine. await Misuzu.CreateBanAsync( - userId, userIp, - ctx.User.UserId.ToString(), ctx.Connection.RemoteAddress.ToString(), + userId, userRemoteAddr, + ctx.User.UserId.ToString(), ctx.Connection.RemoteAddress, duration, banReason ); diff --git a/SharpChat/Commands/ShutdownRestartCommand.cs b/SharpChat/Commands/ShutdownRestartCommand.cs index 05f6ab0..c8db58e 100644 --- a/SharpChat/Commands/ShutdownRestartCommand.cs +++ b/SharpChat/Commands/ShutdownRestartCommand.cs @@ -5,11 +5,17 @@ using System.Threading; namespace SharpChat.Commands { public class ShutdownRestartCommand : IUserCommand { private readonly ManualResetEvent WaitHandle; - private readonly Func ShutdownCheck; + private readonly Func ShuttingDown; + private readonly Action SetShutdown; - public ShutdownRestartCommand(ManualResetEvent waitHandle, Func shutdownCheck) { + public ShutdownRestartCommand( + ManualResetEvent waitHandle, + Func shuttingDown, + Action setShutdown + ) { WaitHandle = waitHandle; - ShutdownCheck = shutdownCheck; + ShuttingDown = shuttingDown; + SetShutdown = setShutdown; } public bool IsMatch(UserCommandContext ctx) { @@ -23,13 +29,10 @@ namespace SharpChat.Commands { return; } - if(!ShutdownCheck()) + if(ShuttingDown()) return; - if(ctx.NameEquals("restart")) - foreach(ConnectionInfo conn in ctx.Chat.Connections) - conn.PrepareForRestart(); - + SetShutdown(ctx.NameEquals("restart")); ctx.Chat.Update(); WaitHandle?.Set(); } diff --git a/SharpChat/Commands/WhoisCommand.cs b/SharpChat/Commands/WhoisCommand.cs index ea3b25a..14ee228 100644 --- a/SharpChat/Commands/WhoisCommand.cs +++ b/SharpChat/Commands/WhoisCommand.cs @@ -1,6 +1,5 @@ using SharpChat.Packet; using System.Linq; -using System.Net; namespace SharpChat.Commands { public class WhoisCommand : IUserCommand { @@ -24,8 +23,8 @@ namespace SharpChat.Commands { return; } - foreach(IPAddress ip in ctx.Chat.GetRemoteAddresses(ipUser)) - ctx.Chat.SendTo(ctx.User, new WhoisResponsePacket(ipUser.UserName, ip.ToString())); + foreach(string remoteAddr in ctx.Chat.Connections.GetUserRemoteAddresses(ipUser)) + ctx.Chat.SendTo(ctx.User, new WhoisResponsePacket(ipUser.UserName, remoteAddr)); } } } diff --git a/SharpChat/ConnectionInfo.cs b/SharpChat/ConnectionInfo.cs index 950497b..5c37fbf 100644 --- a/SharpChat/ConnectionInfo.cs +++ b/SharpChat/ConnectionInfo.cs @@ -3,45 +3,42 @@ using System; using System.Net; namespace SharpChat { - public class ConnectionInfo : IDisposable { - public const int ID_LENGTH = 20; - -#if DEBUG - public static TimeSpan SessionTimeOut { get; } = TimeSpan.FromMinutes(1); -#else - public static TimeSpan SessionTimeOut { get; } = TimeSpan.FromMinutes(5); -#endif - + public class ConnectionInfo { public IWebSocketConnection Socket { get; } + public DateTimeOffset LastPing { get; private set; } - public string Id { get; } - public bool IsDisposed { get; private set; } - public DateTimeOffset LastPing { get; set; } = DateTimeOffset.Now; - public UserInfo? User { get; set; } + public long UserId { get; private set; } = 0; - private int CloseCode { get; set; } = 1000; - - public IPAddress RemoteAddress { get; } + public string RemoteAddress { get; } public ushort RemotePort { get; } - - public bool IsAlive => !IsDisposed && !HasTimedOut; - - public bool IsAuthed => IsAlive && User is not null; + public string RemoteEndPoint { get; } public ConnectionInfo(IWebSocketConnection sock) { Socket = sock; - Id = RNG.SecureRandomString(ID_LENGTH); - if(!IPAddress.TryParse(sock.ConnectionInfo.ClientIpAddress, out IPAddress? addr)) - throw new Exception("Unable to parse remote address?????"); + BumpPing(); - if(IPAddress.IsLoopback(addr) + IPAddress remoteAddr = IPAddress.Parse(sock.ConnectionInfo.ClientIpAddress); + + if(IPAddress.IsLoopback(remoteAddr) && sock.ConnectionInfo.Headers.ContainsKey("X-Real-IP") && IPAddress.TryParse(sock.ConnectionInfo.Headers["X-Real-IP"], out IPAddress? realAddr)) - addr = realAddr; + remoteAddr = realAddr; - RemoteAddress = addr; + RemoteAddress = remoteAddr.ToString(); RemotePort = (ushort)sock.ConnectionInfo.ClientPort; + RemoteEndPoint = string.Format( + RemoteAddress.Contains(':') ? "[{0}]:{1}" : "{0}:{1}", + RemoteAddress, RemotePort + ); + } + + // please call this through ConnectionsContext + public void SetUserId(long userId) { + if(UserId > 0) + throw new InvalidOperationException("This connection is already associated with a user."); + + UserId = userId; } public void Send(SockChatS2CPacket packet) { @@ -50,39 +47,15 @@ namespace SharpChat { string data = packet.Pack(); if(!string.IsNullOrWhiteSpace(data)) - Socket.Send(data); + Socket.Send(data).Wait(); } public void BumpPing() { - LastPing = DateTimeOffset.Now; + LastPing = DateTimeOffset.UtcNow; } - public bool HasTimedOut - => DateTimeOffset.Now - LastPing > SessionTimeOut; - - public void PrepareForRestart() { - CloseCode = 1012; - } - - ~ConnectionInfo() { - DoDispose(); - } - - public void Dispose() { - DoDispose(); - GC.SuppressFinalize(this); - } - - private void DoDispose() { - if(IsDisposed) - return; - - IsDisposed = true; - Socket.Close(CloseCode); - } - - public override string ToString() { - return Id; + public void Close(int code) { + Socket.Close(code); } } } diff --git a/SharpChat/ConnectionsContext.cs b/SharpChat/ConnectionsContext.cs new file mode 100644 index 0000000..f1b53ea --- /dev/null +++ b/SharpChat/ConnectionsContext.cs @@ -0,0 +1,176 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace SharpChat { + public class ConnectionsContext { + public static readonly TimeSpan TimeOut = TimeSpan.FromMinutes(5); + + private readonly HashSet Connections = new(); + private readonly HashSet AuthedConnections = new(); + private readonly Dictionary> UserConnections = new(); + + public ConnectionInfo[] All => Connections.ToArray(); + public ConnectionInfo[] Authed => AuthedConnections.ToArray(); + + public void WithAll(Action body) { + foreach(ConnectionInfo conn in Connections) + body(conn); + } + + public void WithAuthed(Action body) { + foreach(ConnectionInfo conn in AuthedConnections) + body(conn); + } + + public ConnectionInfo[] GetTimedOut() { + List conns = new(); + + foreach(ConnectionInfo conn in Connections) + if(DateTimeOffset.UtcNow - conn.LastPing > TimeOut) + conns.Add(conn); + + return conns.ToArray(); + } + + public int GetCountForUser(long userId) { + return UserConnections.ContainsKey(userId) ? UserConnections.Count : 0; + } + + public int GetCountForUser(UserInfo userInfo) { + return GetCountForUser(userInfo.UserId); + } + + public bool HasUser(long userId) { + return GetCountForUser(userId) > 0; + } + + public bool HasUser(UserInfo userInfo) { + return HasUser(userInfo.UserId); + } + + public ConnectionInfo[] GetUser(long userId) { + if(!UserConnections.ContainsKey(userId)) + return Array.Empty(); + + return UserConnections[userId].ToArray(); + } + + public ConnectionInfo[] GetUser(UserInfo userInfo) { + return GetUser(userInfo.UserId); + } + + public void WithUser(long userId, Action body) { + if(!UserConnections.ContainsKey(userId)) + return; + + foreach(ConnectionInfo conn in UserConnections[userId]) + body(conn); + } + + public void WithUser(UserInfo userInfo, Action body) { + WithUser(userInfo.UserId, body); + } + + public string[] GetAllRemoteAddresses() { + HashSet addrs = new(); + + foreach(ConnectionInfo conn in Connections) + addrs.Add(conn.RemoteAddress); + + return addrs.ToArray(); + } + + public string[] GetAuthedRemoteAddresses() { + HashSet addrs = new(); + + foreach(ConnectionInfo conn in AuthedConnections) + addrs.Add(conn.RemoteAddress); + + return addrs.ToArray(); + } + + public string[] GetUserRemoteAddresses(long userId) { + if(!UserConnections.ContainsKey(userId)) + return Array.Empty(); + + HashSet addrs = new(); + + foreach(ConnectionInfo conn in UserConnections[userId]) + addrs.Add(conn.RemoteAddress); + + return addrs.ToArray(); + } + + public string[] GetUserRemoteAddresses(UserInfo userInfo) { + return GetUserRemoteAddresses(userInfo.UserId); + } + + public void Add(ConnectionInfo conn) { + if(Connections.Contains(conn)) + return; + + Connections.Add(conn); + + if(conn.UserId > 0) { + AuthedConnections.Add(conn); + + if(UserConnections.ContainsKey(conn.UserId)) + UserConnections[conn.UserId].Add(conn); + else + UserConnections.Add(conn.UserId, new() { conn }); + } + } + + public void Remove(ConnectionInfo conn) { + if(Connections.Contains(conn)) + Connections.Remove(conn); + + if(AuthedConnections.Contains(conn)) + AuthedConnections.Remove(conn); + + if(conn.UserId > 0 && UserConnections.ContainsKey(conn.UserId)) { + UserConnections[conn.UserId].Remove(conn); + + if(UserConnections[conn.UserId].Count < 1) + UserConnections.Remove(conn.UserId); + } + } + + public void SetUser(ConnectionInfo conn, long userId) { + if(!Connections.Contains(conn)) + return; + + // so yeah this is implemented but SetUserId will throw an exception if we're trying to re-auth a connection + // will just leave this here but i'm not sure how to go forward with this + if(conn.UserId > 0) { + if(UserConnections[conn.UserId].Contains(conn)) + UserConnections[conn.UserId].Remove(conn); + + if(UserConnections[conn.UserId].Count < 1) + UserConnections.Remove(conn.UserId); + } + + conn.SetUserId(userId); + + if(conn.UserId > 0) { + if(UserConnections.ContainsKey(conn.UserId)) + UserConnections[conn.UserId].Add(conn); + else + UserConnections.Add(conn.UserId, new() { conn }); + } + + if(conn.UserId > 0) { + if(!AuthedConnections.Contains(conn)) + AuthedConnections.Add(conn); + } else { + if(AuthedConnections.Contains(conn)) + AuthedConnections.Remove(conn); + } + } + + public void SetUser(ConnectionInfo conn, UserInfo userInfo) { + SetUser(conn, userInfo.UserId); + } + } +} diff --git a/SharpChat/PacketHandlers/AuthHandler.cs b/SharpChat/PacketHandlers/AuthHandler.cs index 38df9bf..22a59f5 100644 --- a/SharpChat/PacketHandlers/AuthHandler.cs +++ b/SharpChat/PacketHandlers/AuthHandler.cs @@ -41,14 +41,14 @@ namespace SharpChat.PacketHandlers { string? authMethod = args.ElementAtOrDefault(1); if(string.IsNullOrWhiteSpace(authMethod)) { ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } string? authToken = args.ElementAtOrDefault(2); if(string.IsNullOrWhiteSpace(authToken)) { ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } @@ -60,14 +60,14 @@ namespace SharpChat.PacketHandlers { Task.Run(async () => { MisuzuAuthInfo? fai; - string ipAddr = ctx.Connection.RemoteAddress.ToString(); + string ipAddr = ctx.Connection.RemoteAddress; try { fai = await Misuzu.AuthVerifyAsync(authMethod, authToken, ipAddr); } catch(Exception ex) { - Logger.Write($"<{ctx.Connection.Id}> Failed to authenticate: {ex}"); + Logger.Write($"<{ctx.Connection.RemoteEndPoint}> Failed to authenticate: {ex}"); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); #if DEBUG throw; #else @@ -76,16 +76,16 @@ namespace SharpChat.PacketHandlers { } if(fai == null) { - Logger.Debug($"<{ctx.Connection.Id}> Auth fail: "); + Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Auth fail: "); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } if(!fai.Success) { - Logger.Debug($"<{ctx.Connection.Id}> Auth fail: {fai.Reason}"); + Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Auth fail: {fai.Reason}"); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } @@ -93,9 +93,9 @@ namespace SharpChat.PacketHandlers { try { fbi = await Misuzu.CheckBanAsync(fai.UserId.ToString(), ipAddr); } catch(Exception ex) { - Logger.Write($"<{ctx.Connection.Id}> Failed auth ban check: {ex}"); + Logger.Write($"<{ctx.Connection.RemoteEndPoint}> Failed auth ban check: {ex}"); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); #if DEBUG throw; #else @@ -104,16 +104,16 @@ namespace SharpChat.PacketHandlers { } if(fbi == null) { - Logger.Debug($"<{ctx.Connection.Id}> Ban check fail: "); + Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Ban check fail: "); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } if(fbi.IsBanned && !fbi.HasExpired) { - Logger.Write($"<{ctx.Connection.Id}> User is banned."); + Logger.Write($"<{ctx.Connection.RemoteEndPoint}> User is banned."); ctx.Connection.Send(new AuthFailPacket(fbi.ExpiresAt)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } @@ -141,14 +141,14 @@ namespace SharpChat.PacketHandlers { ); // Enforce a maximum amount of connections per user - if(ctx.Chat.Connections.Count(conn => conn.User == user) >= MaxConnections) { + if(ctx.Chat.Connections.GetCountForUser(user) >= MaxConnections) { ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.MaxSessions)); - ctx.Connection.Dispose(); + ctx.Connection.Close(1000); return; } ctx.Connection.BumpPing(); - ctx.Connection.User = user; + ctx.Chat.Connections.SetUser(ctx.Connection, user); ctx.Connection.Send(new MOTDPacket(Started, $"Welcome to Flashii Chat, {user.UserName}!")); if(File.Exists(MOTD_FILE)) { diff --git a/SharpChat/PacketHandlers/PingHandler.cs b/SharpChat/PacketHandlers/PingHandler.cs index 46aca77..178ca93 100644 --- a/SharpChat/PacketHandlers/PingHandler.cs +++ b/SharpChat/PacketHandlers/PingHandler.cs @@ -1,6 +1,7 @@ using SharpChat.Misuzu; using SharpChat.Packet; using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -31,12 +32,20 @@ namespace SharpChat.PacketHandlers { ctx.Chat.ContextAccess.Wait(); try { if(LastBump < DateTimeOffset.UtcNow - BumpInterval) { - (string, string)[] bumpList = ctx.Chat.Users.All - .Where(u => u.Status == UserStatus.Online && ctx.Chat.Connections.Any(c => c.User == u)) - .Select(u => (u.UserId.ToString(), ctx.Chat.GetRemoteAddresses(u).FirstOrDefault()?.ToString() ?? string.Empty)) - .ToArray(); + List<(string, string)> bumpList = new(); - if(bumpList.Any()) + foreach(UserInfo userInfo in ctx.Chat.Users.All) { + if(userInfo.Status != UserStatus.Online) + continue; + + string[] remoteAddrs = ctx.Chat.Connections.GetUserRemoteAddresses(userInfo); + if(remoteAddrs.Length < 1) + continue; + + bumpList.Add((userInfo.UserId.ToString(), remoteAddrs[0])); + } + + if(bumpList.Count > 0) Task.Run(async () => { await Misuzu.BumpUsersOnlineAsync(bumpList); }).Wait(); diff --git a/SharpChat/PacketHandlers/SendMessageHandler.cs b/SharpChat/PacketHandlers/SendMessageHandler.cs index 2e0f64c..5cdea96 100644 --- a/SharpChat/PacketHandlers/SendMessageHandler.cs +++ b/SharpChat/PacketHandlers/SendMessageHandler.cs @@ -29,7 +29,7 @@ namespace SharpChat.PacketHandlers { public void Handle(PacketHandlerContext ctx) { string[] args = ctx.SplitText(3); - UserInfo? user = ctx.Connection.User; + UserInfo? user = ctx.Chat.Users.Get(ctx.Connection.UserId); // No longer concats everything after index 1 with \t, no previous implementation did that either string? messageText = args.ElementAtOrDefault(2); @@ -59,7 +59,7 @@ namespace SharpChat.PacketHandlers { messageText = messageText.Trim(); #if DEBUG - Logger.Write($"<{ctx.Connection.Id} {user.UserName}> {messageText}"); + Logger.Write($"<{user.UserId} {user.UserName}> {messageText}"); #endif if(messageText.StartsWith("/")) { diff --git a/SharpChat/SockChatServer.cs b/SharpChat/SockChatServer.cs index 35492c4..53fe732 100644 --- a/SharpChat/SockChatServer.cs +++ b/SharpChat/SockChatServer.cs @@ -17,7 +17,6 @@ namespace SharpChat { public const int DEFAULT_MSG_LENGTH_MAX = 5000; public const int DEFAULT_MAX_CONNECTIONS = 5; public const int DEFAULT_FLOOD_KICK_LENGTH = 30; - public const int DEFAULT_FLOOD_KICK_EXEMPT_RANK = 9; public IWebSocketServer Server { get; } public ChatContext Context { get; } @@ -28,13 +27,13 @@ namespace SharpChat { private readonly CachedValue MaxMessageLength; private readonly CachedValue MaxConnections; private readonly CachedValue FloodKickLength; - private readonly CachedValue FloodKickExemptRank; private readonly List GuestHandlers = new(); private readonly List AuthedHandlers = new(); private readonly SendMessageHandler SendMessageHandler; private bool IsShuttingDown = false; + private bool IsRestarting = false; public SockChatServer(HttpClient httpClient, MisuzuClient msz, IEventStorage evtStore, IConfig config) { Logger.Write("Initialising Sock Chat server..."); @@ -46,7 +45,6 @@ namespace SharpChat { MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX); MaxConnections = config.ReadCached("connMaxCount", DEFAULT_MAX_CONNECTIONS); FloodKickLength = config.ReadCached("floodKickLength", DEFAULT_FLOOD_KICK_LENGTH); - FloodKickExemptRank = config.ReadCached("floodKickExemptRank", DEFAULT_FLOOD_KICK_EXEMPT_RANK); Context = new ChatContext(evtStore); @@ -111,7 +109,14 @@ namespace SharpChat { public void Listen(ManualResetEvent waitHandle) { if(waitHandle != null) - SendMessageHandler.AddCommand(new ShutdownRestartCommand(waitHandle, () => !IsShuttingDown && (IsShuttingDown = true))); + SendMessageHandler.AddCommand(new ShutdownRestartCommand( + waitHandle, + () => IsShuttingDown, + restarting => { + IsShuttingDown = true; + IsRestarting = restarting; + } + )); Server.Start(sock => { if(IsShuttingDown) { @@ -132,24 +137,27 @@ namespace SharpChat { } private void OnOpen(ConnectionInfo conn) { - Logger.Write($"Connection opened from {conn.RemoteAddress}:{conn.RemotePort}"); + Logger.Write($"Connection opened from {conn.RemoteEndPoint}"); Context.SafeUpdate(); } private void OnError(ConnectionInfo conn, Exception ex) { - Logger.Write($"[{conn.Id} {conn.RemoteAddress}] {ex}"); + Logger.Write($"<{conn.RemoteEndPoint}> {ex}"); Context.SafeUpdate(); } private void OnClose(ConnectionInfo conn) { - Logger.Write($"Connection closed from {conn.RemoteAddress}:{conn.RemotePort}"); + Logger.Write($"Connection closed from {conn.RemoteEndPoint}"); Context.ContextAccess.Wait(); try { Context.Connections.Remove(conn); - if(conn.User != null && !Context.Connections.Any(c => c.User == conn.User)) - Context.HandleDisconnect(conn.User); + if(!Context.Connections.HasUser(conn.UserId)) { + UserInfo? userInfo = Context.Users.Get(conn.UserId); + if(userInfo != null) + Context.HandleDisconnect(userInfo); + } Context.Update(); } finally { @@ -161,15 +169,15 @@ namespace SharpChat { Context.SafeUpdate(); // this doesn't affect non-authed connections????? - if(conn.User is not null && conn.User.Rank < FloodKickExemptRank) { - UserInfo? banUser = null; + if(conn.UserId > 0) { + long banUserId = 0; string banAddr = string.Empty; TimeSpan banDuration = TimeSpan.MinValue; Context.ContextAccess.Wait(); try { - if(!Context.UserRateLimiters.TryGetValue(conn.User.UserId, out RateLimiter? rateLimiter)) - Context.UserRateLimiters.Add(conn.User.UserId, rateLimiter = new RateLimiter( + if(!Context.UserRateLimiters.TryGetValue(conn.UserId, out RateLimiter? rateLimiter)) + Context.UserRateLimiters.Add(conn.UserId, rateLimiter = new RateLimiter( UserInfo.DEFAULT_SIZE, UserInfo.DEFAULT_MINIMUM_DELAY, UserInfo.DEFAULT_RISKY_OFFSET @@ -179,27 +187,30 @@ namespace SharpChat { if(rateLimiter.IsExceeded) { banDuration = TimeSpan.FromSeconds(FloodKickLength); - banUser = conn.User; - banAddr = conn.RemoteAddress.ToString(); + banUserId = conn.UserId; + banAddr = conn.RemoteAddress; } else if(rateLimiter.IsRisky) { - banUser = conn.User; + banUserId = conn.UserId; } - if(banUser is not null) { - if(banDuration == TimeSpan.MinValue) { - Context.SendTo(conn.User, new FloodWarningPacket()); - } else { - Context.BanUser(conn.User, banDuration, UserDisconnectReason.Flood); + if(banUserId > 0) { + UserInfo? userInfo = Context.Users.Get(banUserId); + if(userInfo != null) { + if(banDuration == TimeSpan.MinValue) { + Context.SendTo(userInfo, new FloodWarningPacket()); + } else { + Context.BanUser(userInfo, banDuration, UserDisconnectReason.Flood); - if(banDuration > TimeSpan.Zero) - Misuzu.CreateBanAsync( - conn.User.UserId.ToString(), conn.RemoteAddress.ToString(), - string.Empty, "::1", - banDuration, - "Kicked from chat for flood protection." - ).Wait(); + if(banDuration > TimeSpan.Zero) + Misuzu.CreateBanAsync( + banUserId.ToString(), conn.RemoteAddress, + string.Empty, "::1", + banDuration, + "Kicked from chat for flood protection." + ).Wait(); - return; + return; + } } } } finally { @@ -208,9 +219,9 @@ namespace SharpChat { } PacketHandlerContext context = new(msg, Context, conn); - IPacketHandler? handler = conn.User is null - ? GuestHandlers.FirstOrDefault(h => h.IsMatch(context)) - : AuthedHandlers.FirstOrDefault(h => h.IsMatch(context)); + IPacketHandler? handler = conn.UserId > 0 + ? AuthedHandlers.FirstOrDefault(h => h.IsMatch(context)) + : GuestHandlers.FirstOrDefault(h => h.IsMatch(context)); handler?.Handle(context); } @@ -232,8 +243,7 @@ namespace SharpChat { IsDisposed = true; IsShuttingDown = true; - foreach(ConnectionInfo conn in Context.Connections) - conn.Dispose(); + Context.Connections.WithAll(conn => conn.Close(IsRestarting ? 1012 : 1001)); Server?.Dispose(); HttpClient?.Dispose();