sharp-chat/SharpChat/Context.cs

407 lines
16 KiB
C#

using Microsoft.Extensions.Logging;
using SharpChat.Auth;
using SharpChat.Bans;
using SharpChat.Channels;
using SharpChat.Configuration;
using SharpChat.Events;
using SharpChat.Messages;
using SharpChat.Snowflake;
using SharpChat.SockChat;
using SharpChat.SockChat.S2CPackets;
using System.Dynamic;
using System.Net;
using ZLogger;
namespace SharpChat;
/// <summary>
/// Central chat state: tracked connections, online users, channel membership, and the helpers that
/// send packets to and log messages for them.
/// </summary>
/// <remarks>
/// The collections exposed here are plain <see cref="HashSet{T}"/> / <see cref="Dictionary{TKey, TValue}"/>
/// instances. <see cref="SafeUpdate"/> serialises <see cref="Update"/> through <see cref="ContextAccess"/>;
/// presumably other callers acquire the same semaphore before touching the context (TODO confirm at call sites).
/// </remarks>
public class Context {
    /// <summary>Fallback for the <c>msgMaxLength</c> config key.</summary>
    public const int DEFAULT_MSG_LENGTH_MAX = 5000;
    /// <summary>Fallback for the <c>connMaxCount</c> config key.</summary>
    public const int DEFAULT_MAX_CONNECTIONS = 5;
    /// <summary>Fallback for the <c>floodKickLength</c> config key.</summary>
    public const int DEFAULT_FLOOD_KICK_LENGTH = 30;
    /// <summary>Fallback for the <c>floodKickExemptRank</c> config key.</summary>
    public const int DEFAULT_FLOOD_KICK_EXEMPT_RANK = 9;

    /// <summary>Records that the user with <paramref name="UserId"/> is a member of the channel named <paramref name="ChannelName"/>.</summary>
    /// <param name="UserId">Id of the member user.</param>
    /// <param name="ChannelName">Channel name exactly as stored on the channel object.</param>
    public record ChannelUserAssoc(string UserId, string ChannelName);

    /// <summary>Single-holder async lock; <see cref="SafeUpdate"/> holds it while running <see cref="Update"/>.</summary>
    public readonly SemaphoreSlim ContextAccess = new(1, 1);

    public ILoggerFactory LoggerFactory { get; }
    public Config Config { get; }
    /// <summary>Message log storage created from the <c>Storage</c> passed to the constructor.</summary>
    public MessageStorage Messages { get; }
    public AuthClient Auth { get; }
    public BansClient Bans { get; }

    /// <summary>Cached <c>msgMaxLength</c> config value.</summary>
    public CachedValue<int> MaxMessageLength { get; }
    /// <summary>Cached <c>connMaxCount</c> config value.</summary>
    public CachedValue<int> MaxConnections { get; }
    /// <summary>Cached <c>floodKickLength</c> config value (unit not visible here — TODO confirm).</summary>
    public CachedValue<int> FloodKickLength { get; }
    /// <summary>Cached <c>floodKickExemptRank</c> config value.</summary>
    public CachedValue<int> FloodKickExemptRank { get; }

    private readonly ILogger Logger;

    public SnowflakeGenerator SnowflakeGenerator { get; } = new();
    /// <summary>Id source for packets and log entries; built on top of <see cref="SnowflakeGenerator"/>.</summary>
    public RandomSnowflake RandomSnowflake { get; }

    public ChannelsContext Channels { get; } = new();
    /// <summary>Tracked connections; disposed ones are pruned by <see cref="Update"/> and <see cref="BanUser"/>.</summary>
    public HashSet<Connection> Connections { get; } = [];
    /// <summary>Users currently considered online; <see cref="HandleDisconnect"/> removes them.</summary>
    public HashSet<User> Users { get; } = [];
    /// <summary>Channel membership associations.</summary>
    public HashSet<ChannelUserAssoc> ChannelUsers { get; } = [];
    /// <summary>Rate limiters; presumably keyed by user id (not used in this file — TODO confirm).</summary>
    public Dictionary<string, RateLimiter> UserRateLimiters { get; } = [];
    /// <summary>The channel each user (keyed by user id) is currently in.</summary>
    public Dictionary<string, Channel> UserLastChannel { get; } = [];

    /// <summary>
    /// Wires up dependencies, reads the cached config values and creates the configured channels.
    /// </summary>
    /// <param name="logFactory">Factory used for the <c>"ctx"</c> logger; also exposed as <see cref="LoggerFactory"/>.</param>
    /// <param name="config">Root configuration.</param>
    /// <param name="storage">Storage backend from which <see cref="Messages"/> is created.</param>
    /// <param name="authClient">Authentication client.</param>
    /// <param name="bansClient">Bans client.</param>
    public Context(
        ILoggerFactory logFactory,
        Config config,
        Storage storage,
        AuthClient authClient,
        BansClient bansClient
    ) {
        LoggerFactory = logFactory;
        Logger = logFactory.CreateLogger("ctx");
        Config = config;
        Messages = storage.CreateMessageStorage();
        Auth = authClient;
        Bans = bansClient;
        RandomSnowflake = new(SnowflakeGenerator);

        Logger.ZLogDebug($"Reading cached config values...");
        MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX);
        MaxConnections = config.ReadCached("connMaxCount", DEFAULT_MAX_CONNECTIONS);
        FloodKickLength = config.ReadCached("floodKickLength", DEFAULT_FLOOD_KICK_LENGTH);
        FloodKickExemptRank = config.ReadCached("floodKickExemptRank", DEFAULT_FLOOD_KICK_EXEMPT_RANK);

        Logger.ZLogDebug($"Creating channel list...");
        // With no "channels" list configured, a single "lounge" channel is created.
        // The ?? guarantees a non-null array, so no further null check is needed.
        string[] channelNames = config.ReadValue<string[]>("channels") ?? ["lounge"];
        foreach(string channelName in channelNames) {
            Config channelCfg = config.ScopeTo($"channels:{channelName}");

            // The config key doubles as the display name unless a non-blank "name" is set.
            string name = channelCfg.SafeReadValue("name", string.Empty)!;
            if(string.IsNullOrWhiteSpace(name))
                name = channelName;

            Channels.CreateChannel(
                name,
                channelCfg.SafeReadValue("password", string.Empty)!,
                rank: channelCfg.SafeReadValue("minRank", 0)
            );
        }
    }

    /// <summary>
    /// Delivers a chat event to the relevant connections and logs it. Only <see cref="MessageCreateEvent"/>
    /// is handled; any other event type is ignored.
    /// </summary>
    /// <remarks>
    /// Broadcasts go to every authenticated connection. Private messages use a channel name of the form
    /// <c>@&lt;id&gt;-&lt;id&gt;</c>: they are dropped (and not logged) when the name is malformed or the
    /// other participant is not online. Everything else is sent to the named channel, if it exists.
    /// </remarks>
    public async Task DispatchEvent(ChatEvent eventInfo) {
        if(eventInfo is not MessageCreateEvent mce)
            return;

        if(mce.IsBroadcast) {
            await Send(new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.BROADCAST, false, mce.MessageText));
        } else if(mce.IsPrivate) {
            // The channel name returned by GetDMChannelName should not be exposed to the user, instead @<Target User> should be displayed
            // e.g. nook sees @Arysil and Arysil sees @nook
            // this entire routine is garbage, channels should probably be in the db
            if(!mce.ChannelName.StartsWith('@'))
                return;

            // Materialised once: the ids are counted and then probed for every online user.
            // Non-numeric parts become "-1", which is not expected to match a real user id.
            string[] uids = [.. mce.ChannelName[1..].Split('-', 3).Select(u => (long.TryParse(u, out long up) ? up : -1).ToString())];
            if(uids.Length != 2)
                return;

            User[] users = [.. Users.Where(u => uids.Contains(u.UserId))];
            User? target = users.FirstOrDefault(u => u.UserId != mce.SenderId);
            if(target is null)
                return;

            foreach(User user in users)
                await SendTo(user, new ChatMessageAddS2CPacket(
                    mce.MessageId,
                    DateTimeOffset.Now,
                    mce.SenderId,
                    // The sender's own copy is prefixed with the recipient's name.
                    mce.SenderId == user.UserId ? $"{target.LegacyName} {mce.MessageText}" : mce.MessageText,
                    mce.IsAction,
                    true
                ));
        } else {
            Channel? channel = Channels.GetChannel(mce.ChannelName);
            if(channel is not null)
                await SendTo(channel, new ChatMessageAddS2CPacket(
                    mce.MessageId,
                    DateTimeOffset.Now,
                    mce.SenderId,
                    mce.MessageText,
                    mce.IsAction,
                    false
                ));
        }

        // "act" is only present in the logged payload for action messages.
        dynamic data = new ExpandoObject();
        data.text = mce.MessageText;
        if(mce.IsAction)
            data.act = true;

        await Messages.LogMessage(mce.MessageId, "msg:add", mce.ChannelName, mce.SenderId, mce.SenderName, mce.SenderColour, mce.SenderRank, mce.SenderNickName, mce.SenderPerms, data);
    }

    /// <summary>
    /// Disposes timed-out connections, prunes disposed connections, and disconnects (reason TimeOut)
    /// every user that no longer has any connection.
    /// </summary>
    public async Task Update() {
        foreach(Connection conn in Connections)
            if(!conn.IsDisposed && conn.HasTimedOut) {
                conn.Logger.ZLogInformation($"Nuking connection associated with user {conn.User?.UserId ?? "no-one"}");
                conn.Dispose();
            }

        Connections.RemoveWhere(conn => conn.IsDisposed);

        // Snapshot first: HandleDisconnect removes each user from Users, and mutating a HashSet
        // while a foreach is enumerating it throws InvalidOperationException.
        User[] orphanedUsers = [.. Users.Where(user => !Connections.Any(conn => conn.User == user))];
        foreach(User user in orphanedUsers) {
            Logger.ZLogInformation($"Timing out user {user.UserId} (no more connections).");
            await HandleDisconnect(user, UserDisconnectS2CPacket.Reason.TimeOut);
        }
    }

    /// <summary>Runs <see cref="Update"/> while holding <see cref="ContextAccess"/>.</summary>
    public async Task SafeUpdate() {
        // Wait asynchronously; the synchronous Wait() would block a thread-pool thread inside an async method.
        await ContextAccess.WaitAsync();
        try {
            await Update();
        } finally {
            ContextAccess.Release();
        }
    }

    /// <summary>Whether an exact (user id, channel name) membership entry exists.</summary>
    public bool IsInChannel(User user, Channel channel) {
        return ChannelUsers.Contains(new ChannelUserAssoc(user.UserId, channel.Name));
    }

    /// <summary>Names of all channels the user is a member of.</summary>
    public string[] GetUserChannelNames(User user) {
        return [.. ChannelUsers.Where(cu => cu.UserId == user.UserId).Select(cu => cu.ChannelName)];
    }

    /// <summary>Channel objects for the user's memberships, as resolved by <see cref="Channels"/>.</summary>
    public Channel[] GetUserChannels(User user) {
        return [.. Channels.GetChannels(GetUserChannelNames(user))];
    }

    /// <summary>Ids of all users whose membership entry matches the channel via <c>NameEquals</c>.</summary>
    public string[] GetChannelUserIds(Channel channel) {
        return [.. ChannelUsers.Where(cu => channel.NameEquals(cu.ChannelName)).Select(cu => cu.UserId)];
    }

    /// <summary>Online users that are members of the channel, in <see cref="Users"/> enumeration order.</summary>
    public User[] GetChannelUsers(Channel channel) {
        // Hash lookup instead of a linear array scan per online user.
        HashSet<string> ids = [.. GetChannelUserIds(channel)];
        return [.. Users.Where(u => ids.Contains(u.UserId))];
    }

    /// <summary>
    /// Applies any supplied (non-null) changes to <paramref name="user"/>. If anything actually changed,
    /// a user update packet is sent to everyone sharing a channel with the user.
    /// </summary>
    /// <param name="silent">Suppresses only the nickname-change announcement; the update packet is still sent.</param>
    public async Task UpdateUser(
        User user,
        string? userName = null,
        string? nickName = null,
        ColourInheritable? colour = null,
        UserStatus? status = null,
        string? statusText = null,
        int? rank = null,
        UserPermissions? perms = null,
        bool silent = false
    ) {
        bool hasChanged = false;
        // Stays empty unless a non-silent nickname change happens; drives the announcement below.
        string previousName = string.Empty;

        if(userName != null && !user.UserName.Equals(userName)) {
            user.UserName = userName;
            hasChanged = true;
        }

        if(nickName != null && !user.NickName.Equals(nickName)) {
            if(!silent)
                previousName = user.LegacyName;

            user.NickName = nickName;
            hasChanged = true;
        }

        if(colour.HasValue && user.Colour != colour.Value) {
            user.Colour = colour.Value;
            hasChanged = true;
        }

        if(status.HasValue && user.Status != status.Value) {
            user.Status = status.Value;
            hasChanged = true;
        }

        if(statusText != null && !user.StatusText.Equals(statusText)) {
            user.StatusText = statusText;
            hasChanged = true;
        }

        if(rank.HasValue && user.Rank != rank.Value) {
            user.Rank = rank.Value;
            hasChanged = true;
        }

        if(perms.HasValue && user.Permissions != perms.Value) {
            user.Permissions = perms.Value;
            hasChanged = true;
        }

        if(hasChanged) {
            if(!string.IsNullOrWhiteSpace(previousName))
                await SendToUserChannels(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.NICKNAME_CHANGE, false, previousName, user.LegacyNameWithStatus));

            await SendToUserChannels(user, new UserUpdateS2CPacket(user.UserId, user.LegacyNameWithStatus, user.Colour, user.Rank, user.Permissions));
        }
    }

    /// <summary>
    /// Sends a force-disconnect packet to the user, disposes all of their connections and then runs
    /// <see cref="HandleDisconnect"/> with <paramref name="reason"/>.
    /// </summary>
    /// <param name="duration">
    /// Positive: the packet carries an expiry (<see cref="TimeSpan.MaxValue"/>, or anything that would run past
    /// <see cref="DateTimeOffset.MaxValue"/>, maps to <see cref="DateTimeOffset.MaxValue"/>). Zero or negative:
    /// the packet carries no expiry.
    /// </param>
    public async Task BanUser(User user, TimeSpan duration, UserDisconnectS2CPacket.Reason reason = UserDisconnectS2CPacket.Reason.Kicked) {
        if(duration > TimeSpan.Zero)
            await SendTo(user, new ForceDisconnectS2CPacket(GetBanExpiry(duration)));
        else
            await SendTo(user, new ForceDisconnectS2CPacket());

        foreach(Connection conn in Connections)
            if(conn.User == user)
                conn.Dispose();

        Connections.RemoveWhere(conn => conn.IsDisposed);

        await HandleDisconnect(user, reason);
    }

    /// <summary>Expiry for a positive ban duration, clamped to <see cref="DateTimeOffset.MaxValue"/> instead of overflowing.</summary>
    private static DateTimeOffset GetBanExpiry(TimeSpan duration) {
        if(duration >= TimeSpan.MaxValue)
            return DateTimeOffset.MaxValue;

        DateTimeOffset now = DateTimeOffset.Now;
        // now + duration must keep both the local clock time and the UTC time representable,
        // otherwise DateTimeOffset.Add throws ArgumentOutOfRangeException.
        DateTime latest = now.DateTime > now.UtcDateTime ? now.DateTime : now.UtcDateTime;
        return duration >= DateTime.MaxValue - latest ? DateTimeOffset.MaxValue : now + duration;
    }

    /// <summary>
    /// Marks the user offline, removes them from <see cref="Users"/> and <see cref="UserLastChannel"/>, and
    /// for each channel they were in: drops the membership, sends and logs a disconnect, and removes the
    /// channel if it is temporary and owned by the user.
    /// </summary>
    public async Task HandleDisconnect(User user, UserDisconnectS2CPacket.Reason reason = UserDisconnectS2CPacket.Reason.Leave) {
        // Runs before the memberships are removed so the offline update still reaches the user's channels.
        await UpdateUser(user, status: UserStatus.Offline);
        Users.Remove(user);
        UserLastChannel.Remove(user.UserId);

        // Materialised up front because the loop body mutates ChannelUsers.
        Channel[] channels = GetUserChannels(user);

        foreach(Channel chan in channels) {
            ChannelUsers.Remove(new ChannelUserAssoc(user.UserId, chan.Name));

            long msgId = RandomSnowflake.Next();
            await SendTo(chan, new UserDisconnectS2CPacket(msgId, DateTimeOffset.Now, user.UserId, user.LegacyNameWithStatus, reason));
            await Messages.LogMessage(msgId, "user:disconnect", chan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions, new { reason = (int)reason });

            if(chan.IsTemporary && chan.IsOwner(user.UserId))
                await RemoveChannel(chan);
        }
    }

    /// <summary>
    /// Moves <paramref name="user"/> from their current channel (from <see cref="UserLastChannel"/>) into
    /// <paramref name="chan"/>.
    /// </summary>
    /// <remarks>
    /// If the user is already in <paramref name="chan"/>, only the force-join packet is re-sent. Users without
    /// <c>JoinAnyChannel</c> who do not own the channel must meet its rank and, if set, its password. On failure
    /// they get an error response and are forced back into their current channel. The user must already have a
    /// <see cref="UserLastChannel"/> entry; the lookups below throw otherwise.
    /// </remarks>
    public async Task SwitchChannel(User user, Channel chan, string password) {
        if(UserLastChannel.TryGetValue(user.UserId, out Channel? ulc) && chan == ulc) {
            await ForceChannel(user);
            return;
        }

        // Owners and JoinAnyChannel holders bypass the rank/password gate; everyone else must pass it.
        if(!user.Permissions.HasFlag(UserPermissions.JoinAnyChannel) && !chan.IsOwner(user.UserId)) {
            if(chan.Rank > user.Rank) {
                await SendTo(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.CHANNEL_INSUFFICIENT_HIERARCHY, true, chan.Name));
                await ForceChannel(user);
                return;
            }

            if(!string.IsNullOrEmpty(chan.Password) && chan.Password != password) {
                await SendTo(user, new CommandResponseS2CPacket(RandomSnowflake.Next(), LCR.CHANNEL_INVALID_PASSWORD, true, chan.Name));
                await ForceChannel(user);
                return;
            }
        }

        Channel oldChan = UserLastChannel[user.UserId];

        // Leave is announced while the user is still a member of oldChan; join is announced before the
        // user is added to chan, so only existing members receive it.
        long leaveId = RandomSnowflake.Next();
        await SendTo(oldChan, new UserChannelLeaveS2CPacket(leaveId, user.UserId));
        await Messages.LogMessage(leaveId, "chan:leave", oldChan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions);

        long joinId = RandomSnowflake.Next();
        await SendTo(chan, new UserChannelJoinS2CPacket(joinId, user.UserId, user.LegacyNameWithStatus, user.Colour, user.Rank, user.Permissions));
        // Log the plain user name, consistent with the chan:leave and user:disconnect entries.
        await Messages.LogMessage(joinId, "chan:join", chan.Name, user.UserId, user.UserName, user.Colour, user.Rank, user.NickName, user.Permissions);

        // Reset the client's view and replay the new channel's user list and message history.
        await SendTo(user, new ContextClearS2CPacket(ContextClearS2CPacket.Mode.MessagesUsers));
        await SendTo(user, new ContextUsersS2CPacket(
            GetChannelUsers(chan).Except([user]).OrderByDescending(u => u.Rank)
                .Select(u => new ContextUsersS2CPacket.Entry(
                    u.UserId,
                    u.LegacyNameWithStatus,
                    u.Colour,
                    u.Rank,
                    u.Permissions,
                    true
                ))
        ));

        IEnumerable<Message> msgs = await Messages.GetMessages(chan.Name);
        foreach(Message msg in msgs)
            await SendTo(user, new ContextMessageS2CPacket(msg));

        await ForceChannel(user, chan);

        ChannelUsers.Remove(new ChannelUserAssoc(user.UserId, oldChan.Name));
        ChannelUsers.Add(new ChannelUserAssoc(user.UserId, chan.Name));
        UserLastChannel[user.UserId] = chan;

        if(oldChan.IsTemporary && oldChan.IsOwner(user.UserId))
            await RemoveChannel(oldChan);
    }

    /// <summary>Sends <paramref name="packet"/> to every alive connection that has an associated user.</summary>
    public async Task Send(S2CPacket packet) {
        foreach(Connection conn in Connections)
            if(conn.IsAlive && conn.User is not null)
                await conn.Send(packet);
    }

    /// <summary>Sends <paramref name="packet"/> to every alive connection belonging to <paramref name="user"/>.</summary>
    public async Task SendTo(User user, S2CPacket packet) {
        foreach(Connection conn in Connections)
            if(conn.IsAlive && conn.User == user)
                await conn.Send(packet);
    }

    /// <summary>Sends <paramref name="packet"/> to every alive connection whose user is a member of <paramref name="channel"/>.</summary>
    public async Task SendTo(Channel channel, S2CPacket packet) {
        // might be faster to grab the users first and then cascade into that SendTo
        IEnumerable<Connection> conns = Connections.Where(c => c.IsAlive && c.User is not null && IsInChannel(c.User, channel));
        foreach(Connection conn in conns)
            await conn.Send(packet);
    }

    /// <summary>
    /// Sends <paramref name="packet"/> to every alive connection whose user shares at least one channel with
    /// <paramref name="user"/> (including the user's own connections).
    /// </summary>
    public async Task SendToUserChannels(User user, S2CPacket packet) {
        // Resolve the user's channels and the set of co-member ids once, instead of re-running the
        // channel query for every connection and membership entry.
        Channel[] chans = [.. Channels.GetChannels(c => IsInChannel(user, c))];
        HashSet<string> recipientIds = [.. ChannelUsers
            .Where(cu => chans.Any(chan => chan.NameEquals(cu.ChannelName)))
            .Select(cu => cu.UserId)];

        Connection[] conns = [.. Connections.Where(conn => conn.IsAlive && conn.User is not null && recipientIds.Contains(conn.User.UserId))];
        foreach(Connection conn in conns)
            await conn.Send(packet);
    }

    /// <summary>Distinct remote IP addresses of the user's alive connections.</summary>
    public IPAddress[] GetRemoteAddresses(User user) {
        return [.. Connections.Where(c => c.IsAlive && c.User == user).Select(c => c.RemoteEndPoint.Address).Distinct()];
    }

    /// <summary>
    /// Sends a force-join packet for <paramref name="chan"/>, or for the user's <see cref="UserLastChannel"/>
    /// entry when <paramref name="chan"/> is null.
    /// </summary>
    /// <exception cref="ArgumentException">No channel was given and the user has no last channel.</exception>
    public async Task ForceChannel(User user, Channel? chan = null) {
        if(chan == null && !UserLastChannel.TryGetValue(user.UserId, out chan))
            throw new ArgumentException("no channel???");

        await SendTo(user, new UserChannelForceJoinS2CPacket(chan.Name));
    }

    /// <summary>
    /// Applies the supplied changes via <see cref="Channels"/>. Then sends a channel update packet to every
    /// online user whose rank is at least the channel's rank, as read after the update.
    /// </summary>
    public async Task UpdateChannel(
        Channel channel,
        bool? temporary = null,
        int? rank = null,
        string? password = null
    ) {
        Channels.UpdateChannel(
            channel,
            temporary: temporary,
            rank: rank,
            password: password
        );

        // TODO: Users that no longer have access to the channel/gained access to the channel by the hierarchy change should receive delete and create packets respectively
        foreach(User user in Users.Where(u => u.Rank >= channel.Rank))
            await SendTo(user, new ChannelUpdateS2CPacket(channel.Name, channel.Name, channel.HasPassword, channel.IsTemporary));
    }

    /// <summary>
    /// Removes the channel from the listing and moves its members to the default channel. Then sends a
    /// channel delete packet to every online user whose rank is at least the channel's rank.
    /// </summary>
    public async Task RemoveChannel(Channel channel) {
        // Remove channel from the listing
        Channels.RemoveChannel(channel.Name);

        // Move all users back to the main channel
        // TODO: Replace this with a kick. SCv2 supports being in 0 channels, SCv1 should force the user back to DefaultChannel.
        foreach(User user in GetChannelUsers(channel))
            await SwitchChannel(user, Channels.DefaultChannel, string.Empty);

        // Broadcast deletion of channel
        foreach(User user in Users.Where(u => u.Rank >= channel.Rank))
            await SendTo(user, new ChannelDeleteS2CPacket(channel.Name));
    }
}