using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.ExceptionServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Maki.Rest
{
    /// <summary>
    /// A single-use HTTP request, sent through a shared static <see cref="System.Net.Http.HttpClient"/>.
    /// Relative URLs are resolved against <c>RestEndpoints.BASE_URL + RestEndpoints.BASE_PATH</c>.
    /// Supports an idle timeout (<see cref="Timeout"/>), user abort (<see cref="Abort"/>) and progress events.
    /// </summary>
    public class WebRequest : IDisposable
    {
        private const string USER_AGENT = @"DiscordBot (https://github.com/flashwave/maki, 1.0.0.0)";
        private const string GENERIC_CONTENT_TYPE = @"application/octet-stream";
        private const string JSON_CONTENT_TYPE = @"application/json";

        // Extra attempts allowed after a timeout that happened before any response bytes were read.
        private const int MAX_RETRIES = 1;

        // Default idle timeout, in milliseconds.
        private const int TIMEOUT = 10000;

        private const int BUFFER_SIZE = 4096;

        /// <summary>
        /// Idle timeout in milliseconds. It is re-armed every time progress is made
        /// (before sending, and after every read of the response body).
        /// </summary>
        public int Timeout { get; set; } = TIMEOUT;

        public event Action Started;
        public event Action Finished;
        public event Action<Exception> Failed;

        /// <summary>Raised with (bytes read so far, Content-Length or bytes read if unknown).</summary>
        public event Action<long, long> DownloadProgress;

        public event Action<long, long> UploadProgress;

        public bool IsAborted { get; private set; }

        /// <summary>Optional value added to the request's Accept header.</summary>
        public string Accept { get; set; }

        private bool PrivateCompleted;

        /// <summary>
        /// Whether the request has finished. Setting it to true drops the Started, Finished,
        /// DownloadProgress and UploadProgress subscribers, since a request is never reused.
        /// </summary>
        public bool IsCompleted
        {
            get => PrivateCompleted;
            private set
            {
                PrivateCompleted = value;
                if (!PrivateCompleted)
                    return;

                Started = null;
                Finished = null;
                DownloadProgress = null;
                UploadProgress = null;
            }
        }

        private string PrivateUrl;

        private string Url
        {
            get => PrivateUrl;
            set
            {
                // Anything that is not an absolute http(s) URL is treated as a path relative to the API base.
                if (!value.StartsWith(@"http://", StringComparison.Ordinal) && !value.StartsWith(@"https://", StringComparison.Ordinal))
                    value = RestEndpoints.BASE_URL + RestEndpoints.BASE_PATH + value;

                PrivateUrl = value;
            }
        }

        private byte[] Buffer;

        private static HttpClient HttpClient;

        public readonly HttpMethod Method;

        // Methods whose Parameters/Files/raw body are sent as request content rather than in the query string.
        private bool HasBody => Method == HttpMethod.PUT || Method == HttpMethod.POST || Method == HttpMethod.PATCH;

        /// <summary>
        /// Content-Type applied to raw bodies (<see cref="AddRaw(Stream)"/> / <see cref="AddJson"/>).
        /// Form bodies always use the multipart Content-Type generated with their boundary.
        /// </summary>
        public string ContentType { get; set; } = GENERIC_CONTENT_TYPE;

        [Obsolete]
        public long ContentLength => Response.Content.Headers.ContentLength ?? BytesDownloaded;

        [Obsolete]
        internal static string Authorisation { get; set; }

        private readonly Dictionary<string, string> Headers = new Dictionary<string, string>();
        private readonly Dictionary<string, string> Parameters = new Dictionary<string, string>();
        private readonly Dictionary<string, byte[]> Files = new Dictionary<string, byte[]>();

        private long BytesUploaded = 0;
        private long BytesDownloaded = 0;

        private HttpResponseMessage Response;
        private MemoryStream RawRequestBody;
        private Stream ResponseStream;

        private CancellationTokenSource AbortToken;
        private CancellationTokenSource TimeoutToken;

        public int Retries { get; private set; } = 0;

        private byte[] PrivateResponseBytes;

        /// <summary>
        /// The full response body. It is copied out of the buffered response stream on first access and cached.
        /// </summary>
        public byte[] ResponseBytes
        {
            get
            {
                if (PrivateResponseBytes == null)
                {
                    using (MemoryStream ms = new MemoryStream())
                    {
                        ResponseStream.CopyTo(ms);
                        PrivateResponseBytes = ms.ToArray();
                    }
                }

                return PrivateResponseBytes;
            }
        }

        private string PrivateResponseString = string.Empty;

        /// <summary>The response body decoded as UTF-8.</summary>
        public string ResponseString
        {
            get
            {
                if (string.IsNullOrEmpty(PrivateResponseString))
                    PrivateResponseString = Encoding.UTF8.GetString(ResponseBytes);

                return PrivateResponseString;
            }
        }

        /// <summary>Deserializes <see cref="ResponseString"/> as JSON into <typeparamref name="T"/>.</summary>
        public T ResponseJson<T>() => JsonConvert.DeserializeObject<T>(ResponseString);

        [Obsolete]
        public short Status => (short)Response?.StatusCode;

        static WebRequest() => CreateHttpClientInstance();

        public WebRequest(HttpMethod method, string url)
        {
            Method = method;
            Url = url;
        }

        private static void CreateHttpClientInstance()
        {
            HttpClient?.Dispose();
            HttpClient = new HttpClient(new HttpClientHandler
            {
                AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate
            });
            HttpClient.DefaultRequestHeaders.UserAgent.ParseAdd(USER_AGENT);
            HttpClient.DefaultRequestHeaders.ExpectContinue = true;

            // The client never times out on its own; each request enforces its idle timeout via TimeoutToken.
            HttpClient.Timeout = System.Threading.Timeout.InfiniteTimeSpan;
        }

        /// <summary>Uses a copy of <paramref name="stream"/> (from its current position) as the raw request body.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
        public void AddRaw(Stream stream)
        {
            if (stream == null)
                throw new ArgumentNullException(nameof(stream));

            RawRequestBody?.Dispose();
            RawRequestBody = new MemoryStream();
            stream.CopyTo(RawRequestBody);
        }

        public void AddRaw(byte[] bytes)
        {
            using (MemoryStream ms = new MemoryStream(bytes))
                AddRaw(ms);
        }

        public void AddRaw(string str) => AddRaw(Encoding.UTF8.GetBytes(str));

        /// <summary>Serializes <paramref name="obj"/> as the raw body and switches ContentType to JSON.</summary>
        public void AddJson(object obj)
        {
            ContentType = JSON_CONTENT_TYPE;
            AddRaw(JsonConvert.SerializeObject(obj));
        }

        /// <summary>Adds a form field (body methods) or a query-string parameter (other methods).</summary>
        public void AddParam(string name, string contents) => Parameters.Add(name, contents);

        /// <summary>Adds a multipart file part; <paramref name="name"/> is used as both field name and file name.</summary>
        public void AddFile(string name, byte[] bytes) => Files.Add(name, bytes);

        /// <summary>Adds or replaces a request header.</summary>
        public void AddHeader(string name, string value)
        {
            if (string.IsNullOrEmpty(name))
                throw new ArgumentNullException(nameof(name));
            if (value == null)
                throw new ArgumentNullException(nameof(value));

            Headers[name] = value;
        }

        /// <summary>Marks the request as aborted/completed and cancels any in-flight send or read.</summary>
        public void Abort()
        {
            IsAborted = true;
            IsCompleted = true;

            try
            {
                AbortToken?.Cancel();
            }
            catch (ObjectDisposedException)
            {
                // The request already finished and disposed its token; there is nothing left to cancel.
            }
        }

        private static System.Net.Http.HttpMethod FromInternalHttpMethod(HttpMethod method)
        {
            switch (method)
            {
                case HttpMethod.GET:
                    return System.Net.Http.HttpMethod.Get;
                case HttpMethod.DELETE:
                    return System.Net.Http.HttpMethod.Delete;
                case HttpMethod.POST:
                    return System.Net.Http.HttpMethod.Post;
                case HttpMethod.PATCH:
                    return new System.Net.Http.HttpMethod(@"PATCH");
                case HttpMethod.PUT:
                    return System.Net.Http.HttpMethod.Put;
            }

            throw new InvalidOperationException($"Unsupported HTTP method {method}.");
        }

        // Returns Url, with Parameters appended as an escaped query string when the method sends no body.
        private string BuildRequestUri()
        {
            if (HasBody || Parameters.Count == 0)
                return Url;

            StringBuilder urlBuilder = new StringBuilder(Url);

            if (Url.IndexOf('?') < 0)
                urlBuilder.Append('?');
            else if (!Url.EndsWith("?", StringComparison.Ordinal) && !Url.EndsWith("&", StringComparison.Ordinal))
                urlBuilder.Append('&');

            urlBuilder.Append(string.Join("&", Parameters.Select(
                p => Uri.EscapeDataString(p.Key) + "=" + Uri.EscapeDataString(p.Value ?? string.Empty))));

            return urlBuilder.ToString();
        }

        // Builds the request body: the raw body if one was added, otherwise a multipart form of Parameters and Files.
        private HttpContent CreateRequestContent()
        {
            if (RawRequestBody != null)
            {
                if (Parameters.Count > 0 || Files.Count > 0)
                    throw new InvalidOperationException($"You cannot use {nameof(AddRaw)} at the same time as {nameof(AddParam)} or {nameof(AddFile)}");

                HttpContent rawContent = new ByteArrayContent(RawRequestBody.ToArray());
                if (!string.IsNullOrEmpty(ContentType))
                    rawContent.Headers.ContentType = MediaTypeHeaderValue.Parse(ContentType);

                return rawContent;
            }

            MultipartFormDataContent formData = new MultipartFormDataContent();

            foreach (KeyValuePair<string, string> p in Parameters)
                formData.Add(new StringContent(p.Value), p.Key);

            foreach (KeyValuePair<string, byte[]> f in Files)
            {
                ByteArrayContent bac = new ByteArrayContent(f.Value);
                bac.Headers.Add("Content-Type", GENERIC_CONTENT_TYPE);
                formData.Add(bac, f.Key, f.Key);
            }

            // Sent as-is so the Content-Type header keeps the generated multipart boundary.
            return formData;
        }

        // Builds the full HttpRequestMessage: target URI, custom headers, Accept, and the body for PUT/POST/PATCH.
        private HttpRequestMessage CreateRequestMessage()
        {
            HttpRequestMessage request = new HttpRequestMessage(FromInternalHttpMethod(Method), BuildRequestUri());

            try
            {
                foreach (KeyValuePair<string, string> h in Headers)
                    request.Headers.Add(h.Key, h.Value);

                if (!string.IsNullOrEmpty(Accept))
                    request.Headers.Accept.TryParseAdd(Accept);

                if (HasBody)
                    request.Content = CreateRequestContent();

                return request;
            }
            catch
            {
                request.Dispose();
                throw;
            }
        }

        private void PrivatePerform()
        {
            using (AbortToken = new CancellationTokenSource())
            using (TimeoutToken = new CancellationTokenSource())
            using (CancellationTokenSource linkedToken = CancellationTokenSource.CreateLinkedTokenSource(AbortToken.Token, TimeoutToken.Token))
            {
                try
                {
                    using (HttpRequestMessage request = CreateRequestMessage())
                    {
                        // Arms the idle timeout for the send.
                        ReportProgress();
                        Response = HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, linkedToken.Token).Result;
                    }

                    ResponseStream = new MemoryStream();

                    if (HasBody)
                    {
                        ReportProgress();
                        UploadProgress?.Invoke(0, BytesUploaded);
                    }

                    HandleResponse(linkedToken.Token);

                    // NOTE(review): Complete() is not called on the success path. As a result:
                    // - Finished never fires and IsCompleted stays false;
                    // - non-2xx statuses are not turned into exceptions.
                    // Calling Complete() here would make Perform throw on HTTP errors, so confirm callers before enabling it.
                }
                catch (Exception) when (AbortToken.IsCancellationRequested)
                {
                    Complete(new WebException(string.Format("Request to {0} was aborted by the user.", Url), WebExceptionStatus.RequestCanceled));
                }
                catch (Exception) when (TimeoutToken.IsCancellationRequested)
                {
                    Complete(new WebException(string.Format("Request to {0} timed out after {1:N0} seconds idle (read {2:N0} bytes).", Url, TimeSinceLastAction / 1000, BytesDownloaded), WebExceptionStatus.Timeout));
                }
                catch (Exception ex)
                {
                    if (IsCompleted)
                        throw;

                    Complete(ex);
                }
            }
        }

        /// <summary>
        /// Runs the request on a long-running task.
        /// Single-inner <see cref="AggregateException"/>s are unwrapped, so callers see the underlying error.
        /// </summary>
        /// <exception cref="InvalidOperationException">The request has already completed.</exception>
        public async Task PerformAsync()
        {
            if (IsCompleted)
                throw new InvalidOperationException($"{nameof(WebRequest)} has already been run, you can't reuse WebRequest objects.");

            try
            {
                // ConfigureAwait(false): Perform() blocks on this task, which would deadlock under a sync context.
                await Task.Factory.StartNew(PrivatePerform, TaskCreationOptions.LongRunning).ConfigureAwait(false);
            }
            catch (AggregateException ex) when (ex.InnerExceptions.Count == 1)
            {
                // Descend through single-child aggregates to the real error (or a multi-child aggregate).
                Exception inner = ex.InnerException;
                while (inner is AggregateException nested && nested.InnerExceptions.Count == 1)
                    inner = nested.InnerException;

                ExceptionDispatchInfo.Capture(inner).Throw();
            }
        }

        public void Perform() => PerformAsync().Wait();

        // Copies the response body into ResponseStream, reporting progress and honouring cancellation between and during reads.
        private void HandleResponse(CancellationToken cancellationToken)
        {
            using (Stream responseStream = Response.Content.ReadAsStreamAsync().Result)
            {
                Started?.Invoke();
                Buffer = new byte[BUFFER_SIZE];

                while (true)
                {
                    cancellationToken.ThrowIfCancellationRequested();

                    // ReadAsync with the token lets the idle timeout or an abort interrupt a stalled read.
                    int read = responseStream.ReadAsync(Buffer, 0, BUFFER_SIZE, cancellationToken).Result;
                    ReportProgress();

                    if (read <= 0)
                    {
                        ResponseStream.Seek(0, SeekOrigin.Begin);
                        break;
                    }

                    ResponseStream.Write(Buffer, 0, read);
                    BytesDownloaded += read;
                    DownloadProgress?.Invoke(BytesDownloaded, Response.Content.Headers.ContentLength ?? BytesDownloaded);
                }
            }
        }

        /// <summary>
        /// Finishes the request. A timeout that happened before any bytes were read is retried
        /// (up to MAX_RETRIES). Otherwise the request is marked completed, and then either Finished
        /// is raised or Failed is raised and the exception is rethrown.
        /// Does nothing if the request was already aborted or completed.
        /// </summary>
        private void Complete(Exception exception = null)
        {
            if (IsAborted || IsCompleted)
                return;

            bool allowRetry = true;

            if (exception != null)
            {
                allowRetry = exception is WebException webException && webException.Status == WebExceptionStatus.Timeout;
            }
            else if (!Response.IsSuccessStatusCode)
            {
                exception = new WebException($@"HTTP {Response.StatusCode}");

                switch (Response.StatusCode)
                {
                    case HttpStatusCode.NotFound:
                    case HttpStatusCode.MethodNotAllowed:
                    case HttpStatusCode.Forbidden:
                    case HttpStatusCode.Unauthorized:
                        allowRetry = false;
                        break;
                }
            }

            if (exception != null && allowRetry && Retries < MAX_RETRIES && BytesDownloaded < 1)
            {
                ++Retries;
                PrivatePerform();

                // The retry reports its own outcome; the original failure must not be rethrown on top of it.
                return;
            }

            // Setting IsCompleted clears Finished, so capture the handlers first.
            Action finished = Finished;
            Action<Exception> failed = Failed;

            IsCompleted = true;

            if (exception == null)
            {
                finished?.Invoke();
                return;
            }

            IsAborted = true;
            failed?.Invoke(exception);

            // Preserves the original stack trace of exceptions caught in PrivatePerform.
            ExceptionDispatchInfo.Capture(exception).Throw();
        }

        #region Timeout

        private long LastReportedAction = 0;

        // Milliseconds since the last ReportProgress call.
        private long TimeSinceLastAction => (DateTime.UtcNow.Ticks - LastReportedAction) / TimeSpan.TicksPerMillisecond;

        // Records activity and pushes the idle timeout back by Timeout milliseconds.
        private void ReportProgress()
        {
            LastReportedAction = DateTime.UtcNow.Ticks;
            TimeoutToken.CancelAfter(Timeout);
        }

        #endregion

        #region Disposal

        public bool IsDisposed { get; private set; } = false;

        private void Dispose(bool disposing)
        {
            if (IsDisposed)
                return;

            IsDisposed = true;

            if (disposing)
            {
                // Only managed objects are held, so they are released only on an explicit Dispose.
                // ResponseStream (a MemoryStream) is kept so ResponseBytes can still be read lazily.
                Response?.Dispose();
                RawRequestBody?.Dispose();
                GC.SuppressFinalize(this);
            }
        }

        ~WebRequest() => Dispose(false);

        public void Dispose() => Dispose(true);

        #endregion
    }
}