sharp-chat/SharpChat/SockChatServer.cs

252 lines
10 KiB
C#

using Microsoft.Extensions.Logging;
using SharpChat.Auth;
using SharpChat.Bans;
using SharpChat.C2SPacketHandlers;
using SharpChat.ClientCommands;
using SharpChat.Configuration;
using SharpChat.Messages;
using SharpChat.SockChat.S2CPackets;
using System.Net;
using ZLogger;
namespace SharpChat;
/// <summary>
/// Sock Chat server: accepts WebSocket connections, tracks them in the shared <c>Context</c>,
/// applies per-user flood protection and dispatches incoming packets to the registered
/// client-to-server packet handlers.
/// </summary>
public class SockChatServer {
    /// <summary>Fallback for the <c>port</c> config value.</summary>
    public const ushort DEFAULT_PORT = 6770;
    /// <summary>Fallback for the <c>msgMaxLength</c> config value.</summary>
    public const int DEFAULT_MSG_LENGTH_MAX = 5000;
    /// <summary>Fallback for the <c>connMaxCount</c> config value (handed to the auth handler).</summary>
    public const int DEFAULT_MAX_CONNECTIONS = 5;
    /// <summary>Fallback for the <c>floodKickLength</c> config value, in seconds.</summary>
    public const int DEFAULT_FLOOD_KICK_LENGTH = 30;
    /// <summary>Fallback for the <c>floodKickExemptRank</c> config value; users at or above this rank skip flood checks.</summary>
    public const int DEFAULT_FLOOD_KICK_EXEMPT_RANK = 9;

    /// <summary>Shared chat state (connections, channels, rate limiters, ...).</summary>
    public Context Context { get; }

    private readonly ILogger Logger;
    private readonly BansClient BansClient;

    private readonly CachedValue<ushort> Port;
    private readonly CachedValue<int> MaxMessageLength;
    private readonly CachedValue<int> MaxConnections;
    private readonly CachedValue<int> FloodKickLength;
    private readonly CachedValue<int> FloodKickExemptRank;

    // Handlers tried, in order, for connections without / with an authenticated user.
    private readonly List<C2SPacketHandler> GuestHandlers = [];
    private readonly List<C2SPacketHandler> AuthedHandlers = [];
    private readonly SendMessageC2SPacketHandler SendMessageHandler;

    private static readonly string[] DEFAULT_CHANNELS = ["lounge"];

    /// <summary>
    /// Reads configuration, builds the chat context and channel list, and registers
    /// packet handlers and client commands.
    /// </summary>
    /// <param name="logFactory">Factory for the server and per-connection loggers.</param>
    /// <param name="cancellationTokenSource">Handed to the shutdown/restart command.</param>
    /// <param name="authClient">Used by the auth and ping packet handlers.</param>
    /// <param name="bansClient">Used for flood bans and the ban-related commands.</param>
    /// <param name="msgStorage">Message storage for the context; must not be null.</param>
    /// <param name="config">Configuration root to read values from.</param>
    /// <exception cref="ArgumentNullException"><paramref name="msgStorage"/> is null.</exception>
    public SockChatServer(
        ILoggerFactory logFactory,
        CancellationTokenSource cancellationTokenSource,
        AuthClient authClient,
        BansClient bansClient,
        MessageStorage msgStorage,
        Config config
    ) {
        Logger = logFactory.CreateLogger("sockchat");
        Logger.ZLogInformation($"Initialising Sock Chat server...");

        BansClient = bansClient;

        Logger.ZLogDebug($"Fetching configuration values...");
        Port = config.ReadCached("port", DEFAULT_PORT);
        MaxMessageLength = config.ReadCached("msgMaxLength", DEFAULT_MSG_LENGTH_MAX);
        MaxConnections = config.ReadCached("connMaxCount", DEFAULT_MAX_CONNECTIONS);
        FloodKickLength = config.ReadCached("floodKickLength", DEFAULT_FLOOD_KICK_LENGTH);
        FloodKickExemptRank = config.ReadCached("floodKickExemptRank", DEFAULT_FLOOD_KICK_EXEMPT_RANK);

        Logger.ZLogDebug($"Creating context...");
        Context = new Context(logFactory, msgStorage ?? throw new ArgumentNullException(nameof(msgStorage)));

        Logger.ZLogDebug($"Creating channel list...");
        string[]? channelNames = config.ReadValue("channels", DEFAULT_CHANNELS);
        if(channelNames is not null) {
            foreach(string channelName in channelNames) {
                Config channelCfg = config.ScopeTo($"channels:{channelName}");

                // The display name falls back to the config key when unset or blank.
                string name = channelCfg.SafeReadValue("name", string.Empty)!;
                if(string.IsNullOrWhiteSpace(name))
                    name = channelName;

                Context.Channels.CreateChannel(
                    name,
                    channelCfg.SafeReadValue("password", string.Empty)!,
                    rank: channelCfg.SafeReadValue("minRank", 0)
                );
            }
        }

        Logger.ZLogDebug($"Registering unauthenticated handlers...");
        GuestHandlers.Add(new AuthC2SPacketHandler(
            authClient,
            bansClient,
            Context.Channels,
            Context.RandomSnowflake,
            MaxMessageLength,
            MaxConnections
        ));

        Logger.ZLogDebug($"Registering authenticated handlers...");
        AuthedHandlers.AddRange([
            new PingC2SPacketHandler(authClient),
            SendMessageHandler = new SendMessageC2SPacketHandler(Context.RandomSnowflake, MaxMessageLength),
        ]);

        Logger.ZLogDebug($"Registering client commands...");
        SendMessageHandler.AddCommands([
            new AFKClientCommand(),
            new NickClientCommand(),
            new WhisperClientCommand(),
            new ActionClientCommand(),
            new WhoClientCommand(),
            new JoinChannelClientCommand(),
            new CreateChannelClientCommand(),
            new DeleteChannelClientCommand(),
            new PasswordChannelClientCommand(),
            new RankChannelClientCommand(),
            new BroadcastClientCommand(),
            new DeleteMessageClientCommand(),
            new KickBanClientCommand(bansClient),
            new PardonUserClientCommand(bansClient),
            new PardonAddressClientCommand(bansClient),
            new BanListClientCommand(bansClient),
            new RemoteAddressClientCommand(),
            new ShutdownRestartClientCommand(cancellationTokenSource)
        ]);
    }

