diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index 487aab5594d813..4e55cd9575b8b4 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -135,6 +135,9 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => ConnectedWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) => + ConnectedWebSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) => ConnectedWebSocket.ReceiveAsync(buffer, cancellationToken); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs index 9836f31df16cb7..158aa0983a01b1 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -87,6 +87,91 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => }), new LoopbackServer.Options { WebSocketEndpoint = true }); } + [ConditionalFact(nameof(WebSocketsSupported))] + public async Task ThrowsWhenContinuationHasDifferentCompressionFlags() + { + var deflateOpt = new WebSocketDeflateOptions + { + ClientMaxWindowBits = 14, + ClientContextTakeover = true, + ServerMaxWindowBits = 14, + ServerContextTakeover = true + }; + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var cws = new ClientWebSocket(); + using var cts = new CancellationTokenSource(TimeOutMilliseconds); + + cws.Options.DangerousDeflateOptions = deflateOpt; + await ConnectAsync(cws, uri, cts.Token); + + + await cws.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); + Assert.Throws("messageFlags", () => + cws.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); + }, server => server.AcceptConnectionAsync(async connection => + { + string extensionsReply = CreateDeflateOptionsHeader(deflateOpt); + await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + [ConditionalFact(nameof(WebSocketsSupported))] + public async Task SendHelloWithDisableCompression() + { + byte[] bytes = "Hello"u8.ToArray(); + + int prefixLength = 2; + byte[] rawPrefix = new byte[] { 0x81, 0x85 }; // fin=1, rsv=0, opcode=text; mask=1, len=5 + int rawRemainingBytes = 9; // mask bytes (4) + payload bytes (5) + byte[] compressedPrefix = new byte[] { 0xc1, 0x87 }; // fin=1, rsv=compressed, opcode=text; mask=1, len=7 + int compressedRemainingBytes = 11; // mask bytes (4) + payload bytes (7) + + var deflateOpt = new WebSocketDeflateOptions + { + ClientMaxWindowBits = 14, + ClientContextTakeover = true, + ServerMaxWindowBits = 14, + ServerContextTakeover = true + }; + + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var cws = new ClientWebSocket(); + using var cts = new CancellationTokenSource(TimeOutMilliseconds); + + cws.Options.DangerousDeflateOptions = deflateOpt; + await ConnectAsync(cws, uri, cts.Token); + + await cws.SendAsync(bytes, WebSocketMessageType.Text, true, cts.Token); + + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; + await cws.SendAsync(bytes, WebSocketMessageType.Text, flags, cts.Token); + }, server => server.AcceptConnectionAsync(async connection => + { + var buffer = new byte[compressedRemainingBytes]; + string extensionsReply = CreateDeflateOptionsHeader(deflateOpt); + await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + + // first message is compressed + await ReadExactAsync(buffer, prefixLength); + Assert.Equal(compressedPrefix, buffer[..prefixLength]); + // read rest of the frame + await ReadExactAsync(buffer, compressedRemainingBytes); + + // second message is not compressed + await ReadExactAsync(buffer, prefixLength); + Assert.Equal(rawPrefix, buffer[..prefixLength]); + // read rest of the frame + await ReadExactAsync(buffer, rawRemainingBytes); + + async Task ReadExactAsync(byte[] buf, int n) + { + await connection.Stream.ReadAtLeastAsync(buf.AsMemory(0, n), n); + } + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) { var builder = new StringBuilder();