diff --git a/docs/docfx/articles/transforms.md b/docs/docfx/articles/transforms.md index 6e92fbec7..4fc4288f4 100644 --- a/docs/docfx/articles/transforms.md +++ b/docs/docfx/articles/transforms.md @@ -959,6 +959,8 @@ Only header1 and header2 are copied from the proxy response. All request transforms must derive from the abstract base class [RequestTransform](xref:Yarp.ReverseProxy.Transforms.RequestTransform). These can freely modify the proxy `HttpRequestMessage`. Avoid reading or modifying the request body as this may disrupt the proxying flow. Consider also adding a parametrized extension method on `TransformBuilderContext` for discoverability and easy of use. +A request transform may conditionally produce an immediate response such as for error conditions. This prevents any remaining transforms from running and the request from being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`, or writing to the `HttpResponse.Body` or `BodyWriter`. + ### ResponseTransform All response transforms must derive from the abstract base class [ResponseTransform](xref:Yarp.ReverseProxy.Transforms.ResponseTransform). These can freely modify the client `HttpResponse`. Avoid reading or modifying the response body as this may disrupt the proxying flow. Consider also adding a parametrized extension method on `TransformBuilderContext` for discoverability and easy of use. diff --git a/samples/ReverseProxy.Auth.Sample/Startup.cs b/samples/ReverseProxy.Auth.Sample/Startup.cs index 34dfa4056..c848930a8 100644 --- a/samples/ReverseProxy.Auth.Sample/Startup.cs +++ b/samples/ReverseProxy.Auth.Sample/Startup.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Net.Http.Headers; +using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; +using Yarp.ReverseProxy.Transforms; namespace Yarp.Sample { @@ -31,8 +34,37 @@ public void ConfigureServices(IServiceCollection services) // Required to supply the authentication UI in Views/* services.AddRazorPages(); + services.AddSingleton(); + services.AddReverseProxy() - .LoadFromConfig(_configuration.GetSection("ReverseProxy")); + .LoadFromConfig(_configuration.GetSection("ReverseProxy")) + .AddTransforms(transformBuilderContext => // Add transforms inline + { + // For each route+cluster pair decide if we want to add transforms, and if so, which? + // This logic is re-run each time a route is rebuilt. + + // Only do this for routes that require auth. + if (string.Equals("myPolicy", transformBuilderContext.Route.AuthorizationPolicy)) + { + transformBuilderContext.AddRequestTransform(async transformContext => + { + // AuthN and AuthZ will have already been completed after request routing. + var ticket = await transformContext.HttpContext.AuthenticateAsync(CookieAuthenticationDefaults.AuthenticationScheme); + var tokenService = transformContext.HttpContext.RequestServices.GetRequiredService(); + var token = await tokenService.GetAuthTokenAsync(ticket.Principal); + + // Reject invalid requests + if (string.IsNullOrEmpty(token)) + { + var response = transformContext.HttpContext.Response; + response.StatusCode = 401; + return; + } + + transformContext.ProxyRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); + }); + } + }); ; services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme) .AddCookie(); diff --git a/samples/ReverseProxy.Auth.Sample/TokenService.cs b/samples/ReverseProxy.Auth.Sample/TokenService.cs new file mode 100644 index 000000000..fa7db8c95 --- /dev/null +++ b/samples/ReverseProxy.Auth.Sample/TokenService.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Security.Claims; +using System.Threading.Tasks; + +namespace Yarp.Sample +{ + internal class TokenService + { + internal Task GetAuthTokenAsync(ClaimsPrincipal user) + { + // we only have tokens for bob + if (string.Equals("Bob", user.Identity.Name)) + { + return Task.FromResult(Guid.NewGuid().ToString()); + } + return Task.FromResult(null); + } + } +} diff --git a/samples/ReverseProxy.Auth.Sample/Views/Account/Login.cshtml b/samples/ReverseProxy.Auth.Sample/Views/Account/Login.cshtml index 0c2a01f9d..a4b290ca4 100644 --- a/samples/ReverseProxy.Auth.Sample/Views/Account/Login.cshtml +++ b/samples/ReverseProxy.Auth.Sample/Views/Account/Login.cshtml @@ -6,7 +6,7 @@

-
+

Note:The authorization policy will check for the value of "green", other values should pass authentication, but not authorize for specific routes
diff --git a/samples/ReverseProxy.Transforms.Sample/MyTransformProvider.cs b/samples/ReverseProxy.Transforms.Sample/MyTransformProvider.cs index e2fc45f53..b585c0b3f 100644 --- a/samples/ReverseProxy.Transforms.Sample/MyTransformProvider.cs +++ b/samples/ReverseProxy.Transforms.Sample/MyTransformProvider.cs @@ -25,8 +25,7 @@ public void ValidateRoute(TransformRouteValidationContext context) public void ValidateCluster(TransformClusterValidationContext context) { // Check all clusters for a custom property and validate the associated transform data. - string value = null; - if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out value) ?? false) + if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out var value) ?? false) { if (string.IsNullOrEmpty(value)) { diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index 53d75c542..a7f150bfd 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -94,6 +94,11 @@ public async ValueTask SendAsync( _ = requestConfig ?? throw new ArgumentNullException(nameof(requestConfig)); _ = transformer ?? throw new ArgumentNullException(nameof(transformer)); + if (RequestUtilities.IsResponseSet(context.Response)) + { + throw new InvalidOperationException("The request cannot be forwarded, the response has already started"); + } + // HttpClient overload for SendAsync changes response behavior to fully buffered which impacts performance // See discussion in https://github.com/microsoft/reverse-proxy/issues/458 if (httpClient is HttpClient) @@ -116,6 +121,15 @@ public async ValueTask SendAsync( var (destinationRequest, requestContent) = await CreateRequestMessageAsync( context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource); + // Transforms generated a response, do not proxy. + if (RequestUtilities.IsResponseSet(context.Response)) + { + Log.NotProxying(_logger, context.Response.StatusCode); + return ForwarderError.None; + } + + Log.Proxying(_logger, destinationRequest, isStreamingRequest); + // :: Step 4: Send the outgoing request using HttpClient HttpResponseMessage destinationResponse; try @@ -282,6 +296,12 @@ public async ValueTask SendAsync( // :: Step 3: Copy request headers Client --► Proxy --► Destination await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix); + // The transformer generated a response, do not forward. + if (RequestUtilities.IsResponseSet(context.Response)) + { + return (destinationRequest, requestContent); + } + if (isUpgradeRequest) { RestoreUpgradeHeaders(context, destinationRequest); @@ -291,8 +311,6 @@ public async ValueTask SendAsync( var request = context.Request; destinationRequest.RequestUri ??= RequestUtilities.MakeDestinationAddress(destinationPrefix, request.Path, request.QueryString); - Log.Proxying(_logger, destinationRequest, isStreamingRequest); - if (requestConfig?.AllowResponseBuffering != true) { context.Features.Get()?.DisableBuffering(); @@ -765,6 +783,11 @@ private static class Log EventIds.ForwardingError, "{error}: {message}"); + private static readonly Action _notProxying = LoggerMessage.Define( + LogLevel.Information, + EventIds.NotForwarding, + "Not Proxying, a {statusCode} response was set by the transforms."); + public static void ResponseReceived(ILogger logger, HttpResponseMessage msg) { _responseReceived(logger, msg.Version, (int)msg.StatusCode, null); @@ -782,6 +805,11 @@ public static void Proxying(ILogger logger, HttpRequestMessage msg, bool isStrea } } + public static void NotProxying(ILogger logger, int statusCode) + { + _notProxying(logger, statusCode, null); + } + public static void ErrorProxying(ILogger logger, ForwarderError error, Exception ex) { _proxyError(logger, error, GetMessage(error), ex); diff --git a/src/ReverseProxy/Forwarder/HttpTransformer.cs b/src/ReverseProxy/Forwarder/HttpTransformer.cs index 0e34b8255..a3a847385 100644 --- a/src/ReverseProxy/Forwarder/HttpTransformer.cs +++ b/src/ReverseProxy/Forwarder/HttpTransformer.cs @@ -59,6 +59,9 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) => /// See for constructing a custom request Uri. /// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri. /// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority"). + /// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from + /// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`, + /// or writing to the `HttpResponse.Body` or `BodyWriter`. /// /// The incoming request. /// The outgoing proxy request. diff --git a/src/ReverseProxy/Forwarder/RequestUtilities.cs b/src/ReverseProxy/Forwarder/RequestUtilities.cs index 530342971..25f335603 100644 --- a/src/ReverseProxy/Forwarder/RequestUtilities.cs +++ b/src/ReverseProxy/Forwarder/RequestUtilities.cs @@ -404,4 +404,10 @@ static StringValues ToArray(in HeaderStringValues values) values = default; return false; } + + internal static bool IsResponseSet(HttpResponse response) + { + return response.StatusCode != StatusCodes.Status200OK + || response.HasStarted; + } } diff --git a/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs b/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs index ed2814f17..145d9f477 100644 --- a/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs +++ b/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs @@ -88,6 +88,12 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H foreach (var requestTransform in RequestTransforms) { await requestTransform.ApplyAsync(transformContext); + + // The transform generated a response, do not apply further transforms and do not forward. + if (RequestUtilities.IsResponseSet(httpContext.Response)) + { + return; + } } // Allow a transform to directly set a custom RequestUri. diff --git a/src/ReverseProxy/Utilities/EventIds.cs b/src/ReverseProxy/Utilities/EventIds.cs index c548b73e9..22a1cc58e 100644 --- a/src/ReverseProxy/Utilities/EventIds.cs +++ b/src/ReverseProxy/Utilities/EventIds.cs @@ -62,4 +62,5 @@ internal static class EventIds public static readonly EventId ResponseReceived = new EventId(56, "ResponseReceived"); public static readonly EventId DelegationQueueReset = new EventId(57, "DelegationQueueReset"); public static readonly EventId Http10RequestVersionDetected = new EventId(58, "Http10RequestVersionDetected"); + public static readonly EventId NotForwarding = new EventId(59, "NotForwarding"); } diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index 6d6b9ff17..16f0a44b2 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -358,6 +358,125 @@ public async Task TransformRequestAsync_ReplaceBody() events.AssertContainProxyStages(); } + [Fact] + public async Task TransformRequestAsync_SetsStatus_ShortCircuits() + { + var events = TestEventListener.Collect(); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.Method = "POST"; + httpContext.Request.Protocol = "HTTP/2"; + + var destinationPrefix = "https://localhost/"; + + var transforms = new DelegateHttpTransforms() + { + CopyRequestHeaders = true, + OnRequest = (context, request, destination) => + { + context.Response.StatusCode = 401; + return Task.CompletedTask; + } + }; + + var sut = CreateProxy(); + var client = MockHttpHandler.CreateClient( + (HttpRequestMessage request, CancellationToken cancellationToken) => + { + throw new NotImplementedException(); + }); + + var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms); + + Assert.Equal(ForwarderError.None, proxyError); + Assert.Equal(StatusCodes.Status401Unauthorized, httpContext.Response.StatusCode); + + AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode); + events.AssertContainProxyStages(new ForwarderStage[0]); + } + + [Fact] + public async Task TransformRequestAsync_StartsResponse_ShortCircuits() + { + var events = TestEventListener.Collect(); + + var httpContext = new DefaultHttpContext(); + var responseBody = new TestResponseBody(); + httpContext.Features.Set(responseBody); + httpContext.Features.Set(responseBody); + httpContext.Request.Method = "POST"; + httpContext.Request.Protocol = "HTTP/2"; + + var destinationPrefix = "https://localhost/"; + + var transforms = new DelegateHttpTransforms() + { + CopyRequestHeaders = true, + OnRequest = (context, request, destination) => + { + return context.Response.StartAsync(); + } + }; + + var sut = CreateProxy(); + var client = MockHttpHandler.CreateClient( + (HttpRequestMessage request, CancellationToken cancellationToken) => + { + throw new NotImplementedException(); + }); + + var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms); + + Assert.Equal(ForwarderError.None, proxyError); + Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode); + Assert.True(httpContext.Response.HasStarted); + + AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode); + events.AssertContainProxyStages(new ForwarderStage[0]); + } + + [Fact] + public async Task TransformRequestAsync_WritesToResponse_ShortCircuits() + { + var events = TestEventListener.Collect(); + + var httpContext = new DefaultHttpContext(); + var resultStream = new MemoryStream(); + var responseBody = new TestResponseBody(resultStream); + httpContext.Features.Set(responseBody); + httpContext.Features.Set(responseBody); + httpContext.Request.Method = "POST"; + httpContext.Request.Protocol = "HTTP/2"; + + var destinationPrefix = "https://localhost/"; + + var transforms = new DelegateHttpTransforms() + { + CopyRequestHeaders = true, + OnRequest = (context, request, destination) => + { + return context.Response.Body.WriteAsync(Encoding.UTF8.GetBytes("Hello World")).AsTask(); + } + }; + + var sut = CreateProxy(); + var client = MockHttpHandler.CreateClient( + (HttpRequestMessage request, CancellationToken cancellationToken) => + { + throw new NotImplementedException(); + }); + + var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms); + + Assert.Equal(ForwarderError.None, proxyError); + Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode); + Assert.True(httpContext.Response.HasStarted); + Assert.Equal("Hello World", Encoding.UTF8.GetString(resultStream.ToArray())); + + AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode); + events.AssertContainProxyStages(new ForwarderStage[0]); + } + // Tests proxying an upgradeable request. [Theory] [InlineData("WebSocket")] @@ -1887,11 +2006,10 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() var httpContext = new DefaultHttpContext(); httpContext.Request.Method = "GET"; httpContext.Request.Host = new HostString("example.com:3456"); - var responseBody = new TestResponseBody() { HasStarted = true }; + var responseBody = new TestResponseBody(); httpContext.Features.Set(responseBody); httpContext.Features.Set(responseBody); httpContext.Features.Set(responseBody); - httpContext.RequestAborted = new CancellationToken(canceled: true); var destinationPrefix = "https://localhost:123/"; var sut = CreateProxy(); @@ -1900,7 +2018,11 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() { var message = new HttpResponseMessage() { - Content = new StreamContent(new MemoryStream(new byte[1])) + Content = new StreamContent(new CallbackReadStream((_, _) => + { + responseBody.HasStarted = true; + throw new TaskCanceledException(); + })) }; message.Headers.AcceptRanges.Add("bytes"); return Task.FromResult(message); @@ -2828,7 +2950,8 @@ public Task SendFileAsync(string path, long offset, long? count, CancellationTok public Task StartAsync(CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + OnStart(); + return Task.CompletedTask; } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default)