diff --git a/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs
new file mode 100644
index 00000000..50601f66
--- /dev/null
+++ b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs
@@ -0,0 +1,143 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Protocol;
+using System.Net;
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.Client;
+
+///
+/// A transport that automatically detects whether to use Streamable HTTP or SSE transport
+/// by trying Streamable HTTP first and falling back to SSE if that fails.
+///
+internal sealed partial class AutoDetectingClientSessionTransport : ITransport
+{
+ private readonly SseClientTransportOptions _options;
+ private readonly HttpClient _httpClient;
+ private readonly ILoggerFactory? _loggerFactory;
+ private readonly ILogger _logger;
+ private readonly string _name;
+ private readonly Channel _messageChannel;
+
+ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
+ {
+ Throw.IfNull(transportOptions);
+ Throw.IfNull(httpClient);
+
+ _options = transportOptions;
+ _httpClient = httpClient;
+ _loggerFactory = loggerFactory;
+ _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
+ _name = endpointName;
+
+ // Same as TransportBase.cs.
+ _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ });
+ }
+
+ ///
+ /// Returns the active transport (either StreamableHttp or SSE)
+ ///
+ internal ITransport? ActiveTransport { get; private set; }
+
+ public ChannelReader MessageReader => _messageChannel.Reader;
+
+ ///
+ public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (ActiveTransport is null)
+ {
+ return InitializeAsync(message, cancellationToken);
+ }
+
+ return ActiveTransport.SendMessageAsync(message, cancellationToken);
+ }
+
+ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken cancellationToken)
+ {
+ // Try StreamableHttp first
+ var streamableHttpTransport = new StreamableHttpClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory);
+
+ try
+ {
+ LogAttemptingStreamableHttp(_name);
+ using var response = await streamableHttpTransport.SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false);
+
+ if (response.IsSuccessStatusCode)
+ {
+ LogUsingStreamableHttp(_name);
+ ActiveTransport = streamableHttpTransport;
+ }
+ else
+ {
+ // If the status code is not success, fall back to SSE
+ LogStreamableHttpFailed(_name, response.StatusCode);
+
+ await streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
+ await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ catch
+ {
+ // If nothing threw inside the try block, we've either set streamableHttpTransport as the
+ // ActiveTransport, or else we will have disposed it in the !IsSuccessStatusCode else block.
+ await streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
+ throw;
+ }
+ }
+
+ private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
+ {
+ var sseTransport = new SseClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory);
+
+ try
+ {
+ LogAttemptingSSE(_name);
+ await sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
+ await sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+
+ LogUsingSSE(_name);
+ ActiveTransport = sseTransport;
+ }
+ catch
+ {
+ await sseTransport.DisposeAsync().ConfigureAwait(false);
+ throw;
+ }
+ }
+
+ public async ValueTask DisposeAsync()
+ {
+ try
+ {
+ if (ActiveTransport is not null)
+ {
+ await ActiveTransport.DisposeAsync().ConfigureAwait(false);
+ }
+ }
+ finally
+ {
+ // In the majority of cases, either the Streamable HTTP transport or SSE transport has completed the channel by now.
+ // However, this may not be the case if HttpClient throws during the initial request due to misconfiguration.
+ _messageChannel.Writer.TryComplete();
+ }
+ }
+
+ [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} attempting to connect using Streamable HTTP transport.")]
+ private partial void LogAttemptingStreamableHttp(string endpointName);
+
+ [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport.")]
+ private partial void LogStreamableHttpFailed(string endpointName, HttpStatusCode statusCode);
+
+ [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} using Streamable HTTP transport.")]
+ private partial void LogUsingStreamableHttp(string endpointName);
+
+ [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} attempting to connect using SSE transport.")]
+ private partial void LogAttemptingSSE(string endpointName);
+
+ [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} using SSE transport.")]
+ private partial void LogUsingSSE(string endpointName);
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Client/HttpTransportMode.cs b/src/ModelContextProtocol/Client/HttpTransportMode.cs
new file mode 100644
index 00000000..f2d46c30
--- /dev/null
+++ b/src/ModelContextProtocol/Client/HttpTransportMode.cs
@@ -0,0 +1,23 @@
+namespace ModelContextProtocol.Client;
+
+///
+/// Specifies the transport mode for HTTP client connections.
+///
+public enum HttpTransportMode
+{
+ ///
+ /// Automatically detect the appropriate transport by trying Streamable HTTP first, then falling back to SSE if that fails.
+ /// This is the recommended mode for maximum compatibility.
+ ///
+ AutoDetect,
+
+ ///
+ /// Use only the Streamable HTTP transport.
+ ///
+ StreamableHttp,
+
+ ///
+ /// Use only the HTTP with SSE transport.
+ ///
+ Sse
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol/Client/SseClientSessionTransport.cs
index b29306e1..fd2466ea 100644
--- a/src/ModelContextProtocol/Client/SseClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Client/SseClientSessionTransport.cs
@@ -6,6 +6,7 @@
using System.Net.ServerSentEvents;
using System.Text;
using System.Text.Json;
+using System.Threading.Channels;
namespace ModelContextProtocol.Client;
@@ -24,15 +25,16 @@ internal sealed partial class SseClientSessionTransport : TransportBase
private readonly TaskCompletionSource _connectionEstablished;
///
- /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server.
+ /// SSE transport for a single session. Unlike stdio it does not launch a process, but connects to an existing server.
/// The HTTP server can be local or remote, and must support the SSE protocol.
///
- /// Configuration options for the transport.
- /// The HTTP client instance used for requests.
- /// Logger factory for creating loggers.
- /// The endpoint name used for logging purposes.
- public SseClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
- : base(endpointName, loggerFactory)
+ public SseClientSessionTransport(
+ string endpointName,
+ SseClientTransportOptions transportOptions,
+ HttpClient httpClient,
+ Channel? messageChannel,
+ ILoggerFactory? loggerFactory)
+ : base(endpointName, messageChannel, loggerFactory)
{
Throw.IfNull(transportOptions);
Throw.IfNull(httpClient);
diff --git a/src/ModelContextProtocol/Client/SseClientTransport.cs b/src/ModelContextProtocol/Client/SseClientTransport.cs
index df1cdac6..57789c1c 100644
--- a/src/ModelContextProtocol/Client/SseClientTransport.cs
+++ b/src/ModelContextProtocol/Client/SseClientTransport.cs
@@ -4,11 +4,11 @@
namespace ModelContextProtocol.Client;
///
-/// Provides an over HTTP using the Server-Sent Events (SSE) protocol.
+/// Provides an over HTTP using the Server-Sent Events (SSE) or Streamable HTTP protocol.
///
///
-/// This transport connects to an MCP server over HTTP using SSE,
-/// allowing for real-time server-to-client communication with a standard HTTP request.
+/// This transport connects to an MCP server over HTTP using SSE or Streamable HTTP,
+/// allowing for real-time server-to-client communication with a standard HTTP requests.
/// Unlike the , this transport connects to an existing server
/// rather than launching a new process.
///
@@ -36,7 +36,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFac
/// The HTTP client instance used for requests.
/// Logger factory for creating loggers used for diagnostic output during transport operations.
///
- /// to dispose of when the transport is disposed;
+ /// to dispose of when the transport is disposed;
/// if the caller is retaining ownership of the 's lifetime.
///
public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false)
@@ -57,12 +57,22 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient
///
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
- if (_options.UseStreamableHttp)
+ switch (_options.TransportMode)
{
- return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
+ case HttpTransportMode.AutoDetect:
+ return new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
+ case HttpTransportMode.StreamableHttp:
+ return new StreamableHttpClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory);
+ case HttpTransportMode.Sse:
+ return await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false);
+ default:
+ throw new InvalidOperationException($"Unsupported transport mode: {_options.TransportMode}");
}
+ }
- var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
+ private async Task ConnectSseTransportAsync(CancellationToken cancellationToken)
+ {
+ var sessionTransport = new SseClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory);
try
{
diff --git a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs
index f67f6f07..8843fca8 100644
--- a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs
+++ b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs
@@ -31,11 +31,19 @@ public required Uri Endpoint
}
///
- /// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false.
+ /// Gets or sets the transport mode to use for the connection. Defaults to .
+ ///
+ ///
+ ///
+ /// When set to (the default), the client will first attempt to use
+ /// Streamable HTTP transport and automatically fall back to SSE transport if the server doesn't support it.
+ ///
+ ///
/// Streamable HTTP transport specification.
/// HTTP with SSE transport specification.
- ///
- public bool UseStreamableHttp { get; init; }
+ ///
+ ///
+ public HttpTransportMode TransportMode { get; init; } = HttpTransportMode.AutoDetect;
///
/// Gets a transport identifier used for logging purposes.
diff --git a/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs
index 3330f4de..e35e2b18 100644
--- a/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs
@@ -17,7 +17,7 @@ internal class StreamClientSessionTransport : TransportBase
/// Initializes a new instance of the class.
///
///
- /// The text writer connected to the server's input stream.
+ /// The text writer connected to the server's input stream.
/// Messages written to this writer will be sent to the server.
///
///
@@ -41,17 +41,17 @@ public StreamClientSessionTransport(
_serverOutput = serverOutput;
_serverInput = serverInput;
+ SetConnected();
+
// Start reading messages in the background. We use the rarer pattern of new Task + Start
// in order to ensure that the body of the task will always see _readTask initialized.
// It is then able to reliably null it out on completion.
var readTask = new Task(
- thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token),
+ thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token),
this,
TaskCreationOptions.DenyChildAttach);
_readTask = readTask.Unwrap();
readTask.Start();
-
- SetConnected();
}
///
@@ -80,7 +80,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
}
///
- public override ValueTask DisposeAsync() =>
+ public override ValueTask DisposeAsync() =>
CleanupAsync(cancellationToken: CancellationToken.None);
private async Task ReadMessagesAsync(CancellationToken cancellationToken)
diff --git a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs
index 55ecb963..78f99e20 100644
--- a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs
@@ -4,6 +4,8 @@
using System.Net.ServerSentEvents;
using System.Text.Json;
using ModelContextProtocol.Protocol;
+using System.Threading.Channels;
+
#if NET
using System.Net.Http.Json;
#else
@@ -28,8 +30,13 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
private string? _mcpSessionId;
private Task? _getReceiveTask;
- public StreamableHttpClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
- : base(endpointName, loggerFactory)
+ public StreamableHttpClientSessionTransport(
+ string endpointName,
+ SseClientTransportOptions transportOptions,
+ HttpClient httpClient,
+ Channel? messageChannel,
+ ILoggerFactory? loggerFactory)
+ : base(endpointName, messageChannel, loggerFactory)
{
Throw.IfNull(transportOptions);
Throw.IfNull(httpClient);
@@ -46,9 +53,15 @@ public StreamableHttpClientSessionTransport(SseClientTransportOptions transportO
}
///
- public override async Task SendMessageAsync(
- JsonRpcMessage message,
- CancellationToken cancellationToken = default)
+ public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ // Immediately dispose the response. SendHttpRequestAsync only returns the response so the auto transport can look at it.
+ using var response = await SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false);
+ response.EnsureSuccessStatusCode();
+ }
+
+ // This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception.
+ internal async Task SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
cancellationToken = sendCts.Token;
@@ -73,9 +86,14 @@ public override async Task SendMessageAsync(
};
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
- using var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
- response.EnsureSuccessStatusCode();
+ var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
+
+ // We'll let the caller decide whether to throw or fall back given an unsuccessful response.
+ if (!response.IsSuccessStatusCode)
+ {
+ return response;
+ }
var rpcRequest = message as JsonRpcRequest;
JsonRpcMessage? rpcResponseCandidate = null;
@@ -93,7 +111,7 @@ public override async Task SendMessageAsync(
if (rpcRequest is null)
{
- return;
+ return response;
}
if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id)
@@ -111,6 +129,8 @@ public override async Task SendMessageAsync(
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
}
+
+ return response;
}
public override async ValueTask DisposeAsync()
@@ -136,7 +156,12 @@ public override async ValueTask DisposeAsync()
}
finally
{
- SetDisconnected();
+ // If we're auto-detecting the transport and failed to connect, leave the message Channel open for the SSE transport.
+ // This class isn't directly exposed to public callers, so we don't have to worry about changing the _state in this case.
+ if (_options.TransportMode is not HttpTransportMode.AutoDetect || _getReceiveTask is not null)
+ {
+ SetDisconnected();
+ }
}
}
diff --git a/src/ModelContextProtocol/Protocol/TransportBase.cs b/src/ModelContextProtocol/Protocol/TransportBase.cs
index 31b3b146..9be9c6fa 100644
--- a/src/ModelContextProtocol/Protocol/TransportBase.cs
+++ b/src/ModelContextProtocol/Protocol/TransportBase.cs
@@ -36,12 +36,20 @@ public abstract partial class TransportBase : ITransport
/// Initializes a new instance of the class.
///
protected TransportBase(string name, ILoggerFactory? loggerFactory)
+ : this(name, null, loggerFactory)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of the class with a specified channel to back .
+ ///
+ internal TransportBase(string name, Channel? messageChannel, ILoggerFactory? loggerFactory)
{
Name = name;
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
- // Unbounded channel to prevent blocking on writes
- _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ // Unbounded channel to prevent blocking on writes. Ensure AutoDetectingClientSessionTransport matches this.
+ _messageChannel = messageChannel ?? Channel.CreateUnbounded(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false,
@@ -112,7 +120,7 @@ protected void SetConnected()
case StateConnected:
return;
-
+
case StateDisconnected:
throw new IOException("Transport is already disconnected and can't be reconnected.");
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs
index ee717530..30187faa 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs
@@ -56,7 +56,7 @@ public async Task Connect_TestServer_ShouldProvideServerFields()
[Fact]
public async Task ListTools_Sse_TestServer()
- {
+ {
// arrange
// act
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs
index c987bca9..be8763ae 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs
@@ -1,5 +1,6 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
+using ModelContextProtocol.Client;
namespace ModelContextProtocol.AspNetCore.Tests;
@@ -34,4 +35,112 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat
Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name);
}
+
+ [Fact]
+ public async Task StreamableHttpMode_Works_WithRootEndpoint()
+ {
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "StreamableHttpTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(ConfigureStateless);
+ await using var app = Builder.Build();
+
+ app.MapMcp();
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var mcpClient = await ConnectAsync("/", new()
+ {
+ Endpoint = new Uri("http://localhost/"),
+ TransportMode = HttpTransportMode.AutoDetect
+ });
+
+ Assert.Equal("StreamableHttpTestServer", mcpClient.ServerInfo.Name);
+ }
+
+ [Fact]
+ public async Task AutoDetectMode_Works_WithRootEndpoint()
+ {
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "AutoDetectTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(ConfigureStateless);
+ await using var app = Builder.Build();
+
+ app.MapMcp();
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var mcpClient = await ConnectAsync("/", new()
+ {
+ Endpoint = new Uri("http://localhost/"),
+ TransportMode = HttpTransportMode.AutoDetect
+ });
+
+ Assert.Equal("AutoDetectTestServer", mcpClient.ServerInfo.Name);
+ }
+
+ [Fact]
+ public async Task AutoDetectMode_Works_WithSseEndpoint()
+ {
+ Assert.SkipWhen(Stateless, "SSE endpoint is disabled in stateless mode.");
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "AutoDetectSseTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(ConfigureStateless);
+ await using var app = Builder.Build();
+
+ app.MapMcp();
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var mcpClient = await ConnectAsync("/sse", new()
+ {
+ Endpoint = new Uri("http://localhost/sse"),
+ TransportMode = HttpTransportMode.AutoDetect
+ });
+
+ Assert.Equal("AutoDetectSseTestServer", mcpClient.ServerInfo.Name);
+ }
+
+ [Fact]
+ public async Task SseMode_Works_WithSseEndpoint()
+ {
+ Assert.SkipWhen(Stateless, "SSE endpoint is disabled in stateless mode.");
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "SseTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(ConfigureStateless);
+ await using var app = Builder.Build();
+
+ app.MapMcp();
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var mcpClient = await ConnectAsync(options: new()
+ {
+ Endpoint = new Uri("http://localhost/sse"),
+ TransportMode = HttpTransportMode.Sse
+ });
+
+ Assert.Equal("SseTestServer", mcpClient.ServerInfo.Name);
+ }
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
index cf49fee1..6d153220 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
@@ -20,16 +20,17 @@ protected void ConfigureStateless(HttpServerTransportOptions options)
options.Stateless = Stateless;
}
- protected async Task ConnectAsync(string? path = null)
+ protected async Task ConnectAsync(string? path = null, SseClientTransportOptions? options = null)
{
+ // Default behavior when no options are provided
path ??= UseStreamableHttp ? "/" : "/sse";
- var sseClientTransportOptions = new SseClientTransportOptions()
+ await using var transport = new SseClientTransport(options ?? new SseClientTransportOptions()
{
Endpoint = new Uri($"http://localhost{path}"),
- UseStreamableHttp = UseStreamableHttp,
- };
- await using var transport = new SseClientTransport(sseClientTransportOptions, HttpClient, LoggerFactory);
+ TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse,
+ }, HttpClient, LoggerFactory);
+
return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken);
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs
index b1b61805..a9e2e5f5 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs
@@ -9,6 +9,6 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix
{
Endpoint = new Uri("http://localhost/stateless"),
Name = "In-memory Streamable HTTP Client",
- UseStreamableHttp = true,
+ TransportMode = HttpTransportMode.StreamableHttp,
};
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
index 2f364be0..acfc744b 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
@@ -18,7 +18,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem
{
Endpoint = new Uri("http://localhost/"),
Name = "In-memory Streamable HTTP Client",
- UseStreamableHttp = true,
+ TransportMode = HttpTransportMode.StreamableHttp,
};
private async Task StartAsync()
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs
index 94540f8c..d7f8433b 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs
@@ -98,7 +98,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer()
await using var transport = new SseClientTransport(new()
{
Endpoint = new("http://localhost/mcp"),
- UseStreamableHttp = true,
+ TransportMode = HttpTransportMode.StreamableHttp,
}, HttpClient, LoggerFactory);
await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken);
@@ -118,7 +118,7 @@ public async Task CanCallToolConcurrently()
await using var transport = new SseClientTransport(new()
{
Endpoint = new("http://localhost/mcp"),
- UseStreamableHttp = true,
+ TransportMode = HttpTransportMode.StreamableHttp,
}, HttpClient, LoggerFactory);
await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken);
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs
index 64505b3d..7c4366f1 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs
@@ -15,7 +15,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur
{
Endpoint = new Uri("http://localhost/"),
Name = "In-memory Streamable HTTP Client",
- UseStreamableHttp = true,
+ TransportMode = HttpTransportMode.StreamableHttp,
};
[Fact]
diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs
new file mode 100644
index 00000000..8f6fbff2
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs
@@ -0,0 +1,109 @@
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Tests.Utils;
+using System.Net;
+
+namespace ModelContextProtocol.Tests.Transport;
+
+public class SseClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
+{
+ [Fact]
+ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt()
+ {
+ var options = new SseClientTransportOptions
+ {
+ Endpoint = new Uri("http://localhost"),
+ TransportMode = HttpTransportMode.AutoDetect,
+ Name = "AutoDetect test client"
+ };
+
+ using var mockHttpHandler = new MockHttpHandler();
+ using var httpClient = new HttpClient(mockHttpHandler);
+ await using var transport = new SseClientTransport(options, httpClient, LoggerFactory);
+
+ // Simulate successful Streamable HTTP response for initialize
+ mockHttpHandler.RequestHandler = (request) =>
+ {
+ if (request.Method == HttpMethod.Post)
+ {
+ return Task.FromResult(new HttpResponseMessage
+ {
+ StatusCode = HttpStatusCode.OK,
+ Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":\"init-id\",\"result\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{\"tools\":{}}}}"),
+ Headers =
+ {
+ { "Content-Type", "application/json" },
+ { "mcp-session-id", "test-session" }
+ }
+ });
+ }
+
+ // Shouldn't reach here for successful Streamable HTTP
+ throw new InvalidOperationException("Unexpected request");
+ };
+
+ await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken);
+
+ // The auto-detecting transport should be returned
+ Assert.NotNull(session);
+ }
+
+ [Fact]
+ public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails()
+ {
+ var options = new SseClientTransportOptions
+ {
+ Endpoint = new Uri("http://localhost"),
+ TransportMode = HttpTransportMode.AutoDetect,
+ Name = "AutoDetect test client"
+ };
+
+ using var mockHttpHandler = new MockHttpHandler();
+ using var httpClient = new HttpClient(mockHttpHandler);
+ await using var transport = new SseClientTransport(options, httpClient, LoggerFactory);
+
+ var requestCount = 0;
+
+ mockHttpHandler.RequestHandler = (request) =>
+ {
+ requestCount++;
+
+ if (request.Method == HttpMethod.Post && requestCount == 1)
+ {
+ // First POST (Streamable HTTP) fails
+ return Task.FromResult(new HttpResponseMessage
+ {
+ StatusCode = HttpStatusCode.NotFound,
+ Content = new StringContent("Streamable HTTP not supported")
+ });
+ }
+
+ if (request.Method == HttpMethod.Get)
+ {
+ // SSE connection request
+ return Task.FromResult(new HttpResponseMessage
+ {
+ StatusCode = HttpStatusCode.OK,
+ Content = new StringContent("event: endpoint\r\ndata: /sse-endpoint\r\n\r\n"),
+ Headers = { { "Content-Type", "text/event-stream" } }
+ });
+ }
+
+ if (request.Method == HttpMethod.Post && requestCount > 1)
+ {
+ // Subsequent POST to SSE endpoint succeeds
+ return Task.FromResult(new HttpResponseMessage
+ {
+ StatusCode = HttpStatusCode.OK,
+ Content = new StringContent("accepted")
+ });
+ }
+
+ throw new InvalidOperationException($"Unexpected request: {request.Method}, count: {requestCount}");
+ };
+
+ await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken);
+
+ // The auto-detecting transport should be returned
+ Assert.NotNull(session);
+ }
+}
\ No newline at end of file
diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
index 857e496a..ae449ac9 100644
--- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs
@@ -17,6 +17,7 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper)
Endpoint = new Uri("http://localhost:8080"),
ConnectionTimeout = TimeSpan.FromSeconds(2),
Name = "Test Server",
+ TransportMode = HttpTransportMode.Sse,
AdditionalHeaders = new Dictionary
{
["test"] = "header"
diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs
index 7a7f39c2..b8d8d714 100644
--- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs
@@ -1,9 +1,10 @@
using ModelContextProtocol.Client;
+using ModelContextProtocol.Tests.Utils;
using System.Runtime.InteropServices;
namespace ModelContextProtocol.Tests.Transport;
-public class StdioClientTransportTests
+public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
{
[Fact]
public async Task CreateAsync_ValidProcessInvalidServer_Throws()
@@ -11,10 +12,10 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws()
string id = Guid.NewGuid().ToString("N");
StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
- new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }) :
- new(new() { Command = "ls", Arguments = [id] });
+ new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) :
+ new(new() { Command = "ls", Arguments = [id] }, LoggerFactory);
- IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken));
+ IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken));
Assert.Contains(id, e.ToString());
}
}