Reintroduces separate contexts for users, channels, connections (now split into sessions and connections), and user–channel associations. It builds, which is as much assurance as I can give about the stability of this commit, but it's also the bare minimum I require before committing. A lot of state changes still need to be broadcast through events across the application to keep everything consistent, but we'll cross that bridge when we get to it. I really need to stop using that phrase; I'm overusing it.
295 lines
11 KiB
C#
295 lines
11 KiB
C#
using Microsoft.Extensions.Logging;
|
|
using SharpChat.Connections;
|
|
using SharpChat.Snowflake;
|
|
using SharpChat.Users;
|
|
using System.Net;
|
|
using ZLogger;
|
|
|
|
namespace SharpChat.Sessions;
|
|
|
|
/// <summary>
/// Owns the set of chat sessions.
/// It tracks which user each session belongs to, which connection (if any) is currently bound to it,
/// and which sessions are suspended awaiting resumption on a new connection.
/// </summary>
/// <remarks>
/// The internal dictionaries are only touched while holding <c>@lock</c>.
/// Query methods return snapshot collections rather than live views.
/// Callers may therefore enumerate results while sessions are concurrently created, suspended,
/// resumed or destroyed.
/// The <see cref="Session"/> objects themselves are shared and mutable.
/// </remarks>
/// <param name="loggerFactory">Used to create a per-session logger named after the session id.</param>
/// <param name="snowflake">Source of new session ids.</param>
public class SessionsContext(ILoggerFactory loggerFactory, RandomSnowflake snowflake) {
    /// <summary>
    /// Maximum time allowed since a session's last heartbeat before <see cref="IsTimedOut(Session)"/>
    /// reports it as timed out.
    /// </summary>
    public readonly TimeSpan TimeOutInterval = TimeSpan.FromMinutes(5);

    // All sessions, active and suspended, keyed by session id.
    private readonly Dictionary<long, Session> Sessions = [];

    // Sessions grouped by owning user id. A user's entry is dropped once its last session is destroyed.
    private readonly Dictionary<string, List<Session>> UserSessions = [];

    // Sessions that currently have no connection bound to them, keyed by session id.
    private readonly Dictionary<long, Session> SuspendedSessions = [];

    // Reverse lookup from a bound connection to its session. Suspended sessions have no entry here.
    private readonly Dictionary<Connection, Session> ConnectionSession = [];

    private readonly Lock @lock = new();

    /// <summary>Creates a new session for <paramref name="user"/> bound to <paramref name="conn"/>.</summary>
    /// <param name="user">User that owns the session.</param>
    /// <param name="conn">Connection to bind; must not already belong to another session.</param>
    /// <returns>The newly registered session.</returns>
    /// <exception cref="SessionConnectionAlreadyInUseException"><paramref name="conn"/> is already bound to a session.</exception>
    public Session CreateSession(User user, Connection conn)
        => CreateSession(user.UserId, conn);

    /// <summary>
    /// Creates a new session, with a fresh snowflake id and random secret, for the given user id
    /// and binds it to <paramref name="conn"/>.
    /// </summary>
    /// <param name="userId">Id of the user that owns the session.</param>
    /// <param name="conn">Connection to bind; must not already belong to another session.</param>
    /// <returns>The newly registered session.</returns>
    /// <exception cref="SessionConnectionAlreadyInUseException"><paramref name="conn"/> is already bound to a session.</exception>
    public Session CreateSession(string userId, Connection conn) {
        lock(@lock) {
            if(ConnectionSession.ContainsKey(conn))
                throw new SessionConnectionAlreadyInUseException(nameof(conn));

            long sessId = snowflake.Next();
            string secret = RNG.SecureRandomString(20);
            ILogger logger = loggerFactory.CreateLogger($"session:({sessId})");
            Session sess = new(sessId, secret, userId, logger, conn);

            Sessions.Add(sessId, sess);
            ConnectionSession.Add(conn, sess);

            if(!UserSessions.TryGetValue(userId, out var userSessions))
                UserSessions.Add(userId, userSessions = []);

            userSessions.Add(sess);

            logger.ZLogInformation($"Session created for #{userId}.");

            return sess;
        }
    }

