sharp-chat/SharpChatCommon/Sessions/SessionsContext.cs
flashwave 5a7756894b
First bits of the Context overhaul.
Reintroduces separate contexts for users, channels, connections (now split into sessions and connections) and user-channel associations.
It builds, which is as much assurance as I can give about the stability of this commit, but it's also the bare minimum of what I like to commit, sooooo
A lot of things still need to be broadcast through events throughout the application in order to keep states consistent but we'll cross that bridge when we get to it.
I really need to stop using that phrase thingy, I'm overusing it.
2025-05-03 02:49:51 +00:00

295 lines
11 KiB
C#

using Microsoft.Extensions.Logging;
using SharpChat.Connections;
using SharpChat.Snowflake;
using SharpChat.Users;
using System.Net;
using ZLogger;
namespace SharpChat.Sessions;
/// <summary>
/// Tracks chat sessions and the connections attached to them: creation, suspension (detaching the
/// connection while keeping the session), resumption, destruction, and lookup/time-out queries.
/// Every access to the internal collections goes through a single private lock.
/// </summary>
public class SessionsContext(ILoggerFactory loggerFactory, RandomSnowflake snowflake) {
    /// <summary>How long a session may go without a heartbeat before <see cref="IsTimedOut(Session)"/> reports it as timed out.</summary>
    public readonly TimeSpan TimeOutInterval = TimeSpan.FromMinutes(5);

    // Every known session (suspended or not), keyed by session id.
    private readonly Dictionary<long, Session> Sessions = [];
    // Sessions grouped by owning user id; a user's entry is dropped once their last session is destroyed.
    private readonly Dictionary<string, List<Session>> UserSessions = [];
    // Sessions whose connection has been detached and replaced with NullConnection.Instance.
    private readonly Dictionary<long, Session> SuspendedSessions = [];
    // Live connection -> the session it is attached to. Suspended sessions are removed from here.
    private readonly Dictionary<Connection, Session> ConnectionSession = [];
    // Guards all of the collections above.
    private readonly Lock @lock = new();

    /// <summary>Creates a new session for <paramref name="user"/> attached to <paramref name="conn"/>.</summary>
    public Session CreateSession(User user, Connection conn)
        => CreateSession(user.UserId, conn);

    /// <summary>
    /// Creates and registers a new session for <paramref name="userId"/> attached to <paramref name="conn"/>,
    /// with a fresh snowflake id and a 20 character secret.
    /// </summary>
    /// <exception cref="SessionConnectionAlreadyInUseException"><paramref name="conn"/> is already attached to a session.</exception>
    public Session CreateSession(string userId, Connection conn) {
        lock(@lock) {
            if(ConnectionSession.ContainsKey(conn))
                throw new SessionConnectionAlreadyInUseException(nameof(conn));

            long sessId = snowflake.Next();
            string secret = RNG.SecureRandomString(20);
            ILogger logger = loggerFactory.CreateLogger($"session:({sessId})");
            Session sess = new(sessId, secret, userId, logger, conn);

            Sessions.Add(sessId, sess);
            ConnectionSession.Add(conn, sess);
            if(!UserSessions.TryGetValue(userId, out var userSessions))
                UserSessions.Add(userId, userSessions = []);
            userSessions.Add(sess);

            logger.ZLogInformation($"Session created for #{userId}.");
            return sess;
        }
    }

    /// <summary>Reattaches <paramref name="conn"/> to a suspended session if the id and secret match.</summary>
    /// <returns>
    /// The resumed session, or <c>null</c> when no suspended session has that id, the session still has a
    /// connection, or <paramref name="secret"/> does not match.
    /// </returns>
    /// <exception cref="SessionConnectionAlreadyInUseException"><paramref name="conn"/> is already attached to another session.</exception>
    public Session? ResumeSession(long sessId, string secret, Connection conn) {
        lock(@lock) {
            if(!SuspendedSessions.TryGetValue(sessId, out Session? sess)
                || !NullConnection.IsNull(sess.Connection)
                || !sess.Secret.SlowUtf8Equals(secret))
                return null;

            // Report an in-use connection the same way CreateSession does rather than letting
            // Dictionary.Add below fail with a bare ArgumentException.
            if(ConnectionSession.ContainsKey(conn))
                throw new SessionConnectionAlreadyInUseException(nameof(conn));

            ConnectionSession.Add(conn, sess);
            SuspendedSessions.Remove(sessId);
            sess.Connection = conn;
            return sess;
        }
    }

