From 65ccd22177e05f56e41273daf22b923940284e64 Mon Sep 17 00:00:00 2001 From: Travis Plunk Date: Mon, 22 Sep 2025 10:56:19 -0700 Subject: [PATCH 1/2] Ensure that socket timeouts are set only during the token validation (#26066) The main goal is to ensure that socket timeouts are set only during the token validation phase and are properly reset afterward, improving reliability and preventing unintended blocking or premature timeouts in subsequent operations. --- .../common/RemoteSessionHyperVSocket.cs | 57 +++--- test/xUnit/csharp/test_RemoteHyperV.cs | 186 ++++++++++++++++-- 2 files changed, 197 insertions(+), 46 deletions(-) diff --git a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs index a9de12c3931..d62805d7e89 100644 --- a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs +++ b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs @@ -258,30 +258,7 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode, string token, DateTime listenSocket.Listen(1); HyperVSocket = listenSocket.Accept(); - TimeSpan timeout = TimeSpan.FromMinutes(MAX_TOKEN_LIFE_MINUTES); - DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout); - DateTimeOffset now = DateTimeOffset.UtcNow; - - // Calculate remaining time and create cancellation token - TimeSpan remainingTime = timeoutExpiry - now; - - // Check if the token has already expired - if (remainingTime <= TimeSpan.Zero) - { - throw new PSDirectException( - PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired")); - } - - // Set socket timeout for receive operations to prevent indefinite blocking - int timeoutMs = (int)remainingTime.TotalMilliseconds; - HyperVSocket.ReceiveTimeout = timeoutMs; - HyperVSocket.SendTimeout = timeoutMs; - - // Create a cancellation token that will be cancelled when the timeout expires - using var cancellationTokenSource = new CancellationTokenSource(remainingTime); - CancellationToken cancellationToken = cancellationTokenSource.Token; - - ValidateToken(HyperVSocket, token, cancellationToken); + ValidateToken(HyperVSocket, token, tokenCreationTime, MAX_TOKEN_LIFE_MINUTES * 60); Stream = new NetworkStream(HyperVSocket, true); @@ -389,9 +366,33 @@ public void Dispose() /// /// The connected HyperVSocket. /// The expected token string. - /// Cancellation token for timeout handling. - internal static void ValidateToken(Socket socket, string token, CancellationToken cancellationToken = default) + /// The creation time of the token. + /// The maximum lifetime of the token in seconds. + internal static void ValidateToken(Socket socket, string token, DateTimeOffset tokenCreationTime, int maxTokenLifeSeconds) { + TimeSpan timeout = TimeSpan.FromSeconds(maxTokenLifeSeconds); + DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout); + DateTimeOffset now = DateTimeOffset.UtcNow; + + // Calculate remaining time and create cancellation token + TimeSpan remainingTime = timeoutExpiry - now; + + // Check if the token has already expired + if (remainingTime <= TimeSpan.Zero) + { + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired")); + } + + // Create a cancellation token that will be cancelled when the timeout expires + using var cancellationTokenSource = new CancellationTokenSource(remainingTime); + CancellationToken cancellationToken = cancellationTokenSource.Token; + + // Set socket timeout for receive operations to prevent indefinite blocking + int timeoutMs = (int)remainingTime.TotalMilliseconds; + socket.ReceiveTimeout = timeoutMs; + socket.SendTimeout = timeoutMs; + // Check for cancellation before starting validation cancellationToken.ThrowIfCancellationRequested(); @@ -430,6 +431,7 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke // So we expect a response of length 6 + 100 = 106 characters. responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, 110); + // Final check if we got the token before the timeout cancellationToken.ThrowIfCancellationRequested(); if (string.IsNullOrEmpty(responseString) || !responseString.StartsWith("TOKEN ", StringComparison.Ordinal)) @@ -454,6 +456,9 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke // Acknowledge the token is valid with "PASS". socket.Send("PASS"u8); + + socket.ReceiveTimeout = 0; // Disable the timeout after successful validation + socket.SendTimeout = 0; } } diff --git a/test/xUnit/csharp/test_RemoteHyperV.cs b/test/xUnit/csharp/test_RemoteHyperV.cs index f694f6894df..27f7fb17375 100644 --- a/test/xUnit/csharp/test_RemoteHyperV.cs +++ b/test/xUnit/csharp/test_RemoteHyperV.cs @@ -63,15 +63,51 @@ private static void ConnectWithRetry(Socket client, IPAddress address, int port, } } + private static void SendResponse(string name, Socket client, Queue<(byte[] bytes, int delayMs)> serverResponses) + { + if (serverResponses.Count > 0) + { + _output.WriteLine($"Mock {name} ----------------------------------------------------"); + var respTuple = serverResponses.Dequeue(); + var resp = respTuple.bytes; + + if (respTuple.delayMs > 0) + { + _output.WriteLine($"Mock {name} - delaying response by {respTuple.delayMs} ms"); + Thread.Sleep(respTuple.delayMs); + } + if (resp.Length > 0) { + client.Send(resp, resp.Length, SocketFlags.None); + _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); + } + } + } + private static void StartHandshakeServer( string name, int port, - IEnumerable<(string message, - Encoding encoding)> expectedClientSends, + IEnumerable<(string message, Encoding encoding)> expectedClientSends, IEnumerable<(string message, Encoding encoding)> serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false) + { + IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponsesWithDelay = new List<(string message, Encoding encoding, int delayMs)>(); + foreach (var item in serverResponses) + { + ((List<(string message, Encoding encoding, int delayMs)>)serverResponsesWithDelay).Add((item.message, item.encoding, 1)); + } + StartHandshakeServer(name, port, expectedClientSends, serverResponsesWithDelay, verifyConnectionClosed, cancellationToken, sendFirst); + } + + private static void StartHandshakeServer( + string name, + int port, + IEnumerable<(string message, Encoding encoding)> expectedClientSends, + IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponses, + bool verifyConnectionClosed, + CancellationToken cancellationToken, + bool sendFirst = false) { var expectedMessages = new Queue<(string message, byte[] bytes, Encoding encoding)>(); foreach (var item in expectedClientSends) @@ -80,17 +116,27 @@ private static void StartHandshakeServer( expectedMessages.Enqueue((message: item.message, bytes: itemBytes, encoding: item.encoding)); } - var serverResponseBytes = new Queue(); + var serverResponseBytes = new Queue<(byte[] bytes, int delayMs)>(); foreach (var item in serverResponses) { - serverResponseBytes.Enqueue(item.encoding.GetBytes(item.message)); + (byte[] bytes, int delayMs) queueItem = (item.encoding.GetBytes(item.message), item.delayMs); + serverResponseBytes.Enqueue(queueItem); } - StartHandshakeServer(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst); + _output.WriteLine($"Mock {name} - starting listener on port {port} with {expectedMessages.Count} expected messages and {serverResponseBytes.Count} responses."); + StartHandshakeServerImplementation(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst); } - private static void StartHandshakeServer(string name, int port, Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, Queue serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false) + private static void StartHandshakeServerImplementation( + string name, + int port, + Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, + Queue<(byte[] bytes, int delayMs)> serverResponses, + bool verifyConnectionClosed, + CancellationToken cancellationToken, + bool sendFirst = false) { + DateTime startTime = DateTime.UtcNow; var buffer = new byte[1024]; var listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -101,19 +147,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me if (sendFirst) { // Send the first message from the serverResponses queue - if (serverResponses.Count > 0) - { - var resp = serverResponses.Dequeue(); - client.Send(resp, resp.Length, SocketFlags.None); - _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); - } + SendResponse(name, client, serverResponses); } while (expectedClientSends.Count > 0) { + _output.WriteLine($"Mock {name} - time elapsed: {(DateTime.UtcNow - startTime).TotalMilliseconds} milliseconds"); client.ReceiveTimeout = 2 * 1000; // 2 seconds timeout for receiving data cancellationToken.ThrowIfCancellationRequested(); var expectedMessage = expectedClientSends.Dequeue(); + _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}"); var expected = expectedMessage.bytes; Array.Clear(buffer, 0, buffer.Length); int received = client.Receive(buffer); @@ -143,12 +186,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me throw new Exception(errorMessage); } _output.WriteLine($"Mock {name} - received expected message: " + expectedString); - if (serverResponses.Count > 0) - { - var resp = serverResponses.Dequeue(); - client.Send(resp, resp.Length, SocketFlags.None); - _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); - } + SendResponse(name, client, serverResponses); } if (verifyConnectionClosed) @@ -178,6 +216,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me } catch (ObjectDisposedException) { + _output.WriteLine($"Mock {name} - socket already closed."); // Socket already closed } } @@ -185,8 +224,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me _output.WriteLine($"Mock {name} - on port {port} completed successfully."); } + catch (Exception ex) + { + _output.WriteLine($"Mock {name} - Exception: {ex.Message} {ex.GetType().FullName}"); + _output.WriteLine(ex.StackTrace); + throw; + } finally { + _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}"); + _output.WriteLine($"Mock {name} - stopping listener on port {port}."); listener.Stop(); } } @@ -615,13 +662,111 @@ public async Task ValidatePassesWhenTokensMatch(string token, string expectedTok using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { ConnectWithRetry(client, IPAddress.Loopback, port, _output); - System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 1); System.Threading.Thread.Sleep(100); // Allow time for server to process } await serverTask; } + [SkippableTheory] + [InlineData(5500, "A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.", "SocketException")] // test the socket timeout + [InlineData(3200, "canceled", "System.OperationCanceledException")] // test the cancellation token + [InlineData(10, "", "")] + public async Task ValidateTokenTimeoutFails(int timeoutMs, string expectedMessage, string expectedExceptionType = "SocketException") + { + string token = "testToken"; + string expectedToken = token; + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{ + (message: "VERSION", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION + (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1) + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII) // Response to token + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: true, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + if (expectedMessage.Length > 0) + { + var exception = Record.Exception( + () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5)); // set the timeout to 5 seconds or 5000 ms + Assert.NotNull(exception); + string exceptionType = exception.GetType().FullName; + _output.WriteLine($"Caught exception of type {exceptionType} with message: {exception.Message}"); + Assert.Contains(expectedExceptionType, exceptionType, StringComparison.OrdinalIgnoreCase); + Assert.Contains(expectedMessage, exception.Message, StringComparison.OrdinalIgnoreCase); + } + else + { + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5); + } + System.Threading.Thread.Sleep(100); // Allow time for server to process + } + + if (expectedMessage.Length == 0) + { + await serverTask; + } + } + + [SkippableFact] + public async Task ValidateTokenTimeoutDoesAffectSession() + { + string token = "testToken"; + string expectedToken = token; + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{ + (message: "VERSION", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION + (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1), + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 99), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 100), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 101), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 102), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 103) // Send some data after the handshake + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to token + (message: "PSRP-Message0", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message1", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message2", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message3", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message4", encoding: Encoding.ASCII) // + + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: false, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5); + for (int i = 0; i < 5; i++) + { + System.Threading.Thread.Sleep(1500); + client.Send(Encoding.ASCII.GetBytes($"PSRP-Message{i}")); // Send some data after the handshake + } + } + + await serverTask; + } + [SkippableTheory] [InlineData("abc", "xyz")] [InlineData("abc", "abcdef")] @@ -649,8 +794,9 @@ public async Task ValidateFailsWhenTokensMismatch(string token, string expectedT using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { ConnectWithRetry(client, IPAddress.Loopback, port, _output); + DateTimeOffset tokenCreationTime = DateTimeOffset.UtcNow; // Token created 10 minutes ago var exception = Assert.Throws( - () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken)); + () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, tokenCreationTime, 5)); System.Threading.Thread.Sleep(100); // Allow time for server to process Assert.Contains("The credential is invalid.", exception.Message); } From 96b2040c8bac603e11743b3ca1f1446d6e0d9293 Mon Sep 17 00:00:00 2001 From: Travis Plunk Date: Mon, 22 Sep 2025 13:51:26 -0700 Subject: [PATCH 2/2] Change timeout variable to readonly --- test/xUnit/csharp/test_RemoteHyperV.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/xUnit/csharp/test_RemoteHyperV.cs b/test/xUnit/csharp/test_RemoteHyperV.cs index 27f7fb17375..c75159ac279 100644 --- a/test/xUnit/csharp/test_RemoteHyperV.cs +++ b/test/xUnit/csharp/test_RemoteHyperV.cs @@ -20,7 +20,7 @@ namespace PSTests.Sequential public class RemoteHyperVTests { private static ITestOutputHelper _output; - private static TimeSpan timeout = TimeSpan.FromSeconds(15); + private static readonly TimeSpan timeout = TimeSpan.FromSeconds(15); public RemoteHyperVTests(ITestOutputHelper output) {