    /// <summary>Rebinds a suspended session to a new connection.</summary>
    /// <param name="sessId">Id of the suspended session.</param>
    /// <param name="secret">
    /// Secret issued when the session was created.
    /// It is compared with <c>SlowUtf8Equals</c>, presumably a timing-safe comparison.
    /// </param>
    /// <param name="conn">Connection to bind the session to.</param>
    /// <returns>
    /// The resumed session, or <see langword="null"/> in any of these cases:
    /// no suspended session has that id, the session still has a connection,
    /// or the secret does not match.
    /// </returns>
    /// <exception cref="SessionConnectionAlreadyInUseException"><paramref name="conn"/> is already bound to another session.</exception>
    public Session? ResumeSession(long sessId, string secret, Connection conn) {
        lock(@lock) {
            if(!SuspendedSessions.TryGetValue(sessId, out Session? sess)
                || !NullConnection.IsNull(sess.Connection)
                || !sess.Secret.SlowUtf8Equals(secret))
                return null;

            // Same guard as CreateSession: a connection may only ever back one session.
            // It runs after validation, so an invalid resume attempt still just returns null.
            // It also runs before any mutation, so a rejected call leaves the context untouched.
            if(ConnectionSession.ContainsKey(conn))
                throw new SessionConnectionAlreadyInUseException(nameof(conn));

            // NOTE(review): the heartbeat is not checked here, so a suspended session that has
            // already timed out can still be resumed. Confirm this is intended.
            ConnectionSession.Add(conn, sess);
            SuspendedSessions.Remove(sessId);
            sess.Connection = conn;
            return sess;
        }
    }

    /// <summary>Suspends the session bound to <paramref name="conn"/>, if there is one.</summary>
    /// <param name="conn">Connection whose session should be suspended.</param>
    public void SuspendSession(Connection conn) {
        lock(@lock)
            if(ConnectionSession.TryGetValue(conn, out Session? sess))
                SuspendSessionInternal(sess);
    }

    /// <summary>Suspends the registered session with the same id as <paramref name="sess"/>, if any.</summary>
    /// <param name="sess">Session to suspend.</param>
    public void SuspendSession(Session sess)
        => SuspendSession(sess.Id);

    /// <summary>Suspends the session with the given id, if it exists.</summary>
    /// <param name="sessId">Id of the session to suspend.</param>
    public void SuspendSession(long sessId) {
        lock(@lock)
            if(Sessions.TryGetValue(sessId, out Session? sess))
                SuspendSessionInternal(sess);
    }

    // Detaches the session from its connection and makes it resumable via ResumeSession.
    // It is a no-op if the session is already suspended or has no connection.
    // Caller must hold @lock.
    private void SuspendSessionInternal(Session sess) {
        // SuspendedSessions is keyed by session id, so an O(1) key lookup is enough; no need to scan values.
        if(SuspendedSessions.ContainsKey(sess.Id) || NullConnection.IsNull(sess.Connection))
            return;

        ConnectionSession.Remove(sess.Connection);
        SuspendedSessions.Add(sess.Id, sess);
        sess.Connection = NullConnection.Instance;
    }

    /// <summary>Destroys the session bound to <paramref name="conn"/>, if there is one.</summary>
    /// <param name="conn">Connection whose session should be destroyed.</param>
    public void DestroySession(Connection conn) {
        lock(@lock)
            if(ConnectionSession.TryGetValue(conn, out Session? sess))
                DestroySessionInternal(sess);
    }

    /// <summary>Destroys the registered session with the same id as <paramref name="sess"/>, if any.</summary>
    /// <param name="sess">Session to destroy.</param>
    public void DestroySession(Session sess)
        => DestroySession(sess.Id);

    /// <summary>Destroys the session with the given id, whether active or suspended, if it exists.</summary>
    /// <param name="sessId">Id of the session to destroy.</param>
    public void DestroySession(long sessId) {
        lock(@lock)
            if(Sessions.TryGetValue(sessId, out Session? sess))
                DestroySessionInternal(sess);
    }

    // Removes every reference to the session from the context's indexes. Caller must hold @lock.
    private void DestroySessionInternal(Session sess) {
        if(UserSessions.TryGetValue(sess.UserId, out var userSessions)) {
            userSessions.Remove(sess);

            if(userSessions.Count < 1)
                UserSessions.Remove(sess.UserId);
        }

        Sessions.Remove(sess.Id);
        SuspendedSessions.Remove(sess.Id);
        // For a suspended session this is the NullConnection placeholder; removing a missing key is harmless.
        ConnectionSession.Remove(sess.Connection);
    }