    /// <summary>Suspends the session attached to <paramref name="conn"/>, if there is one.</summary>
    public void SuspendSession(Connection conn) {
        lock(@lock)
            if(ConnectionSession.TryGetValue(conn, out Session? sess))
                SuspendSessionInternal(sess);
    }

    /// <summary>Suspends <paramref name="sess"/> if it is still registered.</summary>
    public void SuspendSession(Session sess)
        => SuspendSession(sess.Id);

    /// <summary>Suspends the session with id <paramref name="sessId"/> if it exists.</summary>
    public void SuspendSession(long sessId) {
        lock(@lock)
            if(Sessions.TryGetValue(sessId, out Session? sess))
                SuspendSessionInternal(sess);
    }

    // Detaches the session's connection and records it as suspended; no-op if already suspended
    // or connectionless. Callers hold @lock.
    private void SuspendSessionInternal(Session sess) {
        // Entries are always stored under sess.Id, so a keyed lookup replaces the O(n) ContainsValue scan.
        if(SuspendedSessions.ContainsKey(sess.Id) || NullConnection.IsNull(sess.Connection))
            return;

        ConnectionSession.Remove(sess.Connection);
        SuspendedSessions.Add(sess.Id, sess);
        sess.Connection = NullConnection.Instance;
    }

    /// <summary>Destroys the session attached to <paramref name="conn"/>, if there is one.</summary>
    public void DestroySession(Connection conn) {
        lock(@lock)
            if(ConnectionSession.TryGetValue(conn, out Session? sess))
                DestroySessionInternal(sess);
    }

    /// <summary>Destroys <paramref name="sess"/> if it is still registered.</summary>
    public void DestroySession(Session sess)
        => DestroySession(sess.Id);

    /// <summary>Destroys the session with id <paramref name="sessId"/> if it exists.</summary>
    public void DestroySession(long sessId) {
        lock(@lock)
            if(Sessions.TryGetValue(sessId, out Session? sess))
                DestroySessionInternal(sess);
    }

    // Removes the session from every index. Callers hold @lock.
    private void DestroySessionInternal(Session sess) {
        if(UserSessions.TryGetValue(sess.UserId, out var userSessions)) {
            userSessions.Remove(sess);
            if(userSessions.Count < 1)
                UserSessions.Remove(sess.UserId);
        }

        Sessions.Remove(sess.Id);
        SuspendedSessions.Remove(sess.Id);
        // For a suspended session this is NullConnection.Instance, which is not a key, so nothing happens.
        ConnectionSession.Remove(sess.Connection);
    }

    /// <summary>Whether <paramref name="sess"/> is currently suspended.</summary>
    public bool IsSuspendedSession(Session sess)
        => IsSuspendedSession(sess.Id);

    /// <summary>Whether the session with id <paramref name="sessId"/> is currently suspended.</summary>
    public bool IsSuspendedSession(long sessId) {
        lock(@lock)
            return SuspendedSessions.ContainsKey(sessId);
    }

    /// <summary>Looks up a session by id; <c>null</c> if unknown.</summary>
    public Session? GetSession(long sessId) {
        lock(@lock)
            return Sessions.TryGetValue(sessId, out Session? sess) ? sess : null;
    }

    /// <summary>Looks up the session attached to <paramref name="conn"/>; <c>null</c> if none.</summary>
    public Session? GetSession(Connection conn) {
        lock(@lock)
            return ConnectionSession.TryGetValue(conn, out Session? sess) ? sess : null;
    }

    /// <summary>Returns the first session matching <paramref name="predicate"/> (evaluated while holding the lock), or <c>null</c>.</summary>
    public Session? GetSession(Func<Session, bool> predicate) {
        lock(@lock)
            return Sessions.Values.FirstOrDefault(predicate);
    }

    /// <summary>Number of registered sessions, suspended ones included.</summary>
    public int CountSessions() {
        lock(@lock)
            return Sessions.Count;
    }

    /// <summary>Number of registered sessions matching <paramref name="predicate"/>.</summary>
    public int CountSessions(Func<Session, bool> predicate) {
        lock(@lock)
            return Sessions.Values.Count(predicate);
    }

    /// <summary>Number of sessions owned by <paramref name="user"/>.</summary>
    public int CountSessions(User user)
        => CountSessions(user.UserId);

