diff --git a/.editorconfig b/.editorconfig index 72707109516..57d2f6c6c3e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -126,6 +126,10 @@ dotnet_code_quality_unused_parameters = non_public:suggestion # https://learn.microsoft.com/en-gb/dotnet/fundamentals/code-analysis/quality-rules/ca1859 dotnet_diagnostic.CA1859.severity = suggestion +# Disable SA1600 (ElementsMustBeDocumented) for test directory only +[test/**/*.cs] +dotnet_diagnostic.SA1600.severity = none + # CSharp code style settings: [*.cs] diff --git a/.github/workflows/linux-ci.yml b/.github/workflows/linux-ci.yml index 6a30bfcba22..d4d2c14f8ee 100644 --- a/.github/workflows/linux-ci.yml +++ b/.github/workflows/linux-ci.yml @@ -21,6 +21,7 @@ on: - master - release/** - github-mirror + - "*-feature" # Path filters for PRs need to go into the changes job concurrency: diff --git a/.github/workflows/macos-ci.yml b/.github/workflows/macos-ci.yml index 8e5f1620bb5..dc1c38a162d 100644 --- a/.github/workflows/macos-ci.yml +++ b/.github/workflows/macos-ci.yml @@ -19,6 +19,7 @@ on: - master - release/** - github-mirror + - "*-feature" # Path filters for PRs need to go into the changes job concurrency: diff --git a/.github/workflows/windows-ci.yml b/.github/workflows/windows-ci.yml index 1bc6ebe0a1f..18b426aa191 100644 --- a/.github/workflows/windows-ci.yml +++ b/.github/workflows/windows-ci.yml @@ -18,6 +18,7 @@ on: - master - release/** - github-mirror + - "*-feature" # Path filters for PRs need to go into the changes job diff --git a/build.psm1 b/build.psm1 index d76dd7f63ff..ffbeddb1c21 100644 --- a/build.psm1 +++ b/build.psm1 @@ -1984,7 +1984,9 @@ function Test-PSPesterResults function Start-PSxUnit { [CmdletBinding()]param( - [string] $xUnitTestResultsFile = "xUnitResults.xml" + [string] $xUnitTestResultsFile = "xUnitResults.xml", + [switch] $DebugLogging, + [string] $Filter ) # Add .NET CLI tools to PATH @@ -2042,9 +2044,28 @@ function Start-PSxUnit { # We run the xUnit tests sequentially to avoid race conditions caused by manipulating the config.json file. # xUnit tests run in parallel by default. To make them run sequentially, we need to define the 'xunit.runner.json' file. - dotnet test --configuration $Options.configuration --test-adapter-path:. "--logger:xunit;LogFilePath=$xUnitTestResultsFile" + $extraParams = @() + if($Filter) { + $extraParams += @( + '--filter' + $Filter + ) + } + + if($DebugLogging) { + $extraParams += @( + "--logger:console;verbosity=detailed" + ) + } else { + $extraParams += @( + "--logger:xunit;LogFilePath=$xUnitTestResultsFile" + ) + } + dotnet test @extraParams --configuration $Options.configuration --test-adapter-path:. - Publish-TestResults -Path $xUnitTestResultsFile -Type 'XUnit' -Title 'Xunit Sequential' + if(!$DebugLogging){ + Publish-TestResults -Path $xUnitTestResultsFile -Type 'XUnit' -Title 'Xunit Sequential' + } } finally { $env:DOTNET_ROOT = $originalDOTNET_ROOT diff --git a/src/Microsoft.PowerShell.ConsoleHost/host/msh/CommandLineParameterParser.cs b/src/Microsoft.PowerShell.ConsoleHost/host/msh/CommandLineParameterParser.cs index 50d2bd77d0f..0fb85e740f4 100644 --- a/src/Microsoft.PowerShell.ConsoleHost/host/msh/CommandLineParameterParser.cs +++ b/src/Microsoft.PowerShell.ConsoleHost/host/msh/CommandLineParameterParser.cs @@ -196,6 +196,7 @@ internal static int MaxNameLength() "workingdirectory" }; +#pragma warning disable SA1025 // CodeMustNotContainMultipleWhitespaceInARow /// /// These represent the parameters that are used when starting pwsh. /// We can query in our telemetry to determine how pwsh was invoked. @@ -203,35 +204,36 @@ internal static int MaxNameLength() [Flags] internal enum ParameterBitmap : long { - Command = 0x00000001, // -Command | -c - ConfigurationName = 0x00000002, // -ConfigurationName | -config - CustomPipeName = 0x00000004, // -CustomPipeName - EncodedCommand = 0x00000008, // -EncodedCommand | -e | -ec - EncodedArgument = 0x00000010, // -EncodedArgument - ExecutionPolicy = 0x00000020, // -ExecutionPolicy | -ex | -ep - File = 0x00000040, // -File | -f - Help = 0x00000080, // -Help, -?, /? - InputFormat = 0x00000100, // -InputFormat | -inp | -if - Interactive = 0x00000200, // -Interactive | -i - Login = 0x00000400, // -Login | -l - MTA = 0x00000800, // -MTA - NoExit = 0x00001000, // -NoExit | -noe - NoLogo = 0x00002000, // -NoLogo | -nol - NonInteractive = 0x00004000, // -NonInteractive | -noni - NoProfile = 0x00008000, // -NoProfile | -nop - OutputFormat = 0x00010000, // -OutputFormat | -o | -of - SettingsFile = 0x00020000, // -SettingsFile | -settings - SSHServerMode = 0x00040000, // -SSHServerMode | -sshs - SocketServerMode = 0x00080000, // -SocketServerMode | -sockets - ServerMode = 0x00100000, // -ServerMode | -server - NamedPipeServerMode = 0x00200000, // -NamedPipeServerMode | -namedpipes - STA = 0x00400000, // -STA - Version = 0x00800000, // -Version | -v - WindowStyle = 0x01000000, // -WindowStyle | -w - WorkingDirectory = 0x02000000, // -WorkingDirectory | -wd - ConfigurationFile = 0x04000000, // -ConfigurationFile - NoProfileLoadTime = 0x08000000, // -NoProfileLoadTime - CommandWithArgs = 0x10000000, // -CommandWithArgs | -cwa + Command = 0x0000000000000001, // -Command | -c + ConfigurationName = 0x0000000000000002, // -ConfigurationName | -config + CustomPipeName = 0x0000000000000004, // -CustomPipeName + EncodedCommand = 0x0000000000000008, // -EncodedCommand | -e | -ec + EncodedArgument = 0x0000000000000010, // -EncodedArgument + ExecutionPolicy = 0x0000000000000020, // -ExecutionPolicy | -ex | -ep + File = 0x0000000000000040, // -File | -f + Help = 0x0000000000000080, // -Help, -?, /? + InputFormat = 0x0000000000000100, // -InputFormat | -inp | -if + Interactive = 0x0000000000000200, // -Interactive | -i + Login = 0x0000000000000400, // -Login | -l + MTA = 0x0000000000000800, // -MTA + NoExit = 0x0000000000001000, // -NoExit | -noe + NoLogo = 0x0000000000002000, // -NoLogo | -nol + NonInteractive = 0x0000000000004000, // -NonInteractive | -noni + NoProfile = 0x0000000000008000, // -NoProfile | -nop + OutputFormat = 0x0000000000010000, // -OutputFormat | -o | -of + SettingsFile = 0x0000000000020000, // -SettingsFile | -settings + SSHServerMode = 0x0000000000040000, // -SSHServerMode | -sshs + SocketServerMode = 0x0000000000080000, // -SocketServerMode | -sockets + ServerMode = 0x0000000000100000, // -ServerMode | -server + NamedPipeServerMode = 0x0000000000200000, // -NamedPipeServerMode | -namedpipes + STA = 0x0000000000400000, // -STA + Version = 0x0000000000800000, // -Version | -v + WindowStyle = 0x0000000001000000, // -WindowStyle | -w + WorkingDirectory = 0x0000000002000000, // -WorkingDirectory | -wd + ConfigurationFile = 0x0000000004000000, // -ConfigurationFile + NoProfileLoadTime = 0x0000000008000000, // -NoProfileLoadTime + CommandWithArgs = 0x0000000010000000, // -CommandWithArgs | -cwa + // Enum values for specified ExecutionPolicy EPUnrestricted = 0x0000000100000000, // ExecutionPolicy unrestricted EPRemoteSigned = 0x0000000200000000, // ExecutionPolicy remote signed @@ -241,7 +243,11 @@ internal enum ParameterBitmap : long EPBypass = 0x0000002000000000, // ExecutionPolicy bypass EPUndefined = 0x0000004000000000, // ExecutionPolicy undefined EPIncorrect = 0x0000008000000000, // ExecutionPolicy incorrect + + // V2 Socket Server Mode + V2SocketServerMode = 0x0000100000000000, // -V2SocketServerMode | -v2so } +#pragma warning restore SA1025 // CodeMustNotContainMultipleWhitespaceInARow internal ParameterBitmap ParametersUsed = 0; @@ -597,6 +603,33 @@ internal bool RemoveWorkingDirectoryTrailingCharacter return _removeWorkingDirectoryTrailingCharacter; } } + + internal DateTimeOffset? UTCTimestamp + { + get + { + AssertArgumentsParsed(); + return _utcTimestamp; + } + } + + internal string? Token + { + get + { + AssertArgumentsParsed(); + return _token; + } + } + + internal bool V2SocketServerMode + { + get + { + AssertArgumentsParsed(); + return _v2SocketServerMode; + } + } #endif #endregion Internal properties @@ -916,6 +949,14 @@ private void ParseHelper(string[] args) _showBanner = false; ParametersUsed |= ParameterBitmap.SocketServerMode; } +#if !UNIX + else if (MatchSwitch(switchKey, "v2socketservermode", "v2so")) + { + _v2SocketServerMode = true; + _showBanner = false; + ParametersUsed |= ParameterBitmap.V2SocketServerMode; + } +#endif else if (MatchSwitch(switchKey, "servermode", "s")) { _serverMode = true; @@ -1176,6 +1217,37 @@ private void ParseHelper(string[] args) { _removeWorkingDirectoryTrailingCharacter = true; } + else if (MatchSwitch(switchKey, "token", "to")) + { + ++i; + if (i >= args.Length) + { + SetCommandLineError( + string.Format(CultureInfo.CurrentCulture, CommandLineParameterParserStrings.MissingMandatoryArgument, "-Token")); + break; + } + + _token = args[i]; + + // Not adding anything to ParametersUsed, because it is required with V2 socket server mode + // So, we can assume it based on that bit + } + else if (MatchSwitch(switchKey, "utctimestamp", "utc")) + { + ++i; + if (i >= args.Length) + { + SetCommandLineError( + string.Format(CultureInfo.CurrentCulture, CommandLineParameterParserStrings.MissingMandatoryArgument, "-UTCTimestamp")); + break; + } + + // Parse as iso8601UtcString + _utcTimestamp = DateTimeOffset.ParseExact(args[i], "yyyy-MM-dd'T'HH:mm:ssK", CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind); + + // Not adding anything to ParametersUsed, because it is required with V2 socket server mode + // So, we can assume it based on that bit + } #endif else { @@ -1530,6 +1602,9 @@ private bool CollectArgs(string[] args, ref int i) } private bool _socketServerMode; +#if !UNIX + private bool _v2SocketServerMode; +#endif private bool _serverMode; private bool _namedPipeServerMode; private bool _sshServerMode; @@ -1562,6 +1637,10 @@ private bool CollectArgs(string[] args, ref int i) private string? _executionPolicy; private string? _settingsFile; private string? _workingDirectory; +#if !UNIX + private string? _token; + private DateTimeOffset? _utcTimestamp; +#endif #if !UNIX private ProcessWindowStyle? _windowStyle; diff --git a/src/Microsoft.PowerShell.ConsoleHost/host/msh/ConsoleHost.cs b/src/Microsoft.PowerShell.ConsoleHost/host/msh/ConsoleHost.cs index ab9bdab568d..8a38b1904cb 100644 --- a/src/Microsoft.PowerShell.ConsoleHost/host/msh/ConsoleHost.cs +++ b/src/Microsoft.PowerShell.ConsoleHost/host/msh/ConsoleHost.cs @@ -198,7 +198,26 @@ internal static int Start( } // Servermode parameter validation check. - if ((s_cpp.ServerMode && s_cpp.NamedPipeServerMode) || (s_cpp.ServerMode && s_cpp.SocketServerMode) || (s_cpp.NamedPipeServerMode && s_cpp.SocketServerMode)) + int serverModeCount = 0; + if (s_cpp.ServerMode) + { + serverModeCount++; + } + if (s_cpp.NamedPipeServerMode) + { + serverModeCount++; + } + if (s_cpp.SocketServerMode) + { + serverModeCount++; + } +#if !UNIX + if (s_cpp.V2SocketServerMode) + { + serverModeCount++; + } +#endif + if (serverModeCount > 1) { s_tracer.TraceError("Conflicting server mode parameters, parameters must be used exclusively."); s_theConsoleHost?.ui.WriteErrorLine(ConsoleHostStrings.ConflictingServerModeParameters); @@ -242,6 +261,34 @@ internal static int Start( configurationName: s_cpp.ConfigurationName); exitCode = 0; } +#if !UNIX + else if (s_cpp.V2SocketServerMode) + { + if (s_cpp.Token == null) + { + s_tracer.TraceError("Token is required for V2SocketServerMode."); + s_theConsoleHost?.ui.WriteErrorLine(string.Format(CultureInfo.CurrentCulture, ConsoleHostStrings.MissingMandatoryParameter, "-Token", "-V2SocketServerMode")); + return ExitCodeBadCommandLineParameter; + } + + if (s_cpp.UTCTimestamp == null) + { + s_tracer.TraceError("UTCTimestamp is required for V2SocketServerMode."); + s_theConsoleHost?.ui.WriteErrorLine(string.Format(CultureInfo.CurrentCulture, ConsoleHostStrings.MissingMandatoryParameter, "-UTCTimestamp", "-v2socketservermode")); + return ExitCodeBadCommandLineParameter; + } + + ApplicationInsightsTelemetry.SendPSCoreStartupTelemetry("V2SocketServerMode", s_cpp.ParametersUsedAsDouble); + ProfileOptimization.StartProfile("StartupProfileData-V2SocketServerMode"); + HyperVSocketMediator.Run( + initialCommand: s_cpp.InitialCommand, + configurationName: s_cpp.ConfigurationName, + token: s_cpp.Token, + tokenCreationTime: s_cpp.UTCTimestamp.Value); + + exitCode = 0; + } +#endif else if (s_cpp.SocketServerMode) { ApplicationInsightsTelemetry.SendPSCoreStartupTelemetry("SocketServerMode", s_cpp.ParametersUsedAsDouble); @@ -1879,6 +1926,7 @@ private void DoRunspaceInitialization(RunspaceCreationEventArgs args) { s_theConsoleHost.UI.WriteLine(ManagedEntranceStrings.ShellBannerCLAuditMode); } + break; case PSLanguageMode.NoLanguage: @@ -2745,6 +2793,7 @@ e is RemoteException || #endif } } + // NTRAID#Windows Out Of Band Releases-915506-2005/09/09 // Removed HandleUnexpectedExceptions infrastructure finally diff --git a/src/Microsoft.PowerShell.ConsoleHost/resources/CommandLineParameterParserStrings.resx b/src/Microsoft.PowerShell.ConsoleHost/resources/CommandLineParameterParserStrings.resx index 34bb696c33c..33445ceebd2 100644 --- a/src/Microsoft.PowerShell.ConsoleHost/resources/CommandLineParameterParserStrings.resx +++ b/src/Microsoft.PowerShell.ConsoleHost/resources/CommandLineParameterParserStrings.resx @@ -225,4 +225,7 @@ Valid formats are: Invalid ExecutionPolicy value '{0}'. + + An argument is required to be supplied to the '{0}' parameter. + diff --git a/src/Microsoft.PowerShell.ConsoleHost/resources/ConsoleHostStrings.resx b/src/Microsoft.PowerShell.ConsoleHost/resources/ConsoleHostStrings.resx index 80b3d4aafe0..9bc06e0d42f 100644 --- a/src/Microsoft.PowerShell.ConsoleHost/resources/ConsoleHostStrings.resx +++ b/src/Microsoft.PowerShell.ConsoleHost/resources/ConsoleHostStrings.resx @@ -185,4 +185,7 @@ The current session does not support debugging; execution will continue. PushRunspace can only push a remote runspace. + + The '{0}' parameter is mandatory and must be specified when using the '{1}' parameter. + diff --git a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs index 94ce5af0208..02cf5b29697 100644 --- a/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs +++ b/src/System.Management.Automation/engine/remoting/common/RemoteSessionHyperVSocket.cs @@ -7,8 +7,10 @@ using System.Net.Sockets; using System.Text; using System.Threading; +using System.Buffers; using Dbg = System.Diagnostics.Debug; +using SMA = System.Management.Automation; namespace System.Management.Automation.Remoting { @@ -140,6 +142,10 @@ internal sealed class RemoteSessionHyperVSocketServer : IDisposable private readonly object _syncObject; private readonly PowerShellTraceSource _tracer = PowerShellTraceSourceFactory.GetTraceSource(); + // This is to prevent persistent replay attacks. + // it is not meant to ensure all replay attacks are impossible. + private const int MAX_TOKEN_LIFE_MINUTES = 10; + #endregion #region Properties @@ -175,64 +181,74 @@ internal sealed class RemoteSessionHyperVSocketServer : IDisposable public RemoteSessionHyperVSocketServer(bool LoopbackMode) { - // TODO: uncomment below code when .NET supports Hyper-V socket duplication - /* - NamedPipeClientStream clientPipeStream; - byte[] buffer = new byte[1000]; - int bytesRead; - */ _syncObject = new object(); Exception ex = null; try { - // TODO: uncomment below code when .NET supports Hyper-V socket duplication - /* - if (!LoopbackMode) - { - // - // Create named pipe client. - // - using (clientPipeStream = new NamedPipeClientStream(".", - "PS_VMSession", - PipeDirection.InOut, - PipeOptions.None, - TokenImpersonationLevel.None)) - { - // - // Connect to named pipe server. - // - clientPipeStream.Connect(10*1000); - - // - // Read LPWSAPROTOCOL_INFO. - // - bytesRead = clientPipeStream.Read(buffer, 0, 1000); - } - } + Guid serviceId = new Guid("a5201c21-2770-4c11-a68e-f182edb29220"); // HV_GUID_VM_SESSION_SERVICE_ID_2 + Guid loopbackId = new Guid("e0e16197-dd56-4a10-9195-5ee7a155a838"); // HV_GUID_LOOPBACK + Guid parentId = new Guid("a42e7cda-d03f-480c-9cc2-a4de20abb878"); // HV_GUID_PARENT + Guid vmId = LoopbackMode ? loopbackId : parentId; + HyperVSocketEndPoint endpoint = new HyperVSocketEndPoint(HyperVSocketEndPoint.AF_HYPERV, vmId, serviceId); + + Socket listenSocket = new Socket(endpoint.AddressFamily, SocketType.Stream, (System.Net.Sockets.ProtocolType)1); + listenSocket.Bind(endpoint); + + listenSocket.Listen(1); + HyperVSocket = listenSocket.Accept(); + + Stream = new NetworkStream(HyperVSocket, true); + + // Create reader/writer streams. + TextReader = new StreamReader(Stream); + TextWriter = new StreamWriter(Stream); + TextWriter.AutoFlush = true; // - // Create duplicate socket. + // listenSocket is not closed when it goes out of scope here. Sometimes it is + // closed later in this thread, while other times it is not closed at all. This will + // cause problem when we set up a second PowerShell Direct session. Let's + // explicitly close listenSocket here for safe. // - byte[] protocolInfo = new byte[bytesRead]; - Array.Copy(buffer, protocolInfo, bytesRead); + if (listenSocket != null) + { + try { listenSocket.Dispose(); } + catch (ObjectDisposedException) { } + } + } + catch (Exception e) + { + ex = e; + } - SocketInformation sockInfo = new SocketInformation(); - sockInfo.ProtocolInformation = protocolInfo; - sockInfo.Options = SocketInformationOptions.Connected; + if (ex != null) + { + Dbg.Fail("Unexpected error in RemoteSessionHyperVSocketServer."); - socket = new Socket(sockInfo); - if (socket == null) - { - Dbg.Assert(false, "Unexpected error in RemoteSessionHyperVSocketServer."); + // Unexpected error. + string errorMessage = !string.IsNullOrEmpty(ex.Message) ? ex.Message : string.Empty; + _tracer.WriteMessage("RemoteSessionHyperVSocketServer", "RemoteSessionHyperVSocketServer", Guid.Empty, + "Unexpected error in constructor: {0}", errorMessage); - tracer.WriteMessage("RemoteSessionHyperVSocketServer", "RemoteSessionHyperVSocketServer", Guid.Empty, - "Unexpected error in constructor: {0}", "socket duplication failure"); - } - */ + throw new PSInvalidOperationException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.RemoteSessionHyperVSocketServerConstructorFailure), + ex, + nameof(PSRemotingErrorId.RemoteSessionHyperVSocketServerConstructorFailure), + ErrorCategory.InvalidOperation, + null); + } + } + + public RemoteSessionHyperVSocketServer(bool LoopbackMode, string token, DateTimeOffset tokenCreationTime) + { + _syncObject = new object(); - // TODO: remove below 6 lines of code when .NET supports Hyper-V socket duplication + Exception ex = null; + + try + { Guid serviceId = new Guid("a5201c21-2770-4c11-a68e-f182edb29220"); // HV_GUID_VM_SESSION_SERVICE_ID_2 HyperVSocketEndPoint endpoint = new HyperVSocketEndPoint(HyperVSocketEndPoint.AF_HYPERV, Guid.Empty, serviceId); @@ -242,6 +258,31 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode) listenSocket.Listen(1); HyperVSocket = listenSocket.Accept(); + TimeSpan timeout = TimeSpan.FromMinutes(MAX_TOKEN_LIFE_MINUTES); + DateTimeOffset timeoutExpiry = tokenCreationTime.Add(timeout); + DateTimeOffset now = DateTimeOffset.UtcNow; + + // Calculate remaining time and create cancellation token + TimeSpan remainingTime = timeoutExpiry - now; + + // Check if the token has already expired + if (remainingTime <= TimeSpan.Zero) + { + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential, "Token has expired")); + } + + // Set socket timeout for receive operations to prevent indefinite blocking + int timeoutMs = (int)remainingTime.TotalMilliseconds; + HyperVSocket.ReceiveTimeout = timeoutMs; + HyperVSocket.SendTimeout = timeoutMs; + + // Create a cancellation token that will be cancelled when the timeout expires + using var cancellationTokenSource = new CancellationTokenSource(remainingTime); + CancellationToken cancellationToken = cancellationTokenSource.Token; + + ValidateToken(HyperVSocket, token, cancellationToken); + Stream = new NetworkStream(HyperVSocket, true); // Create reader/writer streams. @@ -257,8 +298,13 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode) // if (listenSocket != null) { - try { listenSocket.Dispose(); } - catch (ObjectDisposedException) { } + try + { + listenSocket.Dispose(); + } + catch (ObjectDisposedException) + { + } } } catch (Exception e) @@ -272,8 +318,12 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode) // Unexpected error. string errorMessage = !string.IsNullOrEmpty(ex.Message) ? ex.Message : string.Empty; - _tracer.WriteMessage("RemoteSessionHyperVSocketServer", "RemoteSessionHyperVSocketServer", Guid.Empty, - "Unexpected error in constructor: {0}", errorMessage); + _tracer.WriteMessage( + "RemoteSessionHyperVSocketServer", + "RemoteSessionHyperVSocketServer", + Guid.Empty, + "Unexpected error in constructor: {0}", + errorMessage); throw new PSInvalidOperationException( PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.RemoteSessionHyperVSocketServerConstructorFailure), @@ -283,7 +333,6 @@ public RemoteSessionHyperVSocketServer(bool LoopbackMode) null); } } - #endregion #region IDisposable @@ -333,6 +382,79 @@ public void Dispose() } #endregion + + /// + /// Validates the token received from the client over the HyperVSocket. + /// Throws PSDirectException if the token is invalid or not received in time. + /// + /// The connected HyperVSocket. + /// The expected token string. + /// Cancellation token for timeout handling. + internal static void ValidateToken(Socket socket, string token, CancellationToken cancellationToken = default) + { + // Check for cancellation before starting validation + cancellationToken.ThrowIfCancellationRequested(); + + // We should move to this pattern and + // in the tests I found I needed to get a bigger buffer than the token length + // and test length of the received data similar to this pattern. + string responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, RemoteSessionHyperVSocketClient.VERSION_REQUEST.Length + 4); + if (string.IsNullOrEmpty(responseString) || responseString.Length != RemoteSessionHyperVSocketClient.VERSION_REQUEST.Length) + { + socket.Send("FAIL"u8); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Client", "Version Request: " + responseString)); + } + + cancellationToken.ThrowIfCancellationRequested(); + + socket.Send(Encoding.UTF8.GetBytes(RemoteSessionHyperVSocketClient.CLIENT_VERSION)); + responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, RemoteSessionHyperVSocketClient.CLIENT_VERSION.Length + 4); + + // In the future we may need to handle different versions, differently. + // For now, we are just checking that we exchanged versions correctly. + if (string.IsNullOrEmpty(responseString) || !responseString.StartsWith(RemoteSessionHyperVSocketClient.VERSION_PREFIX, StringComparison.Ordinal)) + { + socket.Send("FAIL"u8); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Client", "Version Response: " + responseString)); + } + + cancellationToken.ThrowIfCancellationRequested(); + + socket.Send("PASS"u8); + + // The client should send the token in the format TOKEN + // the token should be up to 256 bits, which is less than 50 characters. + // I'll double that to 100 characters to be safe, plus the "TOKEN " prefix. + // So we expect a response of length 6 + 100 = 106 characters. + responseString = RemoteSessionHyperVSocketClient.ReceiveResponse(socket, 110); + + cancellationToken.ThrowIfCancellationRequested(); + + if (string.IsNullOrEmpty(responseString) || !responseString.StartsWith("TOKEN ", StringComparison.Ordinal)) + { + socket.Send("FAIL"u8); + // If the response is not in the expected format, we throw an exception. + // This is a failure to authenticate the client. + // don't send this response for risk of information disclosure. + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Client", "Token Response")); + } + + // Extract the token from the response. + string responseToken = responseString.Substring(6).Trim(); + + if (!string.Equals(responseToken, token, StringComparison.Ordinal)) + { + socket.Send("FAIL"u8); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential)); + } + + // Acknowledge the token is valid with "PASS". + socket.Send("PASS"u8); + } } internal sealed class RemoteSessionHyperVSocketClient : IDisposable @@ -340,7 +462,15 @@ internal sealed class RemoteSessionHyperVSocketClient : IDisposable #region Members private readonly object _syncObject; - private readonly PowerShellTraceSource _tracer = PowerShellTraceSourceFactory.GetTraceSource(); + + #region tracer + /// + /// An instance of the PSTraceSource class used for trace output. + /// + [SMA.TraceSource("RemoteSessionHyperVSocketClient", "Class that has PowerShell Direct Client implementation")] + private static readonly PSTraceSource s_tracer = PSTraceSource.GetTracer("RemoteSessionHyperVSocketClient", "Class that has PowerShell Direct Client implementation"); + + #endregion tracer private static readonly ManualResetEvent s_connectDone = new ManualResetEvent(false); @@ -354,6 +484,14 @@ internal sealed class RemoteSessionHyperVSocketClient : IDisposable #endregion + #region version constants + + internal const string VERSION_REQUEST = "VERSION"; + internal const string CLIENT_VERSION = "VERSION_2"; + internal const string VERSION_PREFIX = "VERSION_"; + + #endregion + #region Properties /// @@ -364,7 +502,7 @@ internal sealed class RemoteSessionHyperVSocketClient : IDisposable /// /// Returns the Hyper-V socket object. /// - public Socket HyperVSocket { get; } + public Socket HyperVSocket { get; private set; } /// /// Returns the network stream object. @@ -381,6 +519,37 @@ internal sealed class RemoteSessionHyperVSocketClient : IDisposable /// public StreamWriter TextWriter { get; private set; } + /// + /// True if the client is a Hyper-V container. + /// + public bool IsContainer { get; } + + /// + /// True if the client is using backwards compatible mode. + /// This is used to determine if the client should use + /// the backwards compatible or not. + /// In modern mode, the vmicvmsession service will + /// hand off the socket to the PowerShell process + /// inside the VM automatically. + /// In backwards compatible mode, the vmicvmsession + /// service create a new socket to the PowerShell process + /// inside the VM. + /// + public bool UseBackwardsCompatibleMode { get; private set; } + + /// + /// The authentication token used for the session. + /// This token is provided by the broker and provided to the server to authenticate the server session. + /// This protocol uses two connections: + /// 1. The first is to the broker or vmicvmsession service to exchange credentials and configuration. + /// The broker will respond with an authentication token. The broker also launches a PowerShell + /// server process with the authentication token. + /// 2. The second is to the server process, that was launched by the broker, + /// inside the VM, which uses the authentication token to verify that the client is the same client + /// that connected to the broker. + /// + public string AuthenticationToken { get; private set; } + /// /// Returns true if object is currently disposed. /// @@ -393,7 +562,9 @@ internal sealed class RemoteSessionHyperVSocketClient : IDisposable internal RemoteSessionHyperVSocketClient( Guid vmId, bool isFirstConnection, - bool isContainer = false) + bool useBackwardsCompatibleMode = false, + bool isContainer = false, + string authenticationToken = null) { Guid serviceId; @@ -412,28 +583,16 @@ internal RemoteSessionHyperVSocketClient( EndPoint = new HyperVSocketEndPoint(HyperVSocketEndPoint.AF_HYPERV, vmId, serviceId); - HyperVSocket = new Socket(EndPoint.AddressFamily, SocketType.Stream, (System.Net.Sockets.ProtocolType)1); + IsContainer = isContainer; - // - // We need to call SetSocketOption() in order to set up Hyper-V socket connection between container host and Hyper-V container. - // Here is the scenario: the Hyper-V container is inside a utility vm, which is inside the container host - // - if (isContainer) - { - var value = new byte[sizeof(uint)]; - value[0] = 1; + UseBackwardsCompatibleMode = useBackwardsCompatibleMode; - try - { - HyperVSocket.SetSocketOption((System.Net.Sockets.SocketOptionLevel)HV_PROTOCOL_RAW, - (System.Net.Sockets.SocketOptionName)HVSOCKET_CONTAINER_PASSTHRU, - (byte[])value); - } - catch - { - throw new PSDirectException( - PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.RemoteSessionHyperVSocketClientConstructorSetSocketOptionFailure)); - } + if (!isFirstConnection && !useBackwardsCompatibleMode && !string.IsNullOrEmpty(authenticationToken)) + { + // If this is not the first connection and we are using backwards compatible mode, + // we should not set the authentication token here. + // The authentication token will be set during the Connect method. + AuthenticationToken = authenticationToken; } } @@ -489,6 +648,81 @@ public void Dispose() #region Public Methods + private void ShutdownSocket() + { + if (HyperVSocket != null) + { + // Ensure the socket is disposed properly. + try + { + s_tracer.WriteLine("ShutdownSocket: Disposing of the HyperVSocket."); + HyperVSocket.Dispose(); + } + catch (Exception ex) + { + s_tracer.WriteLine("ShutdownSocket: Exception while disposing the socket: {0}", ex.Message); + } + } + + // Dispose of the existing stream if it exists. + if (Stream != null) + { + try + { + Stream.Dispose(); + } + catch (Exception ex) + { + s_tracer.WriteLine("ShutdownSocket: Exception while disposing the stream: {0}", ex.Message); + } + } + } + + /// + /// Recreates the HyperVSocket and connects it to the endpoint, updating the Stream if successful. + /// + private bool ConnectSocket() + { + HyperVSocket = new Socket(EndPoint.AddressFamily, SocketType.Stream, (System.Net.Sockets.ProtocolType)1); + + // + // We need to call SetSocketOption() in order to set up Hyper-V socket connection between container host and Hyper-V container. + // Here is the scenario: the Hyper-V container is inside a utility vm, which is inside the container host + // + if (IsContainer) + { + var value = new byte[sizeof(uint)]; + value[0] = 1; + + try + { + HyperVSocket.SetSocketOption( + (System.Net.Sockets.SocketOptionLevel)HV_PROTOCOL_RAW, + (System.Net.Sockets.SocketOptionName)HVSOCKET_CONTAINER_PASSTHRU, + value); + } + catch + { + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.RemoteSessionHyperVSocketClientConstructorSetSocketOptionFailure)); + } + } + + s_tracer.WriteLine("Connect: Client connecting, to {0}; isContainer: {1}.", EndPoint.ServiceId.ToString(), IsContainer); + HyperVSocket.Connect(EndPoint); + + // Check if the socket is connected. + // If it is connected, create a NetworkStream. + if (HyperVSocket.Connected) + { + s_tracer.WriteLine("Connect: Client connected, to {0}; isContainer: {1}.", EndPoint.ServiceId.ToString(), IsContainer); + Stream = new NetworkStream(HyperVSocket, true); + return true; + } + + return false; + } + /// /// Connect to Hyper-V socket server. This is a blocking call until a /// connection occurs or the timeout time has elapsed. @@ -516,100 +750,51 @@ public bool Connect( } } - HyperVSocket.Connect(EndPoint); - - if (HyperVSocket.Connected) + if (ConnectSocket()) { - _tracer.WriteMessage("RemoteSessionHyperVSocketClient", "Connect", Guid.Empty, - "Client connected."); - - Stream = new NetworkStream(HyperVSocket, true); - if (isFirstConnection) { - if (string.IsNullOrEmpty(networkCredential.Domain)) + var exchangeResult = ExchangeCredentialsAndConfiguration(networkCredential, configurationName, HyperVSocket, this.UseBackwardsCompatibleMode); + if (!exchangeResult.success) { - networkCredential.Domain = "localhost"; - } + // We will not block here for a container because a container does not have a broker. + if (IsRequirePsDirectAuthenticationEnabled(@"SOFTWARE\\Microsoft\\PowerShell", Microsoft.Win32.RegistryHive.LocalMachine)) + { + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: RequirePsDirectAuthentication is enabled, requiring latest transport version."); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVNegotiationFailed)); + } - bool emptyPassword = string.IsNullOrEmpty(networkCredential.Password); - bool emptyConfiguration = string.IsNullOrEmpty(configurationName); - - byte[] domain = Encoding.Unicode.GetBytes(networkCredential.Domain); - byte[] userName = Encoding.Unicode.GetBytes(networkCredential.UserName); - byte[] password = Encoding.Unicode.GetBytes(networkCredential.Password); - byte[] response = new byte[4]; // either "PASS" or "FAIL" - string responseString; - - // - // Send credential to VM so that PowerShell process inside VM can be - // created under the correct security context. - // - HyperVSocket.Send(domain); - HyperVSocket.Receive(response); - - HyperVSocket.Send(userName); - HyperVSocket.Receive(response); - - // - // We cannot simply send password because if it is empty, - // the vmicvmsession service in VM will block in recv method. - // - if (emptyPassword) - { - HyperVSocket.Send("EMPTYPW"u8); - HyperVSocket.Receive(response); - responseString = Encoding.ASCII.GetString(response); - } - else - { - HyperVSocket.Send("NONEMPTYPW"u8); - HyperVSocket.Receive(response); + this.UseBackwardsCompatibleMode = true; + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: Using backwards compatible mode."); - HyperVSocket.Send(password); - HyperVSocket.Receive(response); - responseString = Encoding.ASCII.GetString(response); + // If the first connection fails in modern mode, fall back to backwards compatible mode. + ShutdownSocket(); // will terminate the broker + ConnectSocket(); // restart the broker + exchangeResult = ExchangeCredentialsAndConfiguration(networkCredential, configurationName, HyperVSocket, this.UseBackwardsCompatibleMode); + if (!exchangeResult.success) + { + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: Failed to exchange credentials and configuration in backwards compatible mode."); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Broker", "Credential")); + } } - - // - // There are 3 cases for the responseString received above. - // - "FAIL": credential is invalid - // - "PASS": credential is valid, but PowerShell Direct in VM does not support configuration (Server 2016 TP4 and before) - // - "CONF": credential is valid, and PowerShell Direct in VM supports configuration (Server 2016 TP5 and later) - // - - // - // Credential is invalid. - // - if (string.Equals(responseString, "FAIL", StringComparison.Ordinal)) + else { - HyperVSocket.Send(response); - - throw new PSDirectException( - PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential)); + this.AuthenticationToken = exchangeResult.authenticationToken; } + } - // - // If PowerShell Direct in VM supports configuration, send configuration name. - // - if (string.Equals(responseString, "CONF", StringComparison.Ordinal)) + if (!isFirstConnection) + { + if (!this.UseBackwardsCompatibleMode) { - if (emptyConfiguration) - { - HyperVSocket.Send("EMPTYCF"u8); - } - else - { - HyperVSocket.Send("NONEMPTYCF"u8); - HyperVSocket.Receive(response); - - byte[] configName = Encoding.Unicode.GetBytes(configurationName); - HyperVSocket.Send(configName); - } + s_tracer.WriteLine("Connect-Server: Performing transport version and token exchange for Hyper-V socket. isFirstConnection: {0}, UseBackwardsCompatibleMode: {1}", isFirstConnection, this.UseBackwardsCompatibleMode); + RemoteSessionHyperVSocketClient.PerformTransportVersionAndTokenExchange(HyperVSocket, this.AuthenticationToken); } else { - HyperVSocket.Send(response); + s_tracer.WriteLine("Connect-Server: Skipping transport version and token exchange for backwards compatible mode."); } } @@ -621,8 +806,7 @@ public bool Connect( } else { - _tracer.WriteMessage("RemoteSessionHyperVSocketClient", "Connect", Guid.Empty, - "Client unable to connect."); + s_tracer.WriteLine("Connect: Client unable to connect."); result = false; } @@ -630,12 +814,318 @@ public bool Connect( return result; } + /// + /// Performs the transport version and token exchange sequence for the Hyper-V socket connection. + /// Throws PSDirectException on failure. + /// + /// The socket to use for communication. + /// The authentication token to send. + public static void PerformTransportVersionAndTokenExchange(Socket socket, string authenticationToken) + { + if (string.IsNullOrEmpty(authenticationToken)) + { + s_tracer.WriteLine("PerformTransportVersionAndTokenExchange: Authentication token is null or empty. Aborting transport version and token exchange."); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential)); + } + + socket.Send(Encoding.UTF8.GetBytes(VERSION_REQUEST)); + string responseStr = ReceiveResponse(socket, 16); + + // Check if the response starts with the expected version prefix. + // We will rely on the broker to determine if the two can communicate. + // At least, for now. + if (!responseStr.StartsWith(VERSION_PREFIX, StringComparison.Ordinal)) + { + s_tracer.WriteLine("PerformTransportVersionAndTokenExchange: Server responded with an invalid response of {0}. Notifying the transport manager to downgrade if allowed.", responseStr); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Server", "TransportVersion")); + } + + socket.Send(Encoding.UTF8.GetBytes(CLIENT_VERSION)); + string response = ReceiveResponse(socket, 4); // either "PASS" or "FAIL" + + if (!string.Equals(response, "PASS", StringComparison.Ordinal)) + { + s_tracer.WriteLine( + "PerformTransportVersionAndTokenExchange: Transport version negotiation with server failed. Response: {0}", response); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Server", "TransportVersion")); + } + + byte[] tokenBytes = Encoding.UTF8.GetBytes("TOKEN " + authenticationToken); + socket.Send(tokenBytes); + + // This is the opportunity for the server to tell the client to go away. + string tokenResponse = ReceiveResponse(socket, 256); // either "PASS" or "FAIL", but get a little more buffer to allow for better error in the future + if (!string.Equals(tokenResponse, "PASS", StringComparison.Ordinal)) + { + s_tracer.WriteLine( + "PerformTransportVersionAndTokenExchange: Server Authentication Token exchange failed. Response: {0}", tokenResponse); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential)); + } + } + + /// + /// Checks if the registry key RequirePsDirectAuthentication is set to 1. + /// Returns true if fallback should be aborted. + /// Uses the 64-bit registry view on 64-bit systems to ensure consistent behavior regardless of process architecture. + /// On 32-bit systems, uses the default registry view since there is no WOW64 redirection. + /// + internal static bool IsRequirePsDirectAuthenticationEnabled(string keyPath, Microsoft.Win32.RegistryHive registryHive) + { + const string regValueName = "RequirePsDirectAuthentication"; + + try + { + Microsoft.Win32.RegistryView registryView = Environment.Is64BitOperatingSystem + ? Microsoft.Win32.RegistryView.Registry64 + : Microsoft.Win32.RegistryView.Default; + + using (Microsoft.Win32.RegistryKey baseKey = Microsoft.Win32.RegistryKey.OpenBaseKey( + registryHive, + registryView)) + { + using (Microsoft.Win32.RegistryKey key = baseKey.OpenSubKey(keyPath)) + { + if (key != null) + { + var value = key.GetValue(regValueName); + if (value is int intValue && intValue != 0) + { + return true; + } + } + + return false; + } + } + } + catch (Exception regEx) + { + s_tracer.WriteLine("IsRequirePsDirectAuthenticationEnabled: Exception while checking registry key: {0}", regEx.Message); + return false; // If we cannot read the registry, assume the feature is not enabled. + } + } + + /// + /// Handles credential and configuration exchange with the VM for the first connection. + /// + public static (bool success, string authenticationToken) ExchangeCredentialsAndConfiguration(NetworkCredential networkCredential, string configurationName, Socket HyperVSocket, bool useBackwardsCompatibleMode) + { + // Encoding for the Hyper-V socket communication + // To send the domain, username, password, and configuration name, use UTF-16 (Encoding.Unicode) + // All other sends use UTF-8 (Encoding.UTF8) + // Receiving uses ASCII encoding + // NOT CONFUSING AT ALL + + if (!useBackwardsCompatibleMode) + { + HyperVSocket.Send(Encoding.UTF8.GetBytes(VERSION_REQUEST)); + // vmicvmsession service in VM will respond with "VERSION_2" or newer + // Version 1 protocol will respond with "PASS" or "FAIL" + // Receive the response and check for VERSION_2 or newer + string responseStr = ReceiveResponse(HyperVSocket, 16); + if (!responseStr.StartsWith(VERSION_PREFIX, StringComparison.Ordinal)) + { + s_tracer.WriteLine("When asking for version the server responded with an invalid response of {0}.", responseStr); + s_tracer.WriteLine("Session is invalid, continuing session with a fake user to close the session with the broker for stability."); + // If not the new protocol, finish the conversation + // Send a fake user + // Use ? <> that are illegal in user names so no one can create the user + string probeUserName = "?"; // must be less than or equal to 20 characters for Windows Server 2016 + s_tracer.WriteLine("probeUserName (static): length: {0}", probeUserName.Length); + SendUserData(probeUserName, HyperVSocket); + responseStr = ReceiveResponse(HyperVSocket, 4); // either "PASS" or "FAIL" + s_tracer.WriteLine("When sending user {0}.", responseStr); + + // Send that the password is empty + HyperVSocket.Send("EMPTYPW"u8); + responseStr = ReceiveResponse(HyperVSocket, 4); // either "CONF", "PASS" or "FAIL" + s_tracer.WriteLine("When sending EMPTYPW: {0}.", responseStr); // server responds with FAIL so we respond with FAIL and the conversation is done + HyperVSocket.Send("FAIL"u8); + + s_tracer.WriteLine("Notifying the transport manager to downgrade if allowed."); + // end new code + return (false, null); + } + + HyperVSocket.Send(Encoding.UTF8.GetBytes(CLIENT_VERSION)); + ReceiveResponse(HyperVSocket, 4); // either "PASS" or "FAIL" + } + + if (string.IsNullOrEmpty(networkCredential.Domain)) + { + networkCredential.Domain = "localhost"; + } + + System.Security.SecureString securePassword = networkCredential.SecurePassword; + int passwordLength = securePassword.Length; + bool emptyPassword = (passwordLength <= 0); + bool emptyConfiguration = string.IsNullOrEmpty(configurationName); + + string responseString; + + // Send credential to VM so that PowerShell process inside VM can be + // created under the correct security context. + SendUserData(networkCredential.Domain, HyperVSocket); + ReceiveResponse(HyperVSocket, 4); // only "PASS" is expected + + SendUserData(networkCredential.UserName, HyperVSocket); + ReceiveResponse(HyperVSocket, 4); // only "PASS" is expected + + // We cannot simply send password because if it is empty, + // the vmicvmsession service in VM will block in recv method. + if (emptyPassword) + { + HyperVSocket.Send("EMPTYPW"u8); + responseString = ReceiveResponse(HyperVSocket, 4); // either "CONF", "PASS" or "FAIL" (note, "PASS" is not used in VERSION_2 or newer mode) + } + else + { + HyperVSocket.Send("NONEMPTYPW"u8); + ReceiveResponse(HyperVSocket, 4); // only "PASS" is expected + + // Get the password bytes from the SecureString, send them, and then zero out the byte array. + byte[] passwordBytes = Microsoft.PowerShell.SecureStringHelper.GetData(securePassword); + try + { + HyperVSocket.Send(passwordBytes); + } + finally + { + // Zero out the byte array for security + Array.Clear(passwordBytes); + } + + responseString = ReceiveResponse(HyperVSocket, 4); // either "CONF", "PASS" or "FAIL" (note, "PASS" is not used in VERSION_2 or newer mode) + } + + // Check for invalid response from server + if (!string.Equals(responseString, "FAIL", StringComparison.Ordinal) && + !string.Equals(responseString, "PASS", StringComparison.Ordinal) && + !string.Equals(responseString, "CONF", StringComparison.Ordinal)) + { + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: Server responded with an invalid response of {0} for credentials.", responseString); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Broker", "Credential")); + } + + // Credential is invalid. + if (string.Equals(responseString, "FAIL", StringComparison.Ordinal)) + { + HyperVSocket.Send("FAIL"u8); + // should we be doing this? Disabling the test for now + // HyperVSocket.Shutdown(SocketShutdown.Both); + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: Server responded with FAIL for credentials."); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.InvalidCredential)); + } + + // If PowerShell Direct in VM supports configuration, send configuration name. + if (string.Equals(responseString, "CONF", StringComparison.Ordinal)) + { + if (emptyConfiguration) + { + HyperVSocket.Send("EMPTYCF"u8); + } + else + { + HyperVSocket.Send("NONEMPTYCF"u8); + ReceiveResponse(HyperVSocket, 4); // only "PASS" is expected + + SendUserData(configurationName, HyperVSocket); + } + } + else + { + HyperVSocket.Send("PASS"u8); + } + + if (!useBackwardsCompatibleMode) + { + // Receive the token from the server + // Getting 1024 bytes because it is well above the expected token size + // The expected size at the time of writing this would be about 50 based64 characters, + // plus the 6 characters for the "TOKEN " prefix. + // The 50 character size is designed to last 10 years of cryptographic changes. + // Since the broker completely controls the cryptographic portion here, + // allowing a significant larger size, allows the broker to make almost arbitrary changes, + // without breaking the client. + string token = ReceiveResponse(HyperVSocket, 1024); // either "PASS" or "FAIL" + if (token == null || !token.StartsWith("TOKEN ", StringComparison.Ordinal)) + { + s_tracer.WriteLine("ExchangeCredentialsAndConfiguration: Server did not respond with a valid token. Response: {0}", token); + throw new PSDirectException( + PSRemotingErrorInvariants.FormatResourceString(RemotingErrorIdStrings.HyperVInvalidResponse, "Broker", "Token " + token)); + } + + token = token.Substring(6); // remove "TOKEN " prefix + + HyperVSocket.Send("PASS"u8); // acknowledge the token + return (true, token); + } + + return (true, null); + } + public void Close() { Stream.Dispose(); HyperVSocket.Dispose(); } + /// + /// Receives a response from the socket and decodes it. + /// + /// The socket to receive from. + /// The size of the buffer to use for receiving data. + /// The decoded response string. + internal static string ReceiveResponse(Socket socket, int bufferSize) + { + System.Buffers.ArrayPool pool = System.Buffers.ArrayPool.Shared; + byte[] responseBuffer = pool.Rent(bufferSize); + int bytesReceived = 0; + try + { + bytesReceived = socket.Receive(responseBuffer); + if (bytesReceived == 0) + { + return null; + } + + string response = Encoding.ASCII.GetString(responseBuffer, 0, bytesReceived); + + // Handle null terminators and log if found + if (response.EndsWith('\0')) + { + int originalLength = response.Length; + response = response.TrimEnd('\0'); + // Cannot log actual response, because we don't know if it is sensitive + s_tracer.WriteLine( + "ReceiveResponse: Removed null terminator(s). Original length: {0}, New length: {1}", + originalLength, + response.Length); + } + + return response; + } + finally + { + pool.Return(responseBuffer); + } + } + + /// + /// Sends user data (domain, username, etc.) over the HyperVSocket using Unicode encoding. + /// + private static void SendUserData(string data, Socket socket) + { + // this encodes the data in UTF-16 (Unicode) + byte[] buffer = Encoding.Unicode.GetBytes(data); + socket.Send(buffer); + } #endregion } } diff --git a/src/System.Management.Automation/engine/remoting/fanin/OutOfProcTransportManager.cs b/src/System.Management.Automation/engine/remoting/fanin/OutOfProcTransportManager.cs index 96a8b833885..d9532c8691a 100644 --- a/src/System.Management.Automation/engine/remoting/fanin/OutOfProcTransportManager.cs +++ b/src/System.Management.Automation/engine/remoting/fanin/OutOfProcTransportManager.cs @@ -1014,7 +1014,7 @@ internal void OnCloseTimeOutTimerElapsed(object source) } #endregion - + #region Protected Methods /// @@ -1544,8 +1544,9 @@ internal VMHyperVSocketClientSessionTransportManager( /// public override void CreateAsync() { - _client = new RemoteSessionHyperVSocketClient(_vmGuid, true); - if (!_client.Connect(_networkCredential, _configurationName, true)) + // isFirstConnection: true - specifies to use VM_SESSION_SERVICE_ID socket. + _client = new RemoteSessionHyperVSocketClient(_vmGuid, useBackwardsCompatibleMode: false, isFirstConnection: true); + if (!_client.Connect(_networkCredential, _configurationName, isFirstConnection: true)) { _client.Dispose(); throw new PSInvalidOperationException( @@ -1555,11 +1556,14 @@ public override void CreateAsync() ErrorCategory.InvalidOperation, null); } + bool useBackwardsCompatibleMode = _client.UseBackwardsCompatibleMode; + string token = _client.AuthenticationToken; - // TODO: remove below 3 lines when Hyper-V socket duplication is supported in .NET framework. _client.Dispose(); - _client = new RemoteSessionHyperVSocketClient(_vmGuid, false); - if (!_client.Connect(_networkCredential, _configurationName, false)) + + // isFirstConnection: false - specifies to use the SESSION_SERVICE_ID_2 socket. + _client = new RemoteSessionHyperVSocketClient(_vmGuid, useBackwardsCompatibleMode: useBackwardsCompatibleMode, isFirstConnection: false, authenticationToken: token); + if (!_client.Connect(_networkCredential, _configurationName, isFirstConnection: false)) { _client.Dispose(); throw new PSInvalidOperationException( @@ -1617,7 +1621,9 @@ internal ContainerHyperVSocketClientSessionTransportManager( /// public override void CreateAsync() { - _client = new RemoteSessionHyperVSocketClient(_targetGuid, false, true); + // Container scenario is not working. + // When we fix it we need to setup the token in ContainerConnectionInfo and use it here. + _client = new RemoteSessionHyperVSocketClient(_targetGuid, isFirstConnection: false, useBackwardsCompatibleMode: false, isContainer: true); if (!_client.Connect(null, string.Empty, false)) { _client.Dispose(); @@ -1716,7 +1722,7 @@ public override void CreateAsync() // Start connection timeout timer if requested. // Timer callback occurs only once after timeout time. _connectionTimer = new Timer( - callback: (_) => + callback: (_) => { if (_connectionEstablished) { @@ -2505,7 +2511,7 @@ internal OutOfProcessServerSessionTransportManager(OutOfProcessTextWriter outWri _stdErrWriter = errWriter; _cmdTransportManagers = new Dictionary(); - this.WSManTransportErrorOccured += (object sender, TransportErrorOccuredEventArgs e) => + this.WSManTransportErrorOccured += (object sender, TransportErrorOccuredEventArgs e) => { string msg = e.Exception.TransportMessage ?? e.Exception.InnerException?.Message ?? string.Empty; _stdErrWriter.WriteLine(StringUtil.Format(RemotingErrorIdStrings.RemoteTransportError, msg)); diff --git a/src/System.Management.Automation/engine/remoting/server/OutOfProcServerMediator.cs b/src/System.Management.Automation/engine/remoting/server/OutOfProcServerMediator.cs index 14b0240858b..6c794e21b24 100644 --- a/src/System.Management.Automation/engine/remoting/server/OutOfProcServerMediator.cs +++ b/src/System.Management.Automation/engine/remoting/server/OutOfProcServerMediator.cs @@ -635,6 +635,16 @@ private HyperVSocketMediator() originalStdErr = new HyperVSocketErrorTextWriter(_hypervSocketServer.TextWriter); } + private HyperVSocketMediator(string token, + DateTimeOffset tokenCreationTime) + : base(false) + { + _hypervSocketServer = new RemoteSessionHyperVSocketServer(false, token: token, tokenCreationTime: tokenCreationTime); + + originalStdIn = _hypervSocketServer.TextReader; + originalStdOut = new OutOfProcessTextWriter(_hypervSocketServer.TextWriter); + originalStdErr = new HyperVSocketErrorTextWriter(_hypervSocketServer.TextWriter); + } #endregion #region Static Methods @@ -656,6 +666,24 @@ internal static void Run( configurationFile: null); } + internal static void Run( + string initialCommand, + string configurationName, + string token, + DateTimeOffset tokenCreationTime) + { + lock (SyncObject) + { + s_instance = new HyperVSocketMediator(token, tokenCreationTime); + } + + s_instance.Start( + initialCommand: initialCommand, + cryptoHelper: new PSRemotingCryptoHelperServer(), + workingDirectory: null, + configurationName: configurationName, + configurationFile: null); + } #endregion } diff --git a/src/System.Management.Automation/resources/RemotingErrorIdStrings.resx b/src/System.Management.Automation/resources/RemotingErrorIdStrings.resx index 9e572797bd2..a9dc46bdec5 100644 --- a/src/System.Management.Automation/resources/RemotingErrorIdStrings.resx +++ b/src/System.Management.Automation/resources/RemotingErrorIdStrings.resx @@ -1726,4 +1726,10 @@ SSH client process terminated before connection could be established. Failed to get Hyper-V VM State. The value was of the type {0} but was expected to be Microsoft.HyperV.PowerShell.VMState or System.String. + + Hyper-V {0} sent an invalid {1} response during the connection negotiation. + + + Negotiating a secure connection to Hyper-V failed. Make sure the Host and Guest are updated with all relevant Microsoft Updates. + diff --git a/test/xUnit/csharp/test_CommandLineParser.cs b/test/xUnit/csharp/test_CommandLineParser.cs index 5025584d6ac..01f572d230d 100644 --- a/test/xUnit/csharp/test_CommandLineParser.cs +++ b/test/xUnit/csharp/test_CommandLineParser.cs @@ -48,6 +48,9 @@ public static void TestDefaults() Assert.False(cpp.ShowVersion); Assert.False(cpp.SkipProfiles); Assert.False(cpp.SocketServerMode); +#if !UNIX + Assert.False(cpp.V2SocketServerMode); +#endif Assert.False(cpp.SSHServerMode); if (Platform.IsWindows) { @@ -336,6 +339,25 @@ public static void TestParameter_SocketServerMode(params string[] commandLine) Assert.Null(cpp.ErrorMessage); } +#if !UNIX + [Theory] + [InlineData("-v2socketservermode", "-token", "natoheusatoehusnatoeu", "-utctimestamp", "2023-10-01T12:00:00Z")] + [InlineData("-v2so", "-token", "asentuhasoneuthsaoe", "-utctimestamp", "2025-06-09T12:00:00Z")] + public static void TestParameter_V2SocketServerMode(params string[] commandLine) + { + var cpp = new CommandLineParameterParser(); + + cpp.Parse(commandLine); + + Assert.False(cpp.AbortStartup); + Assert.True(cpp.NoExit); + Assert.False(cpp.ShowShortHelp); + Assert.False(cpp.ShowBanner); + Assert.True(cpp.V2SocketServerMode); + Assert.Null(cpp.ErrorMessage); + } +#endif + [Theory] [InlineData("-servermode")] [InlineData("-s")] diff --git a/test/xUnit/csharp/test_RemoteHyperV.cs b/test/xUnit/csharp/test_RemoteHyperV.cs new file mode 100644 index 00000000000..f694f6894df --- /dev/null +++ b/test/xUnit/csharp/test_RemoteHyperV.cs @@ -0,0 +1,661 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Management.Automation.Language; +using System.Management.Automation.Subsystem; +using System.Management.Automation.Subsystem.Prediction; +using System.Threading; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Reflection; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace PSTests.Sequential +{ + public class RemoteHyperVTests + { + private static ITestOutputHelper _output; + private static TimeSpan timeout = TimeSpan.FromSeconds(15); + + public RemoteHyperVTests(ITestOutputHelper output) + { + if (!System.Management.Automation.Platform.IsWindows) + { + throw new SkipException("RemoteHyperVTests are only supported on Windows."); + } + + _output = output; + } + + // Helper method to connect with retries + private static void ConnectWithRetry(Socket client, IPAddress address, int port, ITestOutputHelper output, int maxRetries = 10) + { + int retryDelayMs = 500; + int attempt = 0; + bool connected = false; + while (attempt < maxRetries && !connected) + { + try + { + client.Connect(address, port); + connected = true; + } + catch (SocketException) + { + attempt++; + if (attempt < maxRetries) + { + output?.WriteLine($"Connect attempt {attempt} failed, retrying in {retryDelayMs}ms..."); + Thread.Sleep(retryDelayMs); + retryDelayMs *= 2; + } + else + { + output?.WriteLine($"Failed to connect after {maxRetries} attempts. This is most likely an intermittent failure due to environmental issues."); + throw; + } + } + } + } + + private static void StartHandshakeServer( + string name, + int port, + IEnumerable<(string message, + Encoding encoding)> expectedClientSends, + IEnumerable<(string message, Encoding encoding)> serverResponses, + bool verifyConnectionClosed, + CancellationToken cancellationToken, + bool sendFirst = false) + { + var expectedMessages = new Queue<(string message, byte[] bytes, Encoding encoding)>(); + foreach (var item in expectedClientSends) + { + var itemBytes = item.encoding.GetBytes(item.message); + expectedMessages.Enqueue((message: item.message, bytes: itemBytes, encoding: item.encoding)); + } + + var serverResponseBytes = new Queue(); + foreach (var item in serverResponses) + { + serverResponseBytes.Enqueue(item.encoding.GetBytes(item.message)); + } + + StartHandshakeServer(name, port, expectedMessages, serverResponseBytes, verifyConnectionClosed, cancellationToken, sendFirst); + } + + private static void StartHandshakeServer(string name, int port, Queue<(string message, byte[] bytes, Encoding encoding)> expectedClientSends, Queue serverResponses, bool verifyConnectionClosed, CancellationToken cancellationToken, bool sendFirst = false) + { + var buffer = new byte[1024]; + var listener = new TcpListener(IPAddress.Loopback, port); + listener.Start(); + try + { + using (var client = listener.AcceptSocket()) + { + if (sendFirst) + { + // Send the first message from the serverResponses queue + if (serverResponses.Count > 0) + { + var resp = serverResponses.Dequeue(); + client.Send(resp, resp.Length, SocketFlags.None); + _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); + } + } + + while (expectedClientSends.Count > 0) + { + client.ReceiveTimeout = 2 * 1000; // 2 seconds timeout for receiving data + cancellationToken.ThrowIfCancellationRequested(); + var expectedMessage = expectedClientSends.Dequeue(); + var expected = expectedMessage.bytes; + Array.Clear(buffer, 0, buffer.Length); + int received = client.Receive(buffer); + // Optionally validate received data matches expected + string expectedString = expectedMessage.message; + string bufferString = expectedMessage.encoding.GetString(buffer, 0, received); + string alternativeEncodedString = string.Empty; + if (expectedMessage.encoding == Encoding.Unicode) + { + alternativeEncodedString = Encoding.UTF8.GetString(buffer, 0, received); + } + else if (expectedMessage.encoding == Encoding.UTF8) + { + alternativeEncodedString = Encoding.Unicode.GetString(buffer, 0, received); + } + + if (received != expected.Length) + { + string errorMessage = $"Mock {name} - Expected {expected.Length} bytes, but received {received} bytes: `{bufferString}`(alt encoding: `{alternativeEncodedString}`); expected: {expectedString}"; + _output.WriteLine(errorMessage); + throw new Exception(errorMessage); + } + if (!string.Equals(bufferString, expectedString, StringComparison.OrdinalIgnoreCase)) + { + string errorMessage = $"Mock {name} - Expected `{expectedString}`; length {expected.Length}, but received; length {received}; `{bufferString}`(alt encoding: `{alternativeEncodedString}`) instead."; + _output.WriteLine(errorMessage); + throw new Exception(errorMessage); + } + _output.WriteLine($"Mock {name} - received expected message: " + expectedString); + if (serverResponses.Count > 0) + { + var resp = serverResponses.Dequeue(); + client.Send(resp, resp.Length, SocketFlags.None); + _output.WriteLine($"Mock {name} - sent response: " + Encoding.ASCII.GetString(resp)); + } + } + + if (verifyConnectionClosed) + { + _output.WriteLine($"Mock {name} - verifying client connection is closed."); + // Wait for the client to close the connection synchronously (no timeout) + try + { + while (true) + { + int bytesRead = client.Receive(buffer, SocketFlags.None); + if (bytesRead == 0) + { + break; + } + + // If we receive any data, log and throw (assume UTF8 encoding) + string unexpectedData = Encoding.UTF8.GetString(buffer, 0, bytesRead); + _output.WriteLine($"Mock {name} - received unexpected data after handshake: {unexpectedData}"); + throw new Exception($"Mock {name} - received unexpected data after handshake: {unexpectedData}"); + } + _output.WriteLine($"Mock {name} - client closed the connection."); + } + catch (SocketException ex) + { + _output.WriteLine($"Mock {name} - socket exception while waiting for client close: {ex.Message} {ex.GetType().FullName}"); + } + catch (ObjectDisposedException) + { + // Socket already closed + } + } + } + + _output.WriteLine($"Mock {name} - on port {port} completed successfully."); + } + finally + { + listener.Stop(); + } + } + + // Helper function to create a random 4-character ASCII response + private static string CreateRandomAsciiResponse() + { + var rand = new Random(); + // Randomly return either "PASS" or "FAIL" + return rand.Next(0, 2) == 0 ? "PASS" : "FAIL"; + } + + // Helper method to create test data + private static (List<(string, Encoding)> expectedClientSends, List<(string, Encoding)> serverResponses) CreateHandshakeTestData(NetworkCredential cred) + { + var expectedClientSends = new List<(string message, Encoding encoding)> + { + (message: cred.Domain, encoding: Encoding.Unicode), + (message: cred.UserName, encoding: Encoding.Unicode), + (message: "NONEMPTYPW", encoding: Encoding.ASCII), + (message: cred.Password, encoding: Encoding.Unicode) + }; + + var serverResponses = new List<(string message, Encoding encoding)> + { + (message: CreateRandomAsciiResponse(), encoding: Encoding.ASCII), // Response to domain + (message: CreateRandomAsciiResponse(), encoding: Encoding.ASCII), // Response to username + (message: CreateRandomAsciiResponse(), encoding: Encoding.ASCII) // Response to non-empty password + }; + + return (expectedClientSends, serverResponses); + } + + private static List<(string message, Encoding encoding)> CreateVersionNegotiationClientSends() + { + return new List<(string message, Encoding encoding)> + { + (message: "VERSION", encoding: Encoding.UTF8), + (message: "VERSION_2", encoding: Encoding.UTF8), + }; + } + + private static List<(string, Encoding)> CreateV2Sends(NetworkCredential cred, string configurationName) + { + var sends = CreateVersionNegotiationClientSends(); + var password = cred.Password; + var emptyPassword = string.IsNullOrEmpty(password); + + sends.AddRange(new List<(string message, Encoding encoding)> + { + (message: cred.Domain, encoding: Encoding.Unicode), + (message: cred.UserName, encoding: Encoding.Unicode) + }); + + if (!emptyPassword) + { + sends.AddRange(new List<(string message, Encoding encoding)> + { + (message: "NONEMPTYPW", encoding: Encoding.UTF8), + (message: cred.Password, encoding: Encoding.Unicode) + }); + } + else + { + sends.Add((message: "EMPTYPW", encoding: Encoding.UTF8)); // Empty password and we don't expect a response + } + + if (!string.IsNullOrEmpty(configurationName)) + { + sends.Add((message: "NONEMPTYCF", encoding: Encoding.UTF8)); + sends.Add((message: configurationName, encoding: Encoding.Unicode)); // Configuration string and we don't expect a response + } + else + { + sends.Add((message: "EMPTYCF", encoding: Encoding.UTF8)); // Configuration string and we don't expect a response + } + + sends.Add((message: "PASS", encoding: Encoding.ASCII)); // Response to TOKEN + + return sends; + } + + private static List<(string, Encoding)> CreateV2Responses(string version = "VERSION_2", bool emptyConfig = false, string token = "FakeToken0+/=", bool emptyPassword = false) + { + var responses = new List<(string message, Encoding encoding)> + { + (message: version, encoding: Encoding.ASCII), // Response to VERSION + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to domain + (message: "PASS", encoding: Encoding.ASCII), // Response to username + }; + + if (!emptyPassword) + { + responses.Add((message: "PASS", encoding: Encoding.ASCII)); // Response to non-empty password + } + + responses.Add((message: "CONF", encoding: Encoding.ASCII)); // Response to configuration + + if (!emptyConfig) + { + responses.Add((message: "PASS", encoding: Encoding.ASCII)); // Response to non-empty configuration + } + responses.Add((message: "TOKEN " + token, encoding: Encoding.ASCII)); // Response to with a token than uses each class of character in base 64 encoding + + return responses; + } + + // Helper method to create test data + private static (List<(string, Encoding)> expectedClientSends, List<(string, Encoding)> serverResponses) + CreateHandshakeTestDataV2(NetworkCredential cred, string version, string configurationName, string token) + { + bool emptyConfig = string.IsNullOrEmpty(configurationName); + bool emptyPassword = string.IsNullOrEmpty(cred.Password); + return (CreateV2Sends(cred, configurationName), CreateV2Responses(version, emptyConfig, token, emptyPassword)); + } + + // Helper method to create test data + private static (List<(string, Encoding)> expectedClientSends, List<(string, Encoding)> serverResponses) CreateHandshakeTestDataForFallback(NetworkCredential cred) + { + var expectedClientSends = new List<(string message, Encoding encoding)> + { + (message: "VERSION", encoding: Encoding.UTF8), + (message: @"?", encoding: Encoding.Unicode), + (message: "EMPTYPW", encoding: Encoding.UTF8), // Response to domain + (message: "FAIL", encoding: Encoding.UTF8), // Response to domain + }; + + List<(string message, Encoding encoding)> serverResponses = new List<(string message, Encoding encoding)> + { + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION but v1 server expects domain so it says "PASS" + (message: "PASS", encoding: Encoding.ASCII), // Response to username + (message: "FAIL", encoding: Encoding.ASCII) // Response to EMPTYPW + }; + + return (expectedClientSends, serverResponses); + } + + // Helper to create a password with at least one non-ASCII Unicode character + public static string CreateRandomUnicodePassword(string prefix) + { + var rand = new Random(); + var asciiPart = new char[6 + prefix.Length]; + // Copy prefix into asciiPart + Array.Copy(prefix.ToCharArray(), 0, asciiPart, 0, prefix.Length); + for (int i = prefix.Length; i < asciiPart.Length; i++) + { + asciiPart[i] = (char)rand.Next(33, 127); // ASCII printable + } + // Add a random Unicode character outside ASCII range (e.g., U+0100 to U+017F) + char unicodeChar = (char)rand.Next(0x0100, 0x017F); + // Insert the unicode character at a random position + int insertPos = rand.Next(0, asciiPart.Length + 1); + var passwordChars = new List(asciiPart); + passwordChars.Insert(insertPos, unicodeChar); + return new string(passwordChars.ToArray()); + } + + public static NetworkCredential CreateTestCredential() + { + return new NetworkCredential(CreateRandomUnicodePassword("username"), CreateRandomUnicodePassword("password"), CreateRandomUnicodePassword("domain")); + } + + [SkippableFact] + public async Task PerformCredentialAndConfigurationHandshake_V1_Pass() + { + // Arrange + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + string configurationName = CreateRandomUnicodePassword("config"); + + var (expectedClientSends, serverResponses) = CreateHandshakeTestData(cred); + expectedClientSends.Add(("PASS", Encoding.ASCII)); + serverResponses.Add(("PASS", Encoding.ASCII)); + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Broker", port, expectedClientSends, serverResponses, verifyConnectionClosed: false, cts.Token), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + var exchangeResult = System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.ExchangeCredentialsAndConfiguration(cred, configurationName, client, true); + var result = exchangeResult.success; + _output.WriteLine($"Exchange result: {result}, Token: {exchangeResult.authenticationToken}"); + System.Threading.Thread.Sleep(100); // Allow time for server to process + Assert.True(result, $"Expected Exchange to pass"); + } + + await serverTask; + } + + [SkippableTheory] + [InlineData("VERSION_2", "configurationname1", "FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0+/==")] // a fake base64 token about 512 bits long (double the size when this was spec'ed) + [InlineData("VERSION_10", null, "FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0+/=")] // a fake base64 token about 256 bits Long (the size when this was spec'ed) + public async Task PerformCredentialAndConfigurationHandshake_V2_Pass(string versionResponse, string configurationName, string token) + { + // Arrange + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + + var (expectedClientSends, serverResponses) = CreateHandshakeTestDataV2(cred, versionResponse, configurationName, token); + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Broker", port, expectedClientSends, serverResponses, verifyConnectionClosed: true, cts.Token), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + client.Connect(IPAddress.Loopback, port); + var exchangeResult = System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.ExchangeCredentialsAndConfiguration(cred, configurationName, client, false); + var result = exchangeResult.success; + System.Threading.Thread.Sleep(100); // Allow time for server to process + Assert.True(result, $"Expected Exchange to pass for version response '{versionResponse}'"); + Assert.Equal(token, exchangeResult.authenticationToken); + } + + await serverTask; + } + + [SkippableFact] + public async Task PerformCredentialAndConfigurationHandshake_V1_Fallback() + { + // Arrange + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + string configurationName = CreateRandomUnicodePassword("config"); + + var (expectedClientSends, serverResponses) = CreateHandshakeTestDataForFallback(cred); + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Broker", port, expectedClientSends, serverResponses, verifyConnectionClosed: false, cts.Token), cts.Token); + + bool isFallback = false; + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + _output.WriteLine("Starting handshake with V2 protocol."); + client.Connect(IPAddress.Loopback, port); + var exchangeResult = System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.ExchangeCredentialsAndConfiguration(cred, configurationName, client, false); + isFallback = !exchangeResult.success; + + System.Threading.Thread.Sleep(100); // Allow time for server to process + _output.WriteLine("Handshake indicated fallback to V1."); + Assert.True(isFallback, "Expected fallback to V1."); + } + _output.WriteLine("Handshake completed successfully with fallback to V1."); + + await serverTask; + } + + [SkippableFact] + public async Task PerformCredentialAndConfigurationHandshake_V2_InvalidResponse() + { + // Arrange + int port = 51000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + + var (expectedClientSends, serverResponses) = CreateHandshakeTestData(cred); + //expectedClientSends.Add("FAI1"); + serverResponses.Add(("FAI1", Encoding.ASCII)); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + //cts.Token.Register(() => throw new OperationCanceledException("Test timed out.")); + + var serverTask = Task.Run(() => StartHandshakeServer("Broker", port, expectedClientSends, serverResponses, verifyConnectionClosed: false, cts.Token), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + _output.WriteLine("connecting on port " + port); + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + + var ex = Record.Exception(() => System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.ExchangeCredentialsAndConfiguration(cred, "config", client, true)); + + try + { + await serverTask; + } + catch (AggregateException exAgg) + { + Assert.Null(exAgg.Flatten().InnerExceptions[1].Message); + } + cts.Token.ThrowIfCancellationRequested(); + + Assert.NotNull(ex); + Assert.NotNull(ex.Message); + Assert.Contains("Hyper-V Broker sent an invalid Credential response", ex.Message); + } + } + + [SkippableFact] + public async Task PerformCredentialAndConfigurationHandshake_V1_Fail() + { + // Arrange + int port = 51000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + + var (expectedClientSends, serverResponses) = CreateHandshakeTestData(cred); + expectedClientSends.Add(("FAIL", Encoding.ASCII)); + serverResponses.Add(("FAIL", Encoding.ASCII)); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(15)); + + // This scenario does not close the connection in a timely manner, so we set verifyConnectionClosed to false + var serverTask = Task.Run(() => StartHandshakeServer("Broker", port, expectedClientSends, serverResponses, verifyConnectionClosed: false, cts.Token), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + client.Connect(IPAddress.Loopback, port); + + var ex = Record.Exception(() => System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.ExchangeCredentialsAndConfiguration(cred, "config", client, true)); + + try + { + await serverTask; + } + catch (AggregateException exAgg) + { + Assert.Null(exAgg.Flatten().InnerExceptions[1].Message); + } + + cts.Token.ThrowIfCancellationRequested(); + + Assert.NotNull(ex); + Assert.NotNull(ex.Message); + Assert.Contains("The credential is invalid.", ex.Message); + } + } + + [SkippableTheory] + [InlineData("VERSION_2", "FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0+/==")] // a fake base64 token about 512 bits long (double the size when this was spec'ed) + [InlineData("VERSION_10", "FakeTokenaaaaaaaaaAAAAAAAAAAAAAAAAAAAAAA0+/=")] // a fake base64 token about 256 bits Long (the size when this was spec'ed) + public async Task PerformTransportVersionAndTokenExchange_Pass(string version, string token) + { + // Arrange + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + var cred = CreateTestCredential(); + + var expectedClientSends = CreateVersionNegotiationClientSends(); + expectedClientSends.Add((message: "TOKEN " + token, encoding: Encoding.ASCII)); + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: version, encoding: Encoding.ASCII), // Response to VERSION + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII) // Response to token + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Server", port, expectedClientSends, serverResponses, verifyConnectionClosed: true, cts.Token), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.PerformTransportVersionAndTokenExchange(client, token); + System.Threading.Thread.Sleep(100); // Allow time for server to process + } + + await serverTask; + } + + [SkippableTheory] + [InlineData(1, true)] + [InlineData(2, true)] + [InlineData(0, false)] + [InlineData(null, false)] + [System.Runtime.Versioning.SupportedOSPlatform("windows")] + public void IsRequirePsDirectAuthenticationEnabled(int? regValue, bool expected) + { + const string testKeyPath = @"SOFTWARE\Microsoft\TestRequirePsDirectAuthentication"; + const string valueName = "RequirePsDirectAuthentication"; + if (!System.Management.Automation.Platform.IsWindows) + { + throw new SkipException("RemoteHyperVTests are only supported on Windows."); + } + + // Clean up any previous test key + var regHive = Microsoft.Win32.RegistryHive.CurrentUser; + var baseKey = Microsoft.Win32.RegistryKey.OpenBaseKey(regHive, Microsoft.Win32.RegistryView.Registry64); + baseKey.DeleteSubKeyTree(testKeyPath, false); + + bool? result = null; + + // Create the test key + using (var key = baseKey.CreateSubKey(testKeyPath)) + { + if (regValue.HasValue) + { + key.SetValue(valueName, regValue.Value, Microsoft.Win32.RegistryValueKind.DWord); + } + else + { + // Ensure the value does not exist + key.DeleteValue(valueName, false); + } + + result = System.Management.Automation.Remoting.RemoteSessionHyperVSocketClient.IsRequirePsDirectAuthenticationEnabled(testKeyPath, regHive); + } + + Assert.True(result.HasValue, "IsRequirePsDirectAuthenticationEnabled should return a value."); + Assert.True(expected == result.Value, + $"Expected IsRequirePsDirectAuthenticationEnabled to return {expected} when registry value is {(regValue.HasValue ? regValue.ToString() : "not set")}."); + + return; + } + + [SkippableTheory] + [InlineData("testToken", "testToken")] + [InlineData("testToken\0", "testToken")] + public async Task ValidatePassesWhenTokensMatch(string token, string expectedToken) + { + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding)>{ + (message: "VERSION", encoding: Encoding.ASCII), // Response to VERSION + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII) + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "PASS", encoding: Encoding.ASCII) // Response to token + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: true, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken); + System.Threading.Thread.Sleep(100); // Allow time for server to process + } + + await serverTask; + } + + [SkippableTheory] + [InlineData("abc", "xyz")] + [InlineData("abc", "abcdef")] + [InlineData("abcdef", "abc")] + [InlineData("abc\0def", "abc")] + public async Task ValidateFailsWhenTokensMismatch(string token, string expectedToken) + { + int port = 50000 + (int)(DateTime.Now.Ticks % 10000); + + var expectedClientSends = new List<(string message, Encoding encoding)>{ + (message: "VERSION", encoding: Encoding.ASCII), // Initial request + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: $"TOKEN {token}", encoding: Encoding.ASCII) + }; + + var serverResponses = new List<(string message, Encoding encoding)>{ + (message: "VERSION_2", encoding: Encoding.ASCII), // Response to VERSION + (message: "PASS", encoding: Encoding.ASCII), // Response to VERSION_2 + (message: "FAIL", encoding: Encoding.ASCII) // Response to token + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + var serverTask = Task.Run(() => StartHandshakeServer("Client", port, serverResponses, expectedClientSends, verifyConnectionClosed: true, cts.Token, sendFirst: true), cts.Token); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + ConnectWithRetry(client, IPAddress.Loopback, port, _output); + var exception = Assert.Throws( + () => System.Management.Automation.Remoting.RemoteSessionHyperVSocketServer.ValidateToken(client, expectedToken)); + System.Threading.Thread.Sleep(100); // Allow time for server to process + Assert.Contains("The credential is invalid.", exception.Message); + } + + await serverTask; + } + } +}