    /// <summary>Returns whether the session with <paramref name="sess"/>'s id is currently suspended.</summary>
    /// <param name="sess">Session to check.</param>
    public bool IsSuspendedSession(Session sess)
        => IsSuspendedSession(sess.Id);

    /// <summary>Returns whether the session with the given id is currently suspended.</summary>
    /// <param name="sessId">Session id to check.</param>
    public bool IsSuspendedSession(long sessId) {
        lock(@lock)
            return SuspendedSessions.ContainsKey(sessId);
    }

    /// <summary>Looks up a session, active or suspended, by id.</summary>
    /// <param name="sessId">Session id.</param>
    /// <returns>The session, or <see langword="null"/> if none has that id.</returns>
    public Session? GetSession(long sessId) {
        lock(@lock)
            return Sessions.TryGetValue(sessId, out Session? sess) ? sess : null;
    }

    /// <summary>Looks up the session currently bound to <paramref name="conn"/>.</summary>
    /// <param name="conn">Connection to look up.</param>
    /// <returns>The session, or <see langword="null"/> if the connection is not bound to one.</returns>
    public Session? GetSession(Connection conn) {
        lock(@lock)
            return ConnectionSession.TryGetValue(conn, out Session? sess) ? sess : null;
    }

    /// <summary>Returns the first session matching <paramref name="predicate"/>, or <see langword="null"/>.</summary>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public Session? GetSession(Func<Session, bool> predicate) {
        lock(@lock)
            return Sessions.Values.FirstOrDefault(predicate);
    }

    /// <summary>Counts all sessions, active and suspended.</summary>
    public int CountSessions() {
        lock(@lock)
            return Sessions.Count;
    }

    /// <summary>Counts sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public int CountSessions(Func<Session, bool> predicate) {
        lock(@lock)
            return Sessions.Values.Count(predicate);
    }

    /// <summary>Counts all sessions, active and suspended, owned by <paramref name="user"/>.</summary>
    /// <param name="user">Owning user.</param>
    public int CountSessions(User user)
        => CountSessions(user.UserId);

    /// <summary>Counts all sessions, active and suspended, owned by the given user id.</summary>
    /// <param name="userId">Owning user id.</param>
    public int CountSessions(string userId) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count : 0;
    }

    /// <summary>Counts sessions owned by <paramref name="user"/> that match <paramref name="predicate"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public int CountSessions(User user, Func<Session, bool> predicate)
        => CountSessions(user.UserId, predicate);

    /// <summary>Counts sessions owned by the given user id that match <paramref name="predicate"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public int CountSessions(string userId, Func<Session, bool> predicate) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? userSessions.Count(predicate) : 0;
    }

    /// <summary>Counts <paramref name="user"/>'s sessions that have not timed out (suspended ones included).</summary>
    /// <param name="user">Owning user.</param>
    public int CountActiveSessions(User user)
        => CountActiveSessions(user.UserId);

    /// <summary>Counts the user's sessions that have not timed out (suspended ones included).</summary>
    /// <param name="userId">Owning user id.</param>
    public int CountActiveSessions(string userId) {
        return CountSessions(userId, s => !IsTimedOut(s));
    }

    /// <summary>Counts <paramref name="user"/>'s sessions that are neither suspended nor timed out.</summary>
    /// <param name="user">Owning user.</param>
    public int CountNonSuspendedActiveSessions(User user)
        => CountNonSuspendedActiveSessions(user.UserId);

    /// <summary>Counts the user's sessions that are neither suspended nor timed out.</summary>
    /// <param name="userId">Owning user id.</param>
    public int CountNonSuspendedActiveSessions(string userId) {
        // NOTE: relying on Session.IsSuspended here might create a loophole (author's concern; unverified).
        return CountSessions(userId, s => !s.IsSuspended && !IsTimedOut(s));
    }

    /// <summary>Returns a snapshot of all sessions, active and suspended.</summary>
    public IEnumerable<Session> GetSessions() {
        lock(@lock)
            return [.. Sessions.Values];
    }

    /// <summary>Returns a snapshot of the sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public IEnumerable<Session> GetSessions(Func<Session, bool> predicate) {
        lock(@lock)
            return [.. Sessions.Values.Where(predicate)];
    }

    /// <summary>Returns a snapshot of all sessions owned by <paramref name="user"/>.</summary>
    /// <param name="user">Owning user.</param>
    public IEnumerable<Session> GetSessions(User user)
        => GetSessions(user.UserId);

    /// <summary>Returns a snapshot of all sessions owned by the given user id.</summary>
    /// <param name="userId">Owning user id.</param>
    public IEnumerable<Session> GetSessions(string userId) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions] : [];
    }

    /// <summary>Returns a snapshot of <paramref name="user"/>'s sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public IEnumerable<Session> GetSessions(User user, Func<Session, bool> predicate)
        => GetSessions(user.UserId, predicate);

    /// <summary>Returns a snapshot of the user's sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="predicate">Filter; invoked while the internal lock is held.</param>
    public IEnumerable<Session> GetSessions(string userId, Func<Session, bool> predicate) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions) ? [.. userSessions.Where(predicate)] : [];
    }

    /// <summary>
    /// Returns a snapshot of the sessions whose ids appear in <paramref name="ids"/>.
    /// Unknown ids are skipped.
    /// </summary>
    /// <param name="ids">Session ids to look up; enumerated once, eagerly.</param>
    public IEnumerable<Session> GetSessions(IEnumerable<long> ids) {
        // Materialised like the other GetSessions overloads, so the result is a stable snapshot
        // instead of a query re-run on every enumeration.
        return [.. ids.Select(id => GetSession(id)).OfType<Session>()];
    }

    /// <summary>Returns a snapshot of sessions whose last heartbeat is older than <see cref="TimeOutInterval"/>.</summary>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetTimedOutSessions(DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(s => IsTimedOut(s, now.Value));
    }

    /// <summary>Returns a snapshot of sessions (suspended ones included) that have not timed out.</summary>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetActiveSessions(DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(s => !IsTimedOut(s, now.Value));
    }

    /// <summary>Returns a snapshot of <paramref name="user"/>'s sessions that have not timed out.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetActiveSessions(User user, DateTimeOffset? now = null)
        => GetActiveSessions(user.UserId, now);

    /// <summary>Returns a snapshot of the user's sessions that have not timed out.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetActiveSessions(string userId, DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(userId, s => !IsTimedOut(s, now.Value));
    }

    /// <summary>Returns a snapshot of <paramref name="user"/>'s non-timed-out sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="predicate">Filter; only invoked for sessions that have not timed out.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetActiveSessions(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
        => GetActiveSessions(user.UserId, predicate, now);

    /// <summary>Returns a snapshot of the user's non-timed-out sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="predicate">Filter; only invoked for sessions that have not timed out.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Session> GetActiveSessions(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
        now ??= DateTimeOffset.UtcNow;
        return GetSessions(userId, s => !IsTimedOut(s, now.Value) && predicate(s));
    }

    /// <summary>Returns whether <paramref name="session"/>'s last heartbeat is older than <see cref="TimeOutInterval"/> as of now (UTC).</summary>
    /// <param name="session">Session to check.</param>
    public bool IsTimedOut(Session session)
        => IsTimedOut(session, DateTimeOffset.UtcNow);

    /// <summary>Returns whether <paramref name="session"/>'s last heartbeat is older than <see cref="TimeOutInterval"/> as of <paramref name="now"/>.</summary>
    /// <param name="session">Session to check.</param>
    /// <param name="now">Reference time.</param>
    public bool IsTimedOut(Session session, DateTimeOffset now)
        => now - session.LastHeartbeat > TimeOutInterval;

    /// <summary>Returns the distinct remote end points of all non-suspended sessions.</summary>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints() {
        // The connection is read under the lock so a concurrent suspend cannot swap it
        // between filtering and reading the end point.
        lock(@lock)
            return [.. Sessions.Values.Where(s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()];
    }

    /// <summary>Returns the distinct remote end points of <paramref name="user"/>'s non-suspended sessions.</summary>
    /// <param name="user">Owning user.</param>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints(User user)
        => GetRemoteEndPoints(user.UserId);

    /// <summary>Returns the distinct remote end points of the user's non-suspended sessions.</summary>
    /// <param name="userId">Owning user id.</param>
    public IEnumerable<IPEndPoint> GetRemoteEndPoints(string userId) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions)
                ? [.. userSessions.Where(s => !s.IsSuspended).Select(s => s.Connection.RemoteEndPoint).Distinct()]
                : [];
    }

    // Snapshot of bound connections whose session passes filter. Suspended sessions never appear
    // because they are absent from ConnectionSession. Taking the lock and materialising avoids handing
    // callers a live enumerator over the shared dictionary.
    private Connection[] GetConnectionsInternal(Func<Session, bool> filter) {
        lock(@lock)
            return [.. ConnectionSession.Where(kvp => filter(kvp.Value)).Select(kvp => kvp.Key)];
    }

    // Snapshot of the connections of userId's sessions that pass filter. Suspended sessions carry the
    // NullConnection placeholder, which is dropped so these overloads agree with the global ones.
    private Connection[] GetUserConnectionsInternal(string userId, Func<Session, bool> filter) {
        lock(@lock)
            return UserSessions.TryGetValue(userId, out var userSessions)
                ? [.. userSessions.Where(filter).Select(s => s.Connection).Where(c => !NullConnection.IsNull(c))]
                : [];
    }

    /// <summary>Returns a snapshot of the connections bound to sessions that have not timed out.</summary>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        return GetConnectionsInternal(s => !IsTimedOut(s, at));
    }

    /// <summary>Returns a snapshot of the connections bound to non-timed-out sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="predicate">Filter; only invoked for sessions that have not timed out, while the internal lock is held.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(Func<Session, bool> predicate, DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        return GetConnectionsInternal(s => !IsTimedOut(s, at) && predicate(s));
    }

    /// <summary>Like <see cref="GetConnections(DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(DateTimeOffset? now = null) where T : Connection {
        return GetConnections(now).OfType<T>();
    }

    /// <summary>Like <see cref="GetConnections(Func{Session, bool}, DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="predicate">Session filter.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(predicate, now).OfType<T>();
    }

    /// <summary>Returns a snapshot of the connections of <paramref name="user"/>'s non-suspended, non-timed-out sessions.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(User user, DateTimeOffset? now = null)
        => GetConnections(user.UserId, now);

    /// <summary>Returns a snapshot of the connections of the user's non-suspended, non-timed-out sessions.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(string userId, DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        return GetUserConnectionsInternal(userId, s => !IsTimedOut(s, at));
    }

    /// <summary>Like <see cref="GetConnections(User, DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(User user, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(user.UserId, now).OfType<T>();
    }

    /// <summary>Like <see cref="GetConnections(string, DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(string userId, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(userId, now).OfType<T>();
    }

    /// <summary>Returns a snapshot of the connections of <paramref name="user"/>'s non-suspended, non-timed-out sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="predicate">Filter; only invoked for sessions that have not timed out.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(User user, Func<Session, bool> predicate, DateTimeOffset? now = null)
        => GetConnections(user.UserId, predicate, now);

    /// <summary>Returns a snapshot of the connections of the user's non-suspended, non-timed-out sessions matching <paramref name="predicate"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="predicate">Filter; only invoked for sessions that have not timed out, while the internal lock is held.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<Connection> GetConnections(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) {
        DateTimeOffset at = now ?? DateTimeOffset.UtcNow;
        return GetUserConnectionsInternal(userId, s => !IsTimedOut(s, at) && predicate(s));
    }

    /// <summary>Like <see cref="GetConnections(User, Func{Session, bool}, DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="user">Owning user.</param>
    /// <param name="predicate">Session filter.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(User user, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(user.UserId, predicate, now).OfType<T>();
    }

    /// <summary>Like <see cref="GetConnections(string, Func{Session, bool}, DateTimeOffset?)"/>, restricted to connections of type <typeparamref name="T"/>.</summary>
    /// <param name="userId">Owning user id.</param>
    /// <param name="predicate">Session filter.</param>
    /// <param name="now">Reference time; defaults to the current UTC time.</param>
    public IEnumerable<T> GetConnections<T>(string userId, Func<Session, bool> predicate, DateTimeOffset? now = null) where T : Connection {
        return GetConnections(userId, predicate, now).OfType<T>();
    }
}
|