diff --git a/src/System.Net.Sockets/src/Resources/Strings.resx b/src/System.Net.Sockets/src/Resources/Strings.resx index 585394e704a6..ae657c6e497e 100644 --- a/src/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/System.Net.Sockets/src/Resources/Strings.resx @@ -232,4 +232,7 @@ Positive number required. + + Unable to transfer data on the transport connection: {0}. + diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj index f7d1d639fc4c..1c86441ef071 100644 --- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -40,6 +40,7 @@ + diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index 971d83ff9620..df6e757adcfb 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -16,7 +16,10 @@ namespace System.Net.Sockets public class NetworkStream : Stream { // Used by the class to hold the underlying socket the stream uses. - private Socket _streamSocket; + private readonly Socket _streamSocket; + + // Whether the stream should dispose of the socket when the stream is disposed + private readonly bool _ownsSocket; // Used by the class to indicate that the stream is m_Readable. private bool _readable; @@ -24,8 +27,6 @@ public class NetworkStream : Stream // Used by the class to indicate that the stream is writable. private bool _writeable; - private bool _ownsSocket; - // Creates a new instance of the System.Net.Sockets.NetworkStream class for the specified System.Net.Sockets.Socket. public NetworkStream(Socket socket) : this(socket, FileAccess.ReadWrite, ownsSocket: false) @@ -664,24 +665,48 @@ public void EndWrite(IAsyncResult asyncResult) // A Task representing the read. public override Task ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) { -#if netcore50 + bool canRead = CanRead; // Prevent race with Dispose. + if (_cleanedUp) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + if (!canRead) + { + throw new InvalidOperationException(SR.net_writeonlystream); + } + + // Validate input parameters. + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException(nameof(size)); + } + if (cancellationToken.IsCancellationRequested) { return Task.FromCanceled(cancellationToken); } - return Task.Factory.FromAsync( - (bufferArg, offsetArg, sizeArg, callback, state) => ((NetworkStream)state).BeginRead(bufferArg, offsetArg, sizeArg, callback, state), - iar => ((NetworkStream)iar.AsyncState).EndRead(iar), - buffer, - offset, - size, - this); -#else - // Use optimized Stream.ReadAsync that's more efficient than - // Task.Factory.FromAsync when NetworkStream overrides Begin/EndRead. - return base.ReadAsync(buffer, offset, size, cancellationToken); -#endif + try + { + return _streamSocket.ReceiveAsync( + new ArraySegment(buffer, offset, size), + SocketFlags.None, + wrapExceptionsInIOExceptions: true); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + // Some sort of error occurred on the socket call, + // set the SocketException as InnerException and throw. + throw new IOException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + } } // WriteAsync - provide async write functionality. @@ -701,24 +726,48 @@ public override Task ReadAsync(byte[] buffer, int offset, int size, Cancell // A Task representing the write. public override Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) { -#if netcore50 + bool canWrite = CanWrite; // Prevent race with Dispose. + if (_cleanedUp) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + if (!canWrite) + { + throw new InvalidOperationException(SR.net_readonlystream); + } + + // Validate input parameters. + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException(nameof(size)); + } + if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return Task.FromCanceled(cancellationToken); } - return Task.Factory.FromAsync( - (bufferArg, offsetArg, sizeArg, callback, state) => ((NetworkStream)state).BeginWrite(bufferArg, offsetArg, sizeArg, callback, state), - iar => ((NetworkStream)iar.AsyncState).EndWrite(iar), - buffer, - offset, - size, - this); -#else - // Use optimized Stream.WriteAsync that's more efficient than - // Task.Factory.FromAsync when NetworkStream overrides Begin/EndWrite. - return base.WriteAsync(buffer, offset, size, cancellationToken); -#endif + try + { + return _streamSocket.SendAsync( + new ArraySegment(buffer, offset, size), + SocketFlags.None, + wrapExceptionsInIOExceptions: true); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + // Some sort of error occurred on the socket call, + // set the SocketException as InnerException and throw. + throw new IOException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + } } public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs new file mode 100644 index 000000000000..ff8a7cd1d112 --- /dev/null +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -0,0 +1,574 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Sockets +{ + // The Task-based APIs are currently wrappers over either the APM APIs (e.g. BeginConnect) + // or the SocketAsyncEventArgs APIs (e.g. ReceiveAsync(SocketAsyncEventArgs)). The latter + // are much more efficient when the SocketAsyncEventArg instances can be reused; as such, + // at present we use them for ReceiveAsync and Send{To}Async, caching an instance for each. + // In the future we could potentially maintain a global cache of instances used for accepts + // and connects, and potentially separate per-socket instances for Receive{Message}FromAsync, + // which would need different instances from ReceiveAsync due to having different results + // and thus different Completed logic. We also currently fall back to APM implementations + // when the single cached instance for each of send/receive is otherwise in use; we could + // potentially also employ a global pool from which to pull in such situations. + + public partial class Socket + { + /// + /// Sentinel that can be stored into one of the cached fields to indicate that an instance + /// was previously created but is currently being used by another concurrent operation. + /// + private static readonly Int32TaskSocketAsyncEventArgs s_rentedSentinel = new Int32TaskSocketAsyncEventArgs(); + /// Cached SocketAsyncEventArgs for Task-based ReceiveAsync APIs. + private Int32TaskSocketAsyncEventArgs _cachedReceiveEventArgs; + /// Cached SocketAsyncEventArgs for Task-based SendAsync APIs. + private Int32TaskSocketAsyncEventArgs _cachedSendEventArgs; + + internal Task AcceptAsync() => AcceptAsync((Socket)null); + + internal Task AcceptAsync(Socket acceptSocket) + { + var tcs = new TaskCompletionSource(this); + BeginAccept(acceptSocket, 0, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ConnectAsync(EndPoint remoteEP) + { + var tcs = new TaskCompletionSource(this); + BeginConnect(remoteEP, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try + { + ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); + innerTcs.TrySetResult(true); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ConnectAsync(IPAddress address, int port) + { + var tcs = new TaskCompletionSource(this); + BeginConnect(address, port, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try + { + ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); + innerTcs.TrySetResult(true); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ConnectAsync(IPAddress[] addresses, int port) + { + var tcs = new TaskCompletionSource(this); + BeginConnect(addresses, port, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try + { + ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); + innerTcs.TrySetResult(true); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ConnectAsync(string host, int port) + { + var tcs = new TaskCompletionSource(this); + BeginConnect(host, port, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try + { + ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); + innerTcs.TrySetResult(true); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool wrapExceptionsInIOExceptions) + { + // Validate the arguments. + ValidateBuffer(buffer); + + // Get the SocketAsyncEventArgs to use for the operation. + Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true); + if (saea == null) + { + // We couldn't get a cached instance, which means there's already a receive operation + // happening on this socket. Fall back to wrapping APM. + var tcs = new TaskCompletionSource(this); + BeginReceive(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + // Configure the buffer. We don't clear the buffers when returning the SAEA to the pool, + // so as to minimize overhead if the same buffer is used for subsequent operations (which is likely). + // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer list + // if there is one before we set the desired buffer. + if (saea.BufferList != null) saea.BufferList = null; + saea.SetBuffer(buffer.Array, buffer.Offset, buffer.Count); + saea.SocketFlags = socketFlags; + saea.WrapExceptionsInIOExceptions = wrapExceptionsInIOExceptions; + + // Initiate the receive + Task t; + if (!ReceiveAsync(saea)) + { + // The operation completed synchronously. Get a task for it and return the SAEA for future use. + t = saea.SocketError == SocketError.Success ? + GetSuccessTask(saea) : + Task.FromException(GetException(saea.SocketError, wrapExceptionsInIOExceptions)); + ReturnSocketAsyncEventArgs(saea, isReceive: true); + } + else + { + // The operation completed asynchronously. Get the task for the operation, + // with appropriate synchronization to coordinate with the async callback + // that'll be completing the task. + t = saea.GetTaskSafe(); + } + return t; + } + + internal Task ReceiveAsync(IList> buffers, SocketFlags socketFlags) + { + // Validate the arguments. + ValidateBuffersList(buffers); + + // Get the SocketAsyncEventArgs instance to use for the operation. + Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true); + if (saea == null) + { + // We couldn't get a cached instance, which means there's already a receive operation + // happening on this socket. Fall back to wrapping APM. + var tcs = new TaskCompletionSource(this); + BeginReceive(buffers, socketFlags, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + // Configure the buffer list. We don't clear the buffers when returning the SAEA to the pool, + // so as to minimize overhead if the same buffers are used for subsequent operations (which is likely). + // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer + // if there is one before we set the desired buffer list. + if (saea.Buffer != null) saea.SetBuffer(null, 0, 0); + saea.BufferList = buffers; + saea.SocketFlags = socketFlags; + + // Initiate the receive + Task t; + if (!ReceiveAsync(saea)) + { + // The operation completed synchronously. Get a task for it and return the SAEA for future use. + t = saea.SocketError == SocketError.Success ? + GetSuccessTask(saea) : + Task.FromException(new SocketException((int)saea.SocketError)); + ReturnSocketAsyncEventArgs(saea, isReceive: true); + } + else + { + // The operation completed asynchronously. Get the task for the operation, + // with appropriate synchronization to coordinate with the async callback + // that'll be completing the task. + t = saea.GetTaskSafe(); + } + return t; + } + + internal Task ReceiveFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) + { + var tcs = new StateTaskCompletionSource(this) { _field1 = remoteEndPoint }; + BeginReceiveFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field1, iar => + { + var innerTcs = (StateTaskCompletionSource)iar.AsyncState; + try + { + int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveFrom(iar, ref innerTcs._field1); + innerTcs.TrySetResult(new SocketReceiveFromResult + { + ReceivedBytes = receivedBytes, + RemoteEndPoint = innerTcs._field1 + }); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task ReceiveMessageFromAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) + { + var tcs = new StateTaskCompletionSource(this) { _field1 = socketFlags, _field2 = remoteEndPoint }; + BeginReceiveMessageFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field2, iar => + { + var innerTcs = (StateTaskCompletionSource)iar.AsyncState; + try + { + IPPacketInformation ipPacketInformation; + int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveMessageFrom(iar, ref innerTcs._field1, ref innerTcs._field2, out ipPacketInformation); + innerTcs.TrySetResult(new SocketReceiveMessageFromResult + { + ReceivedBytes = receivedBytes, + RemoteEndPoint = innerTcs._field2, + SocketFlags = innerTcs._field1, + PacketInformation = ipPacketInformation + }); + } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + internal Task SendAsync(ArraySegment buffer, SocketFlags socketFlags, bool wrapExceptionsInIOExceptions) + { + // Validate the arguments. + ValidateBuffer(buffer); + + // Get the SocketAsyncEventArgs instance to use for the operation. + Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false); + if (saea == null) + { + // We couldn't get a cached instance, which means there's already a receive operation + // happening on this socket. Fall back to wrapping APM. + var tcs = new TaskCompletionSource(this); + BeginSend(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + // Configure the buffer. We don't clear the buffers when returning the SAEA to the pool, + // so as to minimize overhead if the same buffer is used for subsequent operations (which is likely). + // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer list + // if there is one before we set the desired buffer. + if (saea.BufferList != null) saea.BufferList = null; + saea.SetBuffer(buffer.Array, buffer.Offset, buffer.Count); + saea.SocketFlags = socketFlags; + saea.WrapExceptionsInIOExceptions = wrapExceptionsInIOExceptions; + + // Initiate the send + Task t; + if (!SendAsync(saea)) + { + // The operation completed synchronously. Get a task for it and return the SAEA for future use. + t = saea.SocketError == SocketError.Success ? + GetSuccessTask(saea) : + Task.FromException(GetException(saea.SocketError, wrapExceptionsInIOExceptions)); + ReturnSocketAsyncEventArgs(saea, isReceive: false); + } + else + { + // The operation completed asynchronously. Get the task for the operation, + // with appropriate synchronization to coordinate with the async callback + // that'll be completing the task. + t = saea.GetTaskSafe(); + } + return t; + } + + internal Task SendAsync(IList> buffers, SocketFlags socketFlags) + { + // Validate the arguments. + ValidateBuffersList(buffers); + + // Get the SocketAsyncEventArgs instance to use for the operation. + Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false); + if (saea == null) + { + // We couldn't get a cached instance, which means there's already a receive operation + // happening on this socket. Fall back to wrapping APM. + var tcs = new TaskCompletionSource(this); + BeginSend(buffers, socketFlags, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + // Configure the buffer list. We don't clear the buffers when returning the SAEA to the pool, + // so as to minimize overhead if the same buffers are used for subsequent operations (which is likely). + // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer + // if there is one before we set the desired buffer list. + if (saea.Buffer != null) saea.SetBuffer(null, 0, 0); + saea.BufferList = buffers; + saea.SocketFlags = socketFlags; + + // Initiate the send + Task t; + if (!SendAsync(saea)) + { + // The operation completed synchronously. Get a task for it and return the SAEA for future use. + t = saea.SocketError == SocketError.Success ? + GetSuccessTask(saea) : + Task.FromException(new SocketException((int)saea.SocketError)); + ReturnSocketAsyncEventArgs(saea, isReceive: false); + } + else + { + // The operation completed asynchronously. Get the task for the operation, + // with appropriate synchronization to coordinate with the async callback + // that'll be completing the task. + t = saea.GetTaskSafe(); + } + return t; + } + + internal Task SendToAsync(ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) + { + var tcs = new TaskCompletionSource(this); + BeginSendTo(buffer.Array, buffer.Offset, buffer.Count, socketFlags, remoteEP, iar => + { + var innerTcs = (TaskCompletionSource)iar.AsyncState; + try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSendTo(iar)); } + catch (Exception e) { innerTcs.TrySetException(e); } + }, tcs); + return tcs.Task; + } + + /// Validates the supplied array segment, throwing if its array or indices are null or out-of-bounds, respectively. + private static void ValidateBuffer(ArraySegment buffer) + { + if (buffer.Array == null) + { + throw new ArgumentNullException(nameof(buffer.Array)); + } + if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length) + { + throw new ArgumentOutOfRangeException(nameof(buffer.Offset)); + } + if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset) + { + throw new ArgumentOutOfRangeException(nameof(buffer.Count)); + } + } + + /// Validates the supplied buffer list, throwing if it's null or empty. + private static void ValidateBuffersList(IList> buffers) + { + if (buffers == null) + { + throw new ArgumentNullException(nameof(buffers)); + } + if (buffers.Count == 0) + { + throw new ArgumentException(SR.Format(SR.net_sockets_zerolist, nameof(buffers)), nameof(buffers)); + } + } + + /// Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool. + private static void CompleteSendReceive(Int32TaskSocketAsyncEventArgs saea, bool isReceive) + { + // Synchronize with the initiating thread accessing the task from the builder. + saea.GetTaskSafe(); + + // Pull the relevant state off of the SAEA and only then return it to the pool. + Socket s = (Socket)saea.UserToken; + AsyncTaskMethodBuilder builder = saea.Builder; + SocketError error = saea.SocketError; + int bytesTransferred = saea.BytesTransferred; + bool wrapExceptionsInIOExceptions = saea.WrapExceptionsInIOExceptions; + + s.ReturnSocketAsyncEventArgs(saea, isReceive); + + // Complete the builder/task with the results. + if (error == SocketError.Success) + { + builder.SetResult(bytesTransferred); + } + else + { + builder.SetException(GetException(error, wrapExceptionsInIOExceptions)); + } + } + + /// Gets a that represents the BytesTransferred from a successful send/receive. + private static Task GetSuccessTask(Int32TaskSocketAsyncEventArgs saea) + { + // Get the number of bytes successfully received/sent. + int bytesTransferred = saea.BytesTransferred; + + // And get any cached, successfully-completed cached task that may exist on this SAEA. + Task lastTask = saea.SuccessfullyCompletedTask; + Debug.Assert(lastTask == null || lastTask.Status == TaskStatus.RanToCompletion); + + // If there is a task and if it has the desired result, simply reuse it. + // Otherwise, create a new one for this result value, and in addition to returning it, + // also store it into the SAEA for potential future reuse. + return lastTask != null && lastTask.Result == bytesTransferred ? + lastTask : + (saea.SuccessfullyCompletedTask = Task.FromResult(bytesTransferred)); + } + + /// Gets a SocketException or an IOException wrapping a SocketException for the specified error. + private static Exception GetException(SocketError error, bool wrapExceptionsInIOExceptions = false) + { + Exception e = new SocketException((int)error); + return wrapExceptionsInIOExceptions ? + new IOException(SR.Format(SR.net_io_readwritefailure, e.Message), e) : + e; + } + + /// Rents a for immediate use. + /// true if this instance will be used for a receive; false if for sends. + private Int32TaskSocketAsyncEventArgs RentSocketAsyncEventArgs(bool isReceive) + { + // Get any cached SocketAsyncEventArg we may have. + Int32TaskSocketAsyncEventArgs saea = isReceive ? + Interlocked.Exchange(ref _cachedReceiveEventArgs, s_rentedSentinel) : + Interlocked.Exchange(ref _cachedSendEventArgs, s_rentedSentinel); + + if (saea == s_rentedSentinel) + { + // An instance was once created (or is currently being created elsewhere), but some other + // concurrent operation is using it. Since we can store at most one, and since an individual + // APM operation is less expensive than creating a new SAEA and using it only once, we simply + // return null, for a caller to fall back to using an APM implementation. + return null; + } + + if (saea == null) + { + // No instance has been created yet, so create one. + saea = new Int32TaskSocketAsyncEventArgs(); + var handler = isReceive ? // branch to avoid capturing isReceive on every call + new EventHandler((_, e) => CompleteSendReceive((Int32TaskSocketAsyncEventArgs)e, isReceive: true)) : + new EventHandler((_, e) => CompleteSendReceive((Int32TaskSocketAsyncEventArgs)e, isReceive: false)); + saea.Completed += handler; + } + + // We got an instance. Configure and return it. + saea.UserToken = this; + return saea; + } + + /// Returns a instance for reuse. + /// The instance to return. + /// true if this instance is used for receives; false if used for sends. + private void ReturnSocketAsyncEventArgs(Int32TaskSocketAsyncEventArgs saea, bool isReceive) + { + Debug.Assert(saea != s_rentedSentinel); + + // Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done + // if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse + // and the costs associated with changing them. + saea.UserToken = null; + saea.Builder = default(AsyncTaskMethodBuilder); + saea.WrapExceptionsInIOExceptions = false; + + // Write this instance back as a cached instance. It should only ever be overwriting the sentinel, + // never null or another instance. + if (isReceive) + { + Debug.Assert(_cachedReceiveEventArgs == s_rentedSentinel); + Volatile.Write(ref _cachedReceiveEventArgs, saea); + } + else + { + Debug.Assert(_cachedSendEventArgs == s_rentedSentinel); + Volatile.Write(ref _cachedSendEventArgs, saea); + } + } + + /// Dispose of any cached instances. + private void DisposeCachedTaskSocketAsyncEventArgs() + { + Int32TaskSocketAsyncEventArgs e = Interlocked.Exchange(ref _cachedReceiveEventArgs, s_rentedSentinel); + if (e != s_rentedSentinel) e?.Dispose(); + + e = Interlocked.Exchange(ref _cachedSendEventArgs, s_rentedSentinel); + if (e != s_rentedSentinel) e?.Dispose(); + } + + /// A TaskCompletionSource that carries an extra field of strongly-typed state. + private class StateTaskCompletionSource : TaskCompletionSource + { + internal TField1 _field1; + public StateTaskCompletionSource(object baseState) : base(baseState) { } + } + + /// A TaskCompletionSource that carries several extra fields of strongly-typed state. + private class StateTaskCompletionSource : StateTaskCompletionSource + { + internal TField2 _field2; + public StateTaskCompletionSource(object baseState) : base(baseState) { } + } + + /// A SocketAsyncEventArgs with an associated async method builder. + internal sealed class Int32TaskSocketAsyncEventArgs : SocketAsyncEventArgs + { + /// A cached, successfully completed task. + internal Task SuccessfullyCompletedTask; + /// + /// The builder used to create the Task representing the result of the async operation. + /// This is a mutable struct. + /// + internal AsyncTaskMethodBuilder Builder; + /// Whether exceptions that emerge should be wrapped in IOExceptions. + internal bool WrapExceptionsInIOExceptions; + /// + /// The lock used to protect initialization fo the Builder's Task. AsyncTaskMethodBuilder + /// expects a particular access pattern as generated by the language compiler, such that + /// its Task property is always accessed in a serialized manner and no synchronization is + /// needed. As such, since in our pattern here the initiater of the async operation may race + /// with asynchronous completion to access the Task, we need to synchronize on its initial + /// access so that the same Task is published/accessed by both sides. + /// + private SpinLock BuilderTaskLock = new SpinLock(enableThreadOwnerTracking: false); + + /// Gets the builder's task with appropriate synchronization. + internal Task GetTaskSafe() + { + bool lockTaken = false; + try + { + BuilderTaskLock.Enter(ref lockTaken); + return Builder.Task; + } + finally + { + if (lockTaken) BuilderTaskLock.Exit(useMemoryBarrier: false); + } + } + } + } +} diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 7f177caca5da..76368800e62a 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -4879,6 +4879,9 @@ protected virtual void Dispose(bool disposing) { NetEventSource.Fail(this, $"handle:{_handle}, Closing the handle threw ObjectDisposedException."); } + + // Clean up any cached data + DisposeCachedTaskSocketAsyncEventArgs(); } public void Dispose() diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs index 52056df61159..f7ecaab70a91 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs @@ -9,226 +9,34 @@ namespace System.Net.Sockets { public static class SocketTaskExtensions { - public static Task AcceptAsync(this Socket socket) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginAccept(iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task AcceptAsync(this Socket socket, Socket acceptSocket) - { - const int ReceiveSize = 0; - var tcs = new TaskCompletionSource(socket); - socket.BeginAccept(acceptSocket, ReceiveSize, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndAccept(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginConnect(remoteEP, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ConnectAsync(this Socket socket, IPAddress address, int port) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginConnect(address, port, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginConnect(addresses, port, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ConnectAsync(this Socket socket, string host, int port) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginConnect(host, port, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try - { - ((Socket)innerTcs.Task.AsyncState).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginReceive(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ReceiveAsync( - this Socket socket, - IList> buffers, - SocketFlags socketFlags) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginReceive(buffers, socketFlags, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndReceive(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ReceiveFromAsync( - this Socket socket, - ArraySegment buffer, - SocketFlags socketFlags, - EndPoint remoteEndPoint) - { - var tcs = new StateTaskCompletionSource(socket) { _field1 = remoteEndPoint }; - socket.BeginReceiveFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field1, iar => - { - var innerTcs = (StateTaskCompletionSource)iar.AsyncState; - try - { - int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveFrom(iar, ref innerTcs._field1); - innerTcs.TrySetResult(new SocketReceiveFromResult - { - ReceivedBytes = receivedBytes, - RemoteEndPoint = innerTcs._field1 - }); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task ReceiveMessageFromAsync( - this Socket socket, - ArraySegment buffer, - SocketFlags socketFlags, - EndPoint remoteEndPoint) - { - var tcs = new StateTaskCompletionSource(socket) { _field1 = socketFlags, _field2 = remoteEndPoint }; - socket.BeginReceiveMessageFrom(buffer.Array, buffer.Offset, buffer.Count, socketFlags, ref tcs._field2, iar => - { - var innerTcs = (StateTaskCompletionSource)iar.AsyncState; - try - { - IPPacketInformation ipPacketInformation; - int receivedBytes = ((Socket)innerTcs.Task.AsyncState).EndReceiveMessageFrom(iar, ref innerTcs._field1, ref innerTcs._field2, out ipPacketInformation); - innerTcs.TrySetResult(new SocketReceiveMessageFromResult - { - ReceivedBytes = receivedBytes, - RemoteEndPoint = innerTcs._field2, - SocketFlags = innerTcs._field1, - PacketInformation = ipPacketInformation - }); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task SendAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginSend(buffer.Array, buffer.Offset, buffer.Count, socketFlags, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task SendAsync( - this Socket socket, - IList> buffers, - SocketFlags socketFlags) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginSend(buffers, socketFlags, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSend(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - public static Task SendToAsync( - this Socket socket, - ArraySegment buffer, - SocketFlags socketFlags, - EndPoint remoteEP) - { - var tcs = new TaskCompletionSource(socket); - socket.BeginSendTo(buffer.Array, buffer.Offset, buffer.Count, socketFlags, remoteEP, iar => - { - var innerTcs = (TaskCompletionSource)iar.AsyncState; - try { innerTcs.TrySetResult(((Socket)innerTcs.Task.AsyncState).EndSendTo(iar)); } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; - } - - private class StateTaskCompletionSource : TaskCompletionSource - { - internal TField1 _field1; - public StateTaskCompletionSource(object baseState) : base(baseState) { } - } - - private class StateTaskCompletionSource : StateTaskCompletionSource - { - internal TField2 _field2; - public StateTaskCompletionSource(object baseState) : base(baseState) { } - } + public static Task AcceptAsync(this Socket socket) => + socket.AcceptAsync(); + public static Task AcceptAsync(this Socket socket, Socket acceptSocket) => + socket.AcceptAsync(acceptSocket); + + public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) => + socket.ConnectAsync(remoteEP); + public static Task ConnectAsync(this Socket socket, IPAddress address, int port) => + socket.ConnectAsync(address, port); + public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) => + socket.ConnectAsync(addresses, port); + public static Task ConnectAsync(this Socket socket, string host, int port) => + socket.ConnectAsync(host, port); + + public static Task ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) => + socket.ReceiveAsync(buffer, socketFlags, wrapExceptionsInIOExceptions: false); + public static Task ReceiveAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) => + socket.ReceiveAsync(buffers, socketFlags); + public static Task ReceiveFromAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) => + socket.ReceiveFromAsync(buffer, socketFlags, remoteEndPoint); + public static Task ReceiveMessageFromAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEndPoint) => + socket.ReceiveMessageFromAsync(buffer, socketFlags, remoteEndPoint); + + public static Task SendAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) => + socket.SendAsync(buffer, socketFlags, wrapExceptionsInIOExceptions: false); + public static Task SendAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) => + socket.SendAsync(buffers, socketFlags); + public static Task SendToAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) => + socket.SendToAsync(buffer, socketFlags, remoteEP); } } diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index 918a08cfc8ba..f080b91880e5 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -224,6 +224,9 @@ public async Task Ctor_SocketFileAccess_CanReadAndWrite() Assert.Equal(1, await clientStream.ReadAsync(buffer, 0, 1)); Assert.Equal('a', (char)buffer[0]); + Assert.Throws(() => { serverStream.BeginRead(buffer, 0, 1, null, null); }); + Assert.Throws(() => { clientStream.BeginWrite(buffer, 0, 1, null, null); }); + Assert.Throws(() => { serverStream.ReadAsync(buffer, 0, 1); }); Assert.Throws(() => { clientStream.WriteAsync(buffer, 0, 1); }); } @@ -363,6 +366,12 @@ await RunWithConnectedNetworkStreamsAsync((server, _) => Assert.Throws(() => server.BeginRead(new byte[1], 0, -1, null, null)); Assert.Throws(() => server.BeginRead(new byte[1], 0, 2, null, null)); + Assert.Throws(() => { server.ReadAsync(null, 0, 0); }); + Assert.Throws(() => { server.ReadAsync(new byte[1], -1, 0); }); + Assert.Throws(() => { server.ReadAsync(new byte[1], 2, 0); }); + Assert.Throws(() => { server.ReadAsync(new byte[1], 0, -1); }); + Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 2); }); + Assert.Throws(() => server.Write(null, 0, 0)); Assert.Throws(() => server.Write(new byte[1], -1, 0)); Assert.Throws(() => server.Write(new byte[1], 2, 0)); @@ -375,6 +384,12 @@ await RunWithConnectedNetworkStreamsAsync((server, _) => Assert.Throws(() => server.BeginWrite(new byte[1], 0, -1, null, null)); Assert.Throws(() => server.BeginWrite(new byte[1], 0, 2, null, null)); + Assert.Throws(() => { server.WriteAsync(null, 0, 0); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], -1, 0); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], 2, 0); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], 0, -1); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 2); }); + Assert.Throws(() => server.EndRead(null)); Assert.Throws(() => server.EndWrite(null));