diff --git a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs index a9de12c3931..d62805d7e89 100644 --- a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs +++ b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs @@ -258,30 +258,7 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode, string token, DateTime listenSocket.Listen(1); HyperVSocket = listenSocket.Accept(); - TimeSpan timeout = TimeSpan.FromMinutes(MAX_TOKEN_LIFE_MINUTES); - DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout); - DateTimeOffset now = DateTimeOffset.UtcNow; - - // Calculate remaining time and create cancellation token - TimeSpan remainingTime = timeoutExpiry - now; - - // Check if the token has already expired - if (remainingTime <= TimeSpan.Zero) - { - throw new PSDirectException( - PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired")); - } - - // Set socket timeout for receive operations to prevent indefinite blocking - int timeoutMs = (int)remainingTime.TotalMilliseconds; - HyperVSocket.ReceiveTimeout = timeoutMs; - HyperVSocket.SendTimeout = timeoutMs; - - // Create a cancellation token that will be cancelled when the timeout expires - using var cancellationTokenSource = new CancellationTokenSource(remainingTime); - CancellationToken cancellationToken = cancellationTokenSource.Token; - - ValidateToken(HyperVSocket, token, cancellationToken); + ValidateToken(HyperVSocket, token, tokenCreationTime, MAX_TOKEN_LIFE_MINUTES * 60); Stream = new NetworkStream(HyperVSocket, true); @@ -389,9 +366,33 @@ public void Dispose() /// /// The connected HyperVSocket. /// The expected token string. - /// Cancellation token for timeout handling. - internal static void ValidateToken(Socket socket, string token, CancellationToken cancellationToken = default) + /// The creation time of the token. + /// The maximum lifetime of the token in seconds. + internal static void ValidateToken(Socket socket, string token, DateTimeOffset tokenCreationTime, int maxTokenLifeSeconds) { + TimeSpan timeout = TimeSpan.FromSeconds(maxTokenLifeSeconds); + DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout); + DateTimeOffset now = DateTimeOffset.UtcNow; + + // Calculate remaining time and create cancellation token + TimeSpan remainingTime = timeoutExpiry - now; + + // Check if the token has already expired + if (remainingTime <= TimeSpan.Zero) + { + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired")); + } + + // Create a cancellation token that will be cancelled when the timeout expires + using var cancellationTokenSource = new CancellationTokenSource(remainingTime); + CancellationToken cancellationToken = cancellationTokenSource.Token; + + // Set socket timeout for receive operations to prevent indefinite blocking + int timeoutMs = (int)remainingTime.TotalMilliseconds; + socket.ReceiveTimeout = timeoutMs; + socket.SendTimeout = timeoutMs; + // Check for cancellation before starting validation cancellationToken.ThrowIfCancellationRequested(); @@ -430,6 +431,7 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke // So we expect a response of length 6 + 100 = 106 characters. responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, 110); + // Final check if we got the token before the timeout cancellationToken.ThrowIfCancellationRequested(); if (string.IsNullOrEmpty(responseString) || !responseString.StartsWith("TOKEN ", StringComparison.Ordinal)) @@ -454,6 +456,9 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke // Acknowledge the token is valid with "PASS". socket.Send("PASS"u8); + + socket.ReceiveTimeout = 0; // Disable the timeout after successful validation + socket.SendTimeout = 0; } } diff --git a/test/xUnit/csharp/test_RemoteHyperV.cs b/test/xUnit/csharp/test_RemoteHyperV.cs index f694f6894df..c75159ac279 100644 --- a/test/xUnit/csharp/test_RemoteHyperV.cs +++ b/test/xUnit/csharp/test_RemoteHyperV.cs @@ -20,7 +20,7 @@ namespace PSTests.Sequential public class RemoteHyperVTests { private static ITestOutputHelper _output; - private static TimeSpan timeout = TimeSpan.FromSeconds(15); + private static readonly TimeSpan timeout = TimeSpan.FromSeconds(15); public RemoteHyperVTests(ITestOutputHelper output) { @@ -63,15 +63,51 @@ private static void ConnectWithRetry(Socket client, IPAddress address, int port, } } + private static void SendResponse(string name, Socket client, Queue<(byte[] bytes, int delayMs)> serverResponses) + { + if (serverResponses.Count > 0) + { + _output.WriteLine($"Mock {name} ----------------------------------------------------"); + var respTuple = serverResponses.Dequeue(); + var resp = respTuple.bytes; + + if (respTuple.delayMs > 0) + { + _output.WriteLine($"Mock {name} - delaying response by {respTuple.delayMs} ms"); + Thread.Sleep(respTuple.delayMs); + } + if (resp.Length > 0) { + client.Send(resp, resp.Length, SocketFlags.None); + _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); + } + } + } + private static void StartHandshakeServer( string name, int port, - IEnumerable<(string message, - Encoding encoding)> expectedClientSends, + IEnumerable<(string message, Encoding encoding)> expectedClientSends, IEnumerable<(string message, Encoding encoding)> serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false) + { + IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponsesWithDelay = new List<(string message, Encoding encoding, int delayMs)>(); + foreach (var item in serverResponses) + { + ((List<(string message, Encoding encoding, int delayMs)>)serverResponsesWithDelay).Add((item.message, item.encoding, 1)); + } + StartHandshakeServer(name, port, expectedClientSends, serverResponsesWithDelay, verifyConnectionClosed, cancellationToken, sendFirst); + } + + private static void StartHandshakeServer( + string name, + int port, + IEnumerable<(string message, Encoding encoding)> expectedClientSends, + IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponses, + bool verifyConnectionClosed, + CancellationToken cancellationToken, + bool sendFirst = false) { var expectedMessages = new Queue<(string message, byte[] bytes, Encoding encoding)>(); foreach (var item in expectedClientSends) @@ -80,17 +116,27 @@ private static void StartHandshakeServer( expectedMessages.Enqueue((message: item.message, bytes: itemBytes, encoding: item.encoding)); } - var serverResponseBytes = new Queue(); + var serverResponseBytes = new Queue<(byte[] bytes, int delayMs)>(); foreach (var item in serverResponses) { - serverResponseBytes.Enqueue(item.encoding.GetBytes(item.message)); + (byte[] bytes, int delayMs) queueItem = (item.encoding.GetBytes(item.message), item.delayMs); + serverResponseBytes.Enqueue(queueItem); } - StartHandshakeServer(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst); + _output.WriteLine($"Mock {name} - starting listener on port {port} with {expectedMessages.Count} expected messages and {serverResponseBytes.Count} responses."); + StartHandshakeServerImplementation(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst); } - private static void StartHandshakeServer(string name, int port, Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, Queue serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false) + private static void StartHandshakeServerImplementation( + string name, + int port, + Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, + Queue<(byte[] bytes, int delayMs)> serverResponses, + bool verifyConnectionClosed, + CancellationToken cancellationToken, + bool sendFirst = false) { + DateTime startTime = DateTime.UtcNow; var buffer = new byte[1024]; var listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -101,19 +147,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me if (sendFirst) { // Send the first message from the serverResponses queue - if (serverResponses.Count > 0) - { - var resp = serverResponses.Dequeue(); - client.Send(resp, resp.Length, SocketFlags.None); - _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); - } + SendResponse(name, client, serverResponses); } while (expectedClientSends.Count > 0) { + _output.WriteLine($"Mock {name} - time elapsed: {(DateTime.UtcNow - startTime).TotalMilliseconds} milliseconds"); client.ReceiveTimeout = 2 * 1000; // 2 seconds timeout for receiving data cancellationToken.ThrowIfCancellationRequested(); var expectedMessage = expectedClientSends.Dequeue(); + _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}"); var expected = expectedMessage.bytes; Array.Clear(buffer, 0, buffer.Length); int received = client.Receive(buffer); @@ -143,12 +186,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me throw new Exception(errorMessage); } _output.WriteLine($"Mock {name} - received expected message: " + expectedString); - if (serverResponses.Count > 0) - { - var resp = serverResponses.Dequeue(); - client.Send(resp, resp.Length, SocketFlags.None); - _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); - } + SendResponse(name, client, serverResponses); } if (verifyConnectionClosed) @@ -178,6 +216,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me } catch (ObjectDisposedException) { + _output.WriteLine($"Mock {name} - socket already closed."); // Socket already closed } } @@ -185,8 +224,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me _output.WriteLine($"Mock {name} - on port {port} completed successfully."); } + catch (Exception ex) + { + _output.WriteLine($"Mock {name} - Exception: {ex.Message} {ex.GetType().FullName}"); + _output.WriteLine(ex.StackTrace); + throw; + } finally { + _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}"); + _output.WriteLine($"Mock {name} - stopping listener on port {port}."); listener.Stop(); } } @@ -615,13 +662,111 @@ public async Task ValidatePassesWhenTokensMatch(string token, string expectedTok using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { ConnectWithRetry(client, IPAddress.Loopback, port, _output); - System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 1); System.Threading.Thread.Sleep(100); // Allow time for server to process } await serverTask; } + [SkippableTheory] + [InlineData(5500, "A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.", "SocketException")] // test the socket timeout + [InlineData(3200, "canceled", "System.OperationCanceledException")] // test the cancellation token + [InlineData(10, "", "")] + public async Task ValidateTokenTimeoutFails(int timeoutMs, string expectedMessage, string expectedExceptionType = "SocketException") + { + string token = "testToken"; + string expectedToken = token; + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{ + (message: "VERSION", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION + (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1) + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII) // Response to token + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: true, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + if (expectedMessage.Length > 0) + { + var exception = Record.Exception( + () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5)); // set the timeout to 5 seconds or 5000 ms + Assert.NotNull(exception); + string exceptionType = exception.GetType().FullName; + _output.WriteLine($"Caught exception of type {exceptionType} with message: {exception.Message}"); + Assert.Contains(expectedExceptionType, exceptionType, StringComparison.OrdinalIgnoreCase); + Assert.Contains(expectedMessage, exception.Message, StringComparison.OrdinalIgnoreCase); + } + else + { + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5); + } + System.Threading.Thread.Sleep(100); // Allow time for server to process + } + + if (expectedMessage.Length == 0) + { + await serverTask; + } + } + + [SkippableFact] + public async Task ValidateTokenTimeoutDoesAffectSession() + { + string token = "testToken"; + string expectedToken = token; + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{ + (message: "VERSION", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION + (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1), + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 99), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 100), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 101), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 102), // Send some data after the handshake + (message: string.Empty, encoding: Encoding.ASCII, delayMs: 103) // Send some data after the handshake + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to token + (message: "PSRP-Message0", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message1", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message2", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message3", encoding: Encoding.ASCII), // Indicate server is ready to receive data + (message: "PSRP-Message4", encoding: Encoding.ASCII) // + + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: false, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5); + for (int i = 0; i < 5; i++) + { + System.Threading.Thread.Sleep(1500); + client.Send(Encoding.ASCII.GetBytes($"PSRP-Message{i}")); // Send some data after the handshake + } + } + + await serverTask; + } + [SkippableTheory] [InlineData("abc", "xyz")] [InlineData("abc", "abcdef")] @@ -649,8 +794,9 @@ public async Task ValidateFailsWhenTokensMismatch(string token, string expectedT using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { ConnectWithRetry(client, IPAddress.Loopback, port, _output); + DateTimeOffset tokenCreationTime = DateTimeOffset.UtcNow; // Token created 10 minutes ago var exception = Assert.Throws( - () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken)); + () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, tokenCreationTime, 5)); System.Threading.Thread.Sleep(100); // Allow time for server to process Assert.Contains("The credential is invalid.", exception.Message); }