diff --git a/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs b/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs
index 17061bb006f0..5d880b412541 100644
--- a/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs
+++ b/src/Common/src/Interop/Windows/Winsock/SafeNativeOverlapped.cs
@@ -2,23 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.Win32.SafeHandles;
-
using System.Diagnostics;
-using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
namespace System.Net.Sockets
{
- internal class SafeNativeOverlapped : SafeHandle
+ internal sealed class SafeNativeOverlapped : SafeHandle
{
- private static readonly SafeNativeOverlapped s_zero = new SafeNativeOverlapped();
- private SafeCloseSocket _safeCloseSocket;
-
- internal static SafeNativeOverlapped Zero { get { return s_zero; } }
+ internal static SafeNativeOverlapped Zero { get; } = new SafeNativeOverlapped();
- protected SafeNativeOverlapped()
+ private SafeNativeOverlapped()
: this(IntPtr.Zero)
{
if (GlobalLog.IsEnabled)
@@ -27,7 +21,7 @@ protected SafeNativeOverlapped()
}
}
- protected SafeNativeOverlapped(IntPtr handle)
+ private SafeNativeOverlapped(IntPtr handle)
: base(IntPtr.Zero, true)
{
SetHandle(handle);
@@ -36,7 +30,7 @@ protected SafeNativeOverlapped(IntPtr handle)
public unsafe SafeNativeOverlapped(SafeCloseSocket socketHandle, NativeOverlapped* handle)
: this((IntPtr)handle)
{
- _safeCloseSocket = socketHandle;
+ SocketHandle = socketHandle;
if (GlobalLog.IsEnabled)
{
@@ -44,24 +38,18 @@ public unsafe SafeNativeOverlapped(SafeCloseSocket socketHandle, NativeOverlappe
}
#if DEBUG
- _safeCloseSocket.AddRef();
+ SocketHandle.AddRef();
#endif
}
- protected override void Dispose(bool disposing)
+ internal unsafe void ReplaceHandle(NativeOverlapped* overlapped)
{
- if (disposing)
- {
- // It is important that the boundHandle is released immediately to allow new overlapped operations.
- if (GlobalLog.IsEnabled)
- {
- GlobalLog.Print("SafeNativeOverlapped#" + LoggingHash.HashString(this) + "::Dispose(true)");
- }
-
- FreeNativeOverlapped();
- }
+ Debug.Assert(handle == IntPtr.Zero, "We should only be replacing the handle when it's already been freed.");
+ SetHandle((IntPtr)overlapped);
}
+ internal SafeCloseSocket SocketHandle { get; }
+
public override bool IsInvalid
{
get { return handle == IntPtr.Zero; }
@@ -75,37 +63,31 @@ protected override bool ReleaseHandle()
}
FreeNativeOverlapped();
+
+#if DEBUG
+ SocketHandle.Release();
+#endif
return true;
}
- private void FreeNativeOverlapped()
+ internal void FreeNativeOverlapped()
{
- IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero);
-
// Do not call free during AppDomain shutdown, there may be an outstanding operation.
// Overlapped will take care calling free when the native callback completes.
+ IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero);
if (oldHandle != IntPtr.Zero && !Environment.HasShutdownStarted)
{
unsafe
{
- Debug.Assert(_safeCloseSocket != null, "m_SafeCloseSocket is null.");
-
- ThreadPoolBoundHandle boundHandle = _safeCloseSocket.IOCPBoundHandle;
- Debug.Assert(boundHandle != null, "SafeNativeOverlapped::ImmediatelyFreeNativeOverlapped - boundHandle is null");
+ Debug.Assert(SocketHandle != null, "SocketHandle is null.");
- if (boundHandle != null)
- {
- // FreeNativeOverlapped will be called even if boundHandle was previously disposed.
- boundHandle.FreeNativeOverlapped((NativeOverlapped*)oldHandle);
- }
+ ThreadPoolBoundHandle boundHandle = SocketHandle.IOCPBoundHandle;
+ Debug.Assert(boundHandle != null, "SafeNativeOverlapped::FreeNativeOverlapped - boundHandle is null");
-#if DEBUG
- _safeCloseSocket.Release();
-#endif
- _safeCloseSocket = null;
+ // FreeNativeOverlapped will be called even if boundHandle was previously disposed.
+ boundHandle?.FreeNativeOverlapped((NativeOverlapped*)oldHandle);
}
}
- return;
}
}
}
diff --git a/src/System.Net.Sockets/src/Resources/Strings.resx b/src/System.Net.Sockets/src/Resources/Strings.resx
index 7b7ab0350980..fb4340e0f086 100644
--- a/src/System.Net.Sockets/src/Resources/Strings.resx
+++ b/src/System.Net.Sockets/src/Resources/Strings.resx
@@ -291,4 +291,16 @@
This platform does not support disconnecting a Socket. Instead, close the Socket and create a new one.
+
+ Stream does not support reading.
+
+
+ Stream does not support writing.
+
+
+ Can not access a closed Stream.
+
+
+ Positive number required.
+
diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj
index 062f4c66de27..22082c581800 100644
--- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj
+++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj
@@ -75,6 +75,12 @@
+
+ Common\System\IO\StreamHelpers.ArrayPoolCopy.cs
+
+
+ Common\System\IO\StreamHelpers.CopyValidation.cs
+
Common\System\Net\Logging\GlobalLog.cs
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
index dc2f8abb6fac..76a029066f2f 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
@@ -2,7 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Buffers;
+using System.Diagnostics;
using System.IO;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
@@ -1025,6 +1028,63 @@ public override Task WriteAsync(byte[] buffer, int offset, int size, Cancellatio
this);
}
+ public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
+ {
+ // Validate arguments as would the base CopyToAsync
+ StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize);
+
+ // And bail early if cancellation has already been requested
+ if (cancellationToken.IsCancellationRequested)
+ {
+ return Task.FromCanceled(cancellationToken);
+ }
+
+ // Then do additional checks as ReadAsync would.
+
+ if (_cleanedUp)
+ {
+ throw new ObjectDisposedException(this.GetType().FullName);
+ }
+
+ Socket streamSocket = _streamSocket;
+ if (streamSocket == null)
+ {
+ throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
+ }
+
+ // Do the copy. We get a copy buffer from the shared pool, and we pass both it and the
+ // socket into the copy as part of the event args so as to avoid additional fields in
+ // the async method's state machine.
+ return CopyToAsyncCore(
+ destination,
+ new AwaitableSocketAsyncEventArgs(streamSocket, ArrayPool.Shared.Rent(bufferSize)),
+ cancellationToken);
+ }
+
+ private static async Task CopyToAsyncCore(Stream destination, AwaitableSocketAsyncEventArgs ea, CancellationToken cancellationToken)
+ {
+ try
+ {
+ while (true)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ int bytesRead = await ea.ReceiveAsync();
+ if (bytesRead == 0)
+ {
+ break;
+ }
+
+ await destination.WriteAsync(ea.Buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(ea.Buffer, clearArray: true);
+ ea.Dispose();
+ }
+ }
+
// Flushes data from the stream. This is meaningless for us, so it does nothing.
public override void Flush()
{
@@ -1092,5 +1152,114 @@ internal void DebugMembers()
_streamSocket.DebugMembers();
}
}
+
+ /// A SocketAsyncEventArgs that can be awaited to get the result of an operation.
+ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion
+ {
+ /// Sentinal object used to indicate that the operation has completed prior to OnCompleted being called.
+ private static readonly Action s_completedSentinel = () => { };
+ ///
+ /// null if the operation has not completed, if it has, and another object
+ /// if OnCompleted was called before the operation could complete, in which case it's the delegate to invoke
+ /// when the operation does complete.
+ ///
+ private Action _continuation;
+
+ /// Initializes the event args.
+ /// The associated socket.
+ /// The buffer to use for all operations.
+ public AwaitableSocketAsyncEventArgs(Socket socket, byte[] buffer)
+ {
+ Debug.Assert(socket != null);
+ Debug.Assert(buffer != null && buffer.Length > 0);
+
+ // Store the socket into the base's UserToken. This avoids the need for an extra field, at the expense
+ // of an object=>Socket cast when we need to access it, which is only once per operation.
+ UserToken = socket;
+
+ // Store the buffer for use by all operations with this instance.
+ SetBuffer(buffer, 0, buffer.Length);
+
+ // Hook up the completed event.
+ Completed += delegate
+ {
+ // When the operation completes, see if OnCompleted was already called to hook up a continuation.
+ // If it was, invoke the continuation.
+ Action c = _continuation;
+ if (c != null)
+ {
+ c();
+ }
+ else
+ {
+ // We may be racing with OnCompleted, so check with synchronization, trying to swap in our
+ // completion sentinel. If we lose the race and OnCompleted did hook up a continuation,
+ // invoke it. Otherwise, there's nothing more to be done.
+ Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null)?.Invoke();
+ }
+ };
+ }
+
+ /// Initiates a receive operation on the associated socket.
+ /// This instance.
+ public AwaitableSocketAsyncEventArgs ReceiveAsync()
+ {
+ if (!Socket.ReceiveAsync(this))
+ {
+ _continuation = s_completedSentinel;
+ }
+ return this;
+ }
+
+ /// Gets this instance.
+ public AwaitableSocketAsyncEventArgs GetAwaiter() => this;
+
+ /// Gets whether the operation has already completed.
+ ///
+ /// This is not a generically usable IsCompleted operation that suggests the whole operation has completed.
+ /// Rather, it's specifically used as part of the await pattern, and is only usable to determine whether the
+ /// operation has completed by the time the instance is awaited.
+ ///
+ public bool IsCompleted => _continuation != null;
+
+ /// Same as
+ public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation);
+
+ /// Queues the provided continuation to be executed once the operation has completed.
+ public void OnCompleted(Action continuation)
+ {
+ if (_continuation == s_completedSentinel || Interlocked.CompareExchange(ref _continuation, continuation, null) == s_completedSentinel)
+ {
+ Task.Run(continuation);
+ }
+ }
+
+ /// Gets the result of the completion operation.
+ /// Number of bytes transferred.
+ ///
+ /// Unlike Task's awaiter's GetResult, this does not block until the operation completes: it must only
+ /// be used once the operation has completed. This is handled implicitly by await.
+ ///
+ public int GetResult()
+ {
+ _continuation = null;
+ if (SocketError != SocketError.Success)
+ {
+ ThrowIOSocketException();
+ }
+ return BytesTransferred;
+ }
+
+ /// Gets the associated socket.
+ internal Socket Socket => (Socket)UserToken; // stored in the base's UserToken to avoid an extra field in the object
+
+ /// Throws an IOException wrapping a SocketException using the current .
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void ThrowIOSocketException()
+ {
+ var se = new SocketException((int)SocketError);
+ throw new IOException(SR.Format(SR.net_io_readfailure, se.Message), se);
+ }
+ }
}
}
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
index 376217c81c7c..0482bd57f945 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
@@ -155,9 +155,21 @@ private unsafe void PrepareIOCPOperation()
"). Returned = " + ((IntPtr)overlapped).ToString("x"));
}
}
-
Debug.Assert(overlapped != null, "NativeOverlapped is null.");
- _ptrNativeOverlapped = new SafeNativeOverlapped(_currentSocket.SafeHandle, overlapped);
+
+ // If we already have a SafeNativeOverlapped SafeHandle and it's associated with the same
+ // socket (due to the last operation that used this SocketAsyncEventArgs using the same socket),
+ // then we can reuse the same SafeHandle object. Otherwise, this is either the first operation
+ // or the last operation was with a different socket, so create a new SafeHandle.
+ if (_ptrNativeOverlapped?.SocketHandle == _currentSocket.SafeHandle)
+ {
+ _ptrNativeOverlapped.ReplaceHandle(overlapped);
+ }
+ else
+ {
+ _ptrNativeOverlapped?.Dispose();
+ _ptrNativeOverlapped = new SafeNativeOverlapped(_currentSocket.SafeHandle, overlapped);
+ }
}
private void CompleteIOCPOperation()
@@ -169,12 +181,10 @@ private void CompleteIOCPOperation()
// it is guaranteed that the IOCP operation will be completed in the callback even if Socket.Success was
// returned by the Win32 API.
- // Required to allow another IOCP operation for the same handle.
- if (_ptrNativeOverlapped != null)
- {
- _ptrNativeOverlapped.Dispose();
- _ptrNativeOverlapped = null;
- }
+ // Required to allow another IOCP operation for the same handle. We release the native overlapped
+ // in the safe handle, but keep the safe handle object around so as to be able to reuse it
+ // for other operations.
+ _ptrNativeOverlapped?.FreeNativeOverlapped();
}
private void InnerStartOperationAccept(bool userSuppliedBuffer)
diff --git a/src/System.Net.Sockets/src/netcore50/project.json b/src/System.Net.Sockets/src/netcore50/project.json
index 2c9fb38cba87..2e16c8f5aeba 100644
--- a/src/System.Net.Sockets/src/netcore50/project.json
+++ b/src/System.Net.Sockets/src/netcore50/project.json
@@ -5,6 +5,7 @@
"Microsoft.NETCore.Platforms": "1.0.1",
"Microsoft.NETCore.Targets": "1.0.1",
"Microsoft.TargetingPack.Private.WinRT": "1.0.1",
+ "System.Buffers": "4.0.0",
"System.Collections": "4.0.0",
"System.Diagnostics.Debug": "4.0.10",
"System.Diagnostics.Tracing": "4.0.20",
diff --git a/src/System.Net.Sockets/src/project.json b/src/System.Net.Sockets/src/project.json
index 8732e9be4e01..5d7dd1829dcb 100644
--- a/src/System.Net.Sockets/src/project.json
+++ b/src/System.Net.Sockets/src/project.json
@@ -2,6 +2,7 @@
"frameworks": {
"netstandard1.7": {
"dependencies": {
+ "System.Buffers": "4.4.0-beta-24615-03",
"System.Collections": "4.4.0-beta-24615-03",
"System.Diagnostics.Debug": "4.4.0-beta-24615-03",
"System.Diagnostics.Tracing": "4.4.0-beta-24615-03",
diff --git a/src/System.Net.Sockets/src/win/project.json b/src/System.Net.Sockets/src/win/project.json
index ab0790970fde..7cc3b92aa205 100644
--- a/src/System.Net.Sockets/src/win/project.json
+++ b/src/System.Net.Sockets/src/win/project.json
@@ -2,6 +2,7 @@
"frameworks": {
"netstandard1.7": {
"dependencies": {
+ "System.Buffers": "4.4.0-beta-24615-03",
"System.Collections": "4.4.0-beta-24615-03",
"System.Diagnostics.Debug": "4.4.0-beta-24615-03",
"System.Diagnostics.Tracing": "4.4.0-beta-24615-03",
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
index 63806ac461fd..95a42c503656 100644
--- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
+++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
@@ -45,6 +46,53 @@ await RunWithConnectedNetworkStreamsAsync(async (server, client) =>
new object[] { new CancellationTokenSource().Token } // CanBeCanceled == true
};
+ [Theory]
+ [InlineData(0)]
+ [InlineData(1)]
+ [InlineData(1024)]
+ [InlineData(4096)]
+ [InlineData(4095)]
+ [InlineData(1024*1024)]
+ public async Task CopyToAsync_AllDataCopied(int byteCount)
+ {
+ await RunWithConnectedNetworkStreamsAsync(async (server, client) =>
+ {
+ var results = new MemoryStream();
+ byte[] dataToCopy = new byte[byteCount];
+ new Random().NextBytes(dataToCopy);
+
+ Task copyTask = client.CopyToAsync(results);
+ await server.WriteAsync(dataToCopy, 0, dataToCopy.Length);
+ server.Dispose();
+ await copyTask;
+
+ Assert.Equal(dataToCopy, results.ToArray());
+ });
+ }
+
+ [Fact]
+ public async Task CopyToAsync_InvalidArguments_Throws()
+ {
+ await RunWithConnectedNetworkStreamsAsync((stream, _) =>
+ {
+ // Null destination
+ Assert.Throws("destination", () => { stream.CopyToAsync(null); });
+
+ // Buffer size out-of-range
+ Assert.Throws("bufferSize", () => { stream.CopyToAsync(new MemoryStream(), 0); });
+ Assert.Throws("bufferSize", () => { stream.CopyToAsync(new MemoryStream(), -1, CancellationToken.None); });
+
+ // Copying to non-writable stream
+ Assert.Throws(() => { stream.CopyToAsync(new MemoryStream(new byte[0], writable: false)); });
+
+ // Copying after disposing the stream
+ stream.Dispose();
+ Assert.Throws(() => { stream.CopyToAsync(new MemoryStream()); });
+
+ return Task.CompletedTask;
+ });
+ }
+
///
/// Creates a pair of connected NetworkStreams and invokes the provided
/// with them as arguments.
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs
new file mode 100644
index 000000000000..cb423377b2ab
--- /dev/null
+++ b/src/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs
@@ -0,0 +1,57 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.Sockets.Tests
+{
+ public class SocketAsyncEventArgsTest
+ {
+ [Fact]
+ public async Task ReuseSocketAsyncEventArgs_SameInstance_MultipleSockets()
+ {
+ using (var listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ listen.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+ listen.Listen(1);
+
+ Task acceptTask = listen.AcceptAsync();
+ await Task.WhenAll(
+ acceptTask,
+ client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listen.LocalEndPoint).Port)));
+
+ using (Socket server = await acceptTask)
+ {
+ TaskCompletionSource tcs = null;
+
+ var args = new SocketAsyncEventArgs();
+ args.SetBuffer(new byte[1024], 0, 1024);
+ args.Completed += (_,__) => tcs.SetResult(true);
+
+ for (int i = 1; i <= 10; i++)
+ {
+ tcs = new TaskCompletionSource();
+ args.Buffer[0] = (byte)i;
+ args.SetBuffer(0, 1);
+ if (server.SendAsync(args))
+ {
+ await tcs.Task;
+ }
+
+ args.Buffer[0] = 0;
+ tcs = new TaskCompletionSource();
+ if (client.ReceiveAsync(args))
+ {
+ await tcs.Task;
+ }
+ Assert.Equal(1, args.BytesTransferred);
+ Assert.Equal(i, args.Buffer[0]);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
index b194c41b9d56..afa9cbeff632 100644
--- a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
+++ b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj
@@ -39,6 +39,7 @@
+