diff --git a/src/EasyWebSockets/EasyWebSockets.csproj b/src/EasyWebSockets/EasyWebSockets.csproj index cc8b5e3..e7ea557 100755 --- a/src/EasyWebSockets/EasyWebSockets.csproj +++ b/src/EasyWebSockets/EasyWebSockets.csproj @@ -1,7 +1,10 @@ - + netstandard2.0 + 9 + Enable + nullable EasyWebSockets 1.0.0 EasyWebSockets @@ -12,8 +15,8 @@ - - + + diff --git a/src/EasyWebSockets/WebSocketConnectionManager.cs b/src/EasyWebSockets/WebSocketConnectionManager.cs index 223381d..248de55 100644 --- a/src/EasyWebSockets/WebSocketConnectionManager.cs +++ b/src/EasyWebSockets/WebSocketConnectionManager.cs @@ -9,11 +9,10 @@ namespace EasyWebSockets { internal class WebSocketConnectionManager { - private ConcurrentDictionary _sockets = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _sockets = new ConcurrentDictionary(); public WebSocket GetSocketById(string id) => _sockets.FirstOrDefault(p => p.Key == id).Value; - public ConcurrentDictionary GetAll() => _sockets; @@ -21,16 +20,15 @@ public WebSocket GetSocketById(string id) => public void AddSocket(WebSocket socket) => _sockets.TryAdd(CreateConnectionId(), socket); - public async Task RemoveSocket(string id) + public Task RemoveSocket(string id, CancellationToken cancellationToken = default) { - WebSocket socket; - _sockets.TryRemove(id, out socket); + _sockets.TryRemove(id, out WebSocket socket); - await socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure, + return socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure, statusDescription: "Closed by the WebSocketManager", - cancellationToken: CancellationToken.None); + cancellationToken: cancellationToken); } - private string CreateConnectionId() => Guid.NewGuid().ToString(); + private static string CreateConnectionId() => Guid.NewGuid().ToString(); } } \ No newline at end of file diff --git a/src/EasyWebSockets/WebSocketHandler.cs b/src/EasyWebSockets/WebSocketHandler.cs index 9741b8b..7c0ef6e 100644 --- a/src/EasyWebSockets/WebSocketHandler.cs +++ b/src/EasyWebSockets/WebSocketHandler.cs @@ -9,16 +9,16 @@ namespace EasyWebSockets { - public interface IWebSocketPublisher + public interface IWebSocketPublisher { - Task SendMessageToAllAsync(object message); + Task SendMessageToAllAsync(object message, CancellationToken cancellationToken = default); } internal class WebSocketHandler : IWebSocketPublisher { private readonly WebSocketConnectionManager _webSocketConnectionManager; - private JsonSerializerSettings _jsonSerializerSettings = new JsonSerializerSettings() + private readonly JsonSerializerSettings _jsonSerializerSettings = new JsonSerializerSettings() { ContractResolver = new CamelCasePropertyNamesContractResolver() }; @@ -26,33 +26,36 @@ internal class WebSocketHandler : IWebSocketPublisher public WebSocketHandler(WebSocketConnectionManager webSocketConnectionManager) => _webSocketConnectionManager = webSocketConnectionManager; - public async Task SendMessageToAllAsync(object message) => - await Task.WhenAll( - _webSocketConnectionManager.GetAll() + public Task SendMessageToAllAsync(object message, CancellationToken cancellationToken = default) => + Task.WhenAll(_webSocketConnectionManager.GetAll() .Where(pair => pair.Value.State == WebSocketState.Open) - .Select(pair => SendMessageAsync(pair.Value, message))); - - public async Task OnConnected(WebSocket socket) + .Select(pair => SendMessageAsync(pair.Value, message, cancellationToken))); + + public Task OnConnected(WebSocket socket, CancellationToken cancellationToken = default) { _webSocketConnectionManager.AddSocket(socket); - await SendMessageAsync(socket, $"Connected with Id: ${_webSocketConnectionManager.GetId(socket)}"); + return SendMessageAsync(socket, $"Connected with Id: ${_webSocketConnectionManager.GetId(socket)}", cancellationToken); } - public async Task OnDisconnected(WebSocket socket) => - await _webSocketConnectionManager.RemoveSocket(_webSocketConnectionManager.GetId(socket)); + public Task OnDisconnected(WebSocket socket, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + return _webSocketConnectionManager.RemoveSocket(_webSocketConnectionManager.GetId(socket), cancellationToken); + } - private async Task SendMessageAsync(WebSocket socket, object message) + private async Task SendMessageAsync(WebSocket socket, object message, CancellationToken cancellationToken = default) { if (socket.State != WebSocketState.Open) return; - var serializedMessage = JsonConvert.SerializeObject(message, _jsonSerializerSettings); + string? serializedMessage = JsonConvert.SerializeObject(message, _jsonSerializerSettings); await socket.SendAsync(buffer: new ArraySegment( - array: Encoding.ASCII.GetBytes(serializedMessage),offset: 0, + array: Encoding.ASCII.GetBytes(serializedMessage), + offset: 0, count: serializedMessage.Length), messageType: WebSocketMessageType.Text, endOfMessage: true, - cancellationToken: CancellationToken.None); + cancellationToken: cancellationToken); } } } \ No newline at end of file diff --git a/src/EasyWebSockets/WebSocketManagerExtensions.cs b/src/EasyWebSockets/WebSocketManagerExtensions.cs index 64b7544..eb53334 100644 --- a/src/EasyWebSockets/WebSocketManagerExtensions.cs +++ b/src/EasyWebSockets/WebSocketManagerExtensions.cs @@ -10,14 +10,16 @@ public static IServiceCollection AddEasyWebSockets(this IServiceCollection servi { services.AddTransient(); services.AddSingleton(); + return services; } public static IApplicationBuilder UseEasyWebSockets(this IApplicationBuilder app, string path = "/ws") { app.UseWebSockets(); - var wsHandler = app.ApplicationServices.GetService(typeof(IWebSocketPublisher)); - return app.Map(new PathString(path), (_app) => _app.UseMiddleware(wsHandler)); + + object wsHandler = app.ApplicationServices.GetRequiredService(); + return app.Map(new PathString(path), builder => builder.UseMiddleware(wsHandler)); } } } \ No newline at end of file diff --git a/src/EasyWebSockets/WebSocketManagerMiddleware.cs b/src/EasyWebSockets/WebSocketManagerMiddleware.cs index 71c7741..3e7018a 100644 --- a/src/EasyWebSockets/WebSocketManagerMiddleware.cs +++ b/src/EasyWebSockets/WebSocketManagerMiddleware.cs @@ -11,7 +11,7 @@ namespace EasyWebSockets internal class WebSocketManagerMiddleware { private readonly RequestDelegate _next; - private WebSocketHandler _webSocketHandler { get; set; } + private readonly WebSocketHandler _webSocketHandler; public WebSocketManagerMiddleware(RequestDelegate next, WebSocketHandler webSocketHandler) { @@ -24,55 +24,44 @@ public async Task Invoke(HttpContext context) if (!context.WebSockets.IsWebSocketRequest) return; - var socket = await context.WebSockets.AcceptWebSocketAsync(); - await _webSocketHandler.OnConnected(socket); + WebSocket? socket = await context.WebSockets.AcceptWebSocketAsync(); + await _webSocketHandler.OnConnected(socket, context.RequestAborted); await Receive(socket, async (result, serializedInvocationDescriptor) => { - if (result.MessageType == WebSocketMessageType.Text) + switch (result.MessageType) { - // await _webSocketHandler.ReceiveAsync(socket, result, serializedInvocationDescriptor); - return; - } - - else if (result.MessageType == WebSocketMessageType.Close) - { - try - { - await _webSocketHandler.OnDisconnected(socket); - } + case WebSocketMessageType.Text: + // await _webSocketHandler.ReceiveAsync(socket, result, serializedInvocationDescriptor); + return; - catch (WebSocketException) - { - throw; //let's not swallow any exception for now - } - - return; + case WebSocketMessageType.Close: + await _webSocketHandler.OnDisconnected(socket, context.RequestAborted); + return; } - }); + }, context.RequestAborted); } - private async Task Receive(WebSocket socket, Action handleMessage) + private static async Task Receive(WebSocket socket, Action handleMessage, CancellationToken cancellationToken = default) { while (socket.State == WebSocketState.Open) { - ArraySegment buffer = new ArraySegment(new Byte[1024 * 4]); - string serializedInvocationDescriptor = null; - WebSocketReceiveResult result = null; + var buffer = new ArraySegment(new byte[1024 * 4]); + string serializedInvocationDescriptor; + WebSocketReceiveResult result; + using (var ms = new MemoryStream()) { do { - result = await socket.ReceiveAsync(buffer, CancellationToken.None); - ms.Write(buffer.Array, buffer.Offset, result.Count); + result = await socket.ReceiveAsync(buffer, cancellationToken); + await ms.WriteAsync(buffer.Array, buffer.Offset, result.Count, cancellationToken); } while (!result.EndOfMessage); ms.Seek(0, SeekOrigin.Begin); - using (var reader = new StreamReader(ms, Encoding.UTF8)) - { - serializedInvocationDescriptor = await reader.ReadToEndAsync(); - } + using var reader = new StreamReader(ms, Encoding.UTF8); + serializedInvocationDescriptor = await reader.ReadToEndAsync(); } handleMessage(result, serializedInvocationDescriptor);