diff --git a/src/Common/src/Interop/Windows/kernel32/Interop.SetFileCompletionNotificationModes.cs b/src/Common/src/Interop/Windows/kernel32/Interop.SetFileCompletionNotificationModes.cs new file mode 100644 index 000000000000..671800b7d251 --- /dev/null +++ b/src/Common/src/Interop/Windows/kernel32/Interop.SetFileCompletionNotificationModes.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class Kernel32 + { + [Flags] + internal enum FileCompletionNotificationModes : byte + { + None = 0, + SkipCompletionPortOnSuccess = 1, + SkipSetEventOnHandle = 2 + } + + [DllImport(Libraries.Kernel32, SetLastError = true)] + internal static unsafe extern bool SetFileCompletionNotificationModes(SafeHandle handle, FileCompletionNotificationModes flags); + } +} diff --git a/src/Common/src/System/Net/CompletionPortHelper.Uap.cs b/src/Common/src/System/Net/CompletionPortHelper.Uap.cs new file mode 100644 index 000000000000..aac501ccb3ed --- /dev/null +++ b/src/Common/src/System/Net/CompletionPortHelper.Uap.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; + +namespace System.Net.Sockets +{ + internal static class CompletionPortHelper + { + internal static bool SkipCompletionPortOnSuccess(SafeHandle handle) + { + // SetFileCompletionNotificationModes is not supported on UAP. + return false; + } + + internal static readonly bool PlatformHasUdpIssue = false; + } +} diff --git a/src/Common/src/System/Net/CompletionPortHelper.Windows.cs b/src/Common/src/System/Net/CompletionPortHelper.Windows.cs new file mode 100644 index 000000000000..1a4ea844ff17 --- /dev/null +++ b/src/Common/src/System/Net/CompletionPortHelper.Windows.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.InteropServices; + +namespace System.Net.Sockets +{ + internal static class CompletionPortHelper + { + internal static bool SkipCompletionPortOnSuccess(SafeHandle handle) + { + return Interop.Kernel32.SetFileCompletionNotificationModes(handle, + Interop.Kernel32.FileCompletionNotificationModes.SkipCompletionPortOnSuccess | + Interop.Kernel32.FileCompletionNotificationModes.SkipSetEventOnHandle); + } + + // There's a bug with using SetFileCompletionNotificationModes with UDP on Windows 7 and before. + // This check tells us if the problem exists on the platform we're running on. + internal static readonly bool PlatformHasUdpIssue = CheckIfPlatformHasUdpIssue(); + + private static bool CheckIfPlatformHasUdpIssue() + { + Version osVersion = Environment.OSVersion.Version; + + // 6.1 == Windows 7 + return (osVersion.Major < 6 || + (osVersion.Major == 6 && osVersion.Minor <= 1)); + } + } +} diff --git a/src/Common/src/System/Net/SafeCloseSocket.Windows.cs b/src/Common/src/System/Net/SafeCloseSocket.Windows.cs index f90789fc1907..4f4615ecab4c 100644 --- a/src/Common/src/System/Net/SafeCloseSocket.Windows.cs +++ b/src/Common/src/System/Net/SafeCloseSocket.Windows.cs @@ -3,9 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.Win32.SafeHandles; - using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; @@ -19,6 +17,7 @@ internal partial class SafeCloseSocket : #endif { private ThreadPoolBoundHandle _iocpBoundHandle; + private bool _skipCompletionPortOnSuccess; private object _iocpBindingLock = new object(); public ThreadPoolBoundHandle IOCPBoundHandle @@ -30,7 +29,7 @@ public ThreadPoolBoundHandle IOCPBoundHandle } // Binds the Socket Win32 Handle to the ThreadPool's CompletionPort. - public ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle() + public ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle(bool trySkipCompletionPortOnSuccess) { if (_released) { @@ -49,12 +48,13 @@ public ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle() // Bind the socket native _handle to the ThreadPool. if (NetEventSource.IsEnabled) NetEventSource.Info(this, "calling ThreadPool.BindHandle()"); + ThreadPoolBoundHandle boundHandle; try { // The handle (this) may have been already released: // E.g.: The socket has been disposed in the main thread. A completion callback may // attempt starting another operation. - _iocpBoundHandle = ThreadPoolBoundHandle.BindHandle(this); + boundHandle = ThreadPoolBoundHandle.BindHandle(this); } catch (Exception exception) { @@ -62,6 +62,16 @@ public ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle() CloseAsIs(); throw; } + + // Try to disable completions for synchronous success, if requested + if (trySkipCompletionPortOnSuccess && + CompletionPortHelper.SkipCompletionPortOnSuccess(boundHandle.Handle)) + { + _skipCompletionPortOnSuccess = true; + } + + // Don't set this until after we've configured the handle above (if we did) + _iocpBoundHandle = boundHandle; } } } @@ -69,6 +79,15 @@ public ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle() return _iocpBoundHandle; } + public bool SkipCompletionPortOnSuccess + { + get + { + Debug.Assert(_iocpBoundHandle != null); + return _skipCompletionPortOnSuccess; + } + } + internal static unsafe SafeCloseSocket CreateWSASocket(byte* pinnedBuffer) { return CreateSocket(InnerSafeCloseSocket.CreateWSASocket(pinnedBuffer)); diff --git a/src/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj b/src/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj index 582975a226e5..39ab948767eb 100644 --- a/src/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj +++ b/src/System.Net.NetworkInformation/src/System.Net.NetworkInformation.csproj @@ -138,6 +138,9 @@ Common\System\Net\ContextAwareResult.Windows.cs + + Common\System\Net\CompletionPortHelper.Windows.cs + Common\System\Net\LazyAsyncResult.cs @@ -197,6 +200,9 @@ Interop\Windows\kernel32\Interop.LocalFree.cs + + Interop\Windows\kernel32\Interop.SetFileCompletionNotificationModes.cs + Interop\Windows\Winsock\Interop.accept.cs diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj index 202817827e55..64853466bcdb 100644 --- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -282,15 +282,24 @@ Common\Interop\Windows\kernel32\Interop.LocalFree.cs + + Interop\Windows\kernel32\Interop.SetFileCompletionNotificationModes.cs + Common\System\Net\ContextAwareResult.Windows.cs + + Common\System\Net\CompletionPortHelper.Windows.cs + Common\System\Net\ContextAwareResult.Uap.cs + + Common\System\Net\CompletionPortHelper.Uap.cs + diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Unix.cs b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Unix.cs index 7fc1d182b33a..d37ede5443fb 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Unix.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Unix.cs @@ -23,16 +23,10 @@ public BaseOverlappedAsyncResult(Socket socket, Object asyncState, AsyncCallback if (NetEventSource.IsEnabled) NetEventSource.Info(this, socket); } - public void CompletionCallback(int numBytes, SocketError errorCode) + protected void CompletionCallback(int numBytes, SocketError errorCode) { ErrorCode = (int)errorCode; InvokeCallback(PostCompletion(numBytes)); } - - private void ReleaseUnmanagedStructures() - { - // NOTE: this method needs to exist to conform to the contract expected by the - // platform-independent code in BaseOverlappedAsyncResult.CheckAsyncCallOverlappedResult. - } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Windows.cs index 9718acccd5b3..8178bee824ca 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.Windows.cs @@ -30,14 +30,6 @@ internal BaseOverlappedAsyncResult(Socket socket, Object asyncState, AsyncCallba if (NetEventSource.IsEnabled) NetEventSource.Info(this, socket); } - internal SafeNativeOverlapped NativeOverlapped - { - get - { - return _nativeOverlapped; - } - } - // SetUnmanagedStructures // // This needs to be called for overlapped IO to function properly. @@ -59,14 +51,15 @@ internal void SetUnmanagedStructures(object objectsToPin) throw new ObjectDisposedException(s.GetType().FullName); } - ThreadPoolBoundHandle boundHandle = s.SafeHandle.GetOrAllocateThreadPoolBoundHandle(); + ThreadPoolBoundHandle boundHandle = s.GetOrAllocateThreadPoolBoundHandle(); unsafe { NativeOverlapped* overlapped = boundHandle.AllocateNativeOverlapped(s_ioCallback, this, objectsToPin); _nativeOverlapped = new SafeNativeOverlapped(s.SafeHandle, overlapped); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"{boundHandle}::AllocateNativeOverlapped. return={_nativeOverlapped}"); } + + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"{boundHandle}::AllocateNativeOverlapped. return={_nativeOverlapped}"); } private static unsafe void CompletionPortCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) @@ -78,8 +71,6 @@ private static unsafe void CompletionPortCallback(uint errorCode, uint numBytes, #endif BaseOverlappedAsyncResult asyncResult = (BaseOverlappedAsyncResult)ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped); - object returnObject = null; - if (asyncResult.InternalPeekCompleted) { NetEventSource.Fail(null, $"asyncResult.IsCompleted: {asyncResult}"); @@ -115,7 +106,7 @@ private static unsafe void CompletionPortCallback(uint errorCode, uint numBytes, SocketFlags ignore; bool success = Interop.Winsock.WSAGetOverlappedResult( socket.SafeHandle, - asyncResult.NativeOverlapped, + asyncResult._nativeOverlapped, out numBytes, false, out ignore); @@ -135,15 +126,23 @@ private static unsafe void CompletionPortCallback(uint errorCode, uint numBytes, } } } - asyncResult.ErrorCode = (int)socketError; - returnObject = asyncResult.PostCompletion((int)numBytes); - asyncResult.ReleaseUnmanagedStructures(); - asyncResult.InvokeCallback(returnObject); + + // Set results and invoke callback + asyncResult.CompletionCallback((int)numBytes, socketError); #if DEBUG } #endif } + // Called either synchronously from SocketPal async routines or asynchronously via CompletionPortCallback above. + private void CompletionCallback(int numBytes, SocketError socketError) + { + ReleaseUnmanagedStructures(); + + ErrorCode = (int)socketError; + InvokeCallback(PostCompletion(numBytes)); + } + // The following property returns the Win32 unsafe pointer to // whichever Overlapped structure we're using for IO. internal SafeHandle OverlappedHandle @@ -157,7 +156,46 @@ internal SafeHandle OverlappedHandle } } - private void ReleaseUnmanagedStructures() + // Check the result of the overlapped operation. + // Handle synchronous success by completing the asyncResult here. + // Handle synchronous failure by cleaning up and returning a SocketError. + internal SocketError ProcessOverlappedResult(bool success, int bytesTransferred) + { + if (success) + { + // Synchronous success. + Socket socket = (Socket)AsyncObject; + if (socket.SafeHandle.SkipCompletionPortOnSuccess) + { + // The socket handle is configured to skip completion on success, + // so we can complete this asyncResult right now. + CompletionCallback(bytesTransferred, SocketError.Success); + return SocketError.Success; + } + + // Socket handle is going to post a completion to the completion port (may have done so already). + // Return pending and we will continue in the completion port callback. + return SocketError.IOPending; + } + + // Get the socket error (which may be IOPending) + SocketError errorCode = SocketPal.GetLastSocketError(); + + if (errorCode == SocketError.IOPending) + { + // Operation is pending. + // We will continue when the completion arrives (may have already at this point). + return SocketError.IOPending; + } + + // Synchronous failure. + // Release overlapped and pinned structures. + ReleaseUnmanagedStructures(); + + return errorCode; + } + + internal void ReleaseUnmanagedStructures() { if (Interlocked.Decrement(ref _cleanupCount) == 0) { diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs index 1248e7ce730d..1a6a243329ab 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/BaseOverlappedAsyncResult.cs @@ -36,30 +36,5 @@ internal int InternalWaitForCompletionInt32Result() base.InternalWaitForCompletion(); return _numBytes; } - - // This method is called after an asynchronous call is made for the user. - // It checks and acts accordingly if the IO: - // 1) completed synchronously. - // 2) was pended. - // 3) failed. - internal unsafe SocketError CheckAsyncCallOverlappedResult(SocketError errorCode) - { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, errorCode); - - if (errorCode == SocketError.Success || errorCode == SocketError.IOPending) - { - // Ignore cases in which a completion packet will be queued: - // we'll deal with this IO in the callback. - return SocketError.Success; - } - - // In the remaining cases a completion packet will NOT be queued: - // we have to call the callback explicitly signaling an error. - ErrorCode = (int)errorCode; - Result = -1; - - ReleaseUnmanagedStructures(); // Additional release for the completion that won't happen. - return errorCode; - } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs index 015c0b8b3a9b..bde3c459d106 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs @@ -6,6 +6,7 @@ using System.Collections; using System.IO; using System.Runtime.InteropServices; +using System.Threading; namespace System.Net.Sockets { @@ -246,12 +247,9 @@ private IAsyncResult BeginSendFileInternal(string fileName, byte[] preBuffer, by SocketError errorCode = SocketPal.SendFileAsync(_handle, fileStream, preBuffer, postBuffer, flags, asyncResult); // Check for synchronous exception - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + throw new SocketException((int)errorCode); } asyncResult.FinishPostingAsyncOp(ref Caches.SendClosureCache); @@ -292,5 +290,14 @@ private void EndSendFileInternal(IAsyncResult asyncResult) } } + + internal ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle() + { + // There is a known bug that exists through Windows 7 with UDP and + // SetFileCompletionNotificationModes. + // So, don't try to enable skipping the completion port on success in this case. + bool trySkipCompletionPortOnSuccess = !(CompletionPortHelper.PlatformHasUdpIssue && _protocolType == ProtocolType.Udp); + return _handle.GetOrAllocateThreadPoolBoundHandle(trySkipCompletionPortOnSuccess); + } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index f3fc7915b688..697943480dd2 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -2415,17 +2415,10 @@ private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}"); - // if the asynchronous native call fails synchronously - // we'll throw a SocketException - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - - if (errorCode != SocketError.Success) + // If the call failed, update our status and throw + if (!CheckErrorAndUpdateStatus(errorCode)) { - // update our internal state after this socket error and throw - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + throw new SocketException((int)errorCode); } if (NetEventSource.IsEnabled) NetEventSource.Exit(this, asyncResult); @@ -2702,32 +2695,16 @@ private SocketError DoBeginSend(byte[] buffer, int offset, int size, SocketFlags { if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"SRC:{LocalEndPoint} DST:{RemoteEndPoint} size:{size}"); - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - // Get the Send going. - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"asyncResult:{asyncResult} size:{size}"); + // Get the Send going. + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"asyncResult:{asyncResult} size:{size}"); - errorCode = SocketPal.SendAsync(_handle, buffer, offset, size, socketFlags, asyncResult); + SocketError errorCode = SocketPal.SendAsync(_handle, buffer, offset, size, socketFlags, asyncResult); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSASend returns:{errorCode} size:{size} returning AsyncResult:{asyncResult}"); - } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSASend returns:{errorCode} size:{size} returning AsyncResult:{asyncResult}"); + + // If the call failed, update our status + CheckErrorAndUpdateStatus(errorCode); - // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) - { - UpdateStatusAfterSocketError(errorCode); - if (NetEventSource.IsEnabled) - { - if (NetEventSource.IsEnabled) NetEventSource.Error(this, new SocketException((int)errorCode)); - } - } return errorCode; } @@ -2784,29 +2761,15 @@ public IAsyncResult BeginSend(IList> buffers, SocketFlags soc private SocketError DoBeginSend(IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"SRC:{LocalEndPoint} DST:{RemoteEndPoint} buffers:{buffers}"); + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"asyncResult:{asyncResult}"); - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"asyncResult:{asyncResult}"); + SocketError errorCode = SocketPal.SendAsync(_handle, buffers, socketFlags, asyncResult); - errorCode = SocketPal.SendAsync(_handle, buffers, socketFlags, asyncResult); + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSASend returns:{errorCode} returning AsyncResult:{asyncResult}"); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSASend returns:{errorCode} returning AsyncResult:{asyncResult}"); - } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } + // If the call failed, update our status + CheckErrorAndUpdateStatus(errorCode); - // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) - { - UpdateStatusAfterSocketError(errorCode); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, new SocketException((int)errorCode)); - } return errorCode; } @@ -3027,20 +2990,14 @@ private void DoBeginSendTo(byte[] buffer, int offset, int size, SocketFlags sock _rightEndPoint = oldEndPoint; throw; } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + + throw new SocketException((int)errorCode); } if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"size:{size} returning AsyncResult:{asyncResult}"); @@ -3216,28 +3173,13 @@ private SocketError DoBeginReceive(byte[] buffer, int offset, int size, SocketFl #if DEBUG IntPtr lastHandle = _handle.DangerousGetHandle(); #endif - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - errorCode = SocketPal.ReceiveAsync(_handle, buffer, offset, size, socketFlags, asyncResult); + SocketError errorCode = SocketPal.ReceiveAsync(_handle, buffer, offset, size, socketFlags, asyncResult); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSARecv returns:{errorCode} returning AsyncResult:{asyncResult}"); - } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSARecv returns:{errorCode} returning AsyncResult:{asyncResult}"); // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { - // Update the internal state of this socket according to the error before throwing. - UpdateStatusAfterSocketError(errorCode); - var socketException = new SocketException((int)errorCode); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - asyncResult.InvokeCallback(new SocketException((int)errorCode)); } #if DEBUG else @@ -3309,28 +3251,12 @@ private SocketError DoBeginReceive(IList> buffers, SocketFlag #if DEBUG IntPtr lastHandle = _handle.DangerousGetHandle(); #endif - // Guarantee to call CheckAsyncCallOverlappedResult if we call SetUnamangedStructures with a cache in order to - // avoid a Socket leak in case of error. - SocketError errorCode = SocketError.SocketError; - try - { - errorCode = SocketPal.ReceiveAsync(_handle, buffers, socketFlags, asyncResult); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSARecv returns:{errorCode} returning AsyncResult:{asyncResult}"); - } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } + SocketError errorCode = SocketPal.ReceiveAsync(_handle, buffers, socketFlags, asyncResult); - // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.WSARecv returns:{errorCode} returning AsyncResult:{asyncResult}"); + + if (!CheckErrorAndUpdateStatus(errorCode)) { - // Update the internal state of this socket according to the error before throwing. - UpdateStatusAfterSocketError(errorCode); - if (NetEventSource.IsEnabled) - { - if (NetEventSource.IsEnabled) NetEventSource.Error(this, new SocketException((int)errorCode)); - } } #if DEBUG else @@ -3530,20 +3456,14 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size, _rightEndPoint = oldEndPoint; throw; } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + + throw new SocketException((int)errorCode); } // Capture the context, maybe call the callback, and return. @@ -3768,20 +3688,14 @@ private void DoBeginReceiveFrom(byte[] buffer, int offset, int size, SocketFlags _rightEndPoint = oldEndPoint; throw; } - finally - { - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - } // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + + throw new SocketException((int)errorCode); } if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"size:{size} return AsyncResult:{asyncResult}"); @@ -3964,17 +3878,12 @@ private void DoBeginAccept(Socket acceptSocket, int receiveSize, AcceptOverlappe int socketAddressSize = _rightEndPoint.Serialize().Size; SocketError errorCode = SocketPal.AcceptAsync(this, _handle, acceptHandle, receiveSize, socketAddressSize, asyncResult); - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.AcceptEx returns:{errorCode} {asyncResult}"); // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + throw new SocketException((int)errorCode); } } @@ -4109,7 +4018,6 @@ public void Shutdown(SocketShutdown how) public bool AcceptAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4142,13 +4050,12 @@ public bool AcceptAsync(SocketAsyncEventArgs e) e.StartOperationAccept(); // Local variables for sync completion. - int bytesTransferred; SocketError socketError = SocketError.Success; // Make the native call. try { - socketError = e.DoOperationAccept(this, _handle, acceptHandle, out bytesTransferred); + socketError = e.DoOperationAccept(this, _handle, acceptHandle); } catch { @@ -4157,25 +4064,15 @@ public bool AcceptAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool ConnectAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; + bool pending; if (CleanedUp) { @@ -4219,7 +4116,7 @@ public bool ConnectAsync(SocketAsyncEventArgs e) e.StartOperationCommon(this); e.StartOperationWrapperConnect(multipleConnectAsync); - retval = multipleConnectAsync.StartConnectAsync(e, dnsEP); + pending = multipleConnectAsync.StartConnectAsync(e, dnsEP); } else { @@ -4256,11 +4153,10 @@ public bool ConnectAsync(SocketAsyncEventArgs e) e.StartOperationConnect(); // Make the native call. - int bytesTransferred; SocketError socketError = SocketError.Success; try { - socketError = e.DoOperationConnect(this, _handle, out bytesTransferred); + socketError = e.DoOperationConnect(this, _handle); } catch { @@ -4271,26 +4167,17 @@ public bool ConnectAsync(SocketAsyncEventArgs e) throw; } - // Handle failure where completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } + pending = (socketError == SocketError.IOPending); } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(null); - bool retval; + bool pending; if (e == null) { @@ -4336,16 +4223,16 @@ public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType e.StartOperationCommon(attemptSocket); e.StartOperationWrapperConnect(multipleConnectAsync); - retval = multipleConnectAsync.StartConnectAsync(e, dnsEP); + pending = multipleConnectAsync.StartConnectAsync(e, dnsEP); } else { Socket attemptSocket = new Socket(endPointSnapshot.AddressFamily, socketType, protocolType); - retval = attemptSocket.ConnectAsync(e); + pending = attemptSocket.ConnectAsync(e); } - if (NetEventSource.IsEnabled) NetEventSource.Exit(null, retval); - return retval; + if (NetEventSource.IsEnabled) NetEventSource.Exit(null, pending); + return pending; } public static void CancelConnectAsync(SocketAsyncEventArgs e) @@ -4360,7 +4247,6 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e) public bool DisconnectAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this); - bool retval; // Throw if socket disposed if (CleanedUp) @@ -4384,17 +4270,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, 0, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } - + bool retval = (socketError == SocketError.IOPending); if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); return retval; } @@ -4402,7 +4278,6 @@ public bool DisconnectAsync(SocketAsyncEventArgs e) public bool ReceiveAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4420,13 +4295,12 @@ public bool ReceiveAsync(SocketAsyncEventArgs e) // Local vars for sync completion of native call. SocketFlags flags; - int bytesTransferred; SocketError socketError; // Wrap native methods with try/catch so event args object can be cleaned up. try { - socketError = e.DoOperationReceive(_handle, out flags, out bytesTransferred); + socketError = e.DoOperationReceive(_handle, out flags); } catch { @@ -4435,25 +4309,14 @@ public bool ReceiveAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, flags); - retval = false; - } - else - { - retval = true; - } - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool ReceiveFromAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4491,12 +4354,11 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e) // Make the native call. SocketFlags flags; - int bytesTransferred; SocketError socketError; try { - socketError = e.DoOperationReceiveFrom(_handle, out flags, out bytesTransferred); + socketError = e.DoOperationReceiveFrom(_handle, out flags); } catch { @@ -4505,25 +4367,14 @@ public bool ReceiveFromAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, flags); - retval = false; - } - else - { - retval = true; - } - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4562,12 +4413,11 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) e.StartOperationReceiveMessageFrom(); // Make the native call. - int bytesTransferred; SocketError socketError; try { - socketError = e.DoOperationReceiveMessageFrom(this, _handle, out bytesTransferred); + socketError = e.DoOperationReceiveMessageFrom(this, _handle); } catch { @@ -4576,25 +4426,14 @@ public bool ReceiveMessageFromAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool SendAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4611,13 +4450,12 @@ public bool SendAsync(SocketAsyncEventArgs e) e.StartOperationSend(); // Local vars for sync completion of native call. - int bytesTransferred; SocketError socketError; // Wrap native methods with try/catch so event args object can be cleaned up. try { - socketError = e.DoOperationSend(_handle, out bytesTransferred); + socketError = e.DoOperationSend(_handle); } catch { @@ -4626,25 +4464,14 @@ public bool SendAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } - - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool SendPacketsAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; // Throw if socket disposed if (CleanedUp) @@ -4686,33 +4513,22 @@ public bool SendPacketsAsync(SocketAsyncEventArgs e) e.Complete(); throw; } - - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, 0, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } } else { // No buffers or files to send. - e.FinishOperationSuccess(SocketError.Success, 0, SocketFlags.None); - retval = false; + e.FinishOperationSyncSuccess(0, SocketFlags.None); + socketError = SocketError.Success; } - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); - return retval; + bool pending = (socketError == SocketError.IOPending); + if (NetEventSource.IsEnabled) NetEventSource.Exit(this, pending); + return pending; } public bool SendToAsync(SocketAsyncEventArgs e) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this, e); - bool retval; if (CleanedUp) { @@ -4737,13 +4553,12 @@ public bool SendToAsync(SocketAsyncEventArgs e) e.StartOperationSendTo(); // Make the native call. - int bytesTransferred; SocketError socketError; // Wrap native methods with try/catch so event args object can be cleaned up. try { - socketError = e.DoOperationSendTo(_handle, out bytesTransferred); + socketError = e.DoOperationSendTo(_handle); } catch { @@ -4752,17 +4567,7 @@ public bool SendToAsync(SocketAsyncEventArgs e) throw; } - // Handle completion when completion port is not posted. - if (socketError != SocketError.Success && socketError != SocketError.IOPending) - { - e.FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - retval = false; - } - else - { - retval = true; - } - + bool retval = (socketError == SocketError.IOPending); if (NetEventSource.IsEnabled) NetEventSource.Exit(this, retval); return retval; } @@ -5431,32 +5236,25 @@ private IAsyncResult BeginConnectEx(EndPoint remoteEP, bool flowContext, AsyncCa } catch { - // If ConnectEx throws we need to unpin the socketAddress buffer. // _rightEndPoint will always equal oldEndPoint. - asyncResult.InternalCleanup(); _rightEndPoint = oldEndPoint; throw; } + if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.connect returns:{errorCode}"); if (errorCode == SocketError.Success) { + // Synchronous success. Indicate that we're connected. SetToConnected(); } - if (NetEventSource.IsEnabled) NetEventSource.Info(this, $"Interop.Winsock.connect returns:{errorCode}"); - - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - - // Throw an appropriate SocketException if the native call fails synchronously. - if (errorCode != SocketError.Success) + if (!CheckErrorAndUpdateStatus(errorCode)) { // Update the internal state of this socket according to the error before throwing. _rightEndPoint = oldEndPoint; - SocketException socketException = new SocketException((int)errorCode); - UpdateStatusAfterSocketError(socketException); - if (NetEventSource.IsEnabled) NetEventSource.Error(this, socketException); - throw socketException; + + throw new SocketException((int)errorCode); } // We didn't throw, so indicate that we're returning this result to the user. This may call the callback. @@ -5810,6 +5608,17 @@ internal void UpdateStatusAfterSocketError(SocketError errorCode) } } + private bool CheckErrorAndUpdateStatus(SocketError errorCode) + { + if (errorCode == SocketError.Success || errorCode == SocketError.IOPending) + { + return true; + } + + UpdateStatusAfterSocketError(errorCode); + return false; + } + // ValidateBlockingMode - called before synchronous calls to validate // the fact that we are in blocking mode (not in non-blocking mode) so the // call will actually be synchronous. diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 36a95649dcab..066e5f3fae89 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -11,27 +11,25 @@ namespace System.Net.Sockets { - // - // NOTE: the publicly-exposed asynchronous methods should match the behavior of - // Winsock overlapped sockets as closely as possible. Especially important are - // completion semantics, as the consuming code relies on the Winsock behavior. - // - // Winsock queues a completion callback for an overlapped operation in two cases: - // 1. the operation successfully completes synchronously, or - // 2. the operation completes asynchronously, successfully or otherwise. - // In other words, a completion callback is queued iff an operation does not - // fail synchronously. The asynchronous methods below (e.g. ReceiveAsync) may - // fail synchronously for either of the following reasons: - // 1. an underlying system call fails synchronously, or - // 2. an underlying system call returns EAGAIN, but the socket is closed before - // the method is able to enqueue its corresponding operation. - // In the first case, the async method should return the SocketError that - // corresponds to the native error code; in the second, the method should return - // SocketError.OperationAborted (which matches what Winsock would return in this - // case). The publicly-exposed synchronous methods may also encounter the second - // case. In this situation these methods should return SocketError.Interrupted - // (which again matches Winsock). - // + // Note on asynchronous behavior here: + + // The asynchronous socket operations here generally do the following: + // (1) If the operation queue is empty, try to perform the operation immediately, non-blocking. + // If this completes (i.e. does not return EWOULDBLOCK), then we return the results immediately + // for both success (SocketError.Success) or failure. + // No callback will happen; callers are expected to handle these synchronous completions themselves. + // (2) If EWOULDBLOCK is returned, or the queue is not empty, then we enqueue an operation to the + // appropriate queue and return SocketError.IOPending. + // Enqueuing itself may fail because the socket is closed before the operation can be enqueued; + // in this case, we return SocketError.OperationAborted (which matches what Winsock would return in this case). + // (3) When the queue completes the operation, it will post a work item to the threadpool + // to call the callback with results (either success or failure). + + // Synchronous operations generally do the same, except that instead of returning IOPending, + // they block on an event handle until the operation is processed by the queue. + // Also, synchronous methods return SocketError.Interrupted when enqueuing fails + // (which again matches Winsock behavior). + internal sealed class SocketAsyncContext { private abstract class AsyncOperation @@ -662,7 +660,7 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, int ti } } - public SocketError AcceptAsync(byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action callback) { Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); @@ -670,20 +668,11 @@ public SocketError AcceptAsync(byte[] socketAddress, int socketAddressLen, Actio SetNonBlocking(); - IntPtr acceptedFd; SocketError errorCode; if (SocketPal.TryCompleteAccept(_socket, socketAddress, ref socketAddressLen, out acceptedFd, out errorCode)) { Debug.Assert(errorCode == SocketError.Success || acceptedFd == (IntPtr)(-1), $"Unexpected values: errorCode={errorCode}, acceptedFd={acceptedFd}"); - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, IntPtr, byte[], int>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, SocketError.Success); - }, Tuple.Create(callback, acceptedFd, socketAddress, socketAddressLen)); - } return errorCode; } @@ -782,11 +771,6 @@ public SocketError ConnectAsync(byte[] socketAddress, int socketAddressLen, Acti { RegisterConnectResult(errorCode); - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(arg => ((Action)arg)(SocketError.Success), callback); - } - return errorCode; } @@ -827,9 +811,10 @@ public SocketError Receive(byte[] buffer, int offset, int count, ref SocketFlags return ReceiveFrom(buffer, offset, count, ref flags, null, ref socketAddressLen, timeout, out bytesReceived); } - public SocketError ReceiveAsync(byte[] buffer, int offset, int count, SocketFlags flags, Action callback) + public SocketError ReceiveAsync(byte[] buffer, int offset, int count, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback) { - return ReceiveFromAsync(buffer, offset, count, flags, null, 0, callback); + int socketAddressLen = 0; + return ReceiveFromAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback); } public SocketError ReceiveFrom(byte[] buffer, int offset, int count, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, int timeout, out int bytesReceived) @@ -897,30 +882,24 @@ public SocketError ReceiveFrom(byte[] buffer, int offset, int count, ref SocketF } } - public SocketError ReceiveFromAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError ReceiveFromAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback) { SetNonBlocking(); lock (_receiveLock) { - int bytesReceived; - SocketFlags receivedFlags; SocketError errorCode; if (_receiveQueue.IsEmpty && SocketPal.TryCompleteReceiveFrom(_socket, buffer, offset, count, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, int, byte[], int, SocketFlags>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, tup.Item5, SocketError.Success); - }, Tuple.Create(callback, bytesReceived, socketAddress, socketAddressLen, receivedFlags)); - } + // Synchronous success or failure return errorCode; } + bytesReceived = 0; + receivedFlags = SocketFlags.None; + var operation = new ReceiveOperation { Callback = callback, @@ -955,9 +934,10 @@ public SocketError Receive(IList> buffers, ref SocketFlags fl return ReceiveFrom(buffers, ref flags, null, 0, timeout, out bytesReceived); } - public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, Action callback) + public SocketError ReceiveAsync(IList> buffers, SocketFlags flags, out int bytesReceived, out SocketFlags receivedFlags, Action callback) { - return ReceiveFromAsync(buffers, flags, null, 0, callback); + int socketAddressLen = 0; + return ReceiveFromAsync(buffers, flags, null, ref socketAddressLen, out bytesReceived, out receivedFlags, callback); } public SocketError ReceiveFrom(IList> buffers, ref SocketFlags flags, byte[] socketAddress, int socketAddressLen, int timeout, out int bytesReceived) @@ -1024,7 +1004,7 @@ public SocketError ReceiveFrom(IList> buffers, ref SocketFlag } } - public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError ReceiveFromAsync(IList> buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesReceived, out SocketFlags receivedFlags, Action callback) { SetNonBlocking(); @@ -1032,23 +1012,17 @@ public SocketError ReceiveFromAsync(IList> buffers, SocketFla lock (_receiveLock) { - int bytesReceived; - SocketFlags receivedFlags; SocketError errorCode; if (_receiveQueue.IsEmpty && SocketPal.TryCompleteReceiveFrom(_socket, buffers, flags, socketAddress, ref socketAddressLen, out bytesReceived, out receivedFlags, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, int, byte[], int, SocketFlags>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, tup.Item5, SocketError.Success); - }, Tuple.Create(callback, bytesReceived, socketAddress, socketAddressLen, receivedFlags)); - } + // Synchronous success or failure return errorCode; } + bytesReceived = 0; + receivedFlags = SocketFlags.None; + operation = new ReceiveOperation { Callback = callback, @@ -1147,31 +1121,25 @@ public SocketError ReceiveMessageFrom(byte[] buffer, int offset, int count, ref } } - public SocketError ReceiveMessageFromAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, bool isIPv4, bool isIPv6, Action callback) + public SocketError ReceiveMessageFromAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action callback) { SetNonBlocking(); lock (_receiveLock) { - int bytesReceived; - SocketFlags receivedFlags; - IPPacketInformation ipPacketInformation; + ipPacketInformation = default(IPPacketInformation); SocketError errorCode; if (_receiveQueue.IsEmpty && SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, offset, count, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, int, byte[], int, SocketFlags, IPPacketInformation>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, tup.Item5, tup.Item6, SocketError.Success); - }, Tuple.Create(callback, bytesReceived, socketAddress, socketAddressLen, receivedFlags, ipPacketInformation)); - } + // Synchronous success or failure return errorCode; } + bytesReceived = 0; + receivedFlags = SocketFlags.None; + var operation = new ReceiveMessageFromOperation { Callback = callback, @@ -1208,9 +1176,10 @@ public SocketError Send(byte[] buffer, int offset, int count, SocketFlags flags, return SendTo(buffer, offset, count, flags, null, 0, timeout, out bytesSent); } - public SocketError SendAsync(byte[] buffer, int offset, int count, SocketFlags flags, Action callback) + public SocketError SendAsync(byte[] buffer, int offset, int count, SocketFlags flags, out int bytesSent, Action callback) { - return SendToAsync(buffer, offset, count, flags, null, 0, callback); + int socketAddressLen = 0; + return SendToAsync(buffer, offset, count, flags, null, ref socketAddressLen, out bytesSent, callback); } public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, int timeout, out int bytesSent) @@ -1274,26 +1243,19 @@ public SocketError SendTo(byte[] buffer, int offset, int count, SocketFlags flag } } - public SocketError SendToAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, int socketAddressLen, Action callback) + public SocketError SendToAsync(byte[] buffer, int offset, int count, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action callback) { SetNonBlocking(); lock (_sendAcceptConnectLock) { - int bytesSent = 0; + bytesSent = 0; SocketError errorCode; if (_sendQueue.IsEmpty && SocketPal.TryCompleteSendTo(_socket, buffer, ref offset, ref count, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, int, byte[], int>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, 0, SocketError.Success); - }, Tuple.Create(callback, bytesSent, socketAddress, socketAddressLen)); - } + // Synchronous success or failure return errorCode; } @@ -1332,9 +1294,10 @@ public SocketError Send(IList> buffers, SocketFlags flags, in return SendTo(buffers, flags, null, 0, timeout, out bytesSent); } - public SocketError SendAsync(IList> buffers, SocketFlags flags, Action callback) + public SocketError SendAsync(IList> buffers, SocketFlags flags, out int bytesSent, Action callback) { - return SendToAsync(buffers, flags, null, 0, callback); + int socketAddressLen = 0; + return SendToAsync(buffers, flags, null, ref socketAddressLen, out bytesSent, callback); } public SocketError SendTo(IList> buffers, SocketFlags flags, byte[] socketAddress, int socketAddressLen, int timeout, out int bytesSent) @@ -1400,28 +1363,23 @@ public SocketError SendTo(IList> buffers, SocketFlags flags, } } - public SocketError SendToAsync(IList> buffers, SocketFlags flags, byte[] socketAddress, int socketAddressLen, Action callback) + + + public SocketError SendToAsync(IList> buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, out int bytesSent, Action callback) { SetNonBlocking(); lock (_sendAcceptConnectLock) { + bytesSent = 0; int bufferIndex = 0; int offset = 0; - int bytesSent = 0; SocketError errorCode; if (_sendQueue.IsEmpty && SocketPal.TryCompleteSendTo(_socket, buffers, ref bufferIndex, ref offset, flags, socketAddress, socketAddressLen, ref bytesSent, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var tup = (Tuple, int, byte[], int>)args; - tup.Item1(tup.Item2, tup.Item3, tup.Item4, SocketFlags.None, SocketError.Success); - }, Tuple.Create(callback, bytesSent, socketAddress, socketAddressLen)); - } + // Synchronous success or failure return errorCode; } @@ -1512,26 +1470,19 @@ public SocketError SendFile(SafeFileHandle fileHandle, long offset, long count, } } - public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, Action callback) + public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, out long bytesSent, Action callback) { SetNonBlocking(); lock (_sendAcceptConnectLock) { - long bytesSent = 0; + bytesSent = 0; SocketError errorCode; if (_sendQueue.IsEmpty && SocketPal.TryCompleteSendFile(_socket, fileHandle, ref offset, ref count, ref bytesSent, out errorCode)) { - if (errorCode == SocketError.Success) - { - ThreadPool.QueueUserWorkItem(args => - { - var c = (Action)args; - c(bytesSent, SocketError.Success); - }, callback); - } + // Synchronous success or failure return errorCode; } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index e93e2bd5d024..f71663b54ba8 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -51,15 +51,20 @@ private void InnerStartOperationAccept(bool userSuppliedBuffer) } private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError) + { + CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize, socketError); + + CompletionCallback(0, SocketFlags.None, socketError); + } + + private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError) { _acceptedFileDescriptor = acceptedFileDescriptor; Debug.Assert(socketAddress == null || socketAddress == _acceptBuffer, $"Unexpected socketAddress: {socketAddress}"); _acceptAddressBufferCount = socketAddressSize; - - CompletionCallback(0, socketError); } - internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle, out int bytesTransferred) + internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle) { if (_buffer != null) { @@ -68,9 +73,17 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket han Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}"); - bytesTransferred = 0; + IntPtr acceptedFd; + int socketAddressLen = _acceptAddressBufferCount / 2; + SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback); - return handle.AsyncContext.AcceptAsync(_acceptBuffer, _acceptAddressBufferCount / 2, AcceptCompletionCallback); + if (socketError != SocketError.IOPending) + { + CompleteAcceptOperation(acceptedFd, _acceptBuffer, socketAddressLen, socketError); + FinishOperationSync(socketError, 0, SocketFlags.None); + } + + return socketError; } private void InnerStartOperationConnect() @@ -80,14 +93,17 @@ private void InnerStartOperationConnect() private void ConnectCompletionCallback(SocketError socketError) { - CompletionCallback(0, socketError); + CompletionCallback(0, SocketFlags.None, socketError); } - internal unsafe SocketError DoOperationConnect(Socket socket, SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationConnect(Socket socket, SafeCloseSocket handle) { - bytesTransferred = 0; - - return handle.AsyncContext.ConnectAsync(_socketAddress.Buffer, _socketAddress.Size, ConnectCompletionCallback); + SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress.Buffer, _socketAddress.Size, ConnectCompletionCallback); + if (socketError != SocketError.IOPending) + { + FinishOperationSync(socketError, 0, SocketFlags.None); + } + return socketError; } internal SocketError DoOperationDisconnect(Socket socket, SafeCloseSocket handle) @@ -104,12 +120,17 @@ private void InnerStartOperationDisconnect() _transferCompletionCallback ?? (_transferCompletionCallback = TransferCompletionCallbackCore); private void TransferCompletionCallbackCore(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, SocketError socketError) + { + CompleteTransferOperation(bytesTransferred, socketAddress, socketAddressSize, receivedFlags, socketError); + + CompletionCallback(bytesTransferred, receivedFlags, socketError); + } + + private void CompleteTransferOperation(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, SocketError socketError) { Debug.Assert(socketAddress == null || socketAddress == _socketAddress.Buffer, $"Unexpected socketAddress: {socketAddress}"); _socketAddressSize = socketAddressSize; _receivedFlags = receivedFlags; - - CompletionCallback(bytesTransferred, socketError); } private void InnerStartOperationReceive() @@ -118,20 +139,25 @@ private void InnerStartOperationReceive() _socketAddressSize = 0; } - internal unsafe SocketError DoOperationReceive(SafeCloseSocket handle, out SocketFlags flags, out int bytesTransferred) + internal unsafe SocketError DoOperationReceive(SafeCloseSocket handle, out SocketFlags flags) { + int bytesReceived; SocketError errorCode; if (_buffer != null) { - errorCode = handle.AsyncContext.ReceiveAsync(_buffer, _offset, _count, _socketFlags, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveAsync(_buffer, _offset, _count, _socketFlags, out bytesReceived, out flags, TransferCompletionCallback); } else { - errorCode = handle.AsyncContext.ReceiveAsync(_bufferList, _socketFlags, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveAsync(_bufferList, _socketFlags, out bytesReceived, out flags, TransferCompletionCallback); + } + + if (errorCode != SocketError.IOPending) + { + CompleteTransferOperation(bytesReceived, null, 0, flags, errorCode); + FinishOperationSync(errorCode, bytesReceived, flags); } - flags = _socketFlags; - bytesTransferred = 0; return errorCode; } @@ -141,20 +167,26 @@ private void InnerStartOperationReceiveFrom() _socketAddressSize = 0; } - internal unsafe SocketError DoOperationReceiveFrom(SafeCloseSocket handle, out SocketFlags flags, out int bytesTransferred) + internal unsafe SocketError DoOperationReceiveFrom(SafeCloseSocket handle, out SocketFlags flags) { SocketError errorCode; + int bytesReceived = 0; + int socketAddressLen = _socketAddress.Size; if (_buffer != null) { - errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, _socketAddress.Size, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); } else { - errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferList, _socketFlags, _socketAddress.Buffer, _socketAddress.Size, TransferCompletionCallback); + errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferList, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback); + } + + if (errorCode != SocketError.IOPending) + { + CompleteTransferOperation(bytesReceived, _socketAddress.Buffer, socketAddressLen, flags, errorCode); + FinishOperationSync(errorCode, bytesReceived, flags); } - flags = _socketFlags; - bytesTransferred = 0; return errorCode; } @@ -166,6 +198,13 @@ private void InnerStartOperationReceiveMessageFrom() } private void ReceiveMessageFromCompletionCallback(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) + { + CompleteReceiveMessageFromOperation(bytesTransferred, socketAddress, socketAddressSize, receivedFlags, ipPacketInformation, errorCode); + + CompletionCallback(bytesTransferred, receivedFlags, errorCode); + } + + private void CompleteReceiveMessageFromOperation(int bytesTransferred, byte[] socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode) { Debug.Assert(_socketAddress != null, "Expected non-null _socketAddress"); Debug.Assert(socketAddress == null || _socketAddress.Buffer == socketAddress, $"Unexpected socketAddress: {socketAddress}"); @@ -173,17 +212,24 @@ private void ReceiveMessageFromCompletionCallback(int bytesTransferred, byte[] s _socketAddressSize = socketAddressSize; _receivedFlags = receivedFlags; _receiveMessageFromPacketInfo = ipPacketInformation; - - CompletionCallback(bytesTransferred, errorCode); } - internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeCloseSocket handle) { bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(socket.AddressFamily, _socketAddress, out isIPv4, out isIPv6); - bytesTransferred = 0; - return handle.AsyncContext.ReceiveMessageFromAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, _socketAddress.Size, isIPv4, isIPv6, ReceiveMessageFromCompletionCallback); + int socketAddressSize = _socketAddress.Size; + int bytesReceived; + SocketFlags receivedFlags; + IPPacketInformation ipPacketInformation; + SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback); + if (socketError != SocketError.IOPending) + { + CompleteReceiveMessageFromOperation(bytesReceived, _socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation, socketError); + FinishOperationSync(socketError, bytesReceived, receivedFlags); + } + return socketError; } private void InnerStartOperationSend() @@ -192,19 +238,25 @@ private void InnerStartOperationSend() _socketAddressSize = 0; } - internal unsafe SocketError DoOperationSend(SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationSend(SafeCloseSocket handle) { + int bytesSent; SocketError errorCode; if (_buffer != null) { - errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback); } else { - errorCode = handle.AsyncContext.SendAsync(_bufferList, _socketFlags, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendAsync(_bufferList, _socketFlags, out bytesSent, TransferCompletionCallback); + } + + if (errorCode != SocketError.IOPending) + { + CompleteTransferOperation(bytesSent, null, 0, SocketFlags.None, errorCode); + FinishOperationSync(errorCode, bytesSent, SocketFlags.None); } - bytesTransferred = 0; return errorCode; } @@ -224,19 +276,26 @@ private void InnerStartOperationSendTo() _socketAddressSize = 0; } - internal SocketError DoOperationSendTo(SafeCloseSocket handle, out int bytesTransferred) + internal SocketError DoOperationSendTo(SafeCloseSocket handle) { + int bytesSent; + int socketAddressLen = _socketAddress.Size; SocketError errorCode; if (_buffer != null) { - errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, _socketAddress.Size, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback); } else { - errorCode = handle.AsyncContext.SendToAsync(_bufferList, _socketFlags, _socketAddress.Buffer, _socketAddress.Size, TransferCompletionCallback); + errorCode = handle.AsyncContext.SendToAsync(_bufferList, _socketFlags, _socketAddress.Buffer, ref socketAddressLen, out bytesSent, TransferCompletionCallback); + } + + if (errorCode != SocketError.IOPending) + { + CompleteTransferOperation(bytesSent, _socketAddress.Buffer, socketAddressLen, SocketFlags.None, errorCode); + FinishOperationSync(errorCode, bytesSent, SocketFlags.None); } - bytesTransferred = 0; return errorCode; } @@ -289,11 +348,11 @@ private void FinishOperationSendPackets() throw new PlatformNotSupportedException(); } - private void CompletionCallback(int bytesTransferred, SocketError socketError) + private void CompletionCallback(int bytesTransferred, SocketFlags flags, SocketError socketError) { if (socketError == SocketError.Success) { - FinishOperationSuccess(socketError, bytesTransferred, _receivedFlags); + FinishOperationAsyncSuccess(bytesTransferred, flags); } else { @@ -302,7 +361,7 @@ private void CompletionCallback(int bytesTransferred, SocketError socketError) socketError = SocketError.OperationAborted; } - FinishOperationAsyncFailure(socketError, bytesTransferred, _receivedFlags); + FinishOperationAsyncFailure(socketError, bytesTransferred, flags); } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 53d6c029c727..934942018bb9 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -127,7 +127,7 @@ private unsafe void PrepareIOCPOperation() Debug.Assert(_currentSocket.SafeHandle != null, "_currentSocket.SafeHandle is null"); Debug.Assert(!_currentSocket.SafeHandle.IsInvalid, "_currentSocket.SafeHandle is invalid"); - ThreadPoolBoundHandle boundHandle = _currentSocket.SafeHandle.GetOrAllocateThreadPoolBoundHandle(); + ThreadPoolBoundHandle boundHandle = _currentSocket.GetOrAllocateThreadPoolBoundHandle(); NativeOverlapped* overlapped = null; if (_preAllocatedOverlapped != null) @@ -157,15 +157,40 @@ private unsafe void PrepareIOCPOperation() } } - private void CompleteIOCPOperation() + private SocketError ProcessIOCPResult(bool success, int bytesTransferred) { - // TODO #4900: Optimization to remove callbacks if the operations are completed synchronously: - // Use SetFileCompletionNotificationModes(FILE_SKIP_COMPLETION_PORT_ON_SUCCESS). + if (success) + { + // Synchronous success. + if (_currentSocket.SafeHandle.SkipCompletionPortOnSuccess) + { + // The socket handle is configured to skip completion on success, + // so we can set the results right now. + FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None); + return SocketError.Success; + } + + // Socket handle is going to post a completion to the completion port (may have done so already). + // Return pending and we will continue in the completion port callback. + return SocketError.IOPending; + } - // If SetFileCompletionNotificationModes(FILE_SKIP_COMPLETION_PORT_ON_SUCCESS) is not set on this handle - // it is guaranteed that the IOCP operation will be completed in the callback even if Socket.Success was - // returned by the Win32 API. + // Get the socket error (which may be IOPending) + SocketError errorCode = SocketPal.GetLastSocketError(); + + if (errorCode == SocketError.IOPending) + { + return errorCode; + } + + FinishOperationSyncFailure(errorCode, bytesTransferred, SocketFlags.None); + + // Note, the overlapped will be release in CompleteIOCPOperation below, for either success or failure + return errorCode; + } + private void CompleteIOCPOperation() + { // Required to allow another IOCP operation for the same handle. We release the native overlapped // in the safe handle, but keep the safe handle object around so as to be able to reuse it // for other operations. @@ -180,13 +205,12 @@ private void InnerStartOperationAccept(bool userSuppliedBuffer) } } - internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle, out int bytesTransferred) + internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle) { PrepareIOCPOperation(); - SocketError socketError = SocketError.Success; - - if (!socket.AcceptEx( + int bytesTransferred; + bool success = socket.AcceptEx( handle, acceptHandle, (_ptrSingleBuffer != IntPtr.Zero) ? _ptrSingleBuffer : _ptrAcceptBuffer, @@ -194,12 +218,9 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeCloseSocket han _acceptAddressBufferCount / 2, _acceptAddressBufferCount / 2, out bytesTransferred, - _ptrNativeOverlapped)) - { - socketError = SocketPal.GetLastSocketError(); - } + _ptrNativeOverlapped); - return socketError; + return ProcessIOCPResult(success, bytesTransferred); } private void InnerStartOperationConnect() @@ -213,25 +234,21 @@ private void InnerStartOperationConnect() CheckPinNoBuffer(); } - internal unsafe SocketError DoOperationConnect(Socket socket, SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationConnect(Socket socket, SafeCloseSocket handle) { PrepareIOCPOperation(); - SocketError socketError = SocketError.Success; - - if (!socket.ConnectEx( + int bytesTransferred; + bool success = socket.ConnectEx( handle, _ptrSocketAddressBuffer, _socketAddress.Size, _ptrSingleBuffer, Count, out bytesTransferred, - _ptrNativeOverlapped)) - { - socketError = SocketPal.GetLastSocketError(); - } + _ptrNativeOverlapped); - return socketError; + return ProcessIOCPResult(success, bytesTransferred); } private void InnerStartOperationDisconnect() @@ -243,18 +260,13 @@ internal SocketError DoOperationDisconnect(Socket socket, SafeCloseSocket handle { PrepareIOCPOperation(); - SocketError socketError = SocketError.Success; - - if (!socket.DisconnectEx( + bool success = socket.DisconnectEx( handle, _ptrNativeOverlapped, (int)(DisconnectReuseSocket ? TransmitFileOptions.ReuseSocket : 0), - 0)) - { - socketError = SocketPal.GetLastSocketError(); - } + 0); - return socketError; + return ProcessIOCPResult(success, 0); } private void InnerStartOperationReceive() @@ -274,12 +286,13 @@ private void InnerStartOperationReceive() // An array of WSABuffer descriptors is allocated. } - internal unsafe SocketError DoOperationReceive(SafeCloseSocket handle, out SocketFlags flags, out int bytesTransferred) + internal unsafe SocketError DoOperationReceive(SafeCloseSocket handle, out SocketFlags flags) { PrepareIOCPOperation(); flags = _socketFlags; + int bytesTransferred; SocketError socketError; if (_buffer != null) { @@ -306,12 +319,7 @@ internal unsafe SocketError DoOperationReceive(SafeCloseSocket handle, out Socke IntPtr.Zero); } - if (socketError == SocketError.SocketError) - { - socketError = SocketPal.GetLastSocketError(); - } - - return socketError; + return ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred); } private void InnerStartOperationReceiveFrom() @@ -334,12 +342,13 @@ private void InnerStartOperationReceiveFrom() PinSocketAddressBuffer(); } - internal unsafe SocketError DoOperationReceiveFrom(SafeCloseSocket handle, out SocketFlags flags, out int bytesTransferred) + internal unsafe SocketError DoOperationReceiveFrom(SafeCloseSocket handle, out SocketFlags flags) { PrepareIOCPOperation(); flags = _socketFlags; + int bytesTransferred; SocketError socketError; if (_buffer != null) { @@ -368,12 +377,7 @@ internal unsafe SocketError DoOperationReceiveFrom(SafeCloseSocket handle, out S IntPtr.Zero); } - if (socketError == SocketError.SocketError) - { - socketError = SocketPal.GetLastSocketError(); - } - - return socketError; + return ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred); } private void InnerStartOperationReceiveMessageFrom() @@ -465,10 +469,11 @@ private void InnerStartOperationReceiveMessageFrom() } } - internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeCloseSocket handle) { PrepareIOCPOperation(); + int bytesTransferred; SocketError socketError = socket.WSARecvMsg( handle, _ptrWSAMessageBuffer, @@ -476,12 +481,7 @@ internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeClo _ptrNativeOverlapped, IntPtr.Zero); - if (socketError == SocketError.SocketError) - { - socketError = SocketPal.GetLastSocketError(); - } - - return socketError; + return ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred); } private void InnerStartOperationSend() @@ -501,11 +501,12 @@ private void InnerStartOperationSend() // An array of WSABuffer descriptors is allocated. } - internal unsafe SocketError DoOperationSend(SafeCloseSocket handle, out int bytesTransferred) + internal unsafe SocketError DoOperationSend(SafeCloseSocket handle) { PrepareIOCPOperation(); SocketError socketError; + int bytesTransferred; if (_buffer != null) { // Single buffer case. @@ -531,12 +532,7 @@ internal unsafe SocketError DoOperationSend(SafeCloseSocket handle, out int byte IntPtr.Zero); } - if (socketError == SocketError.SocketError) - { - socketError = SocketPal.GetLastSocketError(); - } - - return socketError; + return ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred); } private void InnerStartOperationSendPackets() @@ -640,7 +636,7 @@ internal SocketError DoOperationSendPackets(Socket socket, SafeCloseSocket handl _ptrNativeOverlapped, _sendPacketsFlags); - return result ? SocketError.Success : SocketPal.GetLastSocketError(); + return ProcessIOCPResult(result, 0); } private void InnerStartOperationSendTo() @@ -663,10 +659,11 @@ private void InnerStartOperationSendTo() PinSocketAddressBuffer(); } - internal SocketError DoOperationSendTo(SafeCloseSocket handle, out int bytesTransferred) + internal SocketError DoOperationSendTo(SafeCloseSocket handle) { PrepareIOCPOperation(); + int bytesTransferred; SocketError socketError; if (_buffer != null) { @@ -696,12 +693,7 @@ internal SocketError DoOperationSendTo(SafeCloseSocket handle, out int bytesTran IntPtr.Zero); } - if (socketError == SocketError.SocketError) - { - socketError = SocketPal.GetLastSocketError(); - } - - return socketError; + return ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred); } // Ensures Overlapped object exists for operations that need no data buffer. @@ -1218,7 +1210,7 @@ private unsafe void CompletionPortCallback(uint errorCode, uint numBytes, Native if (socketError == SocketError.Success) { - FinishOperationSuccess(socketError, (int)numBytes, socketFlags); + FinishOperationAsyncSuccess((int)numBytes, SocketFlags.None); } else { @@ -1249,6 +1241,7 @@ private unsafe void CompletionPortCallback(uint errorCode, uint numBytes, Native } } } + FinishOperationAsyncFailure(socketError, (int)numBytes, socketFlags); } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index 4b01cf738bef..fcc13d76f5bb 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -585,6 +585,20 @@ internal void UpdatePerfCounters(int size, bool sendOp) } } + internal void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags) + { + Debug.Assert(socketError != SocketError.IOPending); + + if (socketError == SocketError.Success) + { + FinishOperationSyncSuccess(bytesTransferred, flags); + } + else + { + FinishOperationSyncFailure(socketError, bytesTransferred, flags); + } + } + internal void FinishOperationSyncFailure(SocketError socketError, int bytesTransferred, SocketFlags flags) { SetResults(socketError, bytesTransferred, flags); @@ -671,10 +685,11 @@ internal void FinishWrapperConnectSuccess(Socket connectSocket, int bytesTransfe } } - internal void FinishOperationSuccess(SocketError socketError, int bytesTransferred, SocketFlags flags) + internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags) { - SetResults(socketError, bytesTransferred, flags); + SetResults(SocketError.Success, bytesTransferred, flags); + SocketError socketError = SocketError.Success; switch (_completedOperation) { case SocketAsyncOperation.Accept: @@ -871,8 +886,15 @@ internal void FinishOperationSuccess(SocketError socketError, int bytesTransferr _currentSocket.UpdateStatusAfterSocketError(socketError); } - // Complete the operation and raise completion event. + // Complete the operation. Complete(); + } + + internal void FinishOperationAsyncSuccess(int bytesTransferred, SocketFlags flags) + { + FinishOperationSyncSuccess(bytesTransferred, flags); + + // Raise completion event. if (_context == null) { OnCompleted(this); diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index dc53587964fa..67aa36c33d79 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1341,46 +1341,98 @@ public static SocketError Shutdown(SafeCloseSocket handle, bool isConnected, boo public static SocketError ConnectAsync(Socket socket, SafeCloseSocket handle, byte[] socketAddress, int socketAddressLen, ConnectOverlappedAsyncResult asyncResult) { - return handle.AsyncContext.ConnectAsync(socketAddress, socketAddressLen, asyncResult.CompletionCallback); + SocketError socketError = handle.AsyncContext.ConnectAsync(socketAddress, socketAddressLen, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(SocketError.Success); + } + return socketError; } public static SocketError SendAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { - return handle.AsyncContext.SendAsync(buffer, offset, count, socketFlags, asyncResult.CompletionCallback); + int bytesSent; + SocketError socketError = handle.AsyncContext.SendAsync(buffer, offset, count, socketFlags, out bytesSent, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesSent, null, 0, SocketFlags.None, SocketError.Success); + } + return socketError; } public static SocketError SendAsync(SafeCloseSocket handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { - return handle.AsyncContext.SendAsync(buffers, socketFlags, asyncResult.CompletionCallback); + int bytesSent; + SocketError socketError = handle.AsyncContext.SendAsync(buffers, socketFlags, out bytesSent, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesSent, null, 0, SocketFlags.None, SocketError.Success); + } + return socketError; } public static SocketError SendFileAsync(SafeCloseSocket handle, FileStream fileStream, Action callback) { - return handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, 0, (int)fileStream.Length, callback); + long bytesSent; + SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, 0, (int)fileStream.Length, out bytesSent, callback); + if (socketError == SocketError.Success) + { + callback(bytesSent, SocketError.Success); + } + return socketError; } public static SocketError SendToAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) { asyncResult.SocketAddress = socketAddress; - return handle.AsyncContext.SendToAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, socketAddress.Size, asyncResult.CompletionCallback); + int bytesSent; + int socketAddressLen = socketAddress.Size; + SocketError socketError = handle.AsyncContext.SendToAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, ref socketAddressLen, out bytesSent, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesSent, socketAddress.Buffer, socketAddressLen, SocketFlags.None, SocketError.Success); + } + return socketError; } public static SocketError ReceiveAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { - return handle.AsyncContext.ReceiveAsync(buffer, offset, count, socketFlags, asyncResult.CompletionCallback); + int bytesReceived; + SocketFlags receivedFlags; + SocketError socketError = handle.AsyncContext.ReceiveAsync(buffer, offset, count, socketFlags, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesReceived, null, 0, receivedFlags, SocketError.Success); + } + return socketError; } public static SocketError ReceiveAsync(SafeCloseSocket handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { - return handle.AsyncContext.ReceiveAsync(buffers, socketFlags, asyncResult.CompletionCallback); + int bytesReceived; + SocketFlags receivedFlags; + SocketError socketError = handle.AsyncContext.ReceiveAsync(buffers, socketFlags, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesReceived, null, 0, receivedFlags, SocketError.Success); + } + return socketError; } public static SocketError ReceiveFromAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) { asyncResult.SocketAddress = socketAddress; - return handle.AsyncContext.ReceiveFromAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, socketAddress.InternalSize, asyncResult.CompletionCallback); + int socketAddressSize = socketAddress.InternalSize; + int bytesReceived; + SocketFlags receivedFlags; + SocketError socketError = handle.AsyncContext.ReceiveFromAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, ref socketAddressSize, out bytesReceived, out receivedFlags, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesReceived, socketAddress.Buffer, socketAddressSize, receivedFlags, SocketError.Success); + } + return socketError; } public static SocketError ReceiveMessageFromAsync(Socket socket, SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, ReceiveMessageOverlappedAsyncResult asyncResult) @@ -1390,7 +1442,16 @@ public static SocketError ReceiveMessageFromAsync(Socket socket, SafeCloseSocket bool isIPv4, isIPv6; Socket.GetIPProtocolInformation(((Socket)asyncResult.AsyncObject).AddressFamily, socketAddress, out isIPv4, out isIPv6); - return handle.AsyncContext.ReceiveMessageFromAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, socketAddress.InternalSize, isIPv4, isIPv6, asyncResult.CompletionCallback); + int socketAddressSize = socketAddress.InternalSize; + int bytesReceived; + SocketFlags receivedFlags; + IPPacketInformation ipPacketInformation; + SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(buffer, offset, count, socketFlags, socketAddress.Buffer, ref socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(bytesReceived, socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation, SocketError.Success); + } + return socketError; } public static SocketError AcceptAsync(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle, int receiveSize, int socketAddressSize, AcceptOverlappedAsyncResult asyncResult) @@ -1400,7 +1461,14 @@ public static SocketError AcceptAsync(Socket socket, SafeCloseSocket handle, Saf byte[] socketAddressBuffer = new byte[socketAddressSize]; - return handle.AsyncContext.AcceptAsync(socketAddressBuffer, socketAddressSize, asyncResult.CompletionCallback); + IntPtr acceptedFd; + SocketError socketError = handle.AsyncContext.AcceptAsync(socketAddressBuffer, ref socketAddressSize, out acceptedFd, asyncResult.CompletionCallback); + if (socketError == SocketError.Success) + { + asyncResult.CompletionCallback(acceptedFd, socketAddressBuffer, socketAddressSize, SocketError.Success); + } + + return socketError; } internal static SocketError DisconnectAsync(Socket socket, SafeCloseSocket handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult) diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 009a541edf42..2b216eb6bd90 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -214,7 +214,8 @@ public static unsafe SocketError SendFile(SafeCloseSocket handle, SafeFileHandle fixed (byte* prePinnedBuffer = preBuffer) fixed (byte* postPinnedBuffer = postBuffer) { - return TransmitFileHelper(handle, fileHandle, SafeNativeOverlapped.Zero, preBuffer, postBuffer, flags); + bool success = TransmitFileHelper(handle, fileHandle, SafeNativeOverlapped.Zero, preBuffer, postBuffer, flags); + return (success ? SocketError.Success : SocketPal.GetLastSocketError()); } } @@ -802,76 +803,82 @@ public static unsafe SocketError ConnectAsync(Socket socket, SafeCloseSocket han { // This will pin the socketAddress buffer. asyncResult.SetUnmanagedStructures(socketAddress); + try + { + int ignoreBytesSent; + bool success = socket.ConnectEx( + handle, + Marshal.UnsafeAddrOfPinnedArrayElement(socketAddress, 0), + socketAddressLen, + IntPtr.Zero, + 0, + out ignoreBytesSent, + asyncResult.OverlappedHandle); - int ignoreBytesSent; - if (!socket.ConnectEx( - handle, - Marshal.UnsafeAddrOfPinnedArrayElement(socketAddress, 0), - socketAddressLen, - IntPtr.Zero, - 0, - out ignoreBytesSent, - asyncResult.OverlappedHandle)) + return asyncResult.ProcessOverlappedResult(success, 0); + } + catch { - return GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return SocketError.Success; } public static unsafe SocketError SendAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { - // Set up asyncResult for overlapped WSASend. - // This call will use completion ports. + // Set up unmanaged structures for overlapped WSASend. asyncResult.SetUnmanagedStructures(buffer, offset, count, null, false /*don't pin null remoteEP*/); + try + { + // This can throw ObjectDisposedException. + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSASend( + handle, + ref asyncResult._singleBuffer, + 1, // There is only ever 1 buffer being sent. + out bytesTransferred, + socketFlags, + asyncResult.OverlappedHandle, + IntPtr.Zero); - // This can throw ObjectDisposedException. - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASend( - handle, - ref asyncResult._singleBuffer, - 1, // There is only ever 1 buffer being sent. - out bytesTransferred, - socketFlags, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError SendAsync(SafeCloseSocket handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSASend. - // This call will use completion ports. asyncResult.SetUnmanagedStructures(buffers); + try + { + // This can throw ObjectDisposedException. + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSASend( + handle, + asyncResult._wsaBuffers, + asyncResult._wsaBuffers.Length, + out bytesTransferred, + socketFlags, + asyncResult.OverlappedHandle, + IntPtr.Zero); - // This can throw ObjectDisposedException. - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASend( - handle, - asyncResult._wsaBuffers, - asyncResult._wsaBuffers.Length, - out bytesTransferred, - socketFlags, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } // This assumes preBuffer/postBuffer are pinned already - private static unsafe SocketError TransmitFileHelper( + private static unsafe bool TransmitFileHelper( SafeHandle socket, SafeHandle fileHandle, SafeHandle overlapped, @@ -899,141 +906,157 @@ private static unsafe SocketError TransmitFileHelper( bool success = Interop.Mswsock.TransmitFile(socket, fileHandle, 0, 0, overlapped, needTransmitFileBuffers ? &transmitFileBuffers : null, flags); - return success ? SocketError.Success : GetLastSocketError(); + return success; } public static unsafe SocketError SendFileAsync(SafeCloseSocket handle, FileStream fileStream, byte[] preBuffer, byte[] postBuffer, TransmitFileOptions flags, TransmitFileAsyncResult asyncResult) { asyncResult.SetUnmanagedStructures(fileStream, preBuffer, postBuffer, (flags & (TransmitFileOptions.Disconnect | TransmitFileOptions.ReuseSocket)) != 0); - - SocketError errorCode = TransmitFileHelper(handle, fileStream?.SafeFileHandle, asyncResult.OverlappedHandle, preBuffer, postBuffer, flags); - - // This will release resources if necessary - errorCode = asyncResult.CheckAsyncCallOverlappedResult(errorCode); - - return errorCode; + try + { + bool success = TransmitFileHelper( + handle, + fileStream?.SafeFileHandle, + asyncResult.OverlappedHandle, + preBuffer, + postBuffer, + flags); + + return asyncResult.ProcessOverlappedResult(success, 0); + } + catch + { + asyncResult.ReleaseUnmanagedStructures(); + throw; + } } public static unsafe SocketError SendToAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSASendTo. - // This call will use completion ports. asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress, false /* don't pin RemoteEP*/); + try + { + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSASendTo( + handle, + ref asyncResult._singleBuffer, + 1, // There is only ever 1 buffer being sent. + out bytesTransferred, + socketFlags, + asyncResult.GetSocketAddressPtr(), + asyncResult.SocketAddress.Size, + asyncResult.OverlappedHandle, + IntPtr.Zero); - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSASendTo( - handle, - ref asyncResult._singleBuffer, - 1, // There is only ever 1 buffer being sent. - out bytesTransferred, - socketFlags, - asyncResult.GetSocketAddressPtr(), - asyncResult.SocketAddress.Size, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError ReceiveAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSARecv. - // This call will use completion ports. asyncResult.SetUnmanagedStructures(buffer, offset, count, null, false /* don't pin null RemoteEP*/); + try + { + // This can throw ObjectDisposedException. + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSARecv( + handle, + ref asyncResult._singleBuffer, + 1, + out bytesTransferred, + ref socketFlags, + asyncResult.OverlappedHandle, + IntPtr.Zero); - // This can throw ObjectDisposedException. - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecv( - handle, - ref asyncResult._singleBuffer, - 1, - out bytesTransferred, - ref socketFlags, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError ReceiveAsync(SafeCloseSocket handle, IList> buffers, SocketFlags socketFlags, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSASend. - // This call will use completion ports. asyncResult.SetUnmanagedStructures(buffers); + try + { + // This can throw ObjectDisposedException. + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSARecv( + handle, + asyncResult._wsaBuffers, + asyncResult._wsaBuffers.Length, + out bytesTransferred, + ref socketFlags, + asyncResult.OverlappedHandle, + IntPtr.Zero); - // This can throw ObjectDisposedException. - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecv( - handle, - asyncResult._wsaBuffers, - asyncResult._wsaBuffers.Length, - out bytesTransferred, - ref socketFlags, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError ReceiveFromAsync(SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, OverlappedAsyncResult asyncResult) { // Set up asyncResult for overlapped WSARecvFrom. - // This call will use completion ports on WinNT and Overlapped IO on Win9x. asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress, true); + try + { + int bytesTransferred; + SocketError errorCode = Interop.Winsock.WSARecvFrom( + handle, + ref asyncResult._singleBuffer, + 1, + out bytesTransferred, + ref socketFlags, + asyncResult.GetSocketAddressPtr(), + asyncResult.GetSocketAddressSizePtr(), + asyncResult.OverlappedHandle, + IntPtr.Zero); - int bytesTransferred; - SocketError errorCode = Interop.Winsock.WSARecvFrom( - handle, - ref asyncResult._singleBuffer, - 1, - out bytesTransferred, - ref socketFlags, - asyncResult.GetSocketAddressPtr(), - asyncResult.GetSocketAddressSizePtr(), - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransferred); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError ReceiveMessageFromAsync(Socket socket, SafeCloseSocket handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, Internals.SocketAddress socketAddress, ReceiveMessageOverlappedAsyncResult asyncResult) { asyncResult.SetUnmanagedStructures(buffer, offset, count, socketAddress, socketFlags); + try + { + int bytesTransfered; + SocketError errorCode = (SocketError)socket.WSARecvMsg( + handle, + Marshal.UnsafeAddrOfPinnedArrayElement(asyncResult._messageBuffer, 0), + out bytesTransfered, + asyncResult.OverlappedHandle, + IntPtr.Zero); - int bytesTransfered; - SocketError errorCode = (SocketError)socket.WSARecvMsg( - handle, - Marshal.UnsafeAddrOfPinnedArrayElement(asyncResult._messageBuffer, 0), - out bytesTransfered, - asyncResult.OverlappedHandle, - IntPtr.Zero); - - if (errorCode != SocketError.Success) + return asyncResult.ProcessOverlappedResult(errorCode == SocketError.Success, bytesTransfered); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static unsafe SocketError AcceptAsync(Socket socket, SafeCloseSocket handle, SafeCloseSocket acceptHandle, int receiveSize, int socketAddressSize, AcceptOverlappedAsyncResult asyncResult) @@ -1046,24 +1069,27 @@ public static unsafe SocketError AcceptAsync(Socket socket, SafeCloseSocket hand // Set up asyncResult for overlapped AcceptEx. // This call will use completion ports on WinNT. asyncResult.SetUnmanagedStructures(buffer, addressBufferSize); + try + { + // This can throw ObjectDisposedException. + int bytesTransferred; + bool success = socket.AcceptEx( + handle, + acceptHandle, + Marshal.UnsafeAddrOfPinnedArrayElement(asyncResult.Buffer, 0), + receiveSize, + addressBufferSize, + addressBufferSize, + out bytesTransferred, + asyncResult.OverlappedHandle); - // This can throw ObjectDisposedException. - int bytesTransferred; - SocketError errorCode = SocketError.Success; - if (!socket.AcceptEx( - handle, - acceptHandle, - Marshal.UnsafeAddrOfPinnedArrayElement(asyncResult.Buffer, 0), - receiveSize, - addressBufferSize, - addressBufferSize, - out bytesTransferred, - asyncResult.OverlappedHandle)) + return asyncResult.ProcessOverlappedResult(success, 0); + } + catch { - errorCode = GetLastSocketError(); + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } public static void CheckDualModeReceiveSupport(Socket socket) @@ -1074,15 +1100,22 @@ public static void CheckDualModeReceiveSupport(Socket socket) internal static SocketError DisconnectAsync(Socket socket, SafeCloseSocket handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult) { asyncResult.SetUnmanagedStructures(null); - - // This can throw ObjectDisposedException - SocketError errorCode = SocketError.Success; - if (!socket.DisconnectEx(handle, asyncResult.OverlappedHandle, (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0), 0)) + try { - errorCode = GetLastSocketError(); + // This can throw ObjectDisposedException + bool success = socket.DisconnectEx( + handle, + asyncResult.OverlappedHandle, + (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0), + 0); + + return asyncResult.ProcessOverlappedResult(success, 0); + } + catch + { + asyncResult.ReleaseUnmanagedStructures(); + throw; } - - return errorCode; } internal static SocketError Disconnect(Socket socket, SafeCloseSocket handle, bool reuseSocket) diff --git a/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs b/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs index 1542628ba1e7..c3a03890f741 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFromAsync.cs @@ -44,7 +44,11 @@ public void Success() args.Completed += OnCompleted; args.UserToken = completed; - Assert.True(receiver.ReceiveMessageFromAsync(args)); + bool pending = receiver.ReceiveMessageFromAsync(args); + if (!pending) + { + OnCompleted(null, args); + } Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection"); @@ -84,7 +88,11 @@ public void Success_IPv6() args.Completed += OnCompleted; args.UserToken = completed; - Assert.True(receiver.ReceiveMessageFromAsync(args)); + bool pending = receiver.ReceiveMessageFromAsync(args); + if (!pending) + { + OnCompleted(null, args); + } Assert.True(completed.WaitOne(TestSettings.PassingTestTimeout), "Timeout while waiting for connection"); diff --git a/src/System.Net.Sockets/tests/FunctionalTests/Shutdown.cs b/src/System.Net.Sockets/tests/FunctionalTests/Shutdown.cs index c83c827fd493..1e46e00d05df 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/Shutdown.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/Shutdown.cs @@ -30,7 +30,11 @@ private static void OnOperationCompleted(object sender, SocketAsyncEventArgs arg args.SetBuffer(new byte[1], 0, 1); args.UserToken = client; - Assert.True(client.ReceiveAsync(args)); + bool pending = client.ReceiveAsync(args); + if (!pending) + { + OnOperationCompleted(null, args); + } break; } @@ -43,7 +47,11 @@ private static void OnOperationCompleted(object sender, SocketAsyncEventArgs arg break; } - Assert.True(client.SendAsync(args)); + bool pending = client.SendAsync(args); + if (!pending) + { + OnOperationCompleted(null, args); + } break; } @@ -52,7 +60,12 @@ private static void OnOperationCompleted(object sender, SocketAsyncEventArgs arg var client = (Socket)args.UserToken; Assert.True(args.BytesTransferred == args.Buffer.Length); - Assert.True(client.ReceiveAsync(args)); + + bool pending = client.ReceiveAsync(args); + if (!pending) + { + OnOperationCompleted(null, args); + } break; } } diff --git a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs index 82fb41fe281a..fe363e15d8c3 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -70,9 +70,11 @@ public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_Success() using (Socket sock = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified)) { - Assert.True(sock.ConnectAsync(args)); - - await complete.Task; + bool willRaiseEvent = sock.ConnectAsync(args); + if (willRaiseEvent) + { + await complete.Task; + } Assert.Equal(SocketError.Success, args.SocketError); Assert.Null(args.ConnectByNameError);