diff --git a/src/System.Net.Sockets/src/Resources/Strings.resx b/src/System.Net.Sockets/src/Resources/Strings.resx
index 585394e704a6..ae657c6e497e 100644
--- a/src/System.Net.Sockets/src/Resources/Strings.resx
+++ b/src/System.Net.Sockets/src/Resources/Strings.resx
@@ -232,4 +232,7 @@
Positive number required.
+
+ Unable to transfer data on the transport connection: {0}.
+
diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj
index f7d1d639fc4c..1c86441ef071 100644
--- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj
+++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj
@@ -40,6 +40,7 @@
+
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
index 971d83ff9620..df6e757adcfb 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs
@@ -16,7 +16,10 @@ namespace System.Net.Sockets
public class NetworkStream : Stream
{
// Used by the class to hold the underlying socket the stream uses.
- private Socket _streamSocket;
+ private readonly Socket _streamSocket;
+
+ // Whether the stream should dispose of the socket when the stream is disposed
+ private readonly bool _ownsSocket;
// Used by the class to indicate that the stream is m_Readable.
private bool _readable;
@@ -24,8 +27,6 @@ public class NetworkStream : Stream
// Used by the class to indicate that the stream is writable.
private bool _writeable;
- private bool _ownsSocket;
-
// Creates a new instance of the System.Net.Sockets.NetworkStream class for the specified System.Net.Sockets.Socket.
public NetworkStream(Socket socket)
: this(socket, FileAccess.ReadWrite, ownsSocket: false)
@@ -664,24 +665,48 @@ public void EndWrite(IAsyncResult asyncResult)
// A Task representing the read.
public override Task ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken)
{
-#if netcore50
+ bool canRead = CanRead; // Prevent race with Dispose.
+ if (_cleanedUp)
+ {
+ throw new ObjectDisposedException(this.GetType().FullName);
+ }
+ if (!canRead)
+ {
+ throw new InvalidOperationException(SR.net_writeonlystream);
+ }
+
+ // Validate input parameters.
+ if (buffer == null)
+ {
+ throw new ArgumentNullException(nameof(buffer));
+ }
+ if (offset < 0 || offset > buffer.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(offset));
+ }
+ if (size < 0 || size > buffer.Length - offset)
+ {
+ throw new ArgumentOutOfRangeException(nameof(size));
+ }
+
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}
- return Task.Factory.FromAsync(
- (bufferArg, offsetArg, sizeArg, callback, state) => ((NetworkStream)state).BeginRead(bufferArg, offsetArg, sizeArg, callback, state),
- iar => ((NetworkStream)iar.AsyncState).EndRead(iar),
- buffer,
- offset,
- size,
- this);
-#else
- // Use optimized Stream.ReadAsync that's more efficient than
- // Task.Factory.FromAsync when NetworkStream overrides Begin/EndRead.
- return base.ReadAsync(buffer, offset, size, cancellationToken);
-#endif
+ try
+ {
+ return _streamSocket.ReceiveAsync(
+ new ArraySegment(buffer, offset, size),
+ SocketFlags.None,
+ wrapExceptionsInIOExceptions: true);
+ }
+ catch (Exception exception) when (!(exception is OutOfMemoryException))
+ {
+ // Some sort of error occurred on the socket call,
+ // set the SocketException as InnerException and throw.
+ throw new IOException(SR.Format(SR.net_io_readfailure, exception.Message), exception);
+ }
}
// WriteAsync - provide async write functionality.
@@ -701,24 +726,48 @@ public override Task ReadAsync(byte[] buffer, int offset, int size, Cancell
// A Task representing the write.
public override Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken)
{
-#if netcore50
+ bool canWrite = CanWrite; // Prevent race with Dispose.
+ if (_cleanedUp)
+ {
+ throw new ObjectDisposedException(this.GetType().FullName);
+ }
+ if (!canWrite)
+ {
+ throw new InvalidOperationException(SR.net_readonlystream);
+ }
+
+ // Validate input parameters.
+ if (buffer == null)
+ {
+ throw new ArgumentNullException(nameof(buffer));
+ }
+ if (offset < 0 || offset > buffer.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(offset));
+ }
+ if (size < 0 || size > buffer.Length - offset)
+ {
+ throw new ArgumentOutOfRangeException(nameof(size));
+ }
+
if (cancellationToken.IsCancellationRequested)
{
- return Task.FromCanceled(cancellationToken);
+ return Task.FromCanceled(cancellationToken);
}
- return Task.Factory.FromAsync(
- (bufferArg, offsetArg, sizeArg, callback, state) => ((NetworkStream)state).BeginWrite(bufferArg, offsetArg, sizeArg, callback, state),
- iar => ((NetworkStream)iar.AsyncState).EndWrite(iar),
- buffer,
- offset,
- size,
- this);
-#else
- // Use optimized Stream.WriteAsync that's more efficient than
- // Task.Factory.FromAsync when NetworkStream overrides Begin/EndWrite.
- return base.WriteAsync(buffer, offset, size, cancellationToken);
-#endif
+ try
+ {
+ return _streamSocket.SendAsync(
+ new ArraySegment(buffer, offset, size),
+ SocketFlags.None,
+ wrapExceptionsInIOExceptions: true);
+ }
+ catch (Exception exception) when (!(exception is OutOfMemoryException))
+ {
+ // Some sort of error occurred on the socket call,
+ // set the SocketException as InnerException and throw.
+ throw new IOException(SR.Format(SR.net_io_writefailure, exception.Message), exception);
+ }
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
new file mode 100644
index 000000000000..ff8a7cd1d112
--- /dev/null
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
@@ -0,0 +1,574 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Net.Sockets
+{
+ // The Task-based APIs are currently wrappers over either the APM APIs (e.g. BeginConnect)
+ // or the SocketAsyncEventArgs APIs (e.g. ReceiveAsync(SocketAsyncEventArgs)). The latter
+ // are much more efficient when the SocketAsyncEventArg instances can be reused; as such,
+ // at present we use them for ReceiveAsync and Send{To}Async, caching an instance for each.
+ // In the future we could potentially maintain a global cache of instances used for accepts
+ // and connects, and potentially separate per-socket instances for Receive{Message}FromAsync,
+ // which would need different instances from ReceiveAsync due to having different results
+ // and thus different Completed logic. We also currently fall back to APM implementations
+ // when the single cached instance for each of send/receive is otherwise in use; we could
+ // potentially also employ a global pool from which to pull in such situations.
+
+ public partial class Socket
+ {
+ ///
+ /// Sentinel that can be stored into one of the cached fields to indicate that an instance
+ /// was previously created but is currently being used by another concurrent operation.
+ ///
+ private static readonly Int32TaskSocketAsyncEventArgs s_rentedSentinel = new Int32TaskSocketAsyncEventArgs();
+ /// Cached SocketAsyncEventArgs for Task-based ReceiveAsync APIs.
+ private Int32TaskSocketAsyncEventArgs _cachedReceiveEventArgs;
+ /// Cached SocketAsyncEventArgs for Task-based SendAsync APIs.
+ private Int32TaskSocketAsyncEventArgs _cachedSendEventArgs;
+
+ internal Task AcceptAsync() => AcceptAsync((Socket)null);
+
+ internal Task AcceptAsync(Socket acceptSocket)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginAccept(acceptSocket, 0, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ConnectAsync(EndPoint remoteEP)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginConnect(remoteEP, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try
+ {
+ ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
+ innerTcs.TrySetResult(true);
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ConnectAsync(IPAddress address, int port)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginConnect(address, port, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try
+ {
+ ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
+ innerTcs.TrySetResult(true);
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ConnectAsync(IPAddress[] addresses, int port)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginConnect(addresses, port, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try
+ {
+ ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
+ innerTcs.TrySetResult(true);
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ConnectAsync(string host, int port)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginConnect(host, port, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try
+ {
+ ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
+ innerTcs.TrySetResult(true);
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool wrapExceptionsInIOExceptions)
+ {
+ // Validate the arguments.
+ ValidateBuffer(buffer);
+
+ // Get the SocketAsyncEventArgs to use for the operation.
+ Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true);
+ if (saea == null)
+ {
+ // We couldn't get a cached instance, which means there's already a receive operation
+ // happening on this socket. Fall back to wrapping APM.
+ var tcs = new TaskCompletionSource(this);
+ BeginReceive(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ // Configure the buffer. We don't clear the buffers when returning the SAEA to the pool,
+ // so as to minimize overhead if the same buffer is used for subsequent operations (which is likely).
+ // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer list
+ // if there is one before we set the desired buffer.
+ if (saea.BufferList != null) saea.BufferList = null;
+ saea.SetBuffer(buffer.Array, buffer.Offset, buffer.Count);
+ saea.SocketFlags = socketFlags;
+ saea.WrapExceptionsInIOExceptions = wrapExceptionsInIOExceptions;
+
+ // Initiate the receive
+ Task t;
+ if (!ReceiveAsync(saea))
+ {
+ // The operation completed synchronously. Get a task for it and return the SAEA for future use.
+ t = saea.SocketError == SocketError.Success ?
+ GetSuccessTask(saea) :
+ Task.FromException(GetException(saea.SocketError, wrapExceptionsInIOExceptions));
+ ReturnSocketAsyncEventArgs(saea, isReceive: true);
+ }
+ else
+ {
+ // The operation completed asynchronously. Get the task for the operation,
+ // with appropriate synchronization to coordinate with the async callback
+ // that'll be completing the task.
+ t = saea.GetTaskSafe();
+ }
+ return t;
+ }
+
+ internal Task ReceiveAsync(IList> buffers, SocketFlags socketFlags)
+ {
+ // Validate the arguments.
+ ValidateBuffersList(buffers);
+
+ // Get the SocketAsyncEventArgs instance to use for the operation.
+ Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true);
+ if (saea == null)
+ {
+ // We couldn't get a cached instance, which means there's already a receive operation
+ // happening on this socket. Fall back to wrapping APM.
+ var tcs = new TaskCompletionSource(this);
+ BeginReceive(buffers, socketFlags, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ // Configure the buffer list. We don't clear the buffers when returning the SAEA to the pool,
+ // so as to minimize overhead if the same buffers are used for subsequent operations (which is likely).
+ // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer
+ // if there is one before we set the desired buffer list.
+ if (saea.Buffer != null) saea.SetBuffer(null, 0, 0);
+ saea.BufferList = buffers;
+ saea.SocketFlags = socketFlags;
+
+ // Initiate the receive
+ Task t;
+ if (!ReceiveAsync(saea))
+ {
+ // The operation completed synchronously. Get a task for it and return the SAEA for future use.
+ t = saea.SocketError == SocketError.Success ?
+ GetSuccessTask(saea) :
+ Task.FromException(new SocketException((int)saea.SocketError));
+ ReturnSocketAsyncEventArgs(saea, isReceive: true);
+ }
+ else
+ {
+ // The operation completed asynchronously. Get the task for the operation,
+ // with appropriate synchronization to coordinate with the async callback
+ // that'll be completing the task.
+ t = saea.GetTaskSafe();
+ }
+ return t;
+ }
+
+ internal Task ReceiveFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint)
+ {
+ var tcs = new StateTaskCompletionSource(this) { _field1 = remoteEndPoint };
+ BeginReceiveFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field1, iar =>
+ {
+ var innerTcs = (StateTaskCompletionSource)iar.AsyncState;
+ try
+ {
+ int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveFrom(iar, ref innerTcs._field1);
+ innerTcs.TrySetResult(new SocketReceiveFromResult
+ {
+ ReceivedBytes = receivedBytes,
+ RemoteEndPoint = innerTcs._field1
+ });
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task ReceiveMessageFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint)
+ {
+ var tcs = new StateTaskCompletionSource(this) { _field1 = socketFlags, _field2 = remoteEndPoint };
+ BeginReceiveMessageFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field2, iar =>
+ {
+ var innerTcs = (StateTaskCompletionSource)iar.AsyncState;
+ try
+ {
+ IPPacketInformation ipPacketInformation;
+ int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveMessageFrom(iar, ref innerTcs._field1, ref innerTcs._field2, out ipPacketInformation);
+ innerTcs.TrySetResult(new SocketReceiveMessageFromResult
+ {
+ ReceivedBytes = receivedBytes,
+ RemoteEndPoint = innerTcs._field2,
+ SocketFlags = innerTcs._field1,
+ PacketInformation = ipPacketInformation
+ });
+ }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ internal Task SendAsync(ArraySegment buffer, SocketFlags socketFlags, bool wrapExceptionsInIOExceptions)
+ {
+ // Validate the arguments.
+ ValidateBuffer(buffer);
+
+ // Get the SocketAsyncEventArgs instance to use for the operation.
+ Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false);
+ if (saea == null)
+ {
+ // We couldn't get a cached instance, which means there's already a receive operation
+ // happening on this socket. Fall back to wrapping APM.
+ var tcs = new TaskCompletionSource(this);
+ BeginSend(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ // Configure the buffer. We don't clear the buffers when returning the SAEA to the pool,
+ // so as to minimize overhead if the same buffer is used for subsequent operations (which is likely).
+ // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer list
+ // if there is one before we set the desired buffer.
+ if (saea.BufferList != null) saea.BufferList = null;
+ saea.SetBuffer(buffer.Array, buffer.Offset, buffer.Count);
+ saea.SocketFlags = socketFlags;
+ saea.WrapExceptionsInIOExceptions = wrapExceptionsInIOExceptions;
+
+ // Initiate the send
+ Task t;
+ if (!SendAsync(saea))
+ {
+ // The operation completed synchronously. Get a task for it and return the SAEA for future use.
+ t = saea.SocketError == SocketError.Success ?
+ GetSuccessTask(saea) :
+ Task.FromException(GetException(saea.SocketError, wrapExceptionsInIOExceptions));
+ ReturnSocketAsyncEventArgs(saea, isReceive: false);
+ }
+ else
+ {
+ // The operation completed asynchronously. Get the task for the operation,
+ // with appropriate synchronization to coordinate with the async callback
+ // that'll be completing the task.
+ t = saea.GetTaskSafe();
+ }
+ return t;
+ }
+
+ internal Task SendAsync(IList> buffers, SocketFlags socketFlags)
+ {
+ // Validate the arguments.
+ ValidateBuffersList(buffers);
+
+ // Get the SocketAsyncEventArgs instance to use for the operation.
+ Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false);
+ if (saea == null)
+ {
+ // We couldn't get a cached instance, which means there's already a receive operation
+ // happening on this socket. Fall back to wrapping APM.
+ var tcs = new TaskCompletionSource(this);
+ BeginSend(buffers, socketFlags, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ // Configure the buffer list. We don't clear the buffers when returning the SAEA to the pool,
+ // so as to minimize overhead if the same buffers are used for subsequent operations (which is likely).
+ // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer
+ // if there is one before we set the desired buffer list.
+ if (saea.Buffer != null) saea.SetBuffer(null, 0, 0);
+ saea.BufferList = buffers;
+ saea.SocketFlags = socketFlags;
+
+ // Initiate the send
+ Task t;
+ if (!SendAsync(saea))
+ {
+ // The operation completed synchronously. Get a task for it and return the SAEA for future use.
+ t = saea.SocketError == SocketError.Success ?
+ GetSuccessTask(saea) :
+ Task.FromException(new SocketException((int)saea.SocketError));
+ ReturnSocketAsyncEventArgs(saea, isReceive: false);
+ }
+ else
+ {
+ // The operation completed asynchronously. Get the task for the operation,
+ // with appropriate synchronization to coordinate with the async callback
+ // that'll be completing the task.
+ t = saea.GetTaskSafe();
+ }
+ return t;
+ }
+
+ internal Task SendToAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP)
+ {
+ var tcs = new TaskCompletionSource(this);
+ BeginSendTo(buffer.Array, buffer.Offset, buffer.Count, socketFlags, remoteEP, iar =>
+ {
+ var innerTcs = (TaskCompletionSource)iar.AsyncState;
+ try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSendTo(iar)); }
+ catch (Exception e) { innerTcs.TrySetException(e); }
+ }, tcs);
+ return tcs.Task;
+ }
+
+ /// Validates the supplied array segment, throwing if its array or indices are null or out-of-bounds, respectively.
+ private static void ValidateBuffer(ArraySegment buffer)
+ {
+ if (buffer.Array == null)
+ {
+ throw new ArgumentNullException(nameof(buffer.Array));
+ }
+ if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(buffer.Offset));
+ }
+ if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset)
+ {
+ throw new ArgumentOutOfRangeException(nameof(buffer.Count));
+ }
+ }
+
+ /// Validates the supplied buffer list, throwing if it's null or empty.
+ private static void ValidateBuffersList(IList> buffers)
+ {
+ if (buffers == null)
+ {
+ throw new ArgumentNullException(nameof(buffers));
+ }
+ if (buffers.Count == 0)
+ {
+ throw new ArgumentException(SR.Format(SR.net_sockets_zerolist, nameof(buffers)), nameof(buffers));
+ }
+ }
+
+ /// Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.
+ private static void CompleteSendReceive(Int32TaskSocketAsyncEventArgs saea, bool isReceive)
+ {
+ // Synchronize with the initiating thread accessing the task from the builder.
+ saea.GetTaskSafe();
+
+ // Pull the relevant state off of the SAEA and only then return it to the pool.
+ Socket s = (Socket)saea.UserToken;
+ AsyncTaskMethodBuilder builder = saea.Builder;
+ SocketError error = saea.SocketError;
+ int bytesTransferred = saea.BytesTransferred;
+ bool wrapExceptionsInIOExceptions = saea.WrapExceptionsInIOExceptions;
+
+ s.ReturnSocketAsyncEventArgs(saea, isReceive);
+
+ // Complete the builder/task with the results.
+ if (error == SocketError.Success)
+ {
+ builder.SetResult(bytesTransferred);
+ }
+ else
+ {
+ builder.SetException(GetException(error, wrapExceptionsInIOExceptions));
+ }
+ }
+
+ /// Gets a that represents the BytesTransferred from a successful send/receive.
+ private static Task GetSuccessTask(Int32TaskSocketAsyncEventArgs saea)
+ {
+ // Get the number of bytes successfully received/sent.
+ int bytesTransferred = saea.BytesTransferred;
+
+ // And get any cached, successfully-completed cached task that may exist on this SAEA.
+ Task lastTask = saea.SuccessfullyCompletedTask;
+ Debug.Assert(lastTask == null || lastTask.Status == TaskStatus.RanToCompletion);
+
+ // If there is a task and if it has the desired result, simply reuse it.
+ // Otherwise, create a new one for this result value, and in addition to returning it,
+ // also store it into the SAEA for potential future reuse.
+ return lastTask != null && lastTask.Result == bytesTransferred ?
+ lastTask :
+ (saea.SuccessfullyCompletedTask = Task.FromResult(bytesTransferred));
+ }
+
+ /// Gets a SocketException or an IOException wrapping a SocketException for the specified error.
+ private static Exception GetException(SocketError error, bool wrapExceptionsInIOExceptions = false)
+ {
+ Exception e = new SocketException((int)error);
+ return wrapExceptionsInIOExceptions ?
+ new IOException(SR.Format(SR.net_io_readwritefailure, e.Message), e) :
+ e;
+ }
+
+ /// Rents a for immediate use.
+ /// true if this instance will be used for a receive; false if for sends.
+ private Int32TaskSocketAsyncEventArgs RentSocketAsyncEventArgs(bool isReceive)
+ {
+ // Get any cached SocketAsyncEventArg we may have.
+ Int32TaskSocketAsyncEventArgs saea = isReceive ?
+ Interlocked.Exchange(ref _cachedReceiveEventArgs, s_rentedSentinel) :
+ Interlocked.Exchange(ref _cachedSendEventArgs, s_rentedSentinel);
+
+ if (saea == s_rentedSentinel)
+ {
+ // An instance was once created (or is currently being created elsewhere), but some other
+ // concurrent operation is using it. Since we can store at most one, and since an individual
+ // APM operation is less expensive than creating a new SAEA and using it only once, we simply
+ // return null, for a caller to fall back to using an APM implementation.
+ return null;
+ }
+
+ if (saea == null)
+ {
+ // No instance has been created yet, so create one.
+ saea = new Int32TaskSocketAsyncEventArgs();
+ var handler = isReceive ? // branch to avoid capturing isReceive on every call
+ new EventHandler((_, e) => CompleteSendReceive((Int32TaskSocketAsyncEventArgs)e, isReceive: true)) :
+ new EventHandler((_, e) => CompleteSendReceive((Int32TaskSocketAsyncEventArgs)e, isReceive: false));
+ saea.Completed += handler;
+ }
+
+ // We got an instance. Configure and return it.
+ saea.UserToken = this;
+ return saea;
+ }
+
+ /// Returns a instance for reuse.
+ /// The instance to return.
+ /// true if this instance is used for receives; false if used for sends.
+ private void ReturnSocketAsyncEventArgs(Int32TaskSocketAsyncEventArgs saea, bool isReceive)
+ {
+ Debug.Assert(saea != s_rentedSentinel);
+
+ // Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done
+ // if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse
+ // and the costs associated with changing them.
+ saea.UserToken = null;
+ saea.Builder = default(AsyncTaskMethodBuilder);
+ saea.WrapExceptionsInIOExceptions = false;
+
+ // Write this instance back as a cached instance. It should only ever be overwriting the sentinel,
+ // never null or another instance.
+ if (isReceive)
+ {
+ Debug.Assert(_cachedReceiveEventArgs == s_rentedSentinel);
+ Volatile.Write(ref _cachedReceiveEventArgs, saea);
+ }
+ else
+ {
+ Debug.Assert(_cachedSendEventArgs == s_rentedSentinel);
+ Volatile.Write(ref _cachedSendEventArgs, saea);
+ }
+ }
+
+ /// Dispose of any cached instances.
+ private void DisposeCachedTaskSocketAsyncEventArgs()
+ {
+ Int32TaskSocketAsyncEventArgs e = Interlocked.Exchange(ref _cachedReceiveEventArgs, s_rentedSentinel);
+ if (e != s_rentedSentinel) e?.Dispose();
+
+ e = Interlocked.Exchange(ref _cachedSendEventArgs, s_rentedSentinel);
+ if (e != s_rentedSentinel) e?.Dispose();
+ }
+
+ /// A TaskCompletionSource that carries an extra field of strongly-typed state.
+ private class StateTaskCompletionSource : TaskCompletionSource
+ {
+ internal TField1 _field1;
+ public StateTaskCompletionSource(object baseState) : base(baseState) { }
+ }
+
+ /// A TaskCompletionSource that carries several extra fields of strongly-typed state.
+ private class StateTaskCompletionSource : StateTaskCompletionSource
+ {
+ internal TField2 _field2;
+ public StateTaskCompletionSource(object baseState) : base(baseState) { }
+ }
+
+ /// A SocketAsyncEventArgs with an associated async method builder.
+ internal sealed class Int32TaskSocketAsyncEventArgs : SocketAsyncEventArgs
+ {
+ /// A cached, successfully completed task.
+ internal Task SuccessfullyCompletedTask;
+ ///
+ /// The builder used to create the Task representing the result of the async operation.
+ /// This is a mutable struct.
+ ///
+ internal AsyncTaskMethodBuilder Builder;
+ /// Whether exceptions that emerge should be wrapped in IOExceptions.
+ internal bool WrapExceptionsInIOExceptions;
+ ///
+ /// The lock used to protect initialization fo the Builder's Task. AsyncTaskMethodBuilder
+ /// expects a particular access pattern as generated by the language compiler, such that
+ /// its Task property is always accessed in a serialized manner and no synchronization is
+ /// needed. As such, since in our pattern here the initiater of the async operation may race
+ /// with asynchronous completion to access the Task, we need to synchronize on its initial
+ /// access so that the same Task is published/accessed by both sides.
+ ///
+ private SpinLock BuilderTaskLock = new SpinLock(enableThreadOwnerTracking: false);
+
+ /// Gets the builder's task with appropriate synchronization.
+ internal Task GetTaskSafe()
+ {
+ bool lockTaken = false;
+ try
+ {
+ BuilderTaskLock.Enter(ref lockTaken);
+ return Builder.Task;
+ }
+ finally
+ {
+ if (lockTaken) BuilderTaskLock.Exit(useMemoryBarrier: false);
+ }
+ }
+ }
+ }
+}
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
index 7f177caca5da..76368800e62a 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
@@ -4879,6 +4879,9 @@ protected virtual void Dispose(bool disposing)
{
NetEventSource.Fail(this, $"handle:{_handle}, Closing the handle threw ObjectDisposedException.");
}
+
+ // Clean up any cached data
+ DisposeCachedTaskSocketAsyncEventArgs();
}
public void Dispose()
diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs
index 52056df61159..f7ecaab70a91 100644
--- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs
+++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs
@@ -9,226 +9,34 @@ namespace System.Net.Sockets
{
public static class SocketTaskExtensions
{
- public static Task AcceptAsync(this Socket socket)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginAccept(iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task AcceptAsync(this Socket socket, Socket acceptSocket)
- {
- const int ReceiveSize = 0;
- var tcs = new TaskCompletionSource(socket);
- socket.BeginAccept(acceptSocket, ReceiveSize, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ConnectAsync(this Socket socket, EndPoint remoteEP)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginConnect(remoteEP, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try
- {
- ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
- innerTcs.TrySetResult(true);
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ConnectAsync(this Socket socket, IPAddress address, int port)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginConnect(address, port, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try
- {
- ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
- innerTcs.TrySetResult(true);
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginConnect(addresses, port, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try
- {
- ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
- innerTcs.TrySetResult(true);
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ConnectAsync(this Socket socket, string host, int port)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginConnect(host, port, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try
- {
- ((Socket)innerTcs.Task.AsyncState).EndConnect(iar);
- innerTcs.TrySetResult(true);
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginReceive(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ReceiveAsync(
- this Socket socket,
- IList> buffers,
- SocketFlags socketFlags)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginReceive(buffers, socketFlags, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ReceiveFromAsync(
- this Socket socket,
- ArraySegment buffer,
- SocketFlags socketFlags,
- EndPoint remoteEndPoint)
- {
- var tcs = new StateTaskCompletionSource(socket) { _field1 = remoteEndPoint };
- socket.BeginReceiveFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field1, iar =>
- {
- var innerTcs = (StateTaskCompletionSource)iar.AsyncState;
- try
- {
- int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveFrom(iar, ref innerTcs._field1);
- innerTcs.TrySetResult(new SocketReceiveFromResult
- {
- ReceivedBytes = receivedBytes,
- RemoteEndPoint = innerTcs._field1
- });
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task ReceiveMessageFromAsync(
- this Socket socket,
- ArraySegment buffer,
- SocketFlags socketFlags,
- EndPoint remoteEndPoint)
- {
- var tcs = new StateTaskCompletionSource(socket) { _field1 = socketFlags, _field2 = remoteEndPoint };
- socket.BeginReceiveMessageFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field2, iar =>
- {
- var innerTcs = (StateTaskCompletionSource)iar.AsyncState;
- try
- {
- IPPacketInformation ipPacketInformation;
- int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveMessageFrom(iar, ref innerTcs._field1, ref innerTcs._field2, out ipPacketInformation);
- innerTcs.TrySetResult(new SocketReceiveMessageFromResult
- {
- ReceivedBytes = receivedBytes,
- RemoteEndPoint = innerTcs._field2,
- SocketFlags = innerTcs._field1,
- PacketInformation = ipPacketInformation
- });
- }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task SendAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginSend(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task SendAsync(
- this Socket socket,
- IList> buffers,
- SocketFlags socketFlags)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginSend(buffers, socketFlags, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- public static Task SendToAsync(
- this Socket socket,
- ArraySegment buffer,
- SocketFlags socketFlags,
- EndPoint remoteEP)
- {
- var tcs = new TaskCompletionSource(socket);
- socket.BeginSendTo(buffer.Array, buffer.Offset, buffer.Count, socketFlags, remoteEP, iar =>
- {
- var innerTcs = (TaskCompletionSource)iar.AsyncState;
- try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSendTo(iar)); }
- catch (Exception e) { innerTcs.TrySetException(e); }
- }, tcs);
- return tcs.Task;
- }
-
- private class StateTaskCompletionSource : TaskCompletionSource
- {
- internal TField1 _field1;
- public StateTaskCompletionSource(object baseState) : base(baseState) { }
- }
-
- private class StateTaskCompletionSource : StateTaskCompletionSource
- {
- internal TField2 _field2;
- public StateTaskCompletionSource(object baseState) : base(baseState) { }
- }
+ public static Task AcceptAsync(this Socket socket) =>
+ socket.AcceptAsync();
+ public static Task AcceptAsync(this Socket socket, Socket acceptSocket) =>
+ socket.AcceptAsync(acceptSocket);
+
+ public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) =>
+ socket.ConnectAsync(remoteEP);
+ public static Task ConnectAsync(this Socket socket, IPAddress address, int port) =>
+ socket.ConnectAsync(address, port);
+ public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) =>
+ socket.ConnectAsync(addresses, port);
+ public static Task ConnectAsync(this Socket socket, string host, int port) =>
+ socket.ConnectAsync(host, port);
+
+ public static Task ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) =>
+ socket.ReceiveAsync(buffer, socketFlags, wrapExceptionsInIOExceptions: false);
+ public static Task ReceiveAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) =>
+ socket.ReceiveAsync(buffers, socketFlags);
+ public static Task ReceiveFromAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) =>
+ socket.ReceiveFromAsync(buffer, socketFlags, remoteEndPoint);
+ public static Task ReceiveMessageFromAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) =>
+ socket.ReceiveMessageFromAsync(buffer, socketFlags, remoteEndPoint);
+
+ public static Task SendAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) =>
+ socket.SendAsync(buffer, socketFlags, wrapExceptionsInIOExceptions: false);
+ public static Task SendAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) =>
+ socket.SendAsync(buffers, socketFlags);
+ public static Task SendToAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) =>
+ socket.SendToAsync(buffer, socketFlags, remoteEP);
}
}
diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
index 918a08cfc8ba..f080b91880e5 100644
--- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
+++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs
@@ -224,6 +224,9 @@ public async Task Ctor_SocketFileAccess_CanReadAndWrite()
Assert.Equal(1, await clientStream.ReadAsync(buffer, 0, 1));
Assert.Equal('a', (char)buffer[0]);
+ Assert.Throws(() => { serverStream.BeginRead(buffer, 0, 1, null, null); });
+ Assert.Throws(() => { clientStream.BeginWrite(buffer, 0, 1, null, null); });
+
Assert.Throws(() => { serverStream.ReadAsync(buffer, 0, 1); });
Assert.Throws(() => { clientStream.WriteAsync(buffer, 0, 1); });
}
@@ -363,6 +366,12 @@ await RunWithConnectedNetworkStreamsAsync((server, _) =>
Assert.Throws(() => server.BeginRead(new byte[1], 0, -1, null, null));
Assert.Throws(() => server.BeginRead(new byte[1], 0, 2, null, null));
+ Assert.Throws(() => { server.ReadAsync(null, 0, 0); });
+ Assert.Throws(() => { server.ReadAsync(new byte[1], -1, 0); });
+ Assert.Throws(() => { server.ReadAsync(new byte[1], 2, 0); });
+ Assert.Throws(() => { server.ReadAsync(new byte[1], 0, -1); });
+ Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 2); });
+
Assert.Throws(() => server.Write(null, 0, 0));
Assert.Throws(() => server.Write(new byte[1], -1, 0));
Assert.Throws(() => server.Write(new byte[1], 2, 0));
@@ -375,6 +384,12 @@ await RunWithConnectedNetworkStreamsAsync((server, _) =>
Assert.Throws(() => server.BeginWrite(new byte[1], 0, -1, null, null));
Assert.Throws(() => server.BeginWrite(new byte[1], 0, 2, null, null));
+ Assert.Throws(() => { server.WriteAsync(null, 0, 0); });
+ Assert.Throws(() => { server.WriteAsync(new byte[1], -1, 0); });
+ Assert.Throws(() => { server.WriteAsync(new byte[1], 2, 0); });
+ Assert.Throws(() => { server.WriteAsync(new byte[1], 0, -1); });
+ Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 2); });
+
Assert.Throws(() => server.EndRead(null));
Assert.Throws(() => server.EndWrite(null));