From c1ffefb2e983761f577d23642670c8873a8f3604 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 14 Oct 2016 13:28:10 -0400 Subject: [PATCH 1/2] Optimize NetworkStream.CopyToAsync This commit overrides CopyToAsync on NetworkStream to provide an optimized implementation. Several optimizations: - Use ArrayPool for a pooled copy buffer rather than allocating a new one for each CopyToAsync operation - Uses a SocketAsyncEventArgs to avoid per-socket operation costs like pinning of the buffer - Uses a custom awaitable to avoid per-ReadAsync costs like the allocated Tasks and IAsyncResult objects involved --- .../src/Resources/Strings.resx | 12 ++ .../src/System.Net.Sockets.csproj | 6 + .../src/System/Net/Sockets/NetworkStream.cs | 169 ++++++++++++++++++ .../src/netcore50/project.json | 1 + src/System.Net.Sockets/src/project.json | 1 + src/System.Net.Sockets/src/win/project.json | 1 + .../FunctionalTests/NetworkStreamTest.cs | 48 +++++ 7 files changed, 238 insertions(+) diff --git a/src/System.Net.Sockets/src/Resources/Strings.resx b/src/System.Net.Sockets/src/Resources/Strings.resx index 7b7ab0350980..fb4340e0f086 100644 --- a/src/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/System.Net.Sockets/src/Resources/Strings.resx @@ -291,4 +291,16 @@ This platform does not support disconnecting a Socket. Instead, close the Socket and create a new one. + + Stream does not support reading. + + + Stream does not support writing. + + + Can not access a closed Stream. + + + Positive number required. + diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj index 062f4c66de27..22082c581800 100644 --- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -75,6 +75,12 @@ + + Common\System\IO\StreamHelpers.ArrayPoolCopy.cs + + + Common\System\IO\StreamHelpers.CopyValidation.cs + Common\System\Net\Logging\GlobalLog.cs diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index dc2f8abb6fac..76a029066f2f 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Buffers; +using System.Diagnostics; using System.IO; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -1025,6 +1028,63 @@ public override Task WriteAsync(byte[] buffer, int offset, int size, Cancellatio this); } + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + // Validate arguments as would the base CopyToAsync + StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize); + + // And bail early if cancellation has already been requested + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + // Then do additional checks as ReadAsync would. + + if (_cleanedUp) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + + Socket streamSocket = _streamSocket; + if (streamSocket == null) + { + throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed)); + } + + // Do the copy. We get a copy buffer from the shared pool, and we pass both it and the + // socket into the copy as part of the event args so as to avoid additional fields in + // the async method's state machine. + return CopyToAsyncCore( + destination, + new AwaitableSocketAsyncEventArgs(streamSocket, ArrayPool.Shared.Rent(bufferSize)), + cancellationToken); + } + + private static async Task CopyToAsyncCore(Stream destination, AwaitableSocketAsyncEventArgs ea, CancellationToken cancellationToken) + { + try + { + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + int bytesRead = await ea.ReceiveAsync(); + if (bytesRead == 0) + { + break; + } + + await destination.WriteAsync(ea.Buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(ea.Buffer, clearArray: true); + ea.Dispose(); + } + } + // Flushes data from the stream. This is meaningless for us, so it does nothing. public override void Flush() { @@ -1092,5 +1152,114 @@ internal void DebugMembers() _streamSocket.DebugMembers(); } } + + /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. + internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion + { + /// Sentinal object used to indicate that the operation has completed prior to OnCompleted being called. + private static readonly Action s_completedSentinel = () => { }; + /// + /// null if the operation has not completed, if it has, and another object + /// if OnCompleted was called before the operation could complete, in which case it's the delegate to invoke + /// when the operation does complete. + /// + private Action _continuation; + + /// Initializes the event args. + /// The associated socket. + /// The buffer to use for all operations. + public AwaitableSocketAsyncEventArgs(Socket socket, byte[] buffer) + { + Debug.Assert(socket != null); + Debug.Assert(buffer != null && buffer.Length > 0); + + // Store the socket into the base's UserToken. This avoids the need for an extra field, at the expense + // of an object=>Socket cast when we need to access it, which is only once per operation. + UserToken = socket; + + // Store the buffer for use by all operations with this instance. + SetBuffer(buffer, 0, buffer.Length); + + // Hook up the completed event. + Completed += delegate + { + // When the operation completes, see if OnCompleted was already called to hook up a continuation. + // If it was, invoke the continuation. + Action c = _continuation; + if (c != null) + { + c(); + } + else + { + // We may be racing with OnCompleted, so check with synchronization, trying to swap in our + // completion sentinel. If we lose the race and OnCompleted did hook up a continuation, + // invoke it. Otherwise, there's nothing more to be done. + Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null)?.Invoke(); + } + }; + } + + /// Initiates a receive operation on the associated socket. + /// This instance. + public AwaitableSocketAsyncEventArgs ReceiveAsync() + { + if (!Socket.ReceiveAsync(this)) + { + _continuation = s_completedSentinel; + } + return this; + } + + /// Gets this instance. + public AwaitableSocketAsyncEventArgs GetAwaiter() => this; + + /// Gets whether the operation has already completed. + /// + /// This is not a generically usable IsCompleted operation that suggests the whole operation has completed. + /// Rather, it's specifically used as part of the await pattern, and is only usable to determine whether the + /// operation has completed by the time the instance is awaited. + /// + public bool IsCompleted => _continuation != null; + + /// Same as + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + + /// Queues the provided continuation to be executed once the operation has completed. + public void OnCompleted(Action continuation) + { + if (_continuation == s_completedSentinel || Interlocked.CompareExchange(ref _continuation, continuation, null) == s_completedSentinel) + { + Task.Run(continuation); + } + } + + /// Gets the result of the completion operation. + /// Number of bytes transferred. + /// + /// Unlike Task's awaiter's GetResult, this does not block until the operation completes: it must only + /// be used once the operation has completed. This is handled implicitly by await. + /// + public int GetResult() + { + _continuation = null; + if (SocketError != SocketError.Success) + { + ThrowIOSocketException(); + } + return BytesTransferred; + } + + /// Gets the associated socket. + internal Socket Socket => (Socket)UserToken; // stored in the base's UserToken to avoid an extra field in the object + + /// Throws an IOException wrapping a SocketException using the current . + [MethodImpl(MethodImplOptions.NoInlining)] + private void ThrowIOSocketException() + { + var se = new SocketException((int)SocketError); + throw new IOException(SR.Format(SR.net_io_readfailure, se.Message), se); + } + } } } diff --git a/src/System.Net.Sockets/src/netcore50/project.json b/src/System.Net.Sockets/src/netcore50/project.json index 2c9fb38cba87..2e16c8f5aeba 100644 --- a/src/System.Net.Sockets/src/netcore50/project.json +++ b/src/System.Net.Sockets/src/netcore50/project.json @@ -5,6 +5,7 @@ "Microsoft.NETCore.Platforms": "1.0.1", "Microsoft.NETCore.Targets": "1.0.1", "Microsoft.TargetingPack.Private.WinRT": "1.0.1", + "System.Buffers": "4.0.0", "System.Collections": "4.0.0", "System.Diagnostics.Debug": "4.0.10", "System.Diagnostics.Tracing": "4.0.20", diff --git a/src/System.Net.Sockets/src/project.json b/src/System.Net.Sockets/src/project.json index 8732e9be4e01..5d7dd1829dcb 100644 --- a/src/System.Net.Sockets/src/project.json +++ b/src/System.Net.Sockets/src/project.json @@ -2,6 +2,7 @@ "frameworks": { "netstandard1.7": { "dependencies": { + "System.Buffers": "4.4.0-beta-24615-03", "System.Collections": "4.4.0-beta-24615-03", "System.Diagnostics.Debug": "4.4.0-beta-24615-03", "System.Diagnostics.Tracing": "4.4.0-beta-24615-03", diff --git a/src/System.Net.Sockets/src/win/project.json b/src/System.Net.Sockets/src/win/project.json index ab0790970fde..7cc3b92aa205 100644 --- a/src/System.Net.Sockets/src/win/project.json +++ b/src/System.Net.Sockets/src/win/project.json @@ -2,6 +2,7 @@ "frameworks": { "netstandard1.7": { "dependencies": { + "System.Buffers": "4.4.0-beta-24615-03", "System.Collections": "4.4.0-beta-24615-03", "System.Diagnostics.Debug": "4.4.0-beta-24615-03", "System.Diagnostics.Tracing": "4.4.0-beta-24615-03", diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index 63806ac461fd..95a42c503656 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.IO; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -45,6 +46,53 @@ await RunWithConnectedNetworkStreamsAsync(async (server, client) => new object[] { new CancellationTokenSource().Token } // CanBeCanceled == true }; + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(1024)] + [InlineData(4096)] + [InlineData(4095)] + [InlineData(1024*1024)] + public async Task CopyToAsync_AllDataCopied(int byteCount) + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + var results = new MemoryStream(); + byte[] dataToCopy = new byte[byteCount]; + new Random().NextBytes(dataToCopy); + + Task copyTask = client.CopyToAsync(results); + await server.WriteAsync(dataToCopy, 0, dataToCopy.Length); + server.Dispose(); + await copyTask; + + Assert.Equal(dataToCopy, results.ToArray()); + }); + } + + [Fact] + public async Task CopyToAsync_InvalidArguments_Throws() + { + await RunWithConnectedNetworkStreamsAsync((stream, _) => + { + // Null destination + Assert.Throws("destination", () => { stream.CopyToAsync(null); }); + + // Buffer size out-of-range + Assert.Throws("bufferSize", () => { stream.CopyToAsync(new MemoryStream(), 0); }); + Assert.Throws("bufferSize", () => { stream.CopyToAsync(new MemoryStream(), -1, CancellationToken.None); }); + + // Copying to non-writable stream + Assert.Throws(() => { stream.CopyToAsync(new MemoryStream(new byte[0], writable: false)); }); + + // Copying after disposing the stream + stream.Dispose(); + Assert.Throws(() => { stream.CopyToAsync(new MemoryStream()); }); + + return Task.CompletedTask; + }); + } + /// /// Creates a pair of connected NetworkStreams and invokes the provided /// with them as arguments. From 81dbf51df90dd444a1e384117b55e3fc898b09fe Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 14 Oct 2016 16:25:11 -0400 Subject: [PATCH 2/2] Avoid allocating a SafeNativeOverlapped per SocketAsyncEventArgs operation Each operation ends up allocating a SafeNativeOverlapped, which results in a ton of allocations when trying to optimize code via SocketAsyncEventArgs. This commit reuses the same SafeHandle object for many / all of the operations on the event args instance. --- .../Windows/Winsock/SafeNativeOverlapped.cs | 62 +++++++------------ .../Sockets/SocketAsyncEventArgs.Windows.cs | 26 +++++--- .../SocketAsyncEventArgsTest.cs | 57 +++++++++++++++++ .../System.Net.Sockets.Tests.csproj | 1 + 4 files changed, 98 insertions(+), 48 deletions(-) create mode 100644 src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs diff --git a/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs b/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs index 17061bb006f0..5d880b412541 100644 --- a/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs +++ b/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs @@ -2,23 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.Win32.SafeHandles; - using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; namespace System.Net.Sockets { - internal class SafeNativeOverlapped : SafeHandle + internal sealed class SafeNativeOverlapped : SafeHandle { - private static readonly SafeNativeOverlapped s_zero = new SafeNativeOverlapped(); - private SafeCloseSocket _safeCloseSocket; - - internal static SafeNativeOverlapped Zero { get { return s_zero; } } + internal static SafeNativeOverlapped Zero { get; } = new SafeNativeOverlapped(); - protected SafeNativeOverlapped() + private SafeNativeOverlapped() : this(IntPtr.Zero) { if (GlobalLog.IsEnabled) @@ -27,7 +21,7 @@ protected SafeNativeOverlapped() } } - protected SafeNativeOverlapped(IntPtr handle) + private SafeNativeOverlapped(IntPtr handle) : base(IntPtr.Zero, true) { SetHandle(handle); @@ -36,7 +30,7 @@ protected SafeNativeOverlapped(IntPtr handle) public unsafe SafeNativeOverlapped(SafeCloseSocket socketHandle, NativeOverlapped* handle) : this((IntPtr)handle) { - _safeCloseSocket = socketHandle; + SocketHandle = socketHandle; if (GlobalLog.IsEnabled) { @@ -44,24 +38,18 @@ public unsafe SafeNativeOverlapped(SafeCloseSocket socketHandle, NativeOverlappe } #if DEBUG - _safeCloseSocket.AddRef(); + SocketHandle.AddRef(); #endif } - protected override void Dispose(bool disposing) + internal unsafe void ReplaceHandle(NativeOverlapped* overlapped) { - if (disposing) - { - // It is important that the boundHandle is released immediately to allow new overlapped operations. - if (GlobalLog.IsEnabled) - { - GlobalLog.Print("SafeNativeOverlapped#" + LoggingHash.HashString(this) + "::Dispose(true)"); - } - - FreeNativeOverlapped(); - } + Debug.Assert(handle == IntPtr.Zero, "We should only be replacing the handle when it's already been freed."); + SetHandle((IntPtr)overlapped); } + internal SafeCloseSocket SocketHandle { get; } + public override bool IsInvalid { get { return handle == IntPtr.Zero; } @@ -75,37 +63,31 @@ protected override bool ReleaseHandle() } FreeNativeOverlapped(); + +#if DEBUG + SocketHandle.Release(); +#endif return true; } - private void FreeNativeOverlapped() + internal void FreeNativeOverlapped() { - IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); - // Do not call free during AppDomain shutdown, there may be an outstanding operation. // Overlapped will take care calling free when the native callback completes. + IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); if (oldHandle != IntPtr.Zero && !Environment.HasShutdownStarted) { unsafe { - Debug.Assert(_safeCloseSocket != null, "m_SafeCloseSocket is null."); - - ThreadPoolBoundHandle boundHandle = _safeCloseSocket.IOCPBoundHandle; - Debug.Assert(boundHandle != null, "SafeNativeOverlapped::ImmediatelyFreeNativeOverlapped - boundHandle is null"); + Debug.Assert(SocketHandle != null, "SocketHandle is null."); - if (boundHandle != null) - { - // FreeNativeOverlapped will be called even if boundHandle was previously disposed. - boundHandle.FreeNativeOverlapped((NativeOverlapped*)oldHandle); - } + ThreadPoolBoundHandle boundHandle = SocketHandle.IOCPBoundHandle; + Debug.Assert(boundHandle != null, "SafeNativeOverlapped::FreeNativeOverlapped - boundHandle is null"); -#if DEBUG - _safeCloseSocket.Release(); -#endif - _safeCloseSocket = null; + // FreeNativeOverlapped will be called even if boundHandle was previously disposed. + boundHandle?.FreeNativeOverlapped((NativeOverlapped*)oldHandle); } } - return; } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 376217c81c7c..0482bd57f945 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -155,9 +155,21 @@ private unsafe void PrepareIOCPOperation() "). Returned = " + ((IntPtr)overlapped).ToString("x")); } } - Debug.Assert(overlapped != null, "NativeOverlapped is null."); - _ptrNativeOverlapped = new SafeNativeOverlapped(_currentSocket.SafeHandle, overlapped); + + // If we already have a SafeNativeOverlapped SafeHandle and it's associated with the same + // socket (due to the last operation that used this SocketAsyncEventArgs using the same socket), + // then we can reuse the same SafeHandle object. Otherwise, this is either the first operation + // or the last operation was with a different socket, so create a new SafeHandle. + if (_ptrNativeOverlapped?.SocketHandle == _currentSocket.SafeHandle) + { + _ptrNativeOverlapped.ReplaceHandle(overlapped); + } + else + { + _ptrNativeOverlapped?.Dispose(); + _ptrNativeOverlapped = new SafeNativeOverlapped(_currentSocket.SafeHandle, overlapped); + } } private void CompleteIOCPOperation() @@ -169,12 +181,10 @@ private void CompleteIOCPOperation() // it is guaranteed that the IOCP operation will be completed in the callback even if Socket.Success was // returned by the Win32 API. - // Required to allow another IOCP operation for the same handle. - if (_ptrNativeOverlapped != null) - { - _ptrNativeOverlapped.Dispose(); - _ptrNativeOverlapped = null; - } + // Required to allow another IOCP operation for the same handle. We release the native overlapped + // in the safe handle, but keep the safe handle object around so as to be able to reuse it + // for other operations. + _ptrNativeOverlapped?.FreeNativeOverlapped(); } private void InnerStartOperationAccept(bool userSuppliedBuffer) diff --git a/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs new file mode 100644 index 000000000000..cb423377b2ab --- /dev/null +++ b/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Sockets.Tests +{ + public class SocketAsyncEventArgsTest + { + [Fact] + public async Task ReuseSocketAsyncEventArgs_SameInstance_MultipleSockets() + { + using (var listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listen.Listen(1); + + Task acceptTask = listen.AcceptAsync(); + await Task.WhenAll( + acceptTask, + client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listen.LocalEndPoint).Port))); + + using (Socket server = await acceptTask) + { + TaskCompletionSource tcs = null; + + var args = new SocketAsyncEventArgs(); + args.SetBuffer(new byte[1024], 0, 1024); + args.Completed += (_,__) => tcs.SetResult(true); + + for (int i = 1; i <= 10; i++) + { + tcs = new TaskCompletionSource(); + args.Buffer[0] = (byte)i; + args.SetBuffer(0, 1); + if (server.SendAsync(args)) + { + await tcs.Task; + } + + args.Buffer[0] = 0; + tcs = new TaskCompletionSource(); + if (client.ReceiveAsync(args)) + { + await tcs.Task; + } + Assert.Equal(1, args.BytesTransferred); + Assert.Equal(i, args.Buffer[0]); + } + } + } + } + } +} diff --git a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index b194c41b9d56..afa9cbeff632 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -39,6 +39,7 @@ +