Connection handling rewrite.

This commit is contained in:
flash 2024-05-20 16:16:32 +00:00
parent fa8c416b77
commit 610f9ab142
10 changed files with 320 additions and 158 deletions

View file

@ -4,7 +4,6 @@ using SharpChat.Packet;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net;
using System.Threading; using System.Threading;
namespace SharpChat { namespace SharpChat {
@ -12,7 +11,7 @@ namespace SharpChat {
public readonly SemaphoreSlim ContextAccess = new(1, 1); public readonly SemaphoreSlim ContextAccess = new(1, 1);
public ChannelsContext Channels { get; } = new(); public ChannelsContext Channels { get; } = new();
public List<ConnectionInfo> Connections { get; } = new(); public ConnectionsContext Connections { get; } = new();
public UsersContext Users { get; } = new(); public UsersContext Users { get; } = new();
public IEventStorage Events { get; } public IEventStorage Events { get; }
public ChannelsUsersContext ChannelsUsers { get; } = new(); public ChannelsUsersContext ChannelsUsers { get; } = new();
@ -79,18 +78,16 @@ namespace SharpChat {
} }
public void Update() { public void Update() {
foreach(ConnectionInfo conn in Connections) ConnectionInfo[] timedOut = Connections.GetTimedOut();
if(!conn.IsDisposed && conn.HasTimedOut) { foreach(ConnectionInfo conn in timedOut) {
conn.Dispose(); Connections.Remove(conn);
Logger.Write($"Nuked connection {conn.Id} associated with {conn.User}."); conn.Close(1002);
}
int removed = Connections.RemoveAll(conn => conn.IsDisposed); Logger.Write($"<{conn.RemoteEndPoint}> Nuked timed out connection from user #{conn.UserId}.");
if(removed > 0) }
Logger.Write($"Removed {removed} nuked connections from the list.");
foreach(UserInfo user in Users.All) foreach(UserInfo user in Users.All)
if(!Connections.Any(conn => conn.User == user)) { if(!Connections.HasUser(user)) {
HandleDisconnect(user, UserDisconnectReason.TimeOut); HandleDisconnect(user, UserDisconnectReason.TimeOut);
Logger.Write($"Timed out {user} (no more connections)."); Logger.Write($"Timed out {user} (no more connections).");
} }
@ -192,10 +189,13 @@ namespace SharpChat {
} else } else
SendTo(user, new ForceDisconnectPacket()); SendTo(user, new ForceDisconnectPacket());
foreach(ConnectionInfo conn in Connections) ConnectionInfo[] conns = Connections.GetUser(user);
if(conn.User == user) foreach(ConnectionInfo conn in conns) {
conn.Dispose(); Connections.Remove(conn);
Connections.RemoveAll(conn => conn.IsDisposed); conn.Close(1000);
Logger.Write($"<{conn.RemoteEndPoint}> Nuked connection from banned user #{conn.UserId}.");
}
HandleDisconnect(user, reason); HandleDisconnect(user, reason);
} }
@ -372,22 +372,17 @@ namespace SharpChat {
} }
public void Send(SockChatS2CPacket packet) { public void Send(SockChatS2CPacket packet) {
foreach(ConnectionInfo conn in Connections) Connections.WithAuthed(conn => conn.Send(packet));
if(conn.IsAuthed)
conn.Send(packet);
} }
public void SendTo(UserInfo user, SockChatS2CPacket packet) { public void SendTo(UserInfo user, SockChatS2CPacket packet) {
foreach(ConnectionInfo conn in Connections) Connections.WithUser(user, conn => conn.Send(packet));
if(conn.IsAuthed && conn.User!.UserId == user.UserId)
conn.Send(packet);
} }
public void SendTo(ChannelInfo channel, SockChatS2CPacket packet) { public void SendTo(ChannelInfo channel, SockChatS2CPacket packet) {
long[] userIds = ChannelsUsers.GetChannelUserIds(channel); long[] userIds = ChannelsUsers.GetChannelUserIds(channel);
foreach(ConnectionInfo conn in Connections) foreach(long userId in userIds)
if(conn.IsAuthed && userIds.Contains(conn.User!.UserId)) Connections.WithUser(userId, conn => conn.Send(packet));
conn.Send(packet);
} }
public void SendToUserChannels(UserInfo user, SockChatS2CPacket packet) { public void SendToUserChannels(UserInfo user, SockChatS2CPacket packet) {
@ -396,10 +391,6 @@ namespace SharpChat {
SendTo(chan, packet); SendTo(chan, packet);
} }
public IPAddress[] GetRemoteAddresses(UserInfo user) {
return Connections.Where(c => c.IsAlive && c.User == user).Select(c => c.RemoteAddress).Distinct().ToArray();
}
public void ForceChannel(UserInfo user, ChannelInfo? chan = null) { public void ForceChannel(UserInfo user, ChannelInfo? chan = null) {
chan ??= Channels.Get(ChannelsUsers.GetUserLastChannel(user)); chan ??= Channels.Get(ChannelsUsers.GetUserLastChannel(user));
if(chan != null) if(chan != null)

View file

@ -61,19 +61,20 @@ namespace SharpChat.Commands {
Task.Run(async () => { Task.Run(async () => {
string userId = banUser.UserId.ToString(); string userId = banUser.UserId.ToString();
string userIp = ctx.Chat.GetRemoteAddresses(banUser).FirstOrDefault()?.ToString() ?? string.Empty;
// obviously it makes no sense to only check for one ip address but that's current misuzu limitations
MisuzuBanInfo? fbi = await Misuzu.CheckBanAsync(userId, userIp);
MisuzuBanInfo? fbi = await Misuzu.CheckBanAsync(userId);
if(fbi != null && fbi.IsBanned && !fbi.HasExpired) { if(fbi != null && fbi.IsBanned && !fbi.HasExpired) {
ctx.Chat.SendTo(ctx.User, new KickBanNotAllowedErrorPacket(SockChatUtility.GetUserName(banUser))); ctx.Chat.SendTo(ctx.User, new KickBanNotAllowedErrorPacket(SockChatUtility.GetUserName(banUser)));
return; return;
} }
string[] userRemoteAddrs = ctx.Chat.Connections.GetUserRemoteAddresses(banUser);
string userRemoteAddr = string.Format(", ", userRemoteAddrs);
// Misuzu only stores the IP address in private comment and doesn't do any checking, so this is fine.
await Misuzu.CreateBanAsync( await Misuzu.CreateBanAsync(
userId, userIp, userId, userRemoteAddr,
ctx.User.UserId.ToString(), ctx.Connection.RemoteAddress.ToString(), ctx.User.UserId.ToString(), ctx.Connection.RemoteAddress,
duration, banReason duration, banReason
); );

View file

@ -5,11 +5,17 @@ using System.Threading;
namespace SharpChat.Commands { namespace SharpChat.Commands {
public class ShutdownRestartCommand : IUserCommand { public class ShutdownRestartCommand : IUserCommand {
private readonly ManualResetEvent WaitHandle; private readonly ManualResetEvent WaitHandle;
private readonly Func<bool> ShutdownCheck; private readonly Func<bool> ShuttingDown;
private readonly Action<bool> SetShutdown;
public ShutdownRestartCommand(ManualResetEvent waitHandle, Func<bool> shutdownCheck) { public ShutdownRestartCommand(
ManualResetEvent waitHandle,
Func<bool> shuttingDown,
Action<bool> setShutdown
) {
WaitHandle = waitHandle; WaitHandle = waitHandle;
ShutdownCheck = shutdownCheck; ShuttingDown = shuttingDown;
SetShutdown = setShutdown;
} }
public bool IsMatch(UserCommandContext ctx) { public bool IsMatch(UserCommandContext ctx) {
@ -23,13 +29,10 @@ namespace SharpChat.Commands {
return; return;
} }
if(!ShutdownCheck()) if(ShuttingDown())
return; return;
if(ctx.NameEquals("restart")) SetShutdown(ctx.NameEquals("restart"));
foreach(ConnectionInfo conn in ctx.Chat.Connections)
conn.PrepareForRestart();
ctx.Chat.Update(); ctx.Chat.Update();
WaitHandle?.Set(); WaitHandle?.Set();
} }

View file

@ -1,6 +1,5 @@
using SharpChat.Packet; using SharpChat.Packet;
using System.Linq; using System.Linq;
using System.Net;
namespace SharpChat.Commands { namespace SharpChat.Commands {
public class WhoisCommand : IUserCommand { public class WhoisCommand : IUserCommand {
@ -24,8 +23,8 @@ namespace SharpChat.Commands {
return; return;
} }
foreach(IPAddress ip in ctx.Chat.GetRemoteAddresses(ipUser)) foreach(string remoteAddr in ctx.Chat.Connections.GetUserRemoteAddresses(ipUser))
ctx.Chat.SendTo(ctx.User, new WhoisResponsePacket(ipUser.UserName, ip.ToString())); ctx.Chat.SendTo(ctx.User, new WhoisResponsePacket(ipUser.UserName, remoteAddr));
} }
} }
} }

View file

@ -3,45 +3,42 @@ using System;
using System.Net; using System.Net;
namespace SharpChat { namespace SharpChat {
public class ConnectionInfo : IDisposable { public class ConnectionInfo {
public const int ID_LENGTH = 20;
#if DEBUG
public static TimeSpan SessionTimeOut { get; } = TimeSpan.FromMinutes(1);
#else
public static TimeSpan SessionTimeOut { get; } = TimeSpan.FromMinutes(5);
#endif
public IWebSocketConnection Socket { get; } public IWebSocketConnection Socket { get; }
public DateTimeOffset LastPing { get; private set; }
public string Id { get; } public long UserId { get; private set; } = 0;
public bool IsDisposed { get; private set; }
public DateTimeOffset LastPing { get; set; } = DateTimeOffset.Now;
public UserInfo? User { get; set; }
private int CloseCode { get; set; } = 1000; public string RemoteAddress { get; }
public IPAddress RemoteAddress { get; }
public ushort RemotePort { get; } public ushort RemotePort { get; }
public string RemoteEndPoint { get; }
public bool IsAlive => !IsDisposed && !HasTimedOut;
public bool IsAuthed => IsAlive && User is not null;
public ConnectionInfo(IWebSocketConnection sock) { public ConnectionInfo(IWebSocketConnection sock) {
Socket = sock; Socket = sock;
Id = RNG.SecureRandomString(ID_LENGTH);
if(!IPAddress.TryParse(sock.ConnectionInfo.ClientIpAddress, out IPAddress? addr)) BumpPing();
throw new Exception("Unable to parse remote address?????");
if(IPAddress.IsLoopback(addr) IPAddress remoteAddr = IPAddress.Parse(sock.ConnectionInfo.ClientIpAddress);
if(IPAddress.IsLoopback(remoteAddr)
&& sock.ConnectionInfo.Headers.ContainsKey("X-Real-IP") && sock.ConnectionInfo.Headers.ContainsKey("X-Real-IP")
&& IPAddress.TryParse(sock.ConnectionInfo.Headers["X-Real-IP"], out IPAddress? realAddr)) && IPAddress.TryParse(sock.ConnectionInfo.Headers["X-Real-IP"], out IPAddress? realAddr))
addr = realAddr; remoteAddr = realAddr;
RemoteAddress = addr; RemoteAddress = remoteAddr.ToString();
RemotePort = (ushort)sock.ConnectionInfo.ClientPort; RemotePort = (ushort)sock.ConnectionInfo.ClientPort;
RemoteEndPoint = string.Format(
RemoteAddress.Contains(':') ? "[{0}]:{1}" : "{0}:{1}",
RemoteAddress, RemotePort
);
}
// please call this through ConnectionsContext
public void SetUserId(long userId) {
if(UserId > 0)
throw new InvalidOperationException("This connection is already associated with a user.");
UserId = userId;
} }
public void Send(SockChatS2CPacket packet) { public void Send(SockChatS2CPacket packet) {
@ -50,39 +47,15 @@ namespace SharpChat {
string data = packet.Pack(); string data = packet.Pack();
if(!string.IsNullOrWhiteSpace(data)) if(!string.IsNullOrWhiteSpace(data))
Socket.Send(data); Socket.Send(data).Wait();
} }
public void BumpPing() { public void BumpPing() {
LastPing = DateTimeOffset.Now; LastPing = DateTimeOffset.UtcNow;
} }
public bool HasTimedOut public void Close(int code) {
=> DateTimeOffset.Now - LastPing > SessionTimeOut; Socket.Close(code);
public void PrepareForRestart() {
CloseCode = 1012;
}
~ConnectionInfo() {
DoDispose();
}
public void Dispose() {
DoDispose();
GC.SuppressFinalize(this);
}
private void DoDispose() {
if(IsDisposed)
return;
IsDisposed = true;
Socket.Close(CloseCode);
}
public override string ToString() {
return Id;
} }
} }
} }

View file

@ -0,0 +1,176 @@
using System;
using System.Collections.Generic;
using System.Linq;
namespace SharpChat {
public class ConnectionsContext {
public static readonly TimeSpan TimeOut = TimeSpan.FromMinutes(5);
private readonly HashSet<ConnectionInfo> Connections = new();
private readonly HashSet<ConnectionInfo> AuthedConnections = new();
private readonly Dictionary<long, HashSet<ConnectionInfo>> UserConnections = new();
public ConnectionInfo[] All => Connections.ToArray();
public ConnectionInfo[] Authed => AuthedConnections.ToArray();
public void WithAll(Action<ConnectionInfo> body) {
foreach(ConnectionInfo conn in Connections)
body(conn);
}
public void WithAuthed(Action<ConnectionInfo> body) {
foreach(ConnectionInfo conn in AuthedConnections)
body(conn);
}
public ConnectionInfo[] GetTimedOut() {
List<ConnectionInfo> conns = new();
foreach(ConnectionInfo conn in Connections)
if(DateTimeOffset.UtcNow - conn.LastPing > TimeOut)
conns.Add(conn);
return conns.ToArray();
}
public int GetCountForUser(long userId) {
return UserConnections.ContainsKey(userId) ? UserConnections.Count : 0;
}
public int GetCountForUser(UserInfo userInfo) {
return GetCountForUser(userInfo.UserId);
}
public bool HasUser(long userId) {
return GetCountForUser(userId) > 0;
}
public bool HasUser(UserInfo userInfo) {
return HasUser(userInfo.UserId);
}
public ConnectionInfo[] GetUser(long userId) {
if(!UserConnections.ContainsKey(userId))
return Array.Empty<ConnectionInfo>();
return UserConnections[userId].ToArray();
}
public ConnectionInfo[] GetUser(UserInfo userInfo) {
return GetUser(userInfo.UserId);
}
public void WithUser(long userId, Action<ConnectionInfo> body) {
if(!UserConnections.ContainsKey(userId))
return;
foreach(ConnectionInfo conn in UserConnections[userId])
body(conn);
}
public void WithUser(UserInfo userInfo, Action<ConnectionInfo> body) {
WithUser(userInfo.UserId, body);
}
public string[] GetAllRemoteAddresses() {
HashSet<string> addrs = new();
foreach(ConnectionInfo conn in Connections)
addrs.Add(conn.RemoteAddress);
return addrs.ToArray();
}
public string[] GetAuthedRemoteAddresses() {
HashSet<string> addrs = new();
foreach(ConnectionInfo conn in AuthedConnections)
addrs.Add(conn.RemoteAddress);
return addrs.ToArray();
}
public string[] GetUserRemoteAddresses(long userId) {
if(!UserConnections.ContainsKey(userId))
return Array.Empty<string>();
HashSet<string> addrs = new();
foreach(ConnectionInfo conn in UserConnections[userId])
addrs.Add(conn.RemoteAddress);
return addrs.ToArray();
}
public string[] GetUserRemoteAddresses(UserInfo userInfo) {
return GetUserRemoteAddresses(userInfo.UserId);
}
public void Add(ConnectionInfo conn) {
if(Connections.Contains(conn))
return;
Connections.Add(conn);
if(conn.UserId > 0) {
AuthedConnections.Add(conn);
if(UserConnections.ContainsKey(conn.UserId))
UserConnections[conn.UserId].Add(conn);
else
UserConnections.Add(conn.UserId, new() { conn });
}
}
public void Remove(ConnectionInfo conn) {
if(Connections.Contains(conn))
Connections.Remove(conn);
if(AuthedConnections.Contains(conn))
AuthedConnections.Remove(conn);
if(conn.UserId > 0 && UserConnections.ContainsKey(conn.UserId)) {
UserConnections[conn.UserId].Remove(conn);
if(UserConnections[conn.UserId].Count < 1)
UserConnections.Remove(conn.UserId);
}
}
public void SetUser(ConnectionInfo conn, long userId) {
if(!Connections.Contains(conn))
return;
// so yeah this is implemented but SetUserId will throw an exception if we're trying to re-auth a connection
// will just leave this here but i'm not sure how to go forward with this
if(conn.UserId > 0) {
if(UserConnections[conn.UserId].Contains(conn))
UserConnections[conn.UserId].Remove(conn);
if(UserConnections[conn.UserId].Count < 1)
UserConnections.Remove(conn.UserId);
}
conn.SetUserId(userId);
if(conn.UserId > 0) {
if(UserConnections.ContainsKey(conn.UserId))
UserConnections[conn.UserId].Add(conn);
else
UserConnections.Add(conn.UserId, new() { conn });
}
if(conn.UserId > 0) {
if(!AuthedConnections.Contains(conn))
AuthedConnections.Add(conn);
} else {
if(AuthedConnections.Contains(conn))
AuthedConnections.Remove(conn);
}
}
public void SetUser(ConnectionInfo conn, UserInfo userInfo) {
SetUser(conn, userInfo.UserId);
}
}
}

View file

@ -41,14 +41,14 @@ namespace SharpChat.PacketHandlers {
string? authMethod = args.ElementAtOrDefault(1); string? authMethod = args.ElementAtOrDefault(1);
if(string.IsNullOrWhiteSpace(authMethod)) { if(string.IsNullOrWhiteSpace(authMethod)) {
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
string? authToken = args.ElementAtOrDefault(2); string? authToken = args.ElementAtOrDefault(2);
if(string.IsNullOrWhiteSpace(authToken)) { if(string.IsNullOrWhiteSpace(authToken)) {
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
@ -60,14 +60,14 @@ namespace SharpChat.PacketHandlers {
Task.Run(async () => { Task.Run(async () => {
MisuzuAuthInfo? fai; MisuzuAuthInfo? fai;
string ipAddr = ctx.Connection.RemoteAddress.ToString(); string ipAddr = ctx.Connection.RemoteAddress;
try { try {
fai = await Misuzu.AuthVerifyAsync(authMethod, authToken, ipAddr); fai = await Misuzu.AuthVerifyAsync(authMethod, authToken, ipAddr);
} catch(Exception ex) { } catch(Exception ex) {
Logger.Write($"<{ctx.Connection.Id}> Failed to authenticate: {ex}"); Logger.Write($"<{ctx.Connection.RemoteEndPoint}> Failed to authenticate: {ex}");
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
#if DEBUG #if DEBUG
throw; throw;
#else #else
@ -76,16 +76,16 @@ namespace SharpChat.PacketHandlers {
} }
if(fai == null) { if(fai == null) {
Logger.Debug($"<{ctx.Connection.Id}> Auth fail: <null>"); Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Auth fail: <null>");
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
if(!fai.Success) { if(!fai.Success) {
Logger.Debug($"<{ctx.Connection.Id}> Auth fail: {fai.Reason}"); Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Auth fail: {fai.Reason}");
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
@ -93,9 +93,9 @@ namespace SharpChat.PacketHandlers {
try { try {
fbi = await Misuzu.CheckBanAsync(fai.UserId.ToString(), ipAddr); fbi = await Misuzu.CheckBanAsync(fai.UserId.ToString(), ipAddr);
} catch(Exception ex) { } catch(Exception ex) {
Logger.Write($"<{ctx.Connection.Id}> Failed auth ban check: {ex}"); Logger.Write($"<{ctx.Connection.RemoteEndPoint}> Failed auth ban check: {ex}");
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.AuthInvalid));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
#if DEBUG #if DEBUG
throw; throw;
#else #else
@ -104,16 +104,16 @@ namespace SharpChat.PacketHandlers {
} }
if(fbi == null) { if(fbi == null) {
Logger.Debug($"<{ctx.Connection.Id}> Ban check fail: <null>"); Logger.Debug($"<{ctx.Connection.RemoteEndPoint}> Ban check fail: <null>");
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.Null));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
if(fbi.IsBanned && !fbi.HasExpired) { if(fbi.IsBanned && !fbi.HasExpired) {
Logger.Write($"<{ctx.Connection.Id}> User is banned."); Logger.Write($"<{ctx.Connection.RemoteEndPoint}> User is banned.");
ctx.Connection.Send(new AuthFailPacket(fbi.ExpiresAt)); ctx.Connection.Send(new AuthFailPacket(fbi.ExpiresAt));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
@ -141,14 +141,14 @@ namespace SharpChat.PacketHandlers {
); );
// Enforce a maximum amount of connections per user // Enforce a maximum amount of connections per user
if(ctx.Chat.Connections.Count(conn => conn.User == user) >= MaxConnections) { if(ctx.Chat.Connections.GetCountForUser(user) >= MaxConnections) {
ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.MaxSessions)); ctx.Connection.Send(new AuthFailPacket(AuthFailPacket.FailReason.MaxSessions));
ctx.Connection.Dispose(); ctx.Connection.Close(1000);
return; return;
} }
ctx.Connection.BumpPing(); ctx.Connection.BumpPing();
ctx.Connection.User = user; ctx.Chat.Connections.SetUser(ctx.Connection, user);
ctx.Connection.Send(new MOTDPacket(Started, $"Welcome to Flashii Chat, {user.UserName}!")); ctx.Connection.Send(new MOTDPacket(Started, $"Welcome to Flashii Chat, {user.UserName}!"));
if(File.Exists(MOTD_FILE)) { if(File.Exists(MOTD_FILE)) {

View file

@ -1,6 +1,7 @@
using SharpChat.Misuzu; using SharpChat.Misuzu;
using SharpChat.Packet; using SharpChat.Packet;
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -31,12 +32,20 @@ namespace SharpChat.PacketHandlers {
ctx.Chat.ContextAccess.Wait(); ctx.Chat.ContextAccess.Wait();
try { try {
if(LastBump < DateTimeOffset.UtcNow - BumpInterval) { if(LastBump < DateTimeOffset.UtcNow - BumpInterval) {
(string, string)[] bumpList = ctx.Chat.Users.All List<(string, string)> bumpList = new();
.Where(u => u.Status == UserStatus.Online && ctx.Chat.Connections.Any(c => c.User == u))
.Select(u => (u.UserId.ToString(), ctx.Chat.GetRemoteAddresses(u).FirstOrDefault()?.ToString() ?? string.Empty))
.ToArray();
if(bumpList.Any()) foreach(UserInfo userInfo in ctx.Chat.Users.All) {
if(userInfo.Status != UserStatus.Online)
continue;
string[] remoteAddrs = ctx.Chat.Connections.GetUserRemoteAddresses(userInfo);
if(remoteAddrs.Length < 1)
continue;
bumpList.Add((userInfo.UserId.ToString(), remoteAddrs[0]));
}
if(bumpList.Count > 0)
Task.Run(async () => { Task.Run(async () => {
await Misuzu.BumpUsersOnlineAsync(bumpList); await Misuzu.BumpUsersOnlineAsync(bumpList);
}).Wait(); }).Wait();

View file

@ -29,7 +29,7 @@ namespace SharpChat.PacketHandlers {
public void Handle(PacketHandlerContext ctx) { public void Handle(PacketHandlerContext ctx) {
string[] args = ctx.SplitText(3); string[] args = ctx.SplitText(3);
UserInfo? user = ctx.Connection.User; UserInfo? user = ctx.Chat.Users.Get(ctx.Connection.UserId);
// No longer concats everything after index 1 with \t, no previous implementation did that either // No longer concats everything after index 1 with \t, no previous implementation did that either
string? messageText = args.ElementAtOrDefault(2); string? messageText = args.ElementAtOrDefault(2);
@ -59,7 +59,7 @@ namespace SharpChat.PacketHandlers {
messageText = messageText.Trim(); messageText = messageText.Trim();
#if DEBUG #if DEBUG
Logger.Write($"<{ctx.Connection.Id} {user.UserName}> {messageText}"); Logger.Write($"<{user.UserId} {user.UserName}> {messageText}");
#endif #endif
if(messageText.StartsWith("/")) { if(messageText.StartsWith("/")) {

View file

@ -17,7 +17,6 @@ namespace SharpChat {
public const int DEFAULT_MSG_LENGTH_MAX = 5000; public const int DEFAULT_MSG_LENGTH_MAX = 5000;
public const int DEFAULT_MAX_CONNECTIONS = 5; public const int DEFAULT_MAX_CONNECTIONS = 5;
public const int DEFAULT_FLOOD_KICK_LENGTH = 30; public const int DEFAULT_FLOOD_KICK_LENGTH = 30;
public const int DEFAULT_FLOOD_KICK_EXEMPT_RANK = 9;
public IWebSocketServer Server { get; } public IWebSocketServer Server { get; }
public ChatContext Context { get; } public ChatContext Context { get; }
@ -28,13 +27,13 @@ namespace SharpChat {
private readonly CachedValue<int> MaxMessageLength; private readonly CachedValue<int> MaxMessageLength;
private readonly CachedValue<int> MaxConnections; private readonly CachedValue<int> MaxConnections;
private readonly CachedValue<int> FloodKickLength; private readonly CachedValue<int> FloodKickLength;
private readonly CachedValue<int> FloodKickExemptRank;
private readonly List<IPacketHandler> GuestHandlers = new(); private readonly List<IPacketHandler> GuestHandlers = new();
private readonly List<IPacketHandler> AuthedHandlers = new(); private readonly List<IPacketHandler> AuthedHandlers = new();
private readonly SendMessageHandler SendMessageHandler; private readonly SendMessageHandler SendMessageHandler;
private bool IsShuttingDown = false; private bool IsShuttingDown = false;
private bool IsRestarting = false;
public SockChatServer(HttpClient httpClient, MisuzuClient msz, IEventStorage evtStore, IConfig config) { public SockChatServer(HttpClient httpClient, MisuzuClient msz, IEventStorage evtStore, IConfig config) {
Logger.Write("Initialising Sock Chat server..."); Logger.Write("Initialising Sock Chat server...");
@ -46,7 +45,6 @@ namespace SharpChat {
MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX); MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX);
MaxConnections = config.ReadCached("connMaxCount", DEFAULT_MAX_CONNECTIONS); MaxConnections = config.ReadCached("connMaxCount", DEFAULT_MAX_CONNECTIONS);
FloodKickLength = config.ReadCached("floodKickLength", DEFAULT_FLOOD_KICK_LENGTH); FloodKickLength = config.ReadCached("floodKickLength", DEFAULT_FLOOD_KICK_LENGTH);
FloodKickExemptRank = config.ReadCached("floodKickExemptRank", DEFAULT_FLOOD_KICK_EXEMPT_RANK);
Context = new ChatContext(evtStore); Context = new ChatContext(evtStore);
@ -111,7 +109,14 @@ namespace SharpChat {
public void Listen(ManualResetEvent waitHandle) { public void Listen(ManualResetEvent waitHandle) {
if(waitHandle != null) if(waitHandle != null)
SendMessageHandler.AddCommand(new ShutdownRestartCommand(waitHandle, () => !IsShuttingDown && (IsShuttingDown = true))); SendMessageHandler.AddCommand(new ShutdownRestartCommand(
waitHandle,
() => IsShuttingDown,
restarting => {
IsShuttingDown = true;
IsRestarting = restarting;
}
));
Server.Start(sock => { Server.Start(sock => {
if(IsShuttingDown) { if(IsShuttingDown) {
@ -132,24 +137,27 @@ namespace SharpChat {
} }
private void OnOpen(ConnectionInfo conn) { private void OnOpen(ConnectionInfo conn) {
Logger.Write($"Connection opened from {conn.RemoteAddress}:{conn.RemotePort}"); Logger.Write($"Connection opened from {conn.RemoteEndPoint}");
Context.SafeUpdate(); Context.SafeUpdate();
} }
private void OnError(ConnectionInfo conn, Exception ex) { private void OnError(ConnectionInfo conn, Exception ex) {
Logger.Write($"[{conn.Id} {conn.RemoteAddress}] {ex}"); Logger.Write($"<{conn.RemoteEndPoint}> {ex}");
Context.SafeUpdate(); Context.SafeUpdate();
} }
private void OnClose(ConnectionInfo conn) { private void OnClose(ConnectionInfo conn) {
Logger.Write($"Connection closed from {conn.RemoteAddress}:{conn.RemotePort}"); Logger.Write($"Connection closed from {conn.RemoteEndPoint}");
Context.ContextAccess.Wait(); Context.ContextAccess.Wait();
try { try {
Context.Connections.Remove(conn); Context.Connections.Remove(conn);
if(conn.User != null && !Context.Connections.Any(c => c.User == conn.User)) if(!Context.Connections.HasUser(conn.UserId)) {
Context.HandleDisconnect(conn.User); UserInfo? userInfo = Context.Users.Get(conn.UserId);
if(userInfo != null)
Context.HandleDisconnect(userInfo);
}
Context.Update(); Context.Update();
} finally { } finally {
@ -161,15 +169,15 @@ namespace SharpChat {
Context.SafeUpdate(); Context.SafeUpdate();
// this doesn't affect non-authed connections????? // this doesn't affect non-authed connections?????
if(conn.User is not null && conn.User.Rank < FloodKickExemptRank) { if(conn.UserId > 0) {
UserInfo? banUser = null; long banUserId = 0;
string banAddr = string.Empty; string banAddr = string.Empty;
TimeSpan banDuration = TimeSpan.MinValue; TimeSpan banDuration = TimeSpan.MinValue;
Context.ContextAccess.Wait(); Context.ContextAccess.Wait();
try { try {
if(!Context.UserRateLimiters.TryGetValue(conn.User.UserId, out RateLimiter? rateLimiter)) if(!Context.UserRateLimiters.TryGetValue(conn.UserId, out RateLimiter? rateLimiter))
Context.UserRateLimiters.Add(conn.User.UserId, rateLimiter = new RateLimiter( Context.UserRateLimiters.Add(conn.UserId, rateLimiter = new RateLimiter(
UserInfo.DEFAULT_SIZE, UserInfo.DEFAULT_SIZE,
UserInfo.DEFAULT_MINIMUM_DELAY, UserInfo.DEFAULT_MINIMUM_DELAY,
UserInfo.DEFAULT_RISKY_OFFSET UserInfo.DEFAULT_RISKY_OFFSET
@ -179,27 +187,30 @@ namespace SharpChat {
if(rateLimiter.IsExceeded) { if(rateLimiter.IsExceeded) {
banDuration = TimeSpan.FromSeconds(FloodKickLength); banDuration = TimeSpan.FromSeconds(FloodKickLength);
banUser = conn.User; banUserId = conn.UserId;
banAddr = conn.RemoteAddress.ToString(); banAddr = conn.RemoteAddress;
} else if(rateLimiter.IsRisky) { } else if(rateLimiter.IsRisky) {
banUser = conn.User; banUserId = conn.UserId;
} }
if(banUser is not null) { if(banUserId > 0) {
if(banDuration == TimeSpan.MinValue) { UserInfo? userInfo = Context.Users.Get(banUserId);
Context.SendTo(conn.User, new FloodWarningPacket()); if(userInfo != null) {
} else { if(banDuration == TimeSpan.MinValue) {
Context.BanUser(conn.User, banDuration, UserDisconnectReason.Flood); Context.SendTo(userInfo, new FloodWarningPacket());
} else {
Context.BanUser(userInfo, banDuration, UserDisconnectReason.Flood);
if(banDuration > TimeSpan.Zero) if(banDuration > TimeSpan.Zero)
Misuzu.CreateBanAsync( Misuzu.CreateBanAsync(
conn.User.UserId.ToString(), conn.RemoteAddress.ToString(), banUserId.ToString(), conn.RemoteAddress,
string.Empty, "::1", string.Empty, "::1",
banDuration, banDuration,
"Kicked from chat for flood protection." "Kicked from chat for flood protection."
).Wait(); ).Wait();
return; return;
}
} }
} }
} finally { } finally {
@ -208,9 +219,9 @@ namespace SharpChat {
} }
PacketHandlerContext context = new(msg, Context, conn); PacketHandlerContext context = new(msg, Context, conn);
IPacketHandler? handler = conn.User is null IPacketHandler? handler = conn.UserId > 0
? GuestHandlers.FirstOrDefault(h => h.IsMatch(context)) ? AuthedHandlers.FirstOrDefault(h => h.IsMatch(context))
: AuthedHandlers.FirstOrDefault(h => h.IsMatch(context)); : GuestHandlers.FirstOrDefault(h => h.IsMatch(context));
handler?.Handle(context); handler?.Handle(context);
} }
@ -232,8 +243,7 @@ namespace SharpChat {
IsDisposed = true; IsDisposed = true;
IsShuttingDown = true; IsShuttingDown = true;
foreach(ConnectionInfo conn in Context.Connections) Context.Connections.WithAll(conn => conn.Close(IsRestarting ? 1012 : 1001));
conn.Dispose();
Server?.Dispose(); Server?.Dispose();
HttpClient?.Dispose(); HttpClient?.Dispose();