diff --git a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs
index a9de12c3931..d62805d7e89 100644
--- a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs
+++ b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs
@@ -258,30 +258,7 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode, string token, DateTime
listenSocket.Listen(1);
HyperVSocket = listenSocket.Accept();
- TimeSpan timeout = TimeSpan.FromMinutes(MAX_TOKEN_LIFE_MINUTES);
- DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout);
- DateTimeOffset now = DateTimeOffset.UtcNow;
-
- // Calculate remaining time and create cancellation token
- TimeSpan remainingTime = timeoutExpiry - now;
-
- // Check if the token has already expired
- if (remainingTime <= TimeSpan.Zero)
- {
- throw new PSDirectException(
- PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired"));
- }
-
- // Set socket timeout for receive operations to prevent indefinite blocking
- int timeoutMs = (int)remainingTime.TotalMilliseconds;
- HyperVSocket.ReceiveTimeout = timeoutMs;
- HyperVSocket.SendTimeout = timeoutMs;
-
- // Create a cancellation token that will be cancelled when the timeout expires
- using var cancellationTokenSource = new CancellationTokenSource(remainingTime);
- CancellationToken cancellationToken = cancellationTokenSource.Token;
-
- ValidateToken(HyperVSocket, token, cancellationToken);
+ ValidateToken(HyperVSocket, token, tokenCreationTime, MAX_TOKEN_LIFE_MINUTES * 60);
Stream = new NetworkStream(HyperVSocket, true);
@@ -389,9 +366,33 @@ public void Dispose()
///
/// The connected HyperVSocket.
/// The expected token string.
- /// Cancellation token for timeout handling.
- internal static void ValidateToken(Socket socket, string token, CancellationToken cancellationToken = default)
+ /// The creation time of the token.
+ /// The maximum lifetime of the token in seconds.
+ internal static void ValidateToken(Socket socket, string token, DateTimeOffset tokenCreationTime, int maxTokenLifeSeconds)
{
+ TimeSpan timeout = TimeSpan.FromSeconds(maxTokenLifeSeconds);
+ DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout);
+ DateTimeOffset now = DateTimeOffset.UtcNow;
+
+ // Calculate remaining time and create cancellation token
+ TimeSpan remainingTime = timeoutExpiry - now;
+
+ // Check if the token has already expired
+ if (remainingTime <= TimeSpan.Zero)
+ {
+ throw new PSDirectException(
+ PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired"));
+ }
+
+ // Create a cancellation token that will be cancelled when the timeout expires
+ using var cancellationTokenSource = new CancellationTokenSource(remainingTime);
+ CancellationToken cancellationToken = cancellationTokenSource.Token;
+
+ // Set socket timeout for receive operations to prevent indefinite blocking
+ int timeoutMs = (int)remainingTime.TotalMilliseconds;
+ socket.ReceiveTimeout = timeoutMs;
+ socket.SendTimeout = timeoutMs;
+
// Check for cancellation before starting validation
cancellationToken.ThrowIfCancellationRequested();
@@ -430,6 +431,7 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke
// So we expect a response of length 6 + 100 = 106 characters.
responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, 110);
+ // Final check if we got the token before the timeout
cancellationToken.ThrowIfCancellationRequested();
if (string.IsNullOrEmpty(responseString) || !responseString.StartsWith("TOKEN ", StringComparison.Ordinal))
@@ -454,6 +456,9 @@ internal static void ValidateToken(Socket socket, string token, CancellationToke
// Acknowledge the token is valid with "PASS".
socket.Send("PASS"u8);
+
+ socket.ReceiveTimeout = 0; // Disable the timeout after successful validation
+ socket.SendTimeout = 0;
}
}
diff --git a/test/xUnit/csharp/test_RemoteHyperV.cs b/test/xUnit/csharp/test_RemoteHyperV.cs
index f694f6894df..c75159ac279 100644
--- a/test/xUnit/csharp/test_RemoteHyperV.cs
+++ b/test/xUnit/csharp/test_RemoteHyperV.cs
@@ -20,7 +20,7 @@ namespace PSTests.Sequential
public class RemoteHyperVTests
{
private static ITestOutputHelper _output;
- private static TimeSpan timeout = TimeSpan.FromSeconds(15);
+ private static readonly TimeSpan timeout = TimeSpan.FromSeconds(15);
public RemoteHyperVTests(ITestOutputHelper output)
{
@@ -63,15 +63,51 @@ private static void ConnectWithRetry(Socket client, IPAddress address, int port,
}
}
+ private static void SendResponse(string name, Socket client, Queue<(byte[] bytes, int delayMs)> serverResponses)
+ {
+ if (serverResponses.Count > 0)
+ {
+ _output.WriteLine($"Mock {name} ----------------------------------------------------");
+ var respTuple = serverResponses.Dequeue();
+ var resp = respTuple.bytes;
+
+ if (respTuple.delayMs > 0)
+ {
+ _output.WriteLine($"Mock {name} - delaying response by {respTuple.delayMs} ms");
+ Thread.Sleep(respTuple.delayMs);
+ }
+ if (resp.Length > 0) {
+ client.Send(resp, resp.Length, SocketFlags.None);
+ _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp));
+ }
+ }
+ }
+
private static void StartHandshakeServer(
string name,
int port,
- IEnumerable<(string message,
- Encoding encoding)> expectedClientSends,
+ IEnumerable<(string message, Encoding encoding)> expectedClientSends,
IEnumerable<(string message, Encoding encoding)> serverResponses,
bool verifyConnectionClosed,
CancellationToken cancellationToken,
bool sendFirst = false)
+ {
+ IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponsesWithDelay = new List<(string message, Encoding encoding, int delayMs)>();
+ foreach (var item in serverResponses)
+ {
+ ((List<(string message, Encoding encoding, int delayMs)>)serverResponsesWithDelay).Add((item.message, item.encoding, 1));
+ }
+ StartHandshakeServer(name, port, expectedClientSends, serverResponsesWithDelay, verifyConnectionClosed, cancellationToken, sendFirst);
+ }
+
+ private static void StartHandshakeServer(
+ string name,
+ int port,
+ IEnumerable<(string message, Encoding encoding)> expectedClientSends,
+ IEnumerable<(string message, Encoding encoding, int delayMs)> serverResponses,
+ bool verifyConnectionClosed,
+ CancellationToken cancellationToken,
+ bool sendFirst = false)
{
var expectedMessages = new Queue<(string message, byte[] bytes, Encoding encoding)>();
foreach (var item in expectedClientSends)
@@ -80,17 +116,27 @@ private static void StartHandshakeServer(
expectedMessages.Enqueue((message: item.message, bytes: itemBytes, encoding: item.encoding));
}
- var serverResponseBytes = new Queue();
+ var serverResponseBytes = new Queue<(byte[] bytes, int delayMs)>();
foreach (var item in serverResponses)
{
- serverResponseBytes.Enqueue(item.encoding.GetBytes(item.message));
+ (byte[] bytes, int delayMs) queueItem = (item.encoding.GetBytes(item.message), item.delayMs);
+ serverResponseBytes.Enqueue(queueItem);
}
- StartHandshakeServer(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst);
+ _output.WriteLine($"Mock {name} - starting listener on port {port} with {expectedMessages.Count} expected messages and {serverResponseBytes.Count} responses.");
+ StartHandshakeServerImplementation(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst);
}
- private static void StartHandshakeServer(string name, int port, Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, Queue serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false)
+ private static void StartHandshakeServerImplementation(
+ string name,
+ int port,
+ Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends,
+ Queue<(byte[] bytes, int delayMs)> serverResponses,
+ bool verifyConnectionClosed,
+ CancellationToken cancellationToken,
+ bool sendFirst = false)
{
+ DateTime startTime = DateTime.UtcNow;
var buffer = new byte[1024];
var listener = new TcpListener(IPAddress.Loopback, port);
listener.Start();
@@ -101,19 +147,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me
if (sendFirst)
{
// Send the first message from the serverResponses queue
- if (serverResponses.Count > 0)
- {
- var resp = serverResponses.Dequeue();
- client.Send(resp, resp.Length, SocketFlags.None);
- _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp));
- }
+ SendResponse(name, client, serverResponses);
}
while (expectedClientSends.Count > 0)
{
+ _output.WriteLine($"Mock {name} - time elapsed: {(DateTime.UtcNow - startTime).TotalMilliseconds} milliseconds");
client.ReceiveTimeout = 2 * 1000; // 2 seconds timeout for receiving data
cancellationToken.ThrowIfCancellationRequested();
var expectedMessage = expectedClientSends.Dequeue();
+ _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}");
var expected = expectedMessage.bytes;
Array.Clear(buffer, 0, buffer.Length);
int received = client.Receive(buffer);
@@ -143,12 +186,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me
throw new Exception(errorMessage);
}
_output.WriteLine($"Mock {name} - received expected message: " + expectedString);
- if (serverResponses.Count > 0)
- {
- var resp = serverResponses.Dequeue();
- client.Send(resp, resp.Length, SocketFlags.None);
- _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp));
- }
+ SendResponse(name, client, serverResponses);
}
if (verifyConnectionClosed)
@@ -178,6 +216,7 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me
}
catch (ObjectDisposedException)
{
+ _output.WriteLine($"Mock {name} - socket already closed.");
// Socket already closed
}
}
@@ -185,8 +224,16 @@ private static void StartHandshakeServer(string name, int port, Queue<(string me
_output.WriteLine($"Mock {name} - on port {port} completed successfully.");
}
+ catch (Exception ex)
+ {
+ _output.WriteLine($"Mock {name} - Exception: {ex.Message} {ex.GetType().FullName}");
+ _output.WriteLine(ex.StackTrace);
+ throw;
+ }
finally
{
+ _output.WriteLine($"Mock {name} - remaining expected messages: {expectedClientSends.Count}");
+ _output.WriteLine($"Mock {name} - stopping listener on port {port}.");
listener.Stop();
}
}
@@ -615,13 +662,111 @@ public async Task ValidatePassesWhenTokensMatch(string token, string expectedTok
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
ConnectWithRetry(client, IPAddress.Loopback, port, _output);
- System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken);
+ System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 1);
System.Threading.Thread.Sleep(100); // Allow time for server to process
}
await serverTask;
}
+ [SkippableTheory]
+ [InlineData(5500, "A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.", "SocketException")] // test the socket timeout
+ [InlineData(3200, "canceled", "System.OperationCanceledException")] // test the cancellation token
+ [InlineData(10, "", "")]
+ public async Task ValidateTokenTimeoutFails(int timeoutMs, string expectedMessage, string expectedExceptionType = "SocketException")
+ {
+ string token = "testToken";
+ string expectedToken = token;
+ int port = 50000 + (int)(DateTime.Now.Ticks % 10000);
+
+ var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{
+ (message: "VERSION", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION
+ (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: timeoutMs), // Response to VERSION_2
+ (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1)
+ };
+
+ var serverResponses = new List<(string message, Encoding encoding)>{
+ (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2
+ (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2
+ (message: "PASS", encoding: Encoding.ASCII) // Response to token
+ };
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1));
+ var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: true, cts.Token, sendFirst: true), cts.Token);
+
+ using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ ConnectWithRetry(client, IPAddress.Loopback, port, _output);
+ if (expectedMessage.Length > 0)
+ {
+ var exception = Record.Exception(
+ () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5)); // set the timeout to 5 seconds or 5000 ms
+ Assert.NotNull(exception);
+ string exceptionType = exception.GetType().FullName;
+ _output.WriteLine($"Caught exception of type {exceptionType} with message: {exception.Message}");
+ Assert.Contains(expectedExceptionType, exceptionType, StringComparison.OrdinalIgnoreCase);
+ Assert.Contains(expectedMessage, exception.Message, StringComparison.OrdinalIgnoreCase);
+ }
+ else
+ {
+ System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5);
+ }
+ System.Threading.Thread.Sleep(100); // Allow time for server to process
+ }
+
+ if (expectedMessage.Length == 0)
+ {
+ await serverTask;
+ }
+ }
+
+ [SkippableFact]
+ public async Task ValidateTokenTimeoutDoesAffectSession()
+ {
+ string token = "testToken";
+ string expectedToken = token;
+ int port = 50000 + (int)(DateTime.Now.Ticks % 10000);
+
+ var expectedClientSends = new List<(string message, Encoding encoding, int delayMs)>{
+ (message: "VERSION", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION
+ (message: "VERSION_2", encoding: Encoding.ASCII, delayMs: 1), // Response to VERSION_2
+ (message: $"TOKEN {token}", encoding: Encoding.ASCII, delayMs: 1),
+ (message: string.Empty, encoding: Encoding.ASCII, delayMs: 99), // Send some data after the handshake
+ (message: string.Empty, encoding: Encoding.ASCII, delayMs: 100), // Send some data after the handshake
+ (message: string.Empty, encoding: Encoding.ASCII, delayMs: 101), // Send some data after the handshake
+ (message: string.Empty, encoding: Encoding.ASCII, delayMs: 102), // Send some data after the handshake
+ (message: string.Empty, encoding: Encoding.ASCII, delayMs: 103) // Send some data after the handshake
+ };
+
+ var serverResponses = new List<(string message, Encoding encoding)>{
+ (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2
+ (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2
+ (message: "PASS", encoding: Encoding.ASCII), // Response to token
+ (message: "PSRP-Message0", encoding: Encoding.ASCII), // Indicate server is ready to receive data
+ (message: "PSRP-Message1", encoding: Encoding.ASCII), // Indicate server is ready to receive data
+ (message: "PSRP-Message2", encoding: Encoding.ASCII), // Indicate server is ready to receive data
+ (message: "PSRP-Message3", encoding: Encoding.ASCII), // Indicate server is ready to receive data
+ (message: "PSRP-Message4", encoding: Encoding.ASCII) //
+
+ };
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2));
+ var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: false, cts.Token, sendFirst: true), cts.Token);
+
+ using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ ConnectWithRetry(client, IPAddress.Loopback, port, _output);
+ System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, DateTimeOffset.UtcNow, 5);
+ for (int i = 0; i < 5; i++)
+ {
+ System.Threading.Thread.Sleep(1500);
+ client.Send(Encoding.ASCII.GetBytes($"PSRP-Message{i}")); // Send some data after the handshake
+ }
+ }
+
+ await serverTask;
+ }
+
[SkippableTheory]
[InlineData("abc", "xyz")]
[InlineData("abc", "abcdef")]
@@ -649,8 +794,9 @@ public async Task ValidateFailsWhenTokensMismatch(string token, string expectedT
using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
ConnectWithRetry(client, IPAddress.Loopback, port, _output);
+ DateTimeOffset tokenCreationTime = DateTimeOffset.UtcNow; // Token created 10 minutes ago
var exception = Assert.Throws(
- () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken));
+ () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken, tokenCreationTime, 5));
System.Threading.Thread.Sleep(100); // Allow time for server to process
Assert.Contains("The credential is invalid.", exception.Message);
}