First bits of the Context overhaul.

Reintroduces separate contexts for users, channels, connections (now split into sessions and connections) and user-channel associations.
It builds which is as much assurance as I can give about the stability of this commit, but its also the bare minimum of what i like to commit sooooo
A lot of things still need to be broadcast through events throughout the application in order to keep states consistent but we'll cross that bridge when we get to it.
I really need to stop using that phrase thingy, I'm overusing it.
This commit is contained in:
flash 2025-05-03 02:49:51 +00:00
parent f41ca7fb7f
commit 5a7756894b
Signed by: flash
GPG key ID: 2C9C2C574D47FE3E
76 changed files with 1595 additions and 520 deletions
SharpChat.Flashii
SharpChat.MariaDB
SharpChat.SQLite
SharpChat.SockChat
SharpChat
SharpChatCommon

View file

@ -1,4 +1,5 @@
using SharpChat.Auth;
using SharpChat.Users;
using System.Text.Json.Serialization;
namespace SharpChat.Flashii;

View file

@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging;
using MySqlConnector;
using SharpChat.Data;
using SharpChat.Messages;
using SharpChat.Users;
using System.Data.Common;
using System.Text;
using System.Text.Json;

View file

@ -2,11 +2,12 @@ using Microsoft.Extensions.Logging;
using MySqlConnector;
using SharpChat.Configuration;
using SharpChat.Messages;
using SharpChat.Storage;
using ZLogger;
namespace SharpChat.MariaDB;
public class MariaDBStorage(ILogger logger, string connString) : Storage {
public class MariaDBStorage(ILogger logger, string connString) : StorageBackend {
public async Task<MariaDBConnection> CreateConnection() {
MySqlConnection conn = new(connString);
await conn.OpenAsync();

View file

@ -1,3 +1,5 @@
using SharpChat.Users;
namespace SharpChat.MariaDB;
public static class MariaDBUserPermissionsConverter {

View file

@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using SharpChat.Data;
using SharpChat.Messages;
using SharpChat.Users;
using System.Data;
using System.Data.Common;
using System.Data.SQLite;

View file

@ -1,13 +1,14 @@
using Microsoft.Extensions.Logging;
using SharpChat.Configuration;
using SharpChat.Messages;
using SharpChat.Storage;
using System.Data.SQLite;
using ZLogger;
using NativeSQLiteConnection = System.Data.SQLite.SQLiteConnection;
namespace SharpChat.SQLite;
public class SQLiteStorage(ILogger logger, string connString) : Storage, IDisposable {
public class SQLiteStorage(ILogger logger, string connString) : StorageBackend, IDisposable {
public const string MEMORY = "file::memory:?cache=shared";
public const string DEFAULT = "sharpchat.db";

View file

@ -1,3 +1,5 @@
using SharpChat.Users;
namespace SharpChat.SQLite;
public static class SQLiteUserPermissionsConverter {

View file

@ -0,0 +1,9 @@
using Fleck;
namespace SharpChat.SockChat;
public static class IWebSocketConnectionExtensions {
public static void Close(this IWebSocketConnection conn, WebSocketCloseCode closeCode) {
conn.Close((int)closeCode);
}
}

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text;
namespace SharpChat.SockChat.S2CPackets;

View file

@ -1,4 +1,5 @@
using SharpChat.Messages;
using SharpChat.Users;
using System.Text;
using System.Text.Json;

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text;
namespace SharpChat.SockChat.S2CPackets;

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text;
namespace SharpChat.SockChat.S2CPackets;

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text;
namespace SharpChat.SockChat.S2CPackets;

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text;
namespace SharpChat.SockChat.S2CPackets;

View file

@ -6,6 +6,10 @@
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Fleck" Version="1.2.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\SharpChatCommon\SharpChatCommon.csproj" />
</ItemGroup>

View file

@ -1,4 +1,4 @@
#nullable disable
#nullable disable
using Fleck;
using Microsoft.Extensions.Logging;
@ -14,7 +14,7 @@ using ZLogger;
// Fleck's Socket wrapper doesn't provide any way to do this with the normally provided APIs
// https://github.com/statianzo/Fleck/blob/1.1.0/src/Fleck/WebSocketServer.cs
namespace SharpChat;
namespace SharpChat.SockChat;
public class SharpChatWebSocketServer : IWebSocketServer {
private readonly ILogger Logger;
@ -149,7 +149,7 @@ public class SharpChatWebSocketServer : IWebSocketServer {
string responseBody = File.Exists("http-motd.txt") ? File.ReadAllText("http-motd.txt") : "SharpChat";
clientSocket.Stream.Write(Encoding.UTF8.GetBytes(string.Format(
responseMsg, DateTimeOffset.Now.ToString("r"), Encoding.UTF8.GetByteCount(responseBody), responseBody
responseMsg, DateTimeOffset.Now.ToString("r"), responseBody.CountUtf8Bytes(), responseBody
)));
clientSocket.Close();
return null;

View file

@ -1,14 +1,15 @@
namespace SharpChat;
using Microsoft.Extensions.Logging;
using SharpChat.Sessions;
public class C2SPacketHandlerContext(
string text,
Context chat,
Connection connection
namespace SharpChat;
public record class C2SPacketHandlerContext(
string Text,
Context Chat,
SockChatConnection Connection,
Session? Session,
ILogger Logger
) {
public string Text { get; } = text ?? throw new ArgumentNullException(nameof(text));
public Context Chat { get; } = chat ?? throw new ArgumentNullException(nameof(chat));
public Connection Connection { get; } = connection ?? throw new ArgumentNullException(nameof(connection));
public bool CheckPacketId(string packetId) {
return Text == packetId || Text.StartsWith(packetId + '\t');
}

View file

@ -2,9 +2,11 @@ using SharpChat.Auth;
using SharpChat.Bans;
using SharpChat.Channels;
using SharpChat.Configuration;
using SharpChat.Connections;
using SharpChat.Messages;
using SharpChat.Snowflake;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using ZLogger;
namespace SharpChat.C2SPacketHandlers;
@ -22,14 +24,18 @@ public class AuthC2SPacketHandler(
}
public async Task Handle(C2SPacketHandlerContext ctx) {
if(ctx.Session is not null)
return;
string[] args = ctx.SplitText(3);
string? authMethod = args.ElementAtOrDefault(1);
string? authToken = args.ElementAtOrDefault(2);
if(string.IsNullOrWhiteSpace(authMethod) || string.IsNullOrWhiteSpace(authToken)) {
ctx.Logger.ZLogInformation($"Received empty authentication information.");
await ctx.Connection.Send(new AuthFailS2CPacket(AuthFailS2CPacket.Reason.AuthInvalid));
ctx.Connection.Dispose();
ctx.Connection.Close(ConnectionCloseReason.Unauthorized);
return;
}
@ -48,42 +54,26 @@ public class AuthC2SPacketHandler(
BanInfo? banInfo = await bansClient.BanGet(authResult.UserId, ctx.Connection.RemoteEndPoint.Address);
if(banInfo is not null) {
ctx.Connection.Logger.ZLogInformation($"User {authResult.UserId} is banned.");
ctx.Logger.ZLogInformation($"User {authResult.UserId} is banned.");
await ctx.Connection.Send(new AuthFailS2CPacket(AuthFailS2CPacket.Reason.Banned, banInfo.IsPermanent ? DateTimeOffset.MaxValue : banInfo.ExpiresAt));
ctx.Connection.Dispose();
ctx.Connection.Close(ConnectionCloseReason.AccessDenied);
return;
}
await ctx.Chat.ContextAccess.WaitAsync();
try {
User? user = ctx.Chat.Users.FirstOrDefault(u => u.UserId == authResult.UserId);
if(user == null)
user = new User(
authResult.UserId,
authResult.UserName ?? $"({authResult.UserId})",
authResult.UserColour,
authResult.UserRank,
authResult.UserPermissions
);
else
await ctx.Chat.UpdateUser(
user,
userName: authResult.UserName ?? $"({authResult.UserId})",
colour: authResult.UserColour,
rank: authResult.UserRank,
perms: authResult.UserPermissions
);
User user = ctx.Chat.Users.CreateOrUpdateUser(authResult);
// Enforce a maximum amount of connections per user
if(ctx.Chat.Connections.Count(conn => conn.User == user) >= maxConns) {
if(ctx.Chat.Sessions.CountNonSuspendedActiveSessions(user) >= maxConns) {
ctx.Logger.ZLogInformation($"Too many active connections.");
await ctx.Connection.Send(new AuthFailS2CPacket(AuthFailS2CPacket.Reason.MaxSessions));
ctx.Connection.Dispose();
ctx.Connection.Close(ConnectionCloseReason.TooManyConnections);
return;
}
ctx.Connection.BumpPing();
ctx.Connection.User = user;
ctx.Chat.Sessions.CreateSession(user, ctx.Connection);
await ctx.Connection.Send(new CommandResponseS2CPacket(0, LCR.WELCOME, false, $"Welcome to Flashii Chat, {user.UserName}!"));
if(File.Exists("welcome.txt")) {
@ -94,17 +84,17 @@ public class AuthC2SPacketHandler(
await ctx.Connection.Send(new CommandResponseS2CPacket(0, LCR.WELCOME, false, line));
}
Channel channel = channelsCtx.DefaultChannel;
Channel channel = channelsCtx.GetDefaultChannel();
if(!ctx.Chat.IsInChannel(user, channel)) {
if(!ctx.Chat.ChannelsUsers.HasChannelUser(channel, user)) {
long msgId = snowflake.Next();
await ctx.Chat.SendTo(channel, new UserConnectS2CPacket(msgId, DateTimeOffset.Now, user.UserId, user.LegacyNameWithStatus, user.Colour, user.Rank, user.Permissions));
await ctx.Chat.SendTo(channel, new UserConnectS2CPacket(msgId, DateTimeOffset.Now, user.UserId, user.GetLegacyNameWithStatus(), user.Colour, user.Rank, user.Permissions));
await ctx.Chat.Messages.LogMessage(msgId, "user:connect", channel.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions);
}
await ctx.Connection.Send(new AuthSuccessS2CPacket(
user.UserId,
user.LegacyNameWithStatus,
user.GetLegacyNameWithStatus(),
user.Colour,
user.Rank,
user.Permissions,
@ -112,10 +102,10 @@ public class AuthC2SPacketHandler(
maxMsgLength
));
await ctx.Connection.Send(new ContextUsersS2CPacket(
ctx.Chat.GetChannelUsers(channel).Except([user]).OrderByDescending(u => u.Rank)
ctx.Chat.ChannelsUsers.GetChannelUsers(channel).Except([user]).OrderByDescending(u => u.Rank)
.Select(u => new ContextUsersS2CPacket.Entry(
u.UserId,
u.LegacyNameWithStatus,
u.GetLegacyNameWithStatus(),
u.Colour,
u.Rank,
u.Permissions,
@ -132,22 +122,21 @@ public class AuthC2SPacketHandler(
.Select(c => new ContextChannelsS2CPacket.Entry(c.Name, c.HasPassword, c.IsTemporary))
));
ctx.Chat.Users.Add(user);
ctx.Chat.ChannelUsers.Add(new Context.ChannelUserAssoc(user.UserId, channel.Name));
ctx.Chat.UserLastChannel[user.UserId] = channel;
ctx.Chat.ChannelsUsers.AddChannelUser(channel, user);
} finally {
ctx.Chat.ContextAccess.Release();
}
} catch(AuthFailedException ex) {
ctx.Connection.Logger.ZLogWarning($"Failed to authenticate (expected): {ex}");
ctx.Chat.Sessions.DestroySession(ctx.Connection);
ctx.Logger.ZLogWarning($"Failed to authenticate (expected): {ex}");
await ctx.Connection.Send(new AuthFailS2CPacket(AuthFailS2CPacket.Reason.AuthInvalid));
ctx.Connection.Dispose();
ctx.Connection.Close(ConnectionCloseReason.Unauthorized);
throw;
} catch(Exception ex) {
ctx.Connection.Logger.ZLogError($"Failed to authenticate (unexpected): {ex}");
ctx.Chat.Sessions.DestroySession(ctx.Connection);
ctx.Logger.ZLogError($"Failed to authenticate (unexpected): {ex}");
await ctx.Connection.Send(new AuthFailS2CPacket(AuthFailS2CPacket.Reason.Exception));
ctx.Connection.Dispose();
ctx.Connection.Close(ConnectionCloseReason.Error);
throw;
}
}

View file

@ -1,5 +1,6 @@
using SharpChat.Auth;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Net;
namespace SharpChat.C2SPacketHandlers;
@ -13,20 +14,22 @@ public class PingC2SPacketHandler(AuthClient authClient) : C2SPacketHandler {
}
public async Task Handle(C2SPacketHandlerContext ctx) {
if(ctx.Session is null)
return;
string[] parts = ctx.SplitText(2);
if(!int.TryParse(parts.FirstOrDefault(), out int pTime))
return;
ctx.Connection.BumpPing();
ctx.Session.Heartbeat();
await ctx.Connection.Send(new PongS2CPacket());
ctx.Chat.ContextAccess.Wait();
try {
if(LastBump < DateTimeOffset.UtcNow - BumpInterval) {
(IPAddress, string)[] bumpList = [.. ctx.Chat.Users
.Where(u => u.Status == UserStatus.Online && ctx.Chat.Connections.Any(c => c.User == u))
.Select(u => (ctx.Chat.GetRemoteAddresses(u).FirstOrDefault() ?? IPAddress.None, u.UserId))];
(IPAddress, string)[] bumpList = [.. ctx.Chat.Users.GetUsersWithStatus(UserStatus.Online)
.Select(u => (ctx.Chat.Sessions.GetRemoteEndPoints(u).Select(e => e.Address).FirstOrDefault() ?? IPAddress.None, u.UserId))];
if(bumpList.Length > 0)
await authClient.AuthBumpUsersOnline(bumpList);

View file

@ -2,8 +2,7 @@ using SharpChat.Channels;
using SharpChat.Configuration;
using SharpChat.Events;
using SharpChat.Snowflake;
using System.Globalization;
using System.Text;
using SharpChat.Users;
namespace SharpChat.C2SPacketHandlers;
@ -28,9 +27,12 @@ public class SendMessageC2SPacketHandler(
}
public async Task Handle(C2SPacketHandlerContext ctx) {
if(ctx.Session is null)
return;
string[] args = ctx.SplitText(3);
User? user = ctx.Connection.User;
User? user = ctx.Chat.Users.GetUser(ctx.Session.UserId);
string? messageText = args.ElementAtOrDefault(2);
if(user?.Permissions.HasFlag(UserPermissions.SendMessage) != true
@ -43,23 +45,20 @@ public class SendMessageC2SPacketHandler(
ctx.Chat.ContextAccess.Wait();
try {
if(!ctx.Chat.UserLastChannel.TryGetValue(user.UserId, out Channel? channel)
&& (channel is null || !ctx.Chat.IsInChannel(user, channel)))
Channel? channel = ctx.Chat.ChannelsUsers.GetUserLastChannel(user);
if(channel is null)
return;
ctx.Chat.ChannelsUsers.RecordChannelUserActivity(channel, user);
if(user.Status != UserStatus.Online)
await ctx.Chat.UpdateUser(user, status: UserStatus.Online);
int maxMsgLength = MaxMessageLength;
StringInfo messageTextInfo = new(messageText);
if(Encoding.UTF8.GetByteCount(messageText) > (maxMsgLength * 10)
|| messageTextInfo.LengthInTextElements > maxMsgLength)
messageText = messageTextInfo.SubstringByTextElements(0, Math.Min(messageTextInfo.LengthInTextElements, maxMsgLength));
messageText = messageText.Trim();
messageText = messageText.TruncateIfTooLong(maxMsgLength, maxMsgLength * 10).Trim();
if(messageText.StartsWith('/')) {
ClientCommandContext context = new(messageText, ctx.Chat, user, ctx.Connection, channel);
ClientCommandContext context = new(messageText, ctx.Chat, user, ctx.Session, ctx.Connection, channel);
foreach(ClientCommand cmd in Commands)
if(cmd.IsMatch(context)) {
await cmd.Dispatch(context);

View file

@ -1,4 +1,7 @@
using SharpChat.Channels;
using Microsoft.Extensions.Logging;
using SharpChat.Channels;
using SharpChat.Sessions;
using SharpChat.Users;
namespace SharpChat;
@ -7,14 +10,17 @@ public class ClientCommandContext {
public string[] Args { get; }
public Context Chat { get; }
public User User { get; }
public Connection Connection { get; }
public Session Session { get; }
public SockChatConnection Connection { get; }
public Channel Channel { get; }
public ILogger Logger => Session.Logger;
public ClientCommandContext(
string text,
Context chat,
User user,
Connection connection,
Session session,
SockChatConnection connection,
Channel channel
) {
ArgumentNullException.ThrowIfNull(text);
@ -23,29 +29,14 @@ public class ClientCommandContext {
User = user ?? throw new ArgumentNullException(nameof(user));
Connection = connection ?? throw new ArgumentNullException(nameof(connection));
Channel = channel ?? throw new ArgumentNullException(nameof(channel));
Session = session ?? throw new ArgumentNullException(nameof(session));
string[] parts = text[1..].Split(' ');
Name = parts.First().Replace(".", string.Empty);
Args = [.. parts.Skip(1)];
}
public ClientCommandContext(
string name,
string[] args,
Context chat,
User user,
Connection connection,
Channel channel
) {
Name = name ?? throw new ArgumentNullException(nameof(name));
Args = args ?? throw new ArgumentNullException(nameof(args));
Chat = chat ?? throw new ArgumentNullException(nameof(chat));
User = user ?? throw new ArgumentNullException(nameof(user));
Connection = connection ?? throw new ArgumentNullException(nameof(connection));
Channel = channel ?? throw new ArgumentNullException(nameof(channel));
}
public bool NameEquals(string name) {
return Name.Equals(name, StringComparison.InvariantCultureIgnoreCase);
return Name.Equals(name, StringComparison.Ordinal);
}
}

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Globalization;
using System.Text;
@ -14,16 +15,7 @@ public class AFKClientCommand : ClientCommand {
public async Task Dispatch(ClientCommandContext ctx) {
string? statusText = ctx.Args.FirstOrDefault();
if(string.IsNullOrWhiteSpace(statusText))
statusText = DEFAULT;
else {
statusText = statusText.Trim();
StringInfo sti = new(statusText);
if(Encoding.UTF8.GetByteCount(statusText) > MAX_BYTES
|| sti.LengthInTextElements > MAX_GRAPHEMES)
statusText = sti.SubstringByTextElements(0, Math.Min(sti.LengthInTextElements, MAX_GRAPHEMES)).Trim();
}
statusText = string.IsNullOrWhiteSpace(statusText) ? DEFAULT : statusText.TruncateIfTooLong(MAX_GRAPHEMES, MAX_BYTES).Trim();
await ctx.Chat.UpdateUser(
ctx.User,

View file

@ -1,5 +1,6 @@
using SharpChat.Bans;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,5 +1,6 @@
using SharpChat.Events;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,5 +1,6 @@
using SharpChat.Channels;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;
@ -44,7 +45,7 @@ public class CreateChannelClientCommand : ClientCommand {
ownerId: ctx.User.UserId
);
foreach(User ccu in ctx.Chat.Users.Where(u => u.Rank >= ctx.Channel.Rank))
foreach(User ccu in ctx.Chat.Users.GetUsersOfMinimumRank(ctx.Channel.Rank))
await ctx.Chat.SendTo(ccu, new ChannelCreateS2CPacket(channel.Name, channel.HasPassword, channel.IsTemporary));
await ctx.Chat.SwitchChannel(ctx.User, channel, channel.Password);

View file

@ -1,5 +1,6 @@
using SharpChat.Channels;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,5 +1,6 @@
using SharpChat.Messages;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,5 +1,6 @@
using SharpChat.Bans;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Net;
namespace SharpChat.ClientCommands;
@ -22,15 +23,15 @@ public class KickBanClientCommand(BansClient bansClient) : ClientCommand {
string? banUserTarget = ctx.Args.ElementAtOrDefault(0);
string? banDurationStr = ctx.Args.ElementAtOrDefault(1);
int banReasonIndex = 1;
User? banUser = null;
User? banUser;
if(banUserTarget == null || (banUser = ctx.Chat.Users.FirstOrDefault(u => u.NameEquals(banUserTarget))) == null) {
if(banUserTarget == null || (banUser = ctx.Chat.Users.GetUserByLegacyName(banUserTarget)) == null) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.USER_NOT_FOUND, true, banUserTarget ?? "User"));
return;
}
if(banUser.Rank >= ctx.User.Rank && banUser != ctx.User) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.KICK_NOT_ALLOWED, true, banUser.LegacyName));
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.KICK_NOT_ALLOWED, true, banUser.GetLegacyName()));
return;
}
@ -54,14 +55,14 @@ public class KickBanClientCommand(BansClient bansClient) : ClientCommand {
BanInfo? banInfo = await bansClient.BanGet(banUser.UserId);
if(banInfo is not null) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.KICK_NOT_ALLOWED, true, banUser.LegacyName));
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.KICK_NOT_ALLOWED, true, banUser.GetLegacyName()));
return;
}
await bansClient.BanCreate(
BanKind.User,
duration,
ctx.Chat.GetRemoteAddresses(banUser).FirstOrDefault() ?? IPAddress.None,
ctx.Chat.Sessions.GetRemoteEndPoints(banUser).Select(e => e.Address).FirstOrDefault() ?? IPAddress.None,
banUser.UserId,
banReason,
ctx.Connection.RemoteEndPoint.Address,

View file

@ -1,4 +1,5 @@
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Globalization;
using System.Text;
@ -25,7 +26,7 @@ public class NickClientCommand : ClientCommand {
int offset = 0;
if(setOthersNick && long.TryParse(ctx.Args.FirstOrDefault(), out long targetUserId) && targetUserId > 0) {
targetUser = ctx.Chat.Users.FirstOrDefault(u => u.UserId == targetUserId.ToString());
targetUser = ctx.Chat.Users.GetUser(targetUserId.ToString());
++offset;
}
@ -41,18 +42,13 @@ public class NickClientCommand : ClientCommand {
.Replace("\f", string.Empty).Replace("\t", string.Empty)
.Replace(' ', '_').Trim();
if(nickStr == targetUser.UserName)
nickStr = string.Empty;
else if(string.IsNullOrEmpty(nickStr))
nickStr = string.Empty;
else {
StringInfo nsi = new(nickStr);
if(Encoding.UTF8.GetByteCount(nickStr) > MAX_BYTES
|| nsi.LengthInTextElements > MAX_GRAPHEMES)
nickStr = nsi.SubstringByTextElements(0, Math.Min(nsi.LengthInTextElements, MAX_GRAPHEMES)).Trim();
}
nickStr = nickStr == targetUser.UserName
? string.Empty
: (string.IsNullOrEmpty(nickStr)
? string.Empty
: nickStr.TruncateIfTooLong(MAX_GRAPHEMES, MAX_BYTES).Trim());
if(!string.IsNullOrWhiteSpace(nickStr) && ctx.Chat.Users.Any(u => u.NameEquals(nickStr))) {
if(!string.IsNullOrWhiteSpace(nickStr) && ctx.Chat.Users.UserWithLegacyNameExists(nickStr)) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.NAME_IN_USE, true, nickStr));
return;
}

View file

@ -1,5 +1,6 @@
using SharpChat.Bans;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Net;
namespace SharpChat.ClientCommands;

View file

@ -1,5 +1,6 @@
using SharpChat.Bans;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;
@ -24,9 +25,9 @@ public class PardonUserClientCommand(BansClient bansClient) : ClientCommand {
}
string unbanUserDisplay = unbanUserTarget;
User? unbanUser = ctx.Chat.Users.FirstOrDefault(u => u.NameEquals(unbanUserTarget));
User? unbanUser = ctx.Chat.Users.GetUserByLegacyName(unbanUserTarget);
if(unbanUser == null && long.TryParse(unbanUserTarget, out long unbanUserId))
unbanUser = ctx.Chat.Users.FirstOrDefault(u => u.UserId == unbanUserId.ToString());
unbanUser = ctx.Chat.Users.GetUser(unbanUserId.ToString());
if(unbanUser != null) {
unbanUserTarget = unbanUser.UserId;
unbanUserDisplay = unbanUser.UserName;

View file

@ -1,4 +1,5 @@
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,4 +1,5 @@
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;

View file

@ -1,4 +1,5 @@
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Net;
namespace SharpChat.ClientCommands;
@ -18,14 +19,14 @@ public class RemoteAddressClientCommand : ClientCommand {
}
string? ipUserStr = ctx.Args.FirstOrDefault();
User? ipUser = null;
User? ipUser;
if(string.IsNullOrWhiteSpace(ipUserStr) || (ipUser = ctx.Chat.Users.FirstOrDefault(u => u.NameEquals(ipUserStr))) == null) {
if(string.IsNullOrWhiteSpace(ipUserStr) || (ipUser = ctx.Chat.Users.GetUserByLegacyName(ipUserStr)) == null) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.USER_NOT_FOUND, true, ipUserStr ?? "User"));
return;
}
foreach(IPAddress ip in ctx.Chat.GetRemoteAddresses(ipUser))
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.IP_ADDRESS, false, ipUser.UserName, ip));
foreach(IPEndPoint ep in ctx.Chat.Sessions.GetRemoteEndPoints(ipUser))
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.IP_ADDRESS, false, ipUser.UserName, ep.Address));
}
}

View file

@ -3,7 +3,7 @@ using ZLogger;
namespace SharpChat.ClientCommands;
public class ShutdownRestartClientCommand(CancellationTokenSource cancellationTokenSource) : ClientCommand {
public class ShutdownRestartClientCommand(SockChatServer server, CancellationTokenSource cancellationTokenSource) : ClientCommand {
public bool IsMatch(ClientCommandContext ctx) {
return ctx.NameEquals("shutdown")
|| ctx.NameEquals("restart");
@ -11,6 +11,7 @@ public class ShutdownRestartClientCommand(CancellationTokenSource cancellationTo
public async Task Dispatch(ClientCommandContext ctx) {
if(!ctx.User.UserId.Equals("1")) {
ctx.Logger.ZLogInformation($"{ctx.User.UserId}/{ctx.User.UserName} tried to issue /shutdown or /restart");
long msgId = ctx.Chat.RandomSnowflake.Next();
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.COMMAND_NOT_ALLOWED, true, $"/{ctx.Name}"));
return;
@ -19,11 +20,8 @@ public class ShutdownRestartClientCommand(CancellationTokenSource cancellationTo
if(cancellationTokenSource.IsCancellationRequested)
return;
ctx.Connection.Logger.ZLogInformation($"Shutdown requested through Sock Chat command...");
if(ctx.NameEquals("restart"))
foreach(Connection conn in ctx.Chat.Connections)
conn.PrepareForRestart();
server.IsRestarting = ctx.NameEquals("restart");
ctx.Logger.ZLogInformation($"{(server.IsRestarting ? "Restart" : "Shutdown")} requested through Sock Chat command...");
await ctx.Chat.Update();
await cancellationTokenSource.CancelAsync();

View file

@ -1,5 +1,6 @@
using SharpChat.Events;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
namespace SharpChat.ClientCommands;
@ -18,7 +19,7 @@ public class WhisperClientCommand : ClientCommand {
}
string whisperUserStr = ctx.Args.FirstOrDefault() ?? string.Empty;
User? whisperUser = ctx.Chat.Users.FirstOrDefault(u => u.NameEquals(whisperUserStr));
User? whisperUser = ctx.Chat.Users.GetUserByLegacyName(whisperUserStr);
if(whisperUser == null) {
await ctx.Chat.SendTo(ctx.User, new CommandResponseS2CPacket(msgId, LCR.USER_NOT_FOUND, true, whisperUserStr));
@ -30,7 +31,7 @@ public class WhisperClientCommand : ClientCommand {
await ctx.Chat.DispatchEvent(new MessageCreateEvent(
msgId,
User.GetDMChannelName(ctx.User, whisperUser),
ctx.User.GetDMChannelNameWith(whisperUser),
ctx.User.UserId,
ctx.User.UserName,
ctx.User.Colour,

View file

@ -1,5 +1,6 @@
using SharpChat.Channels;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Text;
namespace SharpChat.ClientCommands;
@ -15,14 +16,14 @@ public class WhoClientCommand : ClientCommand {
string? whoChanStr = ctx.Args.FirstOrDefault();
if(string.IsNullOrEmpty(whoChanStr)) {
foreach(User whoUser in ctx.Chat.Users) {
foreach(User whoUser in ctx.Chat.Users.GetUsers()) {
whoChanSB.Append(@"<a href=""javascript:void(0);"" onclick=""UI.InsertChatText(this.innerHTML);""");
if(whoUser == ctx.User)
whoChanSB.Append(@" style=""font-weight: bold;""");
whoChanSB.Append('>');
whoChanSB.Append(whoUser.LegacyName);
whoChanSB.Append(whoUser.GetLegacyNameWithStatus());
whoChanSB.Append("</a>, ");
}
@ -43,14 +44,14 @@ public class WhoClientCommand : ClientCommand {
return;
}
foreach(User whoUser in ctx.Chat.GetChannelUsers(whoChan)) {
foreach(User whoUser in ctx.Chat.ChannelsUsers.GetChannelUsers(whoChan)) {
whoChanSB.Append(@"<a href=""javascript:void(0);"" onclick=""UI.InsertChatText(this.innerHTML);""");
if(whoUser == ctx.User)
whoChanSB.Append(@" style=""font-weight: bold;""");
whoChanSB.Append('>');
whoChanSB.Append(whoUser.LegacyName);
whoChanSB.Append(whoUser.GetLegacyNameWithStatus());
whoChanSB.Append("</a>, ");
}

View file

@ -1,59 +0,0 @@
using Fleck;
using Microsoft.Extensions.Logging;
using SharpChat.SockChat;
using System.Net;
namespace SharpChat;
public class Connection(ILogger logger, IWebSocketConnection sock, IPEndPoint remoteEndPoint) : IDisposable {
public static readonly TimeSpan SessionTimeOut = TimeSpan.FromMinutes(5);
public ILogger Logger { get; } = logger;
public IWebSocketConnection Socket { get; } = sock;
public IPEndPoint RemoteEndPoint { get; } = remoteEndPoint;
public bool IsDisposed { get; private set; }
public DateTimeOffset LastPing { get; set; } = DateTimeOffset.Now;
public User? User { get; set; }
private int CloseCode { get; set; } = 1000;
public bool IsAlive => !IsDisposed && !HasTimedOut;
public async Task Send(S2CPacket packet) {
if(!Socket.IsAvailable)
return;
string data = packet.Pack();
if(!string.IsNullOrWhiteSpace(data))
await Socket.Send(data);
}
public void BumpPing() {
LastPing = DateTimeOffset.Now;
}
public bool HasTimedOut
=> DateTimeOffset.Now - LastPing > SessionTimeOut;
public void PrepareForRestart() {
CloseCode = 1012;
}
~Connection() {
DoDispose();
}
public void Dispose() {
DoDispose();
GC.SuppressFinalize(this);
}
private void DoDispose() {
if(IsDisposed)
return;
IsDisposed = true;
Socket.Close(CloseCode);
}
}

View file

@ -3,11 +3,15 @@ using SharpChat.Auth;
using SharpChat.Bans;
using SharpChat.Channels;
using SharpChat.Configuration;
using SharpChat.Connections;
using SharpChat.Events;
using SharpChat.Messages;
using SharpChat.Sessions;
using SharpChat.Snowflake;
using SharpChat.SockChat;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Storage;
using SharpChat.Users;
using System.Dynamic;
using System.Net;
using ZLogger;
@ -20,8 +24,6 @@ public class Context {
public const int DEFAULT_FLOOD_KICK_LENGTH = 30;
public const int DEFAULT_FLOOD_KICK_EXEMPT_RANK = 9;
public record ChannelUserAssoc(string UserId, string ChannelName);
public readonly SemaphoreSlim ContextAccess = new(1, 1);
public ILoggerFactory LoggerFactory { get; }
@ -40,27 +42,30 @@ public class Context {
public SnowflakeGenerator SnowflakeGenerator { get; } = new();
public RandomSnowflake RandomSnowflake { get; }
public ChannelsContext Channels { get; } = new();
public HashSet<Connection> Connections { get; } = [];
public HashSet<User> Users { get; } = [];
public HashSet<ChannelUserAssoc> ChannelUsers { get; } = [];
public UsersContext Users { get; } = new();
public SessionsContext Sessions { get; }
public ChannelsContext Channels { get; }
public ChannelsUsersContext ChannelsUsers { get; }
public Dictionary<string, RateLimiter> UserRateLimiters { get; } = [];
public Dictionary<string, Channel> UserLastChannel { get; } = [];
public Context(
ILoggerFactory logFactory,
ILoggerFactory loggerFactory,
Config config,
Storage storage,
StorageBackend storage,
AuthClient authClient,
BansClient bansClient
) {
LoggerFactory = logFactory;
Logger = logFactory.CreateLogger("ctx");
LoggerFactory = loggerFactory;
Logger = loggerFactory.CreateLogger("ctx");
Config = config;
Messages = storage.CreateMessageStorage();
Auth = authClient;
Bans = bansClient;
RandomSnowflake = new(SnowflakeGenerator);
Sessions = new(loggerFactory, RandomSnowflake);
Channels = new(RandomSnowflake);
ChannelsUsers = new(Channels, Users);
Logger.ZLogDebug($"Reading cached config values...");
MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX);
@ -102,8 +107,8 @@ public class Context {
if(uids.Count() != 2)
return;
IEnumerable<User> users = Users.Where(u => uids.Any(uid => uid == u.UserId));
User? target = users.FirstOrDefault(u => u.UserId != mce.SenderId);
IEnumerable<User> users = Users.GetUsers(uids);
User? target = users.FirstOrDefault(u => mce.SenderId.Equals(u.UserId, StringComparison.Ordinal));
if(target == null)
return;
@ -112,7 +117,7 @@ public class Context {
mce.MessageId,
DateTimeOffset.Now,
mce.SenderId,
mce.SenderId == user.UserId ? $"{target.LegacyName} {mce.MessageText}" : mce.MessageText,
mce.SenderId == user.UserId ? $"{target.GetLegacyName()} {mce.MessageText}" : mce.MessageText,
mce.IsAction,
true
));
@ -140,16 +145,14 @@ public class Context {
}
public async Task Update() {
foreach(Connection conn in Connections)
if(!conn.IsDisposed && conn.HasTimedOut) {
conn.Logger.ZLogInformation($"Nuking connection associated with user {conn.User?.UserId ?? "no-one"}");
conn.Dispose();
}
foreach(Session session in Sessions.GetTimedOutSessions()) {
session.Logger.ZLogInformation($"Nuking connection associated with user #{session.UserId}");
session.Connection.Close(ConnectionCloseReason.TimeOut);
Sessions.DestroySession(session);
}
Connections.RemoveWhere(conn => conn.IsDisposed);
foreach(User user in Users)
if(!Connections.Any(conn => conn.User == user)) {
foreach(User user in Users.GetUsers())
if(Sessions.CountActiveSessions(user) < 1) {
Logger.ZLogInformation($"Timing out user {user.UserId} (no more connections).");
await HandleDisconnect(user, UserDisconnectS2CPacket.Reason.TimeOut);
}
@ -164,27 +167,6 @@ public class Context {
}
}
public bool IsInChannel(User user, Channel channel) {
return ChannelUsers.Contains(new ChannelUserAssoc(user.UserId, channel.Name));
}
public string[] GetUserChannelNames(User user) {
return [.. ChannelUsers.Where(cu => cu.UserId == user.UserId).Select(cu => cu.ChannelName)];
}
public Channel[] GetUserChannels(User user) {
return [.. Channels.GetChannels(GetUserChannelNames(user))];
}
public string[] GetChannelUserIds(Channel channel) {
return [.. ChannelUsers.Where(cu => channel.NameEquals(cu.ChannelName)).Select(cu => cu.UserId)];
}
public User[] GetChannelUsers(Channel channel) {
string[] ids = GetChannelUserIds(channel);
return [.. Users.Where(u => ids.Contains(u.UserId))];
}
public async Task UpdateUser(
User user,
string? userName = null,
@ -196,52 +178,25 @@ public class Context {
UserPermissions? perms = null,
bool silent = false
) {
bool hasChanged = false;
string previousName = string.Empty;
string previousName = user.GetLegacyName();
UserDiff diff = Users.UpdateUser(
user,
userName,
colour,
rank,
perms,
nickName,
status,
statusText
);
if(userName != null && !user.UserName.Equals(userName)) {
user.UserName = userName;
hasChanged = true;
}
if(diff.Changed) {
string currentName = user.GetLegacyNameWithStatus();
if(nickName != null && !user.NickName.Equals(nickName)) {
if(!silent)
previousName = user.LegacyName;
if(!silent && diff.Nick.Changed)
await SendToUserChannels(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.NICKNAME_CHANGE, false, previousName, currentName));
user.NickName = nickName;
hasChanged = true;
}
if(colour.HasValue && user.Colour != colour.Value) {
user.Colour = colour.Value;
hasChanged = true;
}
if(status.HasValue && user.Status != status.Value) {
user.Status = status.Value;
hasChanged = true;
}
if(statusText != null && !user.StatusText.Equals(statusText)) {
user.StatusText = statusText;
hasChanged = true;
}
if(rank != null && user.Rank != rank) {
user.Rank = (int)rank;
hasChanged = true;
}
if(perms.HasValue && user.Permissions != perms) {
user.Permissions = perms.Value;
hasChanged = true;
}
if(hasChanged) {
if(!string.IsNullOrWhiteSpace(previousName))
await SendToUserChannels(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.NICKNAME_CHANGE, false, previousName, user.LegacyNameWithStatus));
await SendToUserChannels(user, new UserUpdateS2CPacket(user.UserId, user.LegacyNameWithStatus, user.Colour, user.Rank, user.Permissions));
await SendToUserChannels(user, new UserUpdateS2CPacket(diff.Id, currentName, diff.Colour.After, diff.Rank.After, diff.Permissions.After));
}
}
@ -252,35 +207,37 @@ public class Context {
} else
await SendTo(user, new ForceDisconnectS2CPacket());
foreach(Connection conn in Connections)
if(conn.User == user)
conn.Dispose();
Connections.RemoveWhere(conn => conn.IsDisposed);
foreach(SockChatConnection conn in Sessions.GetConnections<SockChatConnection>(user)) {
conn.Close(ConnectionCloseReason.Unauthorized);
Sessions.DestroySession(conn);
}
await Update();
await HandleDisconnect(user, reason);
}
public async Task HandleDisconnect(User user, UserDisconnectS2CPacket.Reason reason = UserDisconnectS2CPacket.Reason.Leave) {
await UpdateUser(user, status: UserStatus.Offline);
Users.Remove(user);
UserLastChannel.Remove(user.UserId);
Users.RemoveUser(user);
Channel[] channels = GetUserChannels(user);
foreach(Channel chan in channels) {
ChannelUsers.Remove(new ChannelUserAssoc(user.UserId, chan.Name));
foreach(Channel chan in ChannelsUsers.GetUserChannels(user)) {
ChannelsUsers.RemoveChannelUser(chan, user);
long msgId = RandomSnowflake.Next();
await SendTo(chan, new UserDisconnectS2CPacket(msgId, DateTimeOffset.Now, user.UserId, user.LegacyNameWithStatus, reason));
await SendTo(chan, new UserDisconnectS2CPacket(msgId, DateTimeOffset.Now, user.UserId, user.GetLegacyNameWithStatus(), reason));
await Messages.LogMessage(msgId, "user:disconnect", chan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions, new { reason = (int)reason });
if(chan.IsTemporary && chan.IsOwner(user.UserId))
await RemoveChannel(chan);
}
ChannelsUsers.RemoveUser(user);
}
public async Task SwitchChannel(User user, Channel chan, string password) {
if(UserLastChannel.TryGetValue(user.UserId, out Channel? ulc) && chan == ulc) {
Channel? oldChan = ChannelsUsers.GetUserLastChannel(user);
if(oldChan?.Id == chan.Id) {
await ForceChannel(user);
return;
}
@ -292,29 +249,33 @@ public class Context {
return;
}
if(!string.IsNullOrEmpty(chan.Password) && chan.Password != password) {
if(!string.IsNullOrEmpty(chan.Password) && chan.Password.SlowUtf8Equals(password)) {
await SendTo(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.CHANNEL_INVALID_PASSWORD, true, chan.Name));
await ForceChannel(user);
return;
}
}
Channel oldChan = UserLastChannel[user.UserId];
if(oldChan is not null) {
long leaveId = RandomSnowflake.Next();
await SendTo(oldChan, new UserChannelLeaveS2CPacket(leaveId, user.UserId));
await Messages.LogMessage(leaveId, "chan:leave", oldChan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions);
ChannelsUsers.RemoveChannelUser(oldChan, user);
long leaveId = RandomSnowflake.Next();
await SendTo(oldChan, new UserChannelLeaveS2CPacket(leaveId, user.UserId));
await Messages.LogMessage(leaveId, "chan:leave", oldChan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions);
if(oldChan.IsTemporary && oldChan.IsOwner(user.UserId))
await RemoveChannel(oldChan);
}
long joinId = RandomSnowflake.Next();
await SendTo(chan, new UserChannelJoinS2CPacket(joinId, user.UserId, user.LegacyNameWithStatus, user.Colour, user.Rank, user.Permissions));
await Messages.LogMessage(joinId, "chan:join", chan.Name, user.UserId, user.LegacyName, user.Colour, user.Rank, user.NickName, user.Permissions);
await SendTo(chan, new UserChannelJoinS2CPacket(joinId, user.UserId, user.GetLegacyNameWithStatus(), user.Colour, user.Rank, user.Permissions));
await Messages.LogMessage(joinId, "chan:join", chan.Name, user.UserId, user.GetLegacyName(), user.Colour, user.Rank, user.NickName, user.Permissions);
await SendTo(user, new ContextClearS2CPacket(ContextClearS2CPacket.Mode.MessagesUsers));
await SendTo(user, new ContextUsersS2CPacket(
GetChannelUsers(chan).Except([user]).OrderByDescending(u => u.Rank)
ChannelsUsers.GetChannelUsers(chan).Except([user]).OrderByDescending(u => u.Rank)
.Select(u => new ContextUsersS2CPacket.Entry(
u.UserId,
u.LegacyNameWithStatus,
u.GetLegacyNameWithStatus(),
u.Colour,
u.Rank,
u.Permissions,
@ -327,49 +288,39 @@ public class Context {
await SendTo(user, new ContextMessageS2CPacket(msg));
await ForceChannel(user, chan);
ChannelUsers.Remove(new ChannelUserAssoc(user.UserId, oldChan.Name));
ChannelUsers.Add(new ChannelUserAssoc(user.UserId, chan.Name));
UserLastChannel[user.UserId] = chan;
if(oldChan.IsTemporary && oldChan.IsOwner(user.UserId))
await RemoveChannel(oldChan);
ChannelsUsers.AddChannelUser(chan, user);
}
public async Task Send(S2CPacket packet) {
foreach(Connection conn in Connections)
if(conn.IsAlive && conn.User is not null)
await conn.Send(packet);
foreach(SockChatConnection conn in Sessions.GetConnections<SockChatConnection>())
await conn.Send(packet);
}
public async Task SendTo(User user, S2CPacket packet) {
foreach(Connection conn in Connections)
if(conn.IsAlive && conn.User == user)
await conn.Send(packet);
foreach(SockChatConnection conn in Sessions.GetConnections<SockChatConnection>(user))
await conn.Send(packet);
}
public async Task SendTo(Channel channel, S2CPacket packet) {
// might be faster to grab the users first and then cascade into that SendTo
IEnumerable<Connection> conns = Connections.Where(c => c.IsAlive && c.User is not null && IsInChannel(c.User, channel));
foreach(Connection conn in conns)
IEnumerable<SockChatConnection> conns = Sessions.GetConnections<SockChatConnection>(
s => ChannelsUsers.HasChannelUser(channel, s.UserId)
);
foreach(SockChatConnection conn in conns)
await conn.Send(packet);
}
public async Task SendToUserChannels(User user, S2CPacket packet) {
IEnumerable<Channel> chans = Channels.GetChannels(c => IsInChannel(user, c));
IEnumerable<Connection> conns = Connections.Where(conn => conn.IsAlive && conn.User is not null && ChannelUsers.Any(cu => cu.UserId == conn.User.UserId && chans.Any(chan => chan.NameEquals(cu.ChannelName))));
foreach(Connection conn in conns)
IEnumerable<Channel> chans = ChannelsUsers.GetUserChannels(user);
IEnumerable<SockChatConnection> conns = Sessions.GetConnections<SockChatConnection>(
s => chans.Any(c => ChannelsUsers.HasChannelUser(c.Id, s.UserId))
);
foreach(SockChatConnection conn in conns)
await conn.Send(packet);
}
public IPAddress[] GetRemoteAddresses(User user) {
return [.. Connections.Where(c => c.IsAlive && c.User == user).Select(c => c.RemoteEndPoint.Address).Distinct()];
}
public async Task ForceChannel(User user, Channel? chan = null) {
if(chan == null && !UserLastChannel.TryGetValue(user.UserId, out chan))
throw new ArgumentException("no channel???");
chan ??= ChannelsUsers.GetUserLastChannel(user) ?? throw new ArgumentException("no channel???");
await SendTo(user, new UserChannelForceJoinS2CPacket(chan.Name));
}
@ -379,7 +330,7 @@ public class Context {
int? rank = null,
string? password = null
) {
Channels.UpdateChannel(
ChannelDiff diff = Channels.UpdateChannel(
channel,
temporary: temporary,
rank: rank,
@ -387,21 +338,24 @@ public class Context {
);
// TODO: Users that no longer have access to the channel/gained access to the channel by the hierarchy change should receive delete and create packets respectively
foreach(User user in Users.Where(u => u.Rank >= channel.Rank))
await SendTo(user, new ChannelUpdateS2CPacket(channel.Name, channel.Name, channel.HasPassword, channel.IsTemporary));
if(diff.Changed)
foreach(User user in Users.GetUsersOfMinimumRank(channel.Rank))
await SendTo(user, new ChannelUpdateS2CPacket(channel.Name, channel.Name, channel.HasPassword, channel.IsTemporary));
}
public async Task RemoveChannel(Channel channel) {
// Remove channel from the listing
Channels.RemoveChannel(channel.Name);
Channels.RemoveChannel(channel);
// Move all users back to the main channel
// TODO: Replace this with a kick. SCv2 supports being in 0 channels, SCv1 should force the user back to DefaultChannel.
foreach(User user in GetChannelUsers(channel))
await SwitchChannel(user, Channels.DefaultChannel, string.Empty);
foreach(User user in ChannelsUsers.GetChannelUsers(channel))
await SwitchChannel(user, Channels.GetDefaultChannel(), string.Empty);
// Broadcast deletion of channel
foreach(User user in Users.Where(u => u.Rank >= channel.Rank))
foreach(User user in Users.GetUsersOfMinimumRank(channel.Rank))
await SendTo(user, new ChannelDeleteS2CPacket(channel.Name));
ChannelsUsers.RemoveChannel(channel);
}
}

View file

@ -1,3 +1,5 @@
using SharpChat.Users;
namespace SharpChat.Events;
public class MessageCreateEvent(

View file

@ -4,6 +4,7 @@ using SharpChat.Configuration;
using SharpChat.Flashii;
using SharpChat.MariaDB;
using SharpChat.SQLite;
using SharpChat.Storage;
using System.Text;
using ZLogger;
using ZLogger.Providers;
@ -172,7 +173,7 @@ FlashiiClient flashii = new(logFactory.CreateLogger("flashii"), httpClient, conf
if(cts.IsCancellationRequested) return;
logger.ZLogInformation($"Initialising storage...");
Storage storage = string.IsNullOrWhiteSpace(config.SafeReadValue("mariadb:host", string.Empty))
StorageBackend storage = string.IsNullOrWhiteSpace(config.SafeReadValue("mariadb:host", string.Empty))
? new SQLiteStorage(logFactory.CreateLogger("sqlite"), SQLiteStorage.BuildConnectionString(config.ScopeTo("sqlite")))
: new MariaDBStorage(logFactory.CreateLogger("mariadb"), MariaDBStorage.BuildConnectionString(config.ScopeTo("mariadb")));

View file

@ -0,0 +1,36 @@
using Fleck;
using Microsoft.Extensions.Logging;
using SharpChat.Connections;
using SharpChat.SockChat;
using System.Net;
namespace SharpChat;
public class SockChatConnection(IWebSocketConnection sock, IPEndPoint remoteEndPoint, ILogger logger) : Connection {
public IWebSocketConnection Socket { get; } = sock;
public IPEndPoint RemoteEndPoint { get; } = remoteEndPoint;
public ILogger Logger { get; } = logger;
public async Task Send(S2CPacket packet) {
if(!Socket.IsAvailable)
return;
string data = packet.Pack();
if(!string.IsNullOrWhiteSpace(data))
await Socket.Send(data);
}
public void Close(ConnectionCloseReason reason = ConnectionCloseReason.Unexpected) {
Socket.Close(reason switch {
ConnectionCloseReason.ShuttingDown => WebSocketCloseCode.GoingAway,
ConnectionCloseReason.Error => WebSocketCloseCode.InternalError,
ConnectionCloseReason.Restarting => WebSocketCloseCode.ServiceRestart,
ConnectionCloseReason.Unavailable => WebSocketCloseCode.ServiceRestart,
ConnectionCloseReason.Unauthorized => WebSocketCloseCode.Unauthorized,
ConnectionCloseReason.TimeOut => WebSocketCloseCode.Timeout,
ConnectionCloseReason.AccessDenied => WebSocketCloseCode.Forbidden,
ConnectionCloseReason.TooManyConnections => WebSocketCloseCode.TryAgainLater,
_ => WebSocketCloseCode.NormalClosure,
});
}
}

View file

@ -3,7 +3,10 @@ using SharpChat.Bans;
using SharpChat.C2SPacketHandlers;
using SharpChat.ClientCommands;
using SharpChat.Configuration;
using SharpChat.Sessions;
using SharpChat.SockChat;
using SharpChat.SockChat.S2CPackets;
using SharpChat.Users;
using System.Net;
using ZLogger;
@ -14,10 +17,15 @@ public class SockChatServer {
public Context Context { get; }
public bool IsRestarting { get; set; }
private readonly ILogger Logger;
private readonly CachedValue<ushort> Port;
private readonly Lock ConnectionsLock = new();
private readonly HashSet<SockChatConnection> Connections = [];
private readonly List<C2SPacketHandler> GuestHandlers = [];
private readonly List<C2SPacketHandler> AuthedHandlers = [];
private readonly SendMessageC2SPacketHandler SendMessageHandler;
@ -70,16 +78,18 @@ public class SockChatServer {
new PardonAddressClientCommand(Context.Bans),
new BanListClientCommand(Context.Bans),
new RemoteAddressClientCommand(),
new ShutdownRestartClientCommand(cancellationTokenSource)
new ShutdownRestartClientCommand(this, cancellationTokenSource)
]);
}
public async Task Listen(CancellationToken cancellationToken) {
// TODO: protocol servers are now responsible of timing out unauthed connections by themselves
using SharpChatWebSocketServer server = new(Context.LoggerFactory.CreateLogger("sockchat:server"), $"ws://0.0.0.0:{Port}");
server.Start(sock => {
if(!IPAddress.TryParse(sock.ConnectionInfo.ClientIpAddress, out IPAddress? addr)) {
Logger.ZLogError($@"A client attempted to connect with an invalid IP address: ""{sock.ConnectionInfo.ClientIpAddress}""");
sock.Close(1011);
sock.Close(WebSocketCloseCode.InternalError);
return;
}
@ -91,51 +101,69 @@ public class SockChatServer {
}
IPEndPoint endPoint = new(addr, sock.ConnectionInfo.ClientPort);
ILogger logger = Context.LoggerFactory.CreateLogger($"sockchat:({endPoint})");
if(cancellationToken.IsCancellationRequested) {
Logger.ZLogInformation($"{endPoint} attepted to connect after shutdown was requested. Connection will be dropped.");
sock.Close(1013);
logger.ZLogInformation($"{endPoint} attempted to connect after shutdown was requested. Connection will be dropped.");
sock.Close(WebSocketCloseCode.TryAgainLater);
return;
}
ILogger logger = Context.LoggerFactory.CreateLogger($"sockchat:({endPoint})");
Connection conn = new(logger, sock, endPoint);
Context.Connections.Add(conn);
lock(ConnectionsLock) {
SockChatConnection conn = new(sock, endPoint, logger);
Connections.Add(conn);
sock.OnOpen = () => OnOpen(conn).Wait();
sock.OnClose = () => OnClose(conn).Wait();
sock.OnError = err => OnError(conn, err).Wait();
sock.OnMessage = msg => OnMessage(conn, msg).Wait();
sock.OnOpen = () => OnOpen(conn).Wait();
sock.OnClose = () => OnClose(conn).Wait();
sock.OnError = err => OnError(conn, err).Wait();
sock.OnMessage = msg => OnMessage(conn, msg).Wait();
}
});
Logger.ZLogInformation($"Listening...");
await Task.Delay(Timeout.Infinite, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
Logger.ZLogDebug($"Disposing all clients...");
foreach(Connection conn in Context.Connections)
conn.Dispose();
lock(ConnectionsLock) {
Logger.ZLogDebug($"Disposing all clients...");
foreach(SockChatConnection conn in Connections)
conn.Socket.Close(IsRestarting ? WebSocketCloseCode.ServiceRestart : WebSocketCloseCode.GoingAway);
}
}
private async Task OnOpen(Connection conn) {
private async Task OnOpen(SockChatConnection conn) {
conn.Logger.ZLogInformation($"Connection opened.");
await Context.SafeUpdate();
}
private async Task OnError(Connection conn, Exception ex) {
private async Task OnError(SockChatConnection conn, Exception ex) {
// TODO: detect timeouts and suspend the session
conn.Logger.ZLogError($"Error: {ex.Message}");
conn.Logger.ZLogDebug($"{ex}");
await Context.SafeUpdate();
}
private async Task OnClose(Connection conn) {
private async Task OnClose(SockChatConnection conn) {
conn.Logger.ZLogInformation($"Connection closed.");
User? noMoreSessionsUser = null;
lock(ConnectionsLock) {
Connections.Remove(conn);
Session? session = Context.Sessions.GetSession(conn);
if(session is not null) {
if(!Context.Sessions.IsSuspendedSession(session))
Context.Sessions.DestroySession(session);
if(Context.Sessions.CountActiveSessions(session.UserId) < 1)
noMoreSessionsUser = Context.Users.GetUser(session.UserId);
}
}
Context.ContextAccess.Wait();
try {
Context.Connections.Remove(conn);
if(conn.User != null && !Context.Connections.Any(c => c.User == conn.User))
await Context.HandleDisconnect(conn.User);
if(noMoreSessionsUser is not null)
await Context.HandleDisconnect(noMoreSessionsUser);
await Context.Update();
} finally {
@ -143,51 +171,50 @@ public class SockChatServer {
}
}
private async Task OnMessage(Connection conn, string msg) {
private async Task OnMessage(SockChatConnection conn, string msg) {
conn.Logger.ZLogTrace($"Received: {msg}");
await Context.SafeUpdate();
Session? session = Context.Sessions.GetSession(conn);
User? user = session is null ? null : Context.Users.GetUser(session.UserId);
// this doesn't affect non-authed connections?????
if(conn.User is not null && conn.User.Rank < Context.FloodKickExemptRank) {
User? banUser = null;
if(user is not null && user.Rank < Context.FloodKickExemptRank) {
bool rusticate = false;
string banAddr = string.Empty;
TimeSpan banDuration = TimeSpan.MinValue;
Context.ContextAccess.Wait();
try {
if(!Context.UserRateLimiters.TryGetValue(conn.User.UserId, out RateLimiter? rateLimiter))
Context.UserRateLimiters.Add(conn.User.UserId, rateLimiter = new RateLimiter(
User.DEFAULT_SIZE,
User.DEFAULT_MINIMUM_DELAY,
User.DEFAULT_RISKY_OFFSET
));
if(!Context.UserRateLimiters.TryGetValue(user.UserId, out RateLimiter? rateLimiter))
Context.UserRateLimiters.Add(user.UserId, rateLimiter = new RateLimiter());
rateLimiter.Update();
if(rateLimiter.IsExceeded) {
banDuration = TimeSpan.FromSeconds(Context.FloodKickLength);
banUser = conn.User;
rusticate = true;
banAddr = conn.RemoteEndPoint.Address.ToString();
conn.Logger.ZLogWarning($"Exceeded flood limit! Issuing ban with duration {banDuration} on {banAddr}/{banUser.UserId}...");
conn.Logger.ZLogWarning($"Exceeded flood limit! Issuing ban with duration {banDuration} on {banAddr}/{user.UserId}...");
} else if(rateLimiter.IsRisky) {
banUser = conn.User;
rusticate = true;
banAddr = conn.RemoteEndPoint.Address.ToString();
conn.Logger.ZLogWarning($"About to exceed flood limit! Issueing warning to {banAddr}/{banUser.UserId}...");
conn.Logger.ZLogWarning($"About to exceed flood limit! Issueing warning to {banAddr}/{user.UserId}...");
}
if(banUser is not null) {
if(rusticate) {
if(banDuration == TimeSpan.MinValue) {
await Context.SendTo(conn.User, new CommandResponseS2CPacket(Context.RandomSnowflake.Next(), LCR.FLOOD_WARN, false));
await Context.SendTo(user, new CommandResponseS2CPacket(Context.RandomSnowflake.Next(), LCR.FLOOD_WARN, false));
} else {
await Context.BanUser(conn.User, banDuration, UserDisconnectS2CPacket.Reason.Flood);
await Context.BanUser(user, banDuration, UserDisconnectS2CPacket.Reason.Flood);
if(banDuration > TimeSpan.Zero)
await Context.Bans.BanCreate(
BanKind.User,
banDuration,
conn.RemoteEndPoint.Address,
conn.User.UserId,
user.UserId,
"Kicked from chat for flood protection.",
IPAddress.IPv6Loopback
);
@ -200,8 +227,8 @@ public class SockChatServer {
}
}
C2SPacketHandlerContext context = new(msg, Context, conn);
C2SPacketHandler? handler = conn.User is null
C2SPacketHandlerContext context = new(msg, Context, conn, session, session?.Logger ?? conn.Logger);
C2SPacketHandler? handler = user is null
? GuestHandlers.FirstOrDefault(h => h.IsMatch(context))
: AuthedHandlers.FirstOrDefault(h => h.IsMatch(context));

View file

@ -1,68 +0,0 @@
using SharpChat.ClientCommands;
using System.Globalization;
using System.Text;
namespace SharpChat;
public class User(
string userId,
string userName,
ColourInheritable colour,
int rank,
UserPermissions perms,
string nickName = "",
UserStatus status = UserStatus.Online,
string statusText = ""
) {
public const int DEFAULT_SIZE = 30;
public const int DEFAULT_MINIMUM_DELAY = 10000;
public const int DEFAULT_RISKY_OFFSET = 5;
public string UserId { get; } = userId;
public string UserName { get; set; } = userName ?? throw new ArgumentNullException(nameof(userName));
public ColourInheritable Colour { get; set; } = colour;
public int Rank { get; set; } = rank;
public UserPermissions Permissions { get; set; } = perms;
public string NickName { get; set; } = nickName;
public UserStatus Status { get; set; } = status;
public string StatusText { get; set; } = statusText;
public string LegacyName => string.IsNullOrWhiteSpace(NickName) ? UserName : $"~{NickName}";
public string LegacyNameWithStatus {
get {
StringBuilder sb = new();
if(Status == UserStatus.Away) {
string statusText = StatusText.Trim();
StringInfo sti = new(statusText);
if(Encoding.UTF8.GetByteCount(statusText) > AFKClientCommand.MAX_BYTES
|| sti.LengthInTextElements > AFKClientCommand.MAX_GRAPHEMES)
statusText = sti.SubstringByTextElements(0, Math.Min(sti.LengthInTextElements, AFKClientCommand.MAX_GRAPHEMES)).Trim();
sb.AppendFormat("&lt;{0}&gt;_", statusText.ToUpperInvariant());
}
sb.Append(LegacyName);
return sb.ToString();
}
}
public bool NameEquals(string name) {
return string.Equals(name, UserName, StringComparison.InvariantCultureIgnoreCase)
|| string.Equals(name, NickName, StringComparison.InvariantCultureIgnoreCase)
|| string.Equals(name, LegacyName, StringComparison.InvariantCultureIgnoreCase)
|| string.Equals(name, LegacyNameWithStatus, StringComparison.InvariantCultureIgnoreCase);
}
public override int GetHashCode() {
return UserId.GetHashCode();
}
public static string GetDMChannelName(User user1, User user2) {
return string.Compare(user1.UserId, user2.UserId, StringComparison.InvariantCultureIgnoreCase) > 0
? $"@{user2.UserId}-{user1.UserId}"
: $"@{user1.UserId}-{user2.UserId}";
}
}

View file

@ -0,0 +1,23 @@
using SharpChat.ClientCommands;
using SharpChat.Users;
namespace SharpChat;
public static class UserExtensions {
public static string GetLegacyName(this User user) {
return string.IsNullOrWhiteSpace(user.NickName) ? user.UserName : '~' + user.NickName;
}
public static string GetLegacyNameWithStatus(this User user) {
return user.Status == UserStatus.Away ? string.Format(
"&lt;{0}&gt;_{1}",
user.StatusText.TruncateIfTooLong(AFKClientCommand.MAX_GRAPHEMES, AFKClientCommand.MAX_BYTES).Trim().ToUpperInvariant(),
user.GetLegacyName()
) : user.GetLegacyName();
}
public static bool LegacyNameEquals(this User user, string name) {
return user.NameEquals(name)
|| string.Equals(name, user.GetLegacyName(), StringComparison.OrdinalIgnoreCase)
|| string.Equals(name, user.GetLegacyNameWithStatus(), StringComparison.OrdinalIgnoreCase);
}
}

View file

@ -0,0 +1,13 @@
using SharpChat.Users;
namespace SharpChat;
public static class UsersContextExtensions {
public static bool UserWithLegacyNameExists(this UsersContext ctx, string name) {
return ctx.UserExists(u => u.LegacyNameEquals(name));
}
public static User? GetUserByLegacyName(this UsersContext ctx, string name) {
return ctx.GetUser(u => u.LegacyNameEquals(name));
}
}

View file

@ -1,3 +1,5 @@
using SharpChat.Users;
namespace SharpChat.Auth;
public interface AuthResult {

View file

@ -0,0 +1,22 @@
using System.Text;
namespace SharpChat;
public static class ByteArrayExtensions {
public static string GetUtf8String(this byte[] buffer) {
return Encoding.UTF8.GetString(buffer);
}
public static bool SlowEquals(this byte[] buffer, byte[] other) {
if(buffer.Length != other.Length)
return false;
int i = 0;
int result = 0;
while(i < buffer.Length) {
result |= buffer[i] ^ other[i];
++i;
}
return result == 0;
}
}

View file

@ -1,12 +1,21 @@
using SharpChat.Users;
namespace SharpChat.Channels;
public class Channel(
long id,
string name,
string password = "",
bool isTemporary = false,
int rank = 0,
string ownerId = ""
) {
/// <summary>
/// Ephemeral unique identifier.
/// ONLY use this to refer to the channel during runtime, not intended for long term storage!!!!!
/// </summary>
public long Id { get; } = id;
public string Name { get; internal set; } = name;
public string Password { get; internal set; } = password ?? string.Empty;
public bool IsTemporary { get; internal set; } = isTemporary;
@ -20,17 +29,13 @@ public class Channel(
=> !HasPassword && Rank < 1;
public bool NameEquals(string name) {
return string.Equals(name, Name, StringComparison.InvariantCultureIgnoreCase);
return string.Equals(name, Name, StringComparison.OrdinalIgnoreCase);
}
public bool IsOwner(string userId) {
return !string.IsNullOrEmpty(OwnerId)
&& !string.IsNullOrEmpty(userId)
&& OwnerId == userId;
}
public override int GetHashCode() {
return Name.GetHashCode();
&& OwnerId.Equals(userId, StringComparison.Ordinal);
}
public static bool CheckName(string name) {

View file

@ -0,0 +1,17 @@
namespace SharpChat.Channels;
public readonly record struct ChannelDiff(
Channel Channel,
StringDiff Name,
StringDiff Password,
ValueDiff<bool> IsTemporary,
ValueDiff<int> MinimumRank,
StringDiff OwnerId
) : Diff {
public bool Changed
=> Name.Changed
|| Password.Changed
|| IsTemporary.Changed
|| MinimumRank.Changed
|| OwnerId.Changed;
}

View file

@ -1,44 +1,62 @@
using SharpChat.Snowflake;
namespace SharpChat.Channels;
public class ChannelsContext {
private readonly List<Channel> Channels = [];
public int Count => Channels.Count;
private Channel? DefaultChannelValue;
public Channel DefaultChannel {
get {
if(DefaultChannelValue is not null) {
if(Channels.Contains(DefaultChannelValue))
return DefaultChannelValue;
DefaultChannelValue = null;
}
return GetChannel(c => c.IsPublic && !c.IsTemporary) ?? throw new NoDefaultChannelException();
}
set => DefaultChannelValue = value;
}
public class ChannelsContext(RandomSnowflake snowflake) {
private readonly Dictionary<long, Channel> Channels = [];
private Channel? DefaultChannel = null;
private readonly Lock @lock = new();
public bool ChannelExists(Func<Channel, bool> predicate) {
return Channels.Any(predicate);
lock(@lock)
return Channels.Values.Any(predicate);
}
public bool ChannelExists(string name) {
return ChannelExists(c => c.NameEquals(name));
}
public Channel GetDefaultChannel() {
lock(@lock) {
DefaultChannel ??= GetChannel(c => c.IsPublic && !c.IsTemporary) ?? throw new NoDefaultChannelException();
return DefaultChannel;
}
}
public void SetDefaultChannel(Channel channel)
=> SetDefaultChannel(channel.Name);
public void SetDefaultChannel(string channelName) {
lock(@lock)
DefaultChannel = GetChannel(channelName) ?? throw new ChannelNotFoundException(nameof(channelName));
}
public Channel? GetChannel(Func<Channel, bool> predicate) {
return Channels.FirstOrDefault(predicate);
lock(@lock)
return Channels.Values.FirstOrDefault(predicate);
}
public Channel? GetChannel(string name) {
return GetChannel(c => c.NameEquals(name));
}
public Channel? GetChannel(long id) {
lock(@lock)
return Channels.TryGetValue(id, out Channel? channel) ? channel : null;
}
public IEnumerable<Channel> GetChannels() {
lock(@lock)
return [.. Channels.Values];
}
public IEnumerable<Channel> GetChannels(IEnumerable<long> ids) {
return [.. ids.Select(GetChannel).Where(c => c is not null).Cast<Channel>()];
}
public IEnumerable<Channel> GetChannels(Func<Channel, bool> predicate) {
return Channels.Where(predicate);
lock(@lock)
return [.. Channels.Values.Where(predicate)];
}
public IEnumerable<Channel> GetChannels(IEnumerable<string> names) {
@ -54,73 +72,122 @@ public class ChannelsContext {
string password = "",
bool temporary = false,
int rank = 0,
string ownerId = ""
string ownerId = "",
long? id = null
) {
if(!Channel.CheckName(name))
throw new ChannelNameFormatException(nameof(name));
if(ChannelExists(name))
throw new ChannelExistsException(nameof(name));
Channel channel = new(
name,
password ?? string.Empty,
temporary,
rank,
ownerId ?? string.Empty
);
lock(@lock) {
if(ChannelExists(name))
throw new ChannelExistsException(nameof(name));
Channels.Add(channel);
id ??= snowflake.Next();
if(Channels.ContainsKey(id.Value))
throw new ChannelExistsException(nameof(id));
return channel;
Channel channel = new(id.Value, name, password, temporary, rank, ownerId);
Channels.Add(id.Value, channel);
return channel;
}
}
public Channel UpdateChannel(
Channel channel,
public ChannelDiff UpdateChannel(
long id,
string? name = null,
string? password = null,
bool? temporary = null,
int? rank = null,
string? ownerId = null
) => UpdateChannel(channel.Name, name, password, temporary, rank, ownerId);
) => UpdateChannelInternal(
GetChannel(id) ?? throw new ChannelNotFoundException(nameof(id)),
name, password, temporary, rank, ownerId
);
public Channel UpdateChannel(
public ChannelDiff UpdateChannel(
string currentName,
string? name = null,
string? password = null,
bool? temporary = null,
int? rank = null,
string? ownerId = null
) => UpdateChannelInternal(
GetChannel(currentName) ?? throw new ChannelNotFoundException(nameof(currentName)),
name, password, temporary, rank, ownerId
);
public ChannelDiff UpdateChannel(
Channel channel,
string? name = null,
string? password = null,
bool? temporary = null,
int? rank = null,
string? ownerId = null
) => UpdateChannel(channel.Id, name, password, temporary, rank, ownerId);
private ChannelDiff UpdateChannelInternal(
Channel channel,
string? name = null,
string? password = null,
bool? temporary = null,
int? rank = null,
string? ownerId = null
) {
Channel channel = GetChannel(currentName) ?? throw new ChannelNotFoundException(nameof(currentName));
lock(@lock) {
StringDiff nameDiff = new(channel.Name, name);
if(nameDiff.Changed) {
if(!Channel.CheckName(nameDiff.After))
throw new ChannelNameFormatException(nameof(name));
if(ChannelExists(nameDiff.After))
throw new ChannelExistsException(nameof(name));
if(name is not null && currentName != name) {
if(!Channel.CheckName(name))
throw new ChannelNameFormatException(nameof(name));
if(ChannelExists(name))
throw new ChannelExistsException(nameof(name));
channel.Name = nameDiff.After;
}
StringDiff passwordDiff = new(channel.Password, password);
if(passwordDiff.Changed)
channel.Password = passwordDiff.After;
ValueDiff<bool> temporaryDiff = new(channel.IsTemporary, temporary);
if(temporaryDiff.Changed)
channel.IsTemporary = temporaryDiff.After;
ValueDiff<int> rankDiff = new(channel.Rank, rank);
if(rankDiff.Changed)
channel.Rank = rankDiff.After;
StringDiff ownerIdDiff = new(channel.OwnerId, ownerId);
if(ownerIdDiff.Changed)
channel.OwnerId = ownerIdDiff.After;
return new(
channel,
nameDiff,
passwordDiff,
temporaryDiff,
rankDiff,
ownerIdDiff
);
}
if(name is not null)
channel.Name = name;
if(password is not null)
channel.Password = password;
if(temporary.HasValue)
channel.IsTemporary = temporary.Value;
if(rank.HasValue)
channel.Rank = rank.Value;
if(ownerId is not null)
channel.OwnerId = ownerId;
return channel;
}
public void RemoveChannel(string name) {
Channel channel = GetChannel(name) ?? throw new ChannelNotFoundException(nameof(name));
public void RemoveChannel(long id)
=> RemoveChannelInternal(GetChannel(id) ?? throw new ChannelNotFoundException(nameof(id)), nameof(id));
Channel defaultChannel = DefaultChannel;
if(channel == defaultChannel || defaultChannel.NameEquals(channel.Name))
throw new ChannelIsDefaultException(nameof(name));
public void RemoveChannel(Channel channel)
=> RemoveChannel(channel.Name);
Channels.Remove(channel);
public void RemoveChannel(string name)
=> RemoveChannelInternal(GetChannel(name) ?? throw new ChannelNotFoundException(nameof(name)), nameof(name));
private void RemoveChannelInternal(Channel channel, string argName) {
lock(@lock) {
Channel defaultChannel = GetDefaultChannel();
if(channel == defaultChannel || defaultChannel.NameEquals(channel.Name))
throw new ChannelIsDefaultException(argName);
Channels.Remove(channel.Id);
}
}
}

View file

@ -0,0 +1,176 @@
using SharpChat.Users;
namespace SharpChat.Channels;
public class ChannelsUsersContext(ChannelsContext channelsCtx, UsersContext usersCtx) {
private readonly Dictionary<string, Dictionary<long, DateTimeOffset>> UserChannels = [];
private readonly Dictionary<long, HashSet<string>> ChannelUsers = [];
private readonly Dictionary<string, long> UserLastChannel = [];
private readonly Lock @lock = new();
private void UpdateUserLastChannel(string userId, long? channelId = null) {
if(channelId.HasValue
&& UserLastChannel.TryGetValue(userId, out long lastChannelId)
&& channelId.Value != lastChannelId)
return;
if(UserChannels.TryGetValue(userId, out var userChannels) && userChannels.Count > 0)
UserLastChannel[userId] = userChannels.OrderByDescending(kvp => kvp.Value).First().Key;
else
UserLastChannel.Remove(userId);
}
public void RemoveChannel(Channel channel)
=> RemoveChannel(channel.Id);
public void RemoveChannel(long channelId) {
lock(@lock) {
if(!ChannelUsers.TryGetValue(channelId, out var channelUsers))
return;
ChannelUsers.Remove(channelId);
foreach(string userId in channelUsers)
if(UserChannels.TryGetValue(userId, out var userChannels)) {
userChannels.Remove(channelId);
UpdateUserLastChannel(userId, channelId);
}
}
}
public void RemoveUser(User user)
=> RemoveUser(user.UserId);
public void RemoveUser(string userId) {
lock(@lock) {
if(!UserChannels.TryGetValue(userId, out var userChannels))
return;
UserChannels.Remove(userId);
UserLastChannel.Remove(userId);
foreach(long channelId in userChannels.Keys)
if(ChannelUsers.TryGetValue(channelId, out var channelUsers))
channelUsers.Remove(userId);
}
}
public void RecordChannelUserActivity(Channel channel, User user)
=> RecordChannelUserActivity(channel.Id, user.UserId);
public void RecordChannelUserActivity(long channelId, User user)
=> RecordChannelUserActivity(channelId, user.UserId);
public void RecordChannelUserActivity(Channel channel, string userId)
=> RecordChannelUserActivity(channel.Id, userId);
public void RecordChannelUserActivity(long channelId, string userId) {
lock(@lock) {
if(!UserChannels.TryGetValue(userId, out var userChannels))
throw new ArgumentException("Attempted to register activity for non-existent user.", nameof(userId));
userChannels[channelId] = DateTimeOffset.UtcNow;
UserLastChannel[userId] = channelId;
}
}
public void AddChannelUser(Channel channel, User user)
=> AddChannelUser(channel.Id, user.UserId);
public void AddChannelUser(long channelId, string userId) {
lock(@lock) {
if(!ChannelUsers.TryGetValue(channelId, out var channelUsers))
ChannelUsers.Add(channelId, channelUsers = []);
if(!UserChannels.ContainsKey(userId))
UserChannels.Add(userId, []);
channelUsers.Add(userId);
RecordChannelUserActivity(channelId, userId);
}
}
public void RemoveChannelUser(Channel channel, User user)
=> RemoveChannelUser(channel.Id, user.UserId);
public void RemoveChannelUser(long channelId, User user)
=> RemoveChannelUser(channelId, user.UserId);
public void RemoveChannelUser(Channel channel, string userId)
=> RemoveChannelUser(channel.Id, userId);
public void RemoveChannelUser(long channelId, string userId) {
lock(@lock) {
ChannelUsers.Remove(channelId);
UserChannels.Remove(userId);
UpdateUserLastChannel(userId, channelId);
}
}
public bool HasChannelUser(Channel channel, User user)
=> HasChannelUser(channel.Id, user.UserId);
public bool HasChannelUser(long channelId, User user)
=> HasChannelUser(channelId, user.UserId);
public bool HasChannelUser(Channel channel, string userId)
=> HasChannelUser(channel.Id, userId);
public bool HasChannelUser(long channelId, string userId) {
lock(@lock)
return ChannelUsers.TryGetValue(channelId, out var channelUsers) && channelUsers.Contains(userId);
}
public long? GetUserLastChannelId(User user)
=> GetUserLastChannelId(user.UserId);
public long? GetUserLastChannelId(string userId) {
lock(@lock)
return UserLastChannel.TryGetValue(userId, out long channelId) ? channelId : null;
}
public Channel? GetUserLastChannel(User user)
=> GetUserLastChannel(user.UserId);
public Channel? GetUserLastChannel(string userId) {
lock(@lock) {
long? channelId = GetUserLastChannelId(userId);
return channelId.HasValue ? channelsCtx.GetChannel(channelId.Value) : null;
}
}
public IEnumerable<string> GetChannelUserIds(Channel channel)
=> GetChannelUserIds(channel.Id);
public IEnumerable<string> GetChannelUserIds(long channelId) {
lock(@lock)
return [.. GetChannelUserIdsInternal(channelId)];
}
private HashSet<string> GetChannelUserIdsInternal(long channelId) {
return ChannelUsers.TryGetValue(channelId, out var channelUsers) ? channelUsers : [];
}
public IEnumerable<User> GetChannelUsers(Channel channel) {
lock(@lock)
return usersCtx.GetUsers(GetChannelUserIdsInternal(channel.Id));
}
public IEnumerable<User> GetChannelUsers(long channelId) {
lock(@lock)
return usersCtx.GetUsers(GetChannelUserIdsInternal(channelId));
}
public IEnumerable<long> GetUserChannelIds(User user)
=> GetUserChannelIds(user.UserId);
public IEnumerable<long> GetUserChannelIds(string userId) {
lock(@lock)
return [.. GetUserChannelIdsInternal(userId)];
}
private IEnumerable<long> GetUserChannelIdsInternal(string userId) {
return UserChannels.TryGetValue(userId, out var userChannels) ? userChannels.Keys : [];
}
public IEnumerable<Channel> GetUserChannels(User user) {
lock(@lock)
return channelsCtx.GetChannels(GetUserChannelIdsInternal(user.UserId));
}
public IEnumerable<Channel> GetUserChannels(string userId) {
lock(@lock)
return channelsCtx.GetChannels(GetUserChannelIdsInternal(userId));
}
}

View file

@ -0,0 +1,8 @@
using System.Net;
namespace SharpChat.Connections;
public interface Connection {
IPEndPoint RemoteEndPoint { get; }
void Close(ConnectionCloseReason reason = ConnectionCloseReason.Unexpected);
}

View file

@ -0,0 +1,14 @@
namespace SharpChat.Connections;
public enum ConnectionCloseReason {
Unexpected,
Normal,
ShuttingDown,
Restarting,
TimeOut,
Unauthorized,
Unavailable,
AccessDenied,
TooManyConnections,
Error,
}

View file

@ -0,0 +1,40 @@
using SharpChat.Snowflake;
using System.Net;
namespace SharpChat.Connections;
public class ConnectionsContext {
private readonly Dictionary<long, Connection> Connections = [];
private readonly Lock @lock = new();
public Connection? GetConnection(long connId) {
lock(@lock)
return Connections.TryGetValue(connId, out Connection? conn) ? conn : null;
}
public IEnumerable<Connection> GetConnections() {
lock(@lock)
return [.. Connections.Values];
}
public IEnumerable<Connection> GetConnections(IEnumerable<long> ids) {
return [.. ids.Select(GetConnection).Where(c => c is not null).Cast<Connection>()];
}
public IEnumerable<Connection> GetConnections(Func<Connection, bool> predicate) {
lock(@lock)
return [.. Connections.Values.Where(predicate)];
}
public IEnumerable<T> GetConnectionsOfType<T>() where T : Connection {
return GetConnections(c => c is T).Cast<T>();
}
public IEnumerable<IPEndPoint> GetRemoteEndPoints() {
return GetConnections().Select(c => c.RemoteEndPoint).Distinct();
}
public IEnumerable<IPEndPoint> GetRemoteEndPoints(Func<Connection, bool> predicate) {
return GetConnections(predicate).Select(c => c.RemoteEndPoint).Distinct();
}
}

View file

@ -0,0 +1,15 @@
using System.Net;
namespace SharpChat.Connections;
public sealed class NullConnection : Connection {
public static readonly NullConnection Instance = new();
private NullConnection() { }
public IPEndPoint RemoteEndPoint { get; } = new(IPAddress.IPv6None, 0);
public void Close(ConnectionCloseReason reason = ConnectionCloseReason.Unexpected) {}
public static bool IsNull(Connection conn)
=> conn is NullConnection;
}

5
SharpChatCommon/Diff.cs Normal file
View file

@ -0,0 +1,5 @@
namespace SharpChat;
public interface Diff {
bool Changed { get; }
}

View file

@ -1,3 +1,4 @@
using SharpChat.Users;
using System.Text.Json;
namespace SharpChat.Messages;

View file

@ -1,3 +1,5 @@
using SharpChat.Users;
namespace SharpChat.Messages;
public interface MessageStorage {

View file

@ -1,12 +1,20 @@
namespace SharpChat;
public class RateLimiter {
public const int DEFAULT_SIZE = 30;
public const int DEFAULT_MINIMUM_DELAY = 10000;
public const int DEFAULT_RISKY_OFFSET = 5;
private readonly int Size;
private readonly int MinimumDelay;
private readonly int RiskyOffset;
private readonly long[] TimePoints;
public RateLimiter(int size, int minDelay, int riskyOffset = 0) {
public RateLimiter(
int size = DEFAULT_SIZE,
int minDelay = DEFAULT_MINIMUM_DELAY,
int riskyOffset = DEFAULT_RISKY_OFFSET
) {
if(size < 2)
throw new ArgumentException("Size is too small.", nameof(size));
if(minDelay < 1000)

View file

@ -0,0 +1,20 @@
using Microsoft.Extensions.Logging;
using SharpChat.Connections;
namespace SharpChat.Sessions;
public class Session(long id, string secret, string userId, ILogger logger, Connection conn) {
public long Id { get; } = id;
public string Secret { get; } = secret;
public string UserId { get; } = userId;
public ILogger Logger { get; } = logger;
public Connection Connection { get; internal set; } = conn;
public DateTimeOffset LastHeartbeat { get; private set; } = DateTimeOffset.UtcNow;
public bool IsSuspended
=> NullConnection.IsNull(Connection);
public void Heartbeat() {
LastHeartbeat = DateTimeOffset.UtcNow;
}
}

View file

@ -0,0 +1,4 @@
namespace SharpChat.Sessions;
public class SessionConnectionAlreadyInUseException(string argName)
: ArgumentException("Provided connection is already in use by another session.", argName) {}

View file

@ -0,0 +1,295 @@
using Microsoft.Extensions.Logging;
using SharpChat.Connections;
using SharpChat.Snowflake;
using SharpChat.Users;
using System.Net;
using ZLogger;
namespace SharpChat.Sessions;
public class SessionsContext(ILoggerFactory loggerFactory, RandomSnowflake snowflake) {
public readonly TimeSpan TimeOutInterval = TimeSpan.FromMinutes(5);
private readonly Dictionary<long, Session> Sessions = [];
private readonly Dictionary<string, List<Session>> UserSessions = [];
private readonly Dictionary<long, Session> SuspendedSessions = [];
private readonly Dictionary<Connection, Session> ConnectionSession = [];
private readonly Lock @lock = new();
public Session CreateSession(User user, Connection conn)
=> CreateSession(user.UserId, conn);
public Session CreateSession(string userId, Connection conn) {
lock(@lock) {
if(ConnectionSession.ContainsKey(conn))
throw new SessionConnectionAlreadyInUseException(nameof(conn));
long sessId = snowflake.Next();
string secret = RNG.SecureRandomString(20);
ILogger logger = loggerFactory.CreateLogger($"session:({sessId})");
Session sess = new(sessId, secret, userId, logger, conn);
Sessions.Add(sessId, sess);
ConnectionSession.Add(conn, sess);
if(!UserSessions.TryGetValue(userId, out var userSessions))
UserSessions.Add(userId, userSessions = []);
userSessions.Add(sess);
logger.ZLogInformation($"Session created for #{userId}.");
return sess;
}
}
public Session? ResumeSession(long sessId, string secret, Connection conn) {
lock(@lock) {
if(!SuspendedSessions.TryGetValue(sessId, out Session? sess)
|| !NullConnection.IsNull(sess.Connection)
|| !sess.Secret.SlowUtf8Equals(secret))
return null;
ConnectionSession.Add(conn, sess);
SuspendedSessions.Remove(sessId);
sess.Connection = conn;
return sess;
}
}
public void SuspendSession(Connection conn) {
lock(@lock)
if(ConnectionSession.TryGetValue(conn, out Session? sess))
SuspendSessionInternal(sess);
}
public void SuspendSession(Session sess)
=> SuspendSession(sess.Id);
public void SuspendSession(long sessId) {
lock(@lock)
if(Sessions.TryGetValue(sessId, out Session? sess))
SuspendSessionInternal(sess);
}
private void SuspendSessionInternal(Session sess) {
if(SuspendedSessions.ContainsValue(sess) || NullConnection.IsNull(sess.Connection))
return;
ConnectionSession.Remove(sess.Connection);
SuspendedSessions.Add(sess.Id, sess);
sess.Connection = NullConnection.Instance;
}
public void DestroySession(Connection conn) {
lock(@lock)
if(ConnectionSession.TryGetValue(conn, out Session? sess))
DestroySessionInternal(sess);
}
public void DestroySession(Session sess)
=> DestroySession(sess.Id);
public void DestroySession(long sessId) {
lock(@lock)
if(Sessions.TryGetValue(sessId, out Session? sess))
DestroySessionInternal(sess);
}
private void DestroySessionInternal(Session sess) {
if(UserSessions.TryGetValue(sess.UserId, out var userSessions)) {
userSessions.Remove(sess);
if(userSessions.Count < 1)
UserSessions.Remove(sess.UserId);
}
Sessions.Remove(sess.Id);
SuspendedSessions.Remove(sess.Id);
ConnectionSession.Remove(sess.Connection);
}
public bool IsSuspendedSession(Session sess)
=> IsSuspendedSession(sess.Id);
public bool IsSuspendedSession(long sessId) {
lock(@lock)
return SuspendedSessions.ContainsKey(sessId);
}
public Session? GetSession(long sessId) {
lock(@lock)
return Sessions.TryGetValue(sessId, out Session? sess) ? sess : null;
}
public Session? GetSession(Connection conn) {
lock(@lock)
return ConnectionSession.TryGetValue(conn, out Session? sess) ? sess : null;
}
public Session? GetSession(Func<Session, bool> predicate) {
lock(@lock)
return Sessions.Values.FirstOrDefault(predicate);
}
public int CountSessions() {
lock(@lock)
return Sessions.Count;
}
public int CountSessions(Func<Session, bool> predicate) {
lock(@lock)
return Sessions.Values.Count(predicate);
}
public int CountSessions(User user)
=> CountSessions(user.UserId);
public int CountSessions(string userId) {
lock(@lock)
return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count : 0;
}
public int CountSessions(User user, Func<Session, bool> predicate)
=> CountSessions(user.UserId, predicate);
public int CountSessions(string userId, Func<Session, bool> predicate) {
lock(@lock)
return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count(predicate) : 0;
}
public int CountActiveSessions(User user)
=> CountActiveSessions(user.UserId);
public int CountActiveSessions(string userId) {
return CountSessions(userId, s => !IsTimedOut(s));
}
public int CountNonSuspendedActiveSessions(User user)
=> CountNonSuspendedActiveSessions(user.UserId);
public int CountNonSuspendedActiveSessions(string userId) {
return CountSessions(userId, s => !s.IsSuspended /* <-- might create a loophole */ && !IsTimedOut(s));
}
public IEnumerable<Session> GetSessions() {
lock(@lock)
return [.. Sessions.Values];
}
public IEnumerable<Session> GetSessions(Func<Session, bool> predicate) {
lock(@lock)
return [.. Sessions.Values.Where(predicate)];
}
public IEnumerable<Session> GetSessions(User user)
=> GetSessions(user.UserId);
public IEnumerable<Session> GetSessions(string userId) {
lock(@lock)
return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions] : [];
}
public IEnumerable<Session> GetSessions(User user, Func<Session, bool> predicate)
=> GetSessions(user.UserId, predicate);
public IEnumerable<Session> GetSessions(string userId, Func<Session, bool> predicate) {
lock(@lock)
return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions.Where(predicate)] : [];
}
public IEnumerable<Session> GetSessions(IEnumerable<long> ids) {
return ids.Select(GetSession).Where(s => s is not null).Cast<Session>();
}
public IEnumerable<Session> GetTimedOutSessions(DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return GetSessions(s => IsTimedOut(s, now.Value));
}
public IEnumerable<Session> GetActiveSessions(DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return GetSessions(s => !IsTimedOut(s, now.Value));
}
public IEnumerable<Session> GetActiveSessions(User user, DateTimeOffset? now = null)
=> GetActiveSessions(user.UserId, now);
public IEnumerable<Session> GetActiveSessions(string userId, DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return GetSessions(userId, s => !IsTimedOut(s, now.Value));
}
public IEnumerable<Session> GetActiveSessions(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
=> GetActiveSessions(user.UserId, predicate, now);
public IEnumerable<Session> GetActiveSessions(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return GetSessions(userId, s => !IsTimedOut(s, now.Value) && predicate(s));
}
public bool IsTimedOut(Session session)
=> IsTimedOut(session, DateTimeOffset.UtcNow);
public bool IsTimedOut(Session session, DateTimeOffset now)
=> now - session.LastHeartbeat > TimeOutInterval;
public IEnumerable<IPEndPoint> GetRemoteEndPoints() {
return [.. GetSessions(s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()];
}
public IEnumerable<IPEndPoint> GetRemoteEndPoints(User user)
=> GetRemoteEndPoints(user.UserId);
public IEnumerable<IPEndPoint> GetRemoteEndPoints(string userId) {
return [.. GetSessions(userId, s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()];
}
public IEnumerable<Connection> GetConnections(DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return ConnectionSession.Where(kvp => !IsTimedOut(kvp.Value)).Select(kvp => kvp.Key);
}
public IEnumerable<Connection> GetConnections(Func<Session, bool> predicate, DateTimeOffset? now = null) {
now ??= DateTimeOffset.UtcNow;
return ConnectionSession.Where(kvp => !IsTimedOut(kvp.Value) && predicate(kvp.Value)).Select(kvp => kvp.Key);
}
public IEnumerable<T> GetConnections<T>(DateTimeOffset? now = null) where T : Connection {
return GetConnections(now).Where(c => c is T).Cast<T>();
}
public IEnumerable<T> GetConnections<T>(Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
return GetConnections(predicate, now).Where(c => c is T).Cast<T>();
}
public IEnumerable<Connection> GetConnections(User user, DateTimeOffset? now = null)
=> GetConnections(user.UserId, now);
public IEnumerable<Connection> GetConnections(string userId, DateTimeOffset? now = null) {
return GetActiveSessions(userId, now).Select(s => s.Connection);
}
public IEnumerable<T> GetConnections<T>(User user, DateTimeOffset? now = null) where T : Connection {
return GetConnections(user.UserId, now).Where(c => c is T).Cast<T>();
}
public IEnumerable<T> GetConnections<T>(string userId, DateTimeOffset? now = null) where T : Connection {
return GetConnections(userId, now).Where(c => c is T).Cast<T>();
}
public IEnumerable<Connection> GetConnections(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
=> GetConnections(user.UserId, predicate, now);
public IEnumerable<Connection> GetConnections(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
return GetActiveSessions(userId, predicate, now).Select(s => s.Connection);
}
public IEnumerable<T> GetConnections<T>(User user, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
return GetConnections(user.UserId, predicate, now).Where(c => c is T).Cast<T>();
}
public IEnumerable<T> GetConnections<T>(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
return GetConnections(userId, predicate, now).Where(c => c is T).Cast<T>();
}
}

View file

@ -1,8 +1,8 @@
using SharpChat.Messages;
namespace SharpChat;
namespace SharpChat.Storage;
public interface Storage {
public interface StorageBackend {
MessageStorage CreateMessageStorage();
Task UpgradeStorage();
}

View file

@ -2,9 +2,9 @@ using Microsoft.Extensions.Logging;
using SharpChat.Messages;
using ZLogger;
namespace SharpChat;
namespace SharpChat.Storage;
public class StorageMigrator(ILogger logger, Storage source, Storage target) {
public class StorageMigrator(ILogger logger, StorageBackend source, StorageBackend target) {
public async Task Migrate(CancellationToken cancellationToken) {
try {
logger.ZLogInformation($"Converting from {source.GetType().Name} to {target.GetType().Name}!");

View file

@ -0,0 +1,13 @@
namespace SharpChat;
public readonly struct StringDiff(
string before,
string? after,
StringComparison comparisonType = StringComparison.Ordinal
) : Diff {
public readonly string Before = before ?? throw new ArgumentNullException(nameof(before));
public readonly string After = after ?? before;
public readonly StringComparison ComparisonType = comparisonType;
public bool Changed => !Before.Equals(After, ComparisonType);
}

View file

@ -0,0 +1,53 @@
using System.Globalization;
using System.Text;
namespace SharpChat;
public static class StringExtensions {
public static byte[] GetUtf8Bytes(this string str) {
return Encoding.UTF8.GetBytes(str);
}
public static byte[] GetUtf8Bytes(this string str, int index, int count) {
return Encoding.UTF8.GetBytes(str, index, count);
}
public static int CountUtf8Bytes(this string str) {
return Encoding.UTF8.GetByteCount(str);
}
public static int CountUnicodeGraphemes(this string str) {
return new StringInfo(str).LengthInTextElements;
}
public static bool SlowUtf8Equals(this string str, string other) {
return str.GetUtf8Bytes().SlowEquals(other.GetUtf8Bytes());
}
public static string TruncateIfTooLong(this string str, int? maxGraphemes = null, int? maxBytes = null) {
StringInfo info = new(str);
if(maxGraphemes.HasValue) {
if(maxGraphemes.Value == 0)
return string.Empty;
if(maxGraphemes.Value < 0)
throw new ArgumentException("Maximum Unicode Grapheme Cluster count must be a positive integer.", nameof(maxGraphemes));
if(info.LengthInTextElements > maxGraphemes.Value)
return info.SubstringByTextElements(0, maxGraphemes.Value);
}
if(maxBytes.HasValue) {
if(maxBytes.Value == 0)
return string.Empty;
if(maxBytes.Value < 0)
throw new ArgumentException("Maximum bytes must be a positive integer.", nameof(maxBytes));
if(str.CountUtf8Bytes() > maxBytes.Value)
return maxGraphemes.HasValue
? info.SubstringByTextElements(0, Math.Min(info.LengthInTextElements, maxGraphemes.Value))
: str.GetUtf8Bytes(0, maxBytes.Value).GetUtf8String();
}
return str;
}
}

View file

@ -0,0 +1,32 @@
namespace SharpChat.Users;
public class User(
string userId,
string userName,
ColourInheritable colour,
int rank,
UserPermissions perms,
string nickName = "",
UserStatus status = UserStatus.Online,
string statusText = ""
) {
public string UserId { get; } = userId;
public string UserName { get; internal set; } = userName;
public ColourInheritable Colour { get; internal set; } = colour;
public int Rank { get; internal set; } = rank;
public UserPermissions Permissions { get; internal set; } = perms;
public string NickName { get; internal set; } = nickName;
public UserStatus Status { get; internal set; } = status;
public string StatusText { get; internal set; } = statusText;
public bool NameEquals(string name) {
return string.Equals(name, UserName, StringComparison.OrdinalIgnoreCase)
|| string.Equals(name, NickName, StringComparison.OrdinalIgnoreCase);
}
public string GetDMChannelNameWith(User other) {
return string.Compare(UserId, other.UserId, StringComparison.Ordinal) > 0
? $"@{other.UserId}-{UserId}"
: $"@{UserId}-{other.UserId}";
}
}

View file

@ -0,0 +1,22 @@
namespace SharpChat.Users;
public readonly record struct UserDiff(
User User,
string Id,
StringDiff Name,
ValueDiff<ColourInheritable> Colour,
ValueDiff<int> Rank,
ValueDiff<UserPermissions> Permissions,
StringDiff Nick,
ValueDiff<UserStatus> Status,
StringDiff StatusText
) : Diff {
public bool Changed
=> Name.Changed
|| Colour.Changed
|| Rank.Changed
|| Permissions.Changed
|| Nick.Changed
|| Status.Changed
|| StatusText.Changed;
}

View file

@ -0,0 +1,4 @@
namespace SharpChat.Users;
public class UserExistsException(string argName)
: ArgumentException("A user with that id already exists.", argName) { }

View file

@ -0,0 +1,4 @@
namespace SharpChat.Users;
public class UserNotFoundException(string argName)
: ArgumentException("A user with that id already exists.", argName) { }

View file

@ -1,4 +1,4 @@
namespace SharpChat;
namespace SharpChat.Users;
/// <summary>
/// User Permissions.

View file

@ -1,4 +1,4 @@
namespace SharpChat;
namespace SharpChat.Users;
public enum UserStatus {
Online,

View file

@ -0,0 +1,193 @@
using SharpChat.Auth;
namespace SharpChat.Users;
public class UsersContext {
private readonly Dictionary<string, User> Users = [];
private readonly Lock @lock = new();
public bool UserExists(string id) {
lock(@lock)
return Users.ContainsKey(id);
}
public bool UserExists(Func<User, bool> predicate) {
lock(@lock)
return Users.Values.Any(predicate);
}
public User? GetUser(string id) {
lock(@lock)
return Users.TryGetValue(id, out User? user) ? user : null;
}
public User? GetUser(Func<User, bool> predicate) {
lock(@lock)
return Users.Values.FirstOrDefault(predicate);
}
public IEnumerable<User> GetUsers() {
lock(@lock)
return [.. Users.Values];
}
public IEnumerable<User> GetUsers(IEnumerable<string> ids) {
return [.. ids.Select(GetUser).Where(u => u is not null).Cast<User>()];
}
public IEnumerable<User> GetUsers(Func<User, bool> predicate) {
lock(@lock)
return [.. Users.Values.Where(predicate)];
}
public IEnumerable<User> GetUsersWithStatus(UserStatus status) {
return GetUsers(u => u.Status == status);
}
public IEnumerable<User> GetUsersOfMinimumRank(int minRank) {
return GetUsers(u => u.Rank >= minRank);
}
public User CreateOrUpdateUser(AuthResult authResult) {
lock(@lock) {
User? user = GetUser(authResult.UserId);
return user is null ? CreateUserInternal(
authResult.UserId,
authResult.UserName,
authResult.UserColour,
authResult.UserRank,
authResult.UserPermissions
) : UpdateUserInternal(
user,
authResult.UserId,
authResult.UserName,
authResult.UserColour,
authResult.UserRank,
authResult.UserPermissions
).User;
}
}
public User CreateUser(
string id,
string name,
ColourInheritable colour,
int rank,
UserPermissions perms,
string nick = "",
UserStatus status = UserStatus.Online,
string statusText = ""
) {
lock(@lock)
return UserExists(id)
? throw new UserExistsException(nameof(id))
: CreateUserInternal(id, name, colour, rank, perms, nick, status, statusText);
}
private User CreateUserInternal(
string id,
string name,
ColourInheritable colour,
int rank,
UserPermissions perms,
string nick = "",
UserStatus status = UserStatus.Online,
string statusText = ""
) {
User user = new(id, name, colour, rank, perms, nick, status, statusText);
Users.Add(id, user);
return user;
}
public UserDiff UpdateUser(
User user,
string? name = null,
ColourInheritable? colour = null,
int? rank = null,
UserPermissions? perms = null,
string? nick = null,
UserStatus? status = null,
string? statusText = null
) => UpdateUser(user.UserId, name, colour, rank, perms, nick, status, statusText);
public UserDiff UpdateUser(
string id,
string? name = null,
ColourInheritable? colour = null,
int? rank = null,
UserPermissions? perms = null,
string? nick = null,
UserStatus? status = null,
string? statusText = null
) {
lock(@lock)
return UpdateUserInternal(
GetUser(id) ?? throw new UserNotFoundException(nameof(id)),
id, name, colour, rank, perms, nick, status, statusText
);
}
private static UserDiff UpdateUserInternal(
User user,
string id,
string? name = null,
ColourInheritable? colour = null,
int? rank = null,
UserPermissions? perms = null,
string? nick = null,
UserStatus? status = null,
string? statusText = null
) {
StringDiff nameDiff = new(user.UserName, name);
if(nameDiff.Changed)
user.UserName = nameDiff.After;
ValueDiff<ColourInheritable> colourDiff = new(user.Colour, colour);
if(colourDiff.Changed)
user.Colour = colourDiff.After;
ValueDiff<int> rankDiff = new(user.Rank, rank);
if(rankDiff.Changed)
user.Rank = rankDiff.After;
ValueDiff<UserPermissions> permsDiff = new(user.Permissions, perms);
if(permsDiff.Changed)
user.Permissions = permsDiff.After;
StringDiff nickDiff = new(user.NickName, nick);
if(nickDiff.Changed)
user.NickName = nickDiff.After;
ValueDiff<UserStatus> statusDiff = new(user.Status, status);
if(statusDiff.Changed)
user.Status = statusDiff.After;
StringDiff statusTextDiff = new(user.StatusText, statusText);
if(statusTextDiff.Changed)
user.StatusText = statusTextDiff.After;
return new(
user,
id,
nameDiff,
colourDiff,
rankDiff,
permsDiff,
nickDiff,
statusDiff,
statusTextDiff
);
}
public void RemoveUser(User user)
=> RemoveUserInternal(user.UserId, nameof(user));
public void RemoveUser(string id)
=> RemoveUserInternal(id, nameof(id));
private void RemoveUserInternal(string id, string argName) {
lock(@lock)
if(!Users.Remove(id))
throw new UserNotFoundException(argName);
}
}

View file

@ -0,0 +1,8 @@
namespace SharpChat;
public readonly struct ValueDiff<T>(T before, T? after) : Diff where T : struct {
public readonly T Before = before;
public readonly T After = after ?? before;
public readonly bool Changed => !Before.Equals(After);
}

View file

@ -0,0 +1,106 @@
namespace SharpChat;
/// <summary>
/// WebSocket Close Code Number Registry
/// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
/// </summary>
public enum WebSocketCloseCode : int {
/// <summary>
/// 1000 indicates a normal closure, meaning that the purpose for which the connection was established has been fulfilled.
/// </summary>
NormalClosure = 1000,
/// <summary>
/// 1001 indicates that an endpoint is "going away", such as a server going down or a browser having navigated away from a page.
/// </summary>
GoingAway = 1001,
/// <summary>
/// 1002 indicates that an endpoint is terminating the connection due to a protocol error.
/// </summary>
ProtocolError = 1002,
/// <summary>
/// 1003 indicates that an endpoint is terminating the connection because it has received a type of data it cannot accept (e.g., an endpoint that understands only text data MAY send this if it receives a binary message).
/// </summary>
UnsupportedData = 1003,
/// <summary>
/// 1005 is a reserved value and MUST NOT be set as a status code in a Close control frame by an endpoint.
/// It is designated for use in applications expecting a status code to indicate that no status code was actually present.
/// </summary>
NoStatusReceived = 1005,
/// <summary>
/// 1006 is a reserved value and MUST NOT be set as a status code in a Close control frame by an endpoint.
/// It is designated for use in applications expecting a status code to indicate that the connection was closed abnormally, e.g., without sending or receiving a Close control frame.
/// </summary>
AbnormalClosure = 1006,
/// <summary>
/// 1007 indicates that an endpoint is terminating the connection because it has received data within a message that was not consistent with the type of the message (e.g., non-UTF-8 [RFC3629] data within a text message).
/// </summary>
InvalidFramePayloadData = 1007,
/// <summary>
/// 1008 indicates that an endpoint is terminating the connection because it has received a message that violates its policy.
/// This is a generic status code that can be returned when there is no other more suitable status code (e.g., 1003 or 1009) or if there is a need to hide specific details about the policy.
/// </summary>
PolicyViolation = 1008,
/// <summary>
/// 1009 indicates that an endpoint is terminating the connection because it has received a message that is too big for it to process.
/// </summary>
MessageTooBig = 1009,
/// <summary>
/// 1010 indicates that an endpoint (client) is terminating the connection because it has expected the server to negotiate one or more extension, but the server didn't return them in the response message of the WebSocket handshake.
/// The list of extensions that are needed SHOULD appear in the /reason/ part of the Close frame.
/// Note that this status code is not used by the server, because it can fail the WebSocket handshake instead.
/// </summary>
MandatoryExtension = 1010,
/// <summary>
/// 1011 indicates that a server is terminating the connection because it encountered an unexpected condition that prevented it from fulfilling the request.
/// </summary>
InternalError = 1011,
/// <summary>
/// 1012 indicates that the service is restarted.
/// A client may reconnect, and if it choses to do, should reconnect using a randomized delay of 5 to 30 seconds.
/// </summary>
ServiceRestart = 1012,
/// <summary>
/// 1013 indicates that the service is experiencing overload.
/// A client should only connect to a different IP (when there are multiple for the target) or reconnect to the same IP upon user action.
/// </summary>
TryAgainLater = 1013,
/// <summary>
/// The server was acting as a gateway or proxy and received an invalid response from the upstream server.
/// This is similar to 502 HTTP Status Code.
/// </summary>
GatewayTimeout = 1014,
/// <summary>
/// 1015 is a reserved value and MUST NOT be set as a status code in a Close control frame by an endpoint.
/// It is designated for use in applications expecting a status code to indicate that the connection was closed due to a failure to perform a TLS handshake (e.g., the server certificate can't be verified).
/// </summary>
TlsHandshake = 1015,
/// <summary>
/// Unauthorized (HTTP 401)
/// </summary>
Unauthorized = 3000,
/// <summary>
/// Forbidden (HTTP 403)
/// </summary>
Forbidden = 3003,
/// <summary>
/// Timeout (HTTP 408)
/// </summary>
Timeout = 3008,
}