    /// <summary>Number of sessions owned by <paramref name="userId"/>.</summary>
    public int CountSessions(string userId) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count : 0;
    }

    /// <summary>Number of sessions owned by <paramref name="user"/> that match <paramref name="predicate"/>.</summary>
    public int CountSessions(User user, Func<Session, bool> predicate)
        => CountSessions(user.UserId, predicate);

    /// <summary>Number of sessions owned by <paramref name="userId"/> that match <paramref name="predicate"/>.</summary>
    public int CountSessions(string userId, Func<Session, bool> predicate) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count(predicate) : 0;
    }

    /// <summary>Number of non-timed-out sessions (suspended ones included) owned by <paramref name="user"/>.</summary>
    public int CountActiveSessions(User user)
        => CountActiveSessions(user.UserId);

    /// <summary>Number of non-timed-out sessions (suspended ones included) owned by <paramref name="userId"/>.</summary>
    public int CountActiveSessions(string userId) {
        return CountSessions(userId, s => !IsTimedOut(s));
    }

    /// <summary>Number of non-suspended, non-timed-out sessions owned by <paramref name="user"/>.</summary>
    public int CountNonSuspendedActiveSessions(User user)
        => CountNonSuspendedActiveSessions(user.UserId);

    /// <summary>Number of non-suspended, non-timed-out sessions owned by <paramref name="userId"/>.</summary>
    public int CountNonSuspendedActiveSessions(string userId) {
        return CountSessions(userId, s => !s.IsSuspended /* <-- might create a loophole */ && !IsTimedOut(s));
    }

    /// <summary>Snapshot of all registered sessions.</summary>
    public IEnumerable<Session> GetSessions() {
        lock(@lock)
            return [.. Sessions.Values];
    }

    /// <summary>Snapshot of all registered sessions matching <paramref name="predicate"/>.</summary>
    public IEnumerable<Session> GetSessions(Func<Session, bool> predicate) {
        lock(@lock)
            return [.. Sessions.Values.Where(predicate)];
    }

    /// <summary>Snapshot of the sessions owned by <paramref name="user"/>.</summary>
    public IEnumerable<Session> GetSessions(User user)
        => GetSessions(user.UserId);

    /// <summary>Snapshot of the sessions owned by <paramref name="userId"/>.</summary>
    public IEnumerable<Session> GetSessions(string userId) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions] : [];
    }

    /// <summary>Snapshot of the sessions owned by <paramref name="user"/> that match <paramref name="predicate"/>.</summary>
    public IEnumerable<Session> GetSessions(User user, Func<Session, bool> predicate)
        => GetSessions(user.UserId, predicate);

    /// <summary>Snapshot of the sessions owned by <paramref name="userId"/> that match <paramref name="predicate"/>.</summary>
    public IEnumerable<Session> GetSessions(string userId, Func<Session, bool> predicate) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions.Where(predicate)] : [];
    }

    /// <summary>
    /// Snapshot of the sessions whose ids appear in <paramref name="ids"/>, in input order; unknown ids are skipped.
    /// </summary>
    public IEnumerable<Session> GetSessions(IEnumerable<long> ids) {
        // One lock and one materialisation, like the other GetSessions overloads, instead of a lazy
        // query that re-locked per id and re-ran on every enumeration.
        lock(@lock) {
            List<Session> found = [];
            foreach(long sessId in ids)
                if(Sessions.TryGetValue(sessId, out Session? sess))
                    found.Add(sess);
            return found;
        }
    }

    /// <summary>Snapshot of all sessions that have timed out as of <paramref name="now"/> (defaults to the current UTC time).</summary>
    public IEnumerable<Session> GetTimedOutSessions(DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(s => IsTimedOut(s, now.Value));
    }

    /// <summary>Snapshot of all sessions (suspended ones included) not timed out as of <paramref name="now"/>.</summary>
    public IEnumerable<Session> GetActiveSessions(DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(s => !IsTimedOut(s, now.Value));
    }

    /// <summary>Snapshot of <paramref name="user"/>'s sessions not timed out as of <paramref name="now"/>.</summary>
    public IEnumerable<Session> GetActiveSessions(User user, DateTimeOffset? now = null)
        => GetActiveSessions(user.UserId, now);

    /// <summary>Snapshot of <paramref name="userId"/>'s sessions not timed out as of <paramref name="now"/>.</summary>
    public IEnumerable<Session> GetActiveSessions(string userId, DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(userId, s => !IsTimedOut(s, now.Value));
    }

    /// <summary>Snapshot of <paramref name="user"/>'s non-timed-out sessions that match <paramref name="predicate"/>.</summary>
    public IEnumerable<Session> GetActiveSessions(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
        => GetActiveSessions(user.UserId, predicate, now);

    /// <summary>Snapshot of <paramref name="userId"/>'s non-timed-out sessions that match <paramref name="predicate"/>.</summary>
    public IEnumerable<Session> GetActiveSessions(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(userId, s => !IsTimedOut(s, now.Value) && predicate(s));
    }

    /// <summary>Whether <paramref name="session"/>'s last heartbeat is older than <see cref="TimeOutInterval"/> right now.</summary>
    public bool IsTimedOut(Session session)
        => IsTimedOut(session, DateTimeOffset.UtcNow);

    /// <summary>Whether <paramref name="session"/>'s last heartbeat is older than <see cref="TimeOutInterval"/> as of <paramref name="now"/>.</summary>
    public bool IsTimedOut(Session session, DateTimeOffset now)
        => now - session.LastHeartbeat > TimeOutInterval;

    /// <summary>Distinct remote endpoints of all non-suspended sessions.</summary>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints() {
        return [.. GetSessions(s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()];
    }

    /// <summary>Distinct remote endpoints of <paramref name="user"/>'s non-suspended sessions.</summary>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints(User user)
        => GetRemoteEndPoints(user.UserId);

    /// <summary>Distinct remote endpoints of <paramref name="userId"/>'s non-suspended sessions.</summary>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints(string userId) {
        return [.. GetSessions(userId, s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()];
    }

    /// <summary>
    /// Snapshot of connections attached to sessions that have not timed out as of <paramref name="now"/>
    /// (defaults to the current UTC time).
    /// </summary>
    public IEnumerable<Connection> GetConnections(DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        // Honour the caller's `now` and enumerate under the lock: a lazy, unlocked query over
        // ConnectionSession can observe concurrent Add/Remove calls.
        lock(@lock)
            return [.. ConnectionSession.Where(kvp => !IsTimedOut(kvp.Value, at)).Select(kvp => kvp.Key)];
    }

    /// <summary>
    /// Snapshot of connections attached to non-timed-out sessions that also match <paramref name="predicate"/>
    /// (evaluated while holding the lock).
    /// </summary>
    public IEnumerable<Connection> GetConnections(Func<Session, bool> predicate, DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        lock(@lock)
            return [.. ConnectionSession.Where(kvp => !IsTimedOut(kvp.Value, at) && predicate(kvp.Value)).Select(kvp => kvp.Key)];
    }

    /// <summary>Connections from <see cref="GetConnections(DateTimeOffset?)"/> that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(DateTimeOffset? now = null) where T : Connection {
        return GetConnections(now).OfType<T>();
    }

    /// <summary>Connections from <see cref="GetConnections(Func{Session, bool}, DateTimeOffset?)"/> that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(predicate, now).OfType<T>();
    }

    /// <summary>Live connections of <paramref name="user"/>'s non-timed-out sessions.</summary>
    public IEnumerable<Connection> GetConnections(User user, DateTimeOffset? now = null)
        => GetConnections(user.UserId, now);

    /// <summary>Live connections of <paramref name="userId"/>'s non-timed-out sessions.</summary>
    public IEnumerable<Connection> GetConnections(string userId, DateTimeOffset? now = null)
        => GetConnections(userId, _ => true, now);

    /// <summary>Live connections of <paramref name="user"/>'s non-timed-out sessions that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(User user, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(user.UserId, now).OfType<T>();
    }

    /// <summary>Live connections of <paramref name="userId"/>'s non-timed-out sessions that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(string userId, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(userId, now).OfType<T>();
    }

    /// <summary>Live connections of <paramref name="user"/>'s non-timed-out sessions that match <paramref name="predicate"/>.</summary>
    public IEnumerable<Connection> GetConnections(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
        => GetConnections(user.UserId, predicate, now);

    /// <summary>
    /// Snapshot of live connections of <paramref name="userId"/>'s non-timed-out sessions that match
    /// <paramref name="predicate"/> (evaluated while holding the lock).
    /// </summary>
    public IEnumerable<Connection> GetConnections(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        lock(@lock) {
            if(!UserSessions.TryGetValue(userId, out var userSessions))
                return [];

            // Suspended sessions carry NullConnection.Instance as a placeholder; skip them so this agrees with
            // GetConnections(DateTimeOffset?), which only sees sessions present in ConnectionSession.
            // Connection is read under the lock so a concurrent suspend cannot swap it mid-query.
            return [.. userSessions
                .Where(s => !NullConnection.IsNull(s.Connection) && !IsTimedOut(s, at) && predicate(s))
                .Select(s => s.Connection)];
        }
    }

    /// <summary>Live connections of <paramref name="user"/>'s matching sessions that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(User user, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(user.UserId, predicate, now).OfType<T>();
    }

    /// <summary>Live connections of <paramref name="userId"/>'s matching sessions that are of type <typeparamref name="T"/>.</summary>
    public IEnumerable<T> GetConnections<T>(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(userId, predicate, now).OfType<T>();
    }
}