    /// <summary>
    /// Starts the WebSocket server on all interfaces and serves clients until
    /// <paramref name="cancellationToken"/> is cancelled, then disposes every tracked connection.
    /// </summary>
    /// <param name="cancellationToken">Signals shutdown; cancellation does not throw.</param>
    public async Task Listen(CancellationToken cancellationToken) {
        using SharpChatWebSocketServer server = new(Context.LoggerFactory.CreateLogger("sockchat:server"), $"ws://0.0.0.0:{Port}");
        server.Start(sock => {
            if(!IPAddress.TryParse(sock.ConnectionInfo.ClientIpAddress, out IPAddress? addr)) {
                Logger.ZLogError($@"A client attempted to connect with an invalid IP address: ""{sock.ConnectionInfo.ClientIpAddress}""");
                sock.Close(1011);
                return;
            }

            // X-Real-IP is only trusted from loopback, i.e. when a local reverse proxy forwarded the request.
            if(IPAddress.IsLoopback(addr) && sock.ConnectionInfo.Headers.TryGetValue("X-Real-IP", out string? addrStr)) {
                if(IPAddress.TryParse(addrStr, out IPAddress? realAddr))
                    addr = realAddr;
                else
                    Logger.ZLogWarning($@"Connection originated from loopback and supplied an X-Real-IP header, but it could not be parsed: ""{addrStr}""");
            }

            IPEndPoint endPoint = new(addr, sock.ConnectionInfo.ClientPort);

            if(cancellationToken.IsCancellationRequested) {
                Logger.ZLogInformation($"{endPoint} attempted to connect after shutdown was requested. Connection will be dropped.");
                sock.Close(1013);
                return;
            }

            ILogger logger = Context.LoggerFactory.CreateLogger($"sockchat:({endPoint})");
            Connection conn = new(logger, sock, endPoint);

            // OnClose mutates Context.Connections while holding ContextAccess; take the same lock
            // here so a connect racing a disconnect cannot corrupt the collection.
            Context.ContextAccess.Wait();
            try {
                Context.Connections.Add(conn);
            } finally {
                Context.ContextAccess.Release();
            }

            // The socket callbacks are assigned synchronous lambdas, so each async handler is run to completion via Wait().
            sock.OnOpen = () => OnOpen(conn).Wait();
            sock.OnClose = () => OnClose(conn).Wait();
            sock.OnError = err => OnError(conn, err).Wait();
            sock.OnMessage = msg => OnMessage(conn, msg).Wait();
        });

        Logger.ZLogInformation($"Listening...");
        await Task.Delay(Timeout.Infinite, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

        Logger.ZLogDebug($"Disposing all clients...");

        // Snapshot under the lock, then dispose outside it. Disposing may close sockets and fire
        // OnClose (which removes from Context.Connections under ContextAccess), so iterating the
        // live collection here could be invalidated mid-enumeration or deadlock on the lock.
        Connection[] conns;
        Context.ContextAccess.Wait();
        try {
            conns = [.. Context.Connections];
        } finally {
            Context.ContextAccess.Release();
        }

        foreach(Connection conn in conns)
            conn.Dispose();
    }

    /// <summary>Logs the newly opened connection and runs a context update.</summary>
    private async Task OnOpen(Connection conn) {
        conn.Logger.ZLogInformation($"Connection opened.");
        await Context.SafeUpdate();
    }

    /// <summary>Logs a socket error (message at error level, full exception at debug) and runs a context update.</summary>
    private async Task OnError(Connection conn, Exception ex) {
        conn.Logger.ZLogError($"Error: {ex.Message}");
        conn.Logger.ZLogDebug($"{ex}");
        await Context.SafeUpdate();
    }

    /// <summary>
    /// Removes the closed connection and, if it was the user's last remaining connection,
    /// runs the user's disconnect handling.
    /// </summary>
    private async Task OnClose(Connection conn) {
        conn.Logger.ZLogInformation($"Connection closed.");

        Context.ContextAccess.Wait();
        try {
            Context.Connections.Remove(conn);

            if(conn.User != null && !Context.Connections.Any(c => c.User == conn.User))
                await Context.HandleDisconnect(conn.User);

            await Context.Update();
        } finally {
            Context.ContextAccess.Release();
        }
    }

    /// <summary>
    /// Handles an incoming text frame. Applies flood protection first, then dispatches the
    /// message to the first matching guest or authenticated packet handler.
    /// </summary>
    private async Task OnMessage(Connection conn, string msg) {
        conn.Logger.ZLogTrace($"Received: {msg}");
        await Context.SafeUpdate();

        // NOTE: flood protection only covers authenticated users ranked below FloodKickExemptRank;
        // messages from unauthenticated connections are not rate limited here.
        if(conn.User is not null && conn.User.Rank < FloodKickExemptRank) {
            Context.ContextAccess.Wait();
            try {
                // Rate limiters are kept per user id, so all of a user's connections share one.
                if(!Context.UserRateLimiters.TryGetValue(conn.User.UserId, out RateLimiter? rateLimiter))
                    Context.UserRateLimiters.Add(conn.User.UserId, rateLimiter = new RateLimiter(
                        User.DEFAULT_SIZE,
                        User.DEFAULT_MINIMUM_DELAY,
                        User.DEFAULT_RISKY_OFFSET
                    ));

                rateLimiter.Update();

                if(rateLimiter.IsExceeded) {
                    TimeSpan banDuration = TimeSpan.FromSeconds(FloodKickLength);
                    string banAddr = conn.RemoteAddress.ToString();
                    conn.Logger.ZLogWarning($"Exceeded flood limit! Issuing ban with duration {banDuration} on {banAddr}/{conn.User.UserId}...");

                    await Context.BanUser(conn.User, banDuration, UserDisconnectS2CPacket.Reason.Flood);

                    // BansClient.BanCreate is only called for a positive duration.
                    if(banDuration > TimeSpan.Zero)
                        await BansClient.BanCreate(
                            BanKind.User,
                            banDuration,
                            conn.RemoteAddress,
                            conn.User.UserId,
                            "Kicked from chat for flood protection.",
                            IPAddress.IPv6Loopback
                        );

                    // The offending message is not dispatched.
                    return;
                }

                if(rateLimiter.IsRisky) {
                    string warnAddr = conn.RemoteAddress.ToString();
                    conn.Logger.ZLogWarning($"About to exceed flood limit! Issuing warning to {warnAddr}/{conn.User.UserId}...");
                    await Context.SendTo(conn.User, new CommandResponseS2CPacket(Context.RandomSnowflake.Next(), LCR.FLOOD_WARN, false));
                }
            } finally {
                Context.ContextAccess.Release();
            }
        }

        C2SPacketHandlerContext context = new(msg, Context, conn);
        C2SPacketHandler? handler = conn.User is null
            ? GuestHandlers.FirstOrDefault(h => h.IsMatch(context))
            : AuthedHandlers.FirstOrDefault(h => h.IsMatch(context));

        if(handler is not null)
            await handler.Handle(context);
    }
}