diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..2746f0762 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +{ + "name": "C# (.NET SDK)", + "image": "mcr.microsoft.com/devcontainers/dotnet:1-8.0-jammy", + "features": { + "ghcr.io/devcontainers/features/dotnet:2": { + "version": "9.0" + }, + "ghcr.io/devcontainers/features/node:1": {} + }, + "customizations": { + "vscode": { + "extensions": [ + "ms-dotnettools.csharp", + "ms-dotnettools.csdevkit" + ], + "settings": { + "dotnet.defaultSolution": "ModelContextProtocol.slnx" + } + } + }, + "postCreateCommand": "dotnet --list-sdks && echo 'Available .NET SDKs installed successfully!'" +} \ No newline at end of file diff --git a/.editorconfig b/.editorconfig index 3ce6343ba..99b4fbaf1 100644 --- a/.editorconfig +++ b/.editorconfig @@ -3,6 +3,7 @@ root = true # C# files [*.cs] +csharp_style_namespace_declarations=file_scoped:warning # Compiler dotnet_diagnostic.CS1998.severity = suggestion # CS1998: Missing awaits diff --git a/.github/workflows/ci-build-test.yml b/.github/workflows/ci-build-test.yml index 9a66ccd07..b29bccd53 100644 --- a/.github/workflows/ci-build-test.yml +++ b/.github/workflows/ci-build-test.yml @@ -18,6 +18,7 @@ on: - "src/**" - "tests/**" - "samples/**" + - "docs/**" permissions: contents: read @@ -50,6 +51,10 @@ jobs: if: runner.os == 'Linux' run: sudo apt-get install -y mono-devel + - name: Setup Mono on macOS + if: runner.os == 'macOS' + run: brew install mono + - name: Set up Node.js uses: actions/setup-node@cdca7365b2dadb8aad0a33bc7601856ffabcc48e # v4.3.0 with: diff --git a/.github/workflows/markdown-link-check.yml b/.github/workflows/markdown-link-check.yml index 6a49bec6a..b69bbc440 100644 --- a/.github/workflows/markdown-link-check.yml +++ b/.github/workflows/markdown-link-check.yml @@ -21,5 +21,5 @@ jobs: - name: Markup Link Checker (mlc) uses: becheran/mlc@c925f90a9a25e16e4c4bfa29058f6f9ffa9f0d8c # v0.21.0 with: - # Ignore external links that result in 403 errors during CI. Do not warn for redirects where we want to keep the vanity URL in the markdown or for GitHub links that redirect to the login. - args: --ignore-links "https://www.anthropic.com/*,https://hackerone.com/anthropic-vdp/*" --do-not-warn-for-redirect-to "https://modelcontextprotocol.io/*,https://github.com/login?*" ./ + # Ignore external links that result in 403 errors during CI. Do not warn for redirects where we want to keep the vanity URL in the markdown or for GitHub links that redirect to the login, and DocFX snippet links. + args: --ignore-links "https://www.anthropic.com/*,https://hackerone.com/anthropic-vdp/*" --do-not-warn-for-redirect-to "https://modelcontextprotocol.io/*,https://github.com/login?*" --ignore-links "*samples/*?name=snippet_*" ./docs diff --git a/Directory.Build.props b/Directory.Build.props index bca375922..b0cf7f215 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -34,4 +34,9 @@ true + + + + true + diff --git a/Directory.Packages.props b/Directory.Packages.props index 6da9521f7..1ae45da31 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,7 +3,7 @@ true 9.0.5 10.0.0-preview.4.25258.110 - 9.7.1 + 9.9.1 @@ -47,13 +47,13 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all - + @@ -66,8 +66,8 @@ - - + + @@ -75,9 +75,9 @@ - - + + - + \ No newline at end of file diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index 5ed8ba0d6..a70e3e310 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -8,12 +8,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + - - + + + @@ -24,6 +56,7 @@ + diff --git a/README.md b/README.md index 163d57f8a..4c87ba9bd 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ dotnet add package ModelContextProtocol --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -48,7 +48,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) @@ -88,7 +88,7 @@ var response = await chatClient.GetResponseAsync( Here is an example of how to create an MCP server and register all tools from the current application. It includes a simple echo tool as an example (this is included in the same file here for easy of copy and paste, but it needn't be in the same file... the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the -`McpTool` attribute as tools.) +`McpServerTool` attribute as tools.) ``` dotnet add package ModelContextProtocol --prerelease @@ -122,14 +122,14 @@ public static class EchoTool } ``` -Tools can have the `IMcpServer` representing the server injected via a parameter to the method, and can use that for interaction with +Tools can have the `McpServer` representing the server injected via a parameter to the method, and can use that for interaction with the connected client. Similarly, arguments may be injected via dependency injection. For example, this tool will use the supplied -`IMcpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via +`McpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via an `HttpClient` injected via dependency injection. ```csharp [McpServerTool(Name = "SummarizeContentFromUrl"), Description("Summarizes content downloaded from a specific URI")] public static async Task SummarizeDownloadedContent( - IMcpServer thisServer, + McpServer thisServer, HttpClient httpClient, [Description("The url from which to download the content to summarize")] string url, CancellationToken cancellationToken) @@ -174,57 +174,54 @@ using System.Text.Json; McpServerOptions options = new() { ServerInfo = new Implementation { Name = "MyServer", Version = "1.0.0" }, - Capabilities = new ServerCapabilities + Handlers = new McpServerHandlers() { - Tools = new ToolsCapability - { - ListToolsHandler = (request, cancellationToken) => - ValueTask.FromResult(new ListToolsResult - { - Tools = - [ - new Tool - { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "The input to echo back" - } - }, - "required": ["message"] - } - """), - } - ] - }), - - CallToolHandler = (request, cancellationToken) => + ListToolsHandler = (request, cancellationToken) => + ValueTask.FromResult(new ListToolsResult { - if (request.Params?.Name == "echo") - { - if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) + Tools = + [ + new Tool { - throw new McpException("Missing required argument 'message'"); + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back" + } + }, + "required": ["message"] + } + """), } + ] + }), - return ValueTask.FromResult(new CallToolResult - { - Content = [new TextContentBlock { Text = $"Echo: {message}", Type = "text" }] - }); + CallToolHandler = (request, cancellationToken) => + { + if (request.Params?.Name == "echo") + { + if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) + { + throw new McpException("Missing required argument 'message'"); } - throw new McpException($"Unknown tool: '{request.Params?.Name}'"); - }, + return ValueTask.FromResult(new CallToolResult + { + Content = [new TextContentBlock { Text = $"Echo: {message}", Type = "text" }] + }); + } + + throw new McpException($"Unknown tool: '{request.Params?.Name}'"); } - }, + } }; -await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); +await using McpServer server = McpServer.Create(new StdioServerTransport("MyServer"), options); await server.RunAsync(); ``` diff --git a/docs/concepts/elicitation/elicitation.md b/docs/concepts/elicitation/elicitation.md new file mode 100644 index 000000000..ebda0979a --- /dev/null +++ b/docs/concepts/elicitation/elicitation.md @@ -0,0 +1,53 @@ +--- +title: Elicitation +author: mikekistler +description: Enable interactive AI experiences by requesting user input during tool execution. +uid: elicitation +--- + +## Elicitation + +The **elicitation** feature allows servers to request additional information from users during interactions. This enables more dynamic and interactive AI experiences, making it easier to gather necessary context before executing tasks. + +### Server Support for Elicitation + +Servers request structured data from users with the [ElicitAsync] extension method on [IMcpServer]. +The C# SDK registers an instance of [IMcpServer] with the dependency injection container, +so tools can simply add a parameter of type [IMcpServer] to their method signature to access it. + +[ElicitAsync]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.McpServerExtensions.html#ModelContextProtocol_Server_McpServerExtensions_ElicitAsync_ModelContextProtocol_Server_IMcpServer_ModelContextProtocol_Protocol_ElicitRequestParams_System_Threading_CancellationToken_ +[IMcpServer]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.IMcpServer.html + +The MCP Server must specify the schema of each input value it is requesting from the user. +Only primitive types (string, number, boolean) are supported for elicitation requests. +The schema may include a description to help the user understand what is being requested. + +The server can request a single input or multiple inputs at once. +To help distinguish multiple inputs, each input has a unique name. + +The following example demonstrates how a server could request a boolean response from the user. + +[!code-csharp[](samples/server/Tools/InteractiveTools.cs?name=snippet_GuessTheNumber)] + +### Client Support for Elicitation + +Elicitation is an optional feature so clients declare their support for it in their capabilities as part of the `initialize` request. In the MCP C# SDK, this is done by configuring an [ElicitationHandler] in the [McpClientOptions]: + +[ElicitationHandler]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Protocol.ElicitationCapability.html#ModelContextProtocol_Protocol_ElicitationCapability_ElicitationHandler +[McpClientOptions]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Client.McpClientOptions.html + +[!code-csharp[](samples/client/Program.cs?name=snippet_McpInitialize)] + +The ElicitationHandler is an asynchronous method that will be called when the server requests additional information. +The ElicitationHandler must request input from the user and return the data in a format that matches the requested schema. +This will be highly dependent on the client application and how it interacts with the user. + +If the user provides the requested information, the ElicitationHandler should return an [ElicitResult] with the action set to "accept" and the content containing the user's input. +If the user does not provide the requested information, the ElicitationHandler should return an [ElicitResult] with the action set to "reject" and no content. + +[ElicitResult]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Protocol.ElicitResult.html + +Below is an example of how a console application might handle elicitation requests. +Here's an example implementation: + +[!code-csharp[](samples/client/Program.cs?name=snippet_ElicitationHandler)] diff --git a/docs/concepts/elicitation/samples/client/ElicitationClient.csproj b/docs/concepts/elicitation/samples/client/ElicitationClient.csproj new file mode 100644 index 000000000..4be0d6ec7 --- /dev/null +++ b/docs/concepts/elicitation/samples/client/ElicitationClient.csproj @@ -0,0 +1,16 @@ + + + + Exe + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/elicitation/samples/client/Program.cs b/docs/concepts/elicitation/samples/client/Program.cs new file mode 100644 index 000000000..b2a91ca4b --- /dev/null +++ b/docs/concepts/elicitation/samples/client/Program.cs @@ -0,0 +1,118 @@ +using System.Text.Json; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; + +var clientTransport = new HttpClientTransport(new() +{ + Endpoint = new Uri(endpoint), + TransportMode = HttpTransportMode.StreamableHttp, +}); + +// +McpClientOptions options = new() +{ + ClientInfo = new() + { + Name = "ElicitationClient", + Version = "1.0.0" + }, + Handlers = new() + { + ElicitationHandler = HandleElicitationAsync + } +}; + +await using var mcpClient = await McpClient.CreateAsync(clientTransport, options); +// + +var tools = await mcpClient.ListToolsAsync(); +foreach (var tool in tools) +{ + Console.WriteLine($"Connected to server with tools: {tool.Name}"); +} + +Console.WriteLine($"Calling tool: {tools.First().Name}"); + +var result = await mcpClient.CallToolAsync(toolName: tools.First().Name); + +foreach (var block in result.Content) +{ + if (block is TextContentBlock textBlock) + { + Console.WriteLine(textBlock.Text); + } + else + { + Console.WriteLine($"Received unexpected result content of type {block.GetType()}"); + } +} + +// +async ValueTask HandleElicitationAsync(ElicitRequestParams? requestParams, CancellationToken token) +{ + // Bail out if the requestParams is null or if the requested schema has no properties + if (requestParams is null || requestParams.RequestedSchema?.Properties is null) + { + return new ElicitResult(); + } + + // Process the elicitation request + if (requestParams.Message is not null) + { + Console.WriteLine(requestParams.Message); + } + + var content = new Dictionary(); + + // Loop through requestParams.requestSchema.Properties dictionary requesting values for each property + foreach (var property in requestParams.RequestedSchema.Properties) + { + if (property.Value is ElicitRequestParams.BooleanSchema booleanSchema) + { + Console.Write($"{booleanSchema.Description}: "); + var clientInput = Console.ReadLine(); + bool parsedBool; + + // Try standard boolean parsing first + if (bool.TryParse(clientInput, out parsedBool)) + { + content[property.Key] = JsonSerializer.Deserialize(JsonSerializer.Serialize(parsedBool)); + } + // Also accept "yes"/"no" as valid boolean inputs + else if (string.Equals(clientInput?.Trim(), "yes", StringComparison.OrdinalIgnoreCase)) + { + content[property.Key] = JsonSerializer.Deserialize(JsonSerializer.Serialize(true)); + } + else if (string.Equals(clientInput?.Trim(), "no", StringComparison.OrdinalIgnoreCase)) + { + content[property.Key] = JsonSerializer.Deserialize(JsonSerializer.Serialize(false)); + } + } + else if (property.Value is ElicitRequestParams.NumberSchema numberSchema) + { + Console.Write($"{numberSchema.Description}: "); + var clientInput = Console.ReadLine(); + double parsedNumber; + if (double.TryParse(clientInput, out parsedNumber)) + { + content[property.Key] = JsonSerializer.Deserialize(JsonSerializer.Serialize(parsedNumber)); + } + } + else if (property.Value is ElicitRequestParams.StringSchema stringSchema) + { + Console.Write($"{stringSchema.Description}: "); + var clientInput = Console.ReadLine(); + content[property.Key] = JsonSerializer.Deserialize(JsonSerializer.Serialize(clientInput)); + } + } + + // Return the user's input + return new ElicitResult + { + Action = "accept", + Content = content + }; +} +// diff --git a/docs/concepts/elicitation/samples/server/Elicitation.csproj b/docs/concepts/elicitation/samples/server/Elicitation.csproj new file mode 100644 index 000000000..f4998aa12 --- /dev/null +++ b/docs/concepts/elicitation/samples/server/Elicitation.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/elicitation/samples/server/Elicitation.http b/docs/concepts/elicitation/samples/server/Elicitation.http new file mode 100644 index 000000000..04dcdb343 --- /dev/null +++ b/docs/concepts/elicitation/samples/server/Elicitation.http @@ -0,0 +1,58 @@ +@HostAddress = http://localhost:3001 + +# No session ID, so elicitation capabilities not declared. + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json +MCP-Protocol-Version: 2025-06-18 + +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "guess_the_number" + } +} + +### + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json + +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "clientInfo": { + "name": "RestClient", + "version": "0.1.0" + }, + "capabilities": { + "elicitation": {} + }, + "protocolVersion": "2025-06-18" + } +} + +### + +@SessionId = lgEu87uKTy8kLffZayO5rQ + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json +Mcp-Session-Id: {{SessionId}} +MCP-Protocol-Version: 2025-06-18 + +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "guess_the_number" + } +} diff --git a/docs/concepts/elicitation/samples/server/Program.cs b/docs/concepts/elicitation/samples/server/Program.cs new file mode 100644 index 000000000..8c6862464 --- /dev/null +++ b/docs/concepts/elicitation/samples/server/Program.cs @@ -0,0 +1,24 @@ +using Elicitation.Tools; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. + +builder.Services.AddMcpServer() + .WithHttpTransport(options => + options.IdleTimeout = Timeout.InfiniteTimeSpan // Never timeout + ) + .WithTools(); + +builder.Logging.AddConsole(options => +{ + options.LogToStandardErrorThreshold = LogLevel.Information; +}); + +var app = builder.Build(); + +app.UseHttpsRedirection(); + +app.MapMcp(); + +app.Run(); diff --git a/samples/AspNetCoreSseServer/Properties/launchSettings.json b/docs/concepts/elicitation/samples/server/Properties/launchSettings.json similarity index 85% rename from samples/AspNetCoreSseServer/Properties/launchSettings.json rename to docs/concepts/elicitation/samples/server/Properties/launchSettings.json index c789fb474..74cf457ef 100644 --- a/samples/AspNetCoreSseServer/Properties/launchSettings.json +++ b/docs/concepts/elicitation/samples/server/Properties/launchSettings.json @@ -1,4 +1,4 @@ -{ +{ "$schema": "https://json.schemastore.org/launchsettings.json", "profiles": { "http": { @@ -7,7 +7,6 @@ "applicationUrl": "http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "sse-server", } }, "https": { @@ -16,8 +15,7 @@ "applicationUrl": "https://localhost:7133;http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "sse-server", } } } -} +} \ No newline at end of file diff --git a/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs b/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs new file mode 100644 index 000000000..b907a805d --- /dev/null +++ b/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs @@ -0,0 +1,126 @@ +using System.ComponentModel; +using System.Text.Json; +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using static ModelContextProtocol.Protocol.ElicitRequestParams; + +namespace Elicitation.Tools; + +[McpServerToolType] +public sealed class InteractiveTools +{ + // + [McpServerTool, Description("A simple game where the user has to guess a number between 1 and 10.")] + public async Task GuessTheNumber( + McpServer server, // Get the McpServer from DI container + CancellationToken token + ) + { + // Check if the client supports elicitation + if (server.ClientCapabilities?.Elicitation == null) + { + // fail the tool call + throw new McpException("Client does not support elicitation"); + } + + // First ask the user if they want to play + var playSchema = new RequestSchema + { + Properties = + { + ["Answer"] = new BooleanSchema() + } + }; + + var playResponse = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Do you want to play a game?", + RequestedSchema = playSchema + }, token); + + // Check if user wants to play + if (playResponse.Action != "accept" || playResponse.Content?["Answer"].ValueKind != JsonValueKind.True) + { + return "Maybe next time!"; + } + // + + // Now ask the user to enter their name + var nameSchema = new RequestSchema + { + Properties = + { + ["Name"] = new StringSchema() + { + Description = "Name of the player", + MinLength = 2, + MaxLength = 50, + } + } + }; + + var nameResponse = await server.ElicitAsync(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = nameSchema + }, token); + + if (nameResponse.Action != "accept") + { + return "Maybe next time!"; + } + string? playerName = nameResponse.Content?["Name"].GetString(); + + // Generate a random number between 1 and 10 + Random random = new Random(); + int targetNumber = random.Next(1, 11); // 1 to 10 inclusive + int attempts = 0; + + var message = "Guess a number between 1 and 10"; + + while (true) + { + attempts++; + + var guessSchema = new RequestSchema + { + Properties = + { + ["Guess"] = new NumberSchema() + { + Type = "integer", + Minimum = 1, + Maximum = 10, + } + } + }; + + var guessResponse = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = guessSchema + }, token); + + if (guessResponse.Action != "accept") + { + return "Maybe next time!"; + } + int guess = (int)(guessResponse.Content?["Guess"].GetInt32())!; + + // Check if the guess is correct + if (guess == targetNumber) + { + return $"Congratulations {playerName}! You guessed the number {targetNumber} in {attempts} attempts!"; + } + else if (guess < targetNumber) + { + message = $"Your guess is too low! Try again (Attempt #{attempts}):"; + } + else + { + message = $"Your guess is too high! Try again (Attempt #{attempts}):"; + } + } + } +} \ No newline at end of file diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md new file mode 100644 index 000000000..3e081d123 --- /dev/null +++ b/docs/concepts/filters.md @@ -0,0 +1,317 @@ +--- +title: Filters +author: halter73 +description: MCP Server Handler Filters +uid: filters +--- + +# MCP Server Handler Filters + +For each handler type in the MCP Server, there are corresponding `AddXXXFilter` methods in `McpServerBuilderExtensions.cs` that allow you to add filters to the handler pipeline. The filters are stored in `McpServerOptions.Filters` and applied during server configuration. + +## Available Filter Methods + +The following filter methods are available: + +- `AddListResourceTemplatesFilter` - Filter for list resource templates handlers +- `AddListToolsFilter` - Filter for list tools handlers +- `AddCallToolFilter` - Filter for call tool handlers +- `AddListPromptsFilter` - Filter for list prompts handlers +- `AddGetPromptFilter` - Filter for get prompt handlers +- `AddListResourcesFilter` - Filter for list resources handlers +- `AddReadResourceFilter` - Filter for read resource handlers +- `AddCompleteFilter` - Filter for completion handlers +- `AddSubscribeToResourcesFilter` - Filter for resource subscription handlers +- `AddUnsubscribeFromResourcesFilter` - Filter for resource unsubscription handlers +- `AddSetLoggingLevelFilter` - Filter for logging level handlers + +## Usage + +Filters are functions that take a handler and return a new handler, allowing you to wrap the original handler with additional functionality: + +```csharp +services.AddMcpServer() + .WithListToolsHandler(async (context, cancellationToken) => + { + // Your base handler logic + return new ListToolsResult { Tools = GetTools() }; + }) + .AddListToolsFilter(next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + + // Pre-processing logic + logger?.LogInformation("Before handler execution"); + + var result = await next(context, cancellationToken); + + // Post-processing logic + logger?.LogInformation("After handler execution"); + return result; + }); +``` + +## Filter Execution Order + +```csharp +services.AddMcpServer() + .WithListToolsHandler(baseHandler) + .AddListToolsFilter(filter1) // Executes first (outermost) + .AddListToolsFilter(filter2) // Executes second + .AddListToolsFilter(filter3); // Executes third (closest to handler) +``` + +Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filter2 -> filter1` + +## Common Use Cases + +### Logging +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + var logger = context.Services?.GetService>(); + + logger?.LogInformation($"Processing request from {context.Meta.ProgressToken}"); + var result = await next(context, cancellationToken); + logger?.LogInformation($"Returning {result.Tools?.Count ?? 0} tools"); + return result; +}); +``` + +### Error Handling +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + try + { + return await next(context, cancellationToken); + } + catch (Exception ex) + { + return new CallToolResult + { + Content = new[] { new TextContent { Type = "text", Text = $"Error: {ex.Message}" } }, + IsError = true + }; + } +}); +``` + +### Performance Monitoring +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + var logger = context.Services?.GetService>(); + + var stopwatch = Stopwatch.StartNew(); + var result = await next(context, cancellationToken); + stopwatch.Stop(); + logger?.LogInformation($"Handler took {stopwatch.ElapsedMilliseconds}ms"); + return result; +}); +``` + +### Caching +```csharp +.AddListResourcesFilter(next => async (context, cancellationToken) => +{ + var cache = context.Services!.GetRequiredService(); + + var cacheKey = $"resources:{context.Params.Cursor}"; + if (cache.TryGetValue(cacheKey, out var cached)) + { + return (ListResourcesResult)cached; + } + + var result = await next(context, cancellationToken); + cache.Set(cacheKey, result, TimeSpan.FromMinutes(5)); + return result; +}); +``` + +## Built-in Authorization Filters + +When using the ASP.NET Core integration (`ModelContextProtocol.AspNetCore`), you can add authorization filters to support `[Authorize]` and `[AllowAnonymous]` attributes on MCP server tools, prompts, and resources by calling `AddAuthorizationFilters()` on your MCP server builder. + +### Enabling Authorization Filters + +To enable authorization support, call `AddAuthorizationFilters()` when configuring your MCP server: + +```csharp +services.AddMcpServer() + .WithHttpTransport() + .AddAuthorizationFilters() // Enable authorization filter support + .WithTools(); +``` + +**Important**: You should always call `AddAuthorizationFilters()` when using ASP.NET Core integration if you want to use authorization attributes like `[Authorize]` on your MCP server tools, prompts, or resources. + +### Authorization Attributes Support + +The MCP server automatically respects the following authorization attributes: + +- **`[Authorize]`** - Requires authentication for access +- **`[Authorize(Roles = "RoleName")]`** - Requires specific roles +- **`[Authorize(Policy = "PolicyName")]`** - Requires specific authorization policies +- **`[AllowAnonymous]`** - Explicitly allows anonymous access (overrides `[Authorize]`) + +### Tool Authorization + +Tools can be decorated with authorization attributes to control access: + +```csharp +[McpServerToolType] +public class WeatherTools +{ + [McpServerTool, Description("Gets public weather data")] + public static string GetWeather(string location) + { + return $"Weather for {location}: Sunny, 25°C"; + } + + [McpServerTool, Description("Gets detailed weather forecast")] + [Authorize] // Requires authentication + public static string GetDetailedForecast(string location) + { + return $"Detailed forecast for {location}: ..."; + } + + [McpServerTool, Description("Manages weather alerts")] + [Authorize(Roles = "Admin")] // Requires Admin role + public static string ManageWeatherAlerts(string alertType) + { + return $"Managing alert: {alertType}"; + } +} +``` + +### Class-Level Authorization + +You can apply authorization at the class level, which affects all tools in the class: + +```csharp +[McpServerToolType] +[Authorize] // All tools require authentication +public class RestrictedTools +{ + [McpServerTool, Description("Restricted tool accessible to authenticated users")] + public static string RestrictedOperation() + { + return "Restricted operation completed"; + } + + [McpServerTool, Description("Public tool accessible to anonymous users")] + [AllowAnonymous] // Overrides class-level [Authorize] + public static string PublicOperation() + { + return "Public operation completed"; + } +} +``` + +### How Authorization Filters Work + +The authorization filters work differently for list operations versus individual operations: + +#### List Operations (ListTools, ListPrompts, ListResources) +For list operations, the filters automatically remove unauthorized items from the results. Users only see tools, prompts, or resources they have permission to access. + +#### Individual Operations (CallTool, GetPrompt, ReadResource) +For individual operations, the filters throw an `McpException` with "Access forbidden" message. These get turned into JSON-RPC errors if uncaught by middleware. + +### Filter Execution Order and Authorization + +Authorization filters are applied automatically when you call `AddAuthorizationFilters()`. These filters run at a specific point in the filter pipeline, which means: + +**Filters added before authorization filters** can see: +- Unauthorized requests for operations before they are rejected by the authorization filters +- Complete listings for unauthorized primitives before they are filtered out by the authorization filters + +**Filters added after authorization filters** will only see: +- Authorized requests that passed authorization checks +- Filtered listings containing only authorized primitives + +This allows you to implement logging, metrics, or other cross-cutting concerns that need to see all requests, while still maintaining proper authorization: + +```csharp +services.AddMcpServer() + .WithHttpTransport() + .AddListToolsFilter(next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + + // This filter runs BEFORE authorization - sees all tools + logger?.LogInformation("Request for tools list - will see all tools"); + var result = await next(context, cancellationToken); + logger?.LogInformation($"Returning {result.Tools?.Count ?? 0} tools after authorization"); + return result; + }) + .AddAuthorizationFilters() // Authorization filtering happens here + .AddListToolsFilter(next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + + // This filter runs AFTER authorization - only sees authorized tools + var result = await next(context, cancellationToken); + logger?.LogInformation($"Post-auth filter sees {result.Tools?.Count ?? 0} authorized tools"); + return result; + }) + .WithTools(); +``` + +### Setup Requirements + +To use authorization features, you must configure authentication and authorization in your ASP.NET Core application and call `AddAuthorizationFilters()`: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddAuthentication("Bearer") + .AddJwtBearer(options => { /* JWT configuration */ }) + .AddMcp(options => { /* Resource metadata configuration */ }); +builder.Services.AddAuthorization(); + +builder.Services.AddMcpServer() + .WithHttpTransport() + .AddAuthorizationFilters() // Required for authorization support + .WithTools() + .AddCallToolFilter(next => async (context, cancellationToken) => + { + // Custom call tool logic + return await next(context, cancellationToken); + }); + +var app = builder.Build(); + +app.MapMcp(); +app.Run(); +``` + +### Custom Authorization Filters + +You can also create custom authorization filters using the filter methods: + +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + // Custom authorization logic + if (context.User?.Identity?.IsAuthenticated != true) + { + return new CallToolResult + { + Content = [new TextContent { Text = "Custom: Authentication required" }], + IsError = true + }; + } + + return await next(context, cancellationToken); +}); +``` + +### RequestContext + +Within filters, you have access to: + +- `context.User` - The current user's `ClaimsPrincipal` +- `context.Services` - The request's service provider for resolving authorization services +- `context.MatchedPrimitive` - The matched tool/prompt/resource with its metadata including authorization attributes via `context.MatchedPrimitive.Metadata` diff --git a/docs/concepts/httpcontext/httpcontext.md b/docs/concepts/httpcontext/httpcontext.md new file mode 100644 index 000000000..51bbd050a --- /dev/null +++ b/docs/concepts/httpcontext/httpcontext.md @@ -0,0 +1,31 @@ +--- +title: HTTP Context +author: mikekistler +description: How to access the HttpContext in the MCP C# SDK. +uid: httpcontext +--- + +## HTTP Context + +When using the Streamable HTTP transport, an MCP server may need to access the underlying [HttpContext] for a request. +The [HttpContext] contains request metadata such as the HTTP headers, authorization context, and the actual path and query string for the request. + +To access the [HttpContext], the MCP server should add the [IHttpContextAccessor] service to the application service collection (typically in Program.cs). +Then any classes, e.g. a class containing MCP tools, should accept an [IHttpContextAccessor] in their constructor and store this for use by its methods. +Methods then use the [HttpContext property][IHttpContextAccessor.HttpContext] of the accessor to get the current context. + +[HttpContext]: https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.http.httpcontext +[IHttpContextAccessor]: https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.http.ihttpcontextaccessor +[IHttpContextAccessor.HttpContext]: https://learn.microsoft.com/dotnet/api/microsoft.aspnetcore.http.ihttpcontextaccessor.httpcontext + +The following code snippet illustrates how to add the [IHttpContextAccessor] service to the application service collection: + +[!code-csharp[](samples/Program.cs?name=snippet_AddHttpContextAccessor)] + +Any class that needs access to the [HttpContext] can accept an [IHttpContextAccessor] in its constructor and store it for later use. +Methods of the class can then access the current [HttpContext] using the stored accessor. + +The following code snippet shows the `ContextTools` class accepting an [IHttpContextAccessor] in its primary constructor +and the `GetHttpHeaders` method accessing the current [HttpContext] to retrieve the HTTP headers from the current request. + +[!code-csharp[](samples/Tools/ContextTools.cs?name=snippet_AccessHttpContext)] diff --git a/docs/concepts/httpcontext/samples/HttpContext.csproj b/docs/concepts/httpcontext/samples/HttpContext.csproj new file mode 100644 index 000000000..2982d8f87 --- /dev/null +++ b/docs/concepts/httpcontext/samples/HttpContext.csproj @@ -0,0 +1,18 @@ + + + + net9.0 + enable + enable + + + + false + false + + + + + + + diff --git a/docs/concepts/httpcontext/samples/HttpContext.http b/docs/concepts/httpcontext/samples/HttpContext.http new file mode 100644 index 000000000..838457e9b --- /dev/null +++ b/docs/concepts/httpcontext/samples/HttpContext.http @@ -0,0 +1,15 @@ +@HostAddress = http://localhost:3001 + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json +MCP-Protocol-Version: 2025-06-18 + +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "get_http_headers" + } +} diff --git a/docs/concepts/httpcontext/samples/Program.cs b/docs/concepts/httpcontext/samples/Program.cs new file mode 100644 index 000000000..043e6069d --- /dev/null +++ b/docs/concepts/httpcontext/samples/Program.cs @@ -0,0 +1,26 @@ +using HttpContext.Tools; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. + +builder.Services.AddMcpServer() + .WithHttpTransport() + .WithTools(); + +// +builder.Services.AddHttpContextAccessor(); +// + +builder.Logging.AddConsole(options => +{ + options.LogToStandardErrorThreshold = LogLevel.Information; +}); + +var app = builder.Build(); + +app.UseHttpsRedirection(); + +app.MapMcp(); + +app.Run(); diff --git a/docs/concepts/httpcontext/samples/Properties/launchSettings.json b/docs/concepts/httpcontext/samples/Properties/launchSettings.json new file mode 100644 index 000000000..c6eb0fa56 --- /dev/null +++ b/docs/concepts/httpcontext/samples/Properties/launchSettings.json @@ -0,0 +1,23 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "https://localhost:7191;http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/docs/concepts/httpcontext/samples/Tools/ContextTools.cs b/docs/concepts/httpcontext/samples/Tools/ContextTools.cs new file mode 100644 index 000000000..222130bb9 --- /dev/null +++ b/docs/concepts/httpcontext/samples/Tools/ContextTools.cs @@ -0,0 +1,28 @@ +using ModelContextProtocol.Server; +using System.ComponentModel; + +namespace HttpContext.Tools; + +// +public class ContextTools(IHttpContextAccessor httpContextAccessor) +{ + [McpServerTool(UseStructuredContent = true)] + [Description("Retrieves the HTTP headers from the current request and returns them as a JSON object.")] + public object GetHttpHeaders() + { + var context = httpContextAccessor.HttpContext; + if (context == null) + { + return "No HTTP context available"; + } + + var headers = new Dictionary(); + foreach (var header in context.Request.Headers) + { + headers[header.Key] = string.Join(", ", header.Value.ToArray()); + } + + return headers; + } +// +} diff --git a/docs/concepts/httpcontext/samples/TryItOut.ipynb b/docs/concepts/httpcontext/samples/TryItOut.ipynb new file mode 100644 index 000000000..95ebfafec --- /dev/null +++ b/docs/concepts/httpcontext/samples/TryItOut.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "37868ee2", + "metadata": { + "language_info": { + "name": "polyglot-notebook" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "source": [ + "## HTTP Context\n", + "\n", + "This project illustrates how to access the HttpContext from tool calls. See the [README](../README.md) for more details.\n" + ] + }, + { + "cell_type": "markdown", + "id": "093a7d4f", + "metadata": {}, + "source": [ + "### Examples\n", + "\n", + "The following request illustrates a tool call that retrieves the HTTP headers from the [HttpContext] using the [IHttpContextAccessor].\n", + "\n", + "The server implements two other tools, `get_request_info` and `get_user_claims`. You can modify the code below to call these tools as well,\n", + "which illustrate how to access other parts of the [HttpContext].\n", + "\n", + "\n", + "[HttpContext]: https://docs.microsoft.com/dotnet/api/microsoft.aspnetcore.http.httpcontext\n", + "[IHttpContextAccessor]: https://docs.microsoft.com/dotnet/api/microsoft.aspnetcore.http.ihttpcontextaccessor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3324ec56", + "metadata": { + "language_info": { + "name": "polyglot-notebook" + }, + "polyglot_notebook": { + "kernelName": "pwsh" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"result\": {\n", + " \"content\": [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": \"[]\"\n", + " }\n", + " ],\n", + " \"structuredContent\": {\n", + " \"result\": []\n", + " }\n", + " },\n", + " \"id\": 1,\n", + " \"jsonrpc\": \"2.0\"\n", + "}\n" + ] + } + ], + "source": [ + "curl -s -X POST http://localhost:3001 `\n", + "-H \"ProtocolVersion: 2025-06-18\" `\n", + "-H \"Accept: application/json, text/event-stream\" `\n", + "-H \"Content-Type: application/json\" `\n", + "-d '{\n", + " \"jsonrpc\": \"2.0\",\n", + " \"id\": 1,\n", + " \"method\": \"tools/call\",\n", + " \"params\": {\n", + " \"name\": \"get_user_claims\"\n", + " }\n", + "}' | Where-Object { $_ -like \"data:*\" } | ForEach-Object { $_.Substring(5) } | jq" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".NET (C#)", + "language": "C#", + "name": ".net-csharp" + }, + "language_info": { + "name": "polyglot-notebook" + }, + "polyglot_notebook": { + "kernelInfo": { + "defaultKernelName": "csharp", + "items": [ + { + "aliases": [], + "languageName": "csharp", + "name": "csharp" + } + ] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/samples/AspNetCoreSseServer/appsettings.Development.json b/docs/concepts/httpcontext/samples/appsettings.Development.json similarity index 100% rename from samples/AspNetCoreSseServer/appsettings.Development.json rename to docs/concepts/httpcontext/samples/appsettings.Development.json diff --git a/samples/AspNetCoreSseServer/appsettings.json b/docs/concepts/httpcontext/samples/appsettings.json similarity index 100% rename from samples/AspNetCoreSseServer/appsettings.json rename to docs/concepts/httpcontext/samples/appsettings.json diff --git a/docs/concepts/index.md b/docs/concepts/index.md new file mode 100644 index 000000000..e038c8996 --- /dev/null +++ b/docs/concepts/index.md @@ -0,0 +1,2 @@ + +Welcome to the conceptual documentation for the Model Context Protocol SDK. Here you'll find high-level overviews, explanations, and guides to help you understand how the SDK implements the Model Context Protocol. diff --git a/docs/concepts/logging/logging.md b/docs/concepts/logging/logging.md new file mode 100644 index 000000000..411a61b1c --- /dev/null +++ b/docs/concepts/logging/logging.md @@ -0,0 +1,101 @@ +--- +title: Logging +author: mikekistler +description: How to use the logging feature in the MCP C# SDK. +uid: logging +--- + +## Logging + +MCP servers may expose log messages to clients through the [Logging utility]. + +[Logging utility]: https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/logging + +This document describes how to implement logging in MCP servers and how clients can consume log messages. + +### Logging Levels + +MCP uses the logging levels defined in [RFC 5424](https://tools.ietf.org/html/rfc5424). + +The MCP C# SDK uses the standard .NET [ILogger] and [ILoggerProvider] abstractions, which support a slightly +different set of logging levels. Here's the levels and how they map to standard .NET logging levels. + +| Level | .NET | Description | Example Use Case | +|-----------|------|-----------------------------------|------------------------------| +| debug | ✓ | Detailed debugging information | Function entry/exit points | +| info | ✓ | General informational messages | Operation progress updates | +| notice | | Normal but significant events | Configuration changes | +| warning | ✓ | Warning conditions | Deprecated feature usage | +| error | ✓ | Error conditions | Operation failures | +| critical | ✓ | Critical conditions | System component failures | +| alert | | Action must be taken immediately | Data corruption detected | +| emergency | | System is unusable | | + +**Note:** .NET's [ILogger] also supports a `Trace` level (more verbose than Debug) log level. +As there is no equivalent level in the MCP logging levels, Trace level logs messages are silently +dropped when sending messages to the client. + +[ILogger]: https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.ilogger +[ILoggerProvider]: https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.iloggerprovider + +### Server configuration and logging + +MCP servers that implement the Logging utility must declare this in the capabilities sent in the +[Initialization] phase at the beginning of the MCP session. + +[Initialization]: https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization + +Servers built with the C# SDK always declare the logging capability. Doing so does not obligate the server +to send log messages -- only allows it. Note that stateless MCP servers may not be capable of sending log +messages as there may not be an open connection to the client on which the log messages could be sent. + +The C# SDK provides an extension method [WithSetLoggingLevelHandler] on [IMcpServerBuilder] to allow the +server to perform any special logic it wants to perform when a client sets the logging level. However, the +SDK already takes care of setting the [LoggingLevel] in the [IMcpServer], so most servers will not need to +implement this. + +[IMcpServer]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.IMcpServer.html +[IMcpServerBuilder]: https://modelcontextprotocol.github.io/csharp-sdk/api/Microsoft.Extensions.DependencyInjection.IMcpServerBuilder.html +[WithSetLoggingLevelHandler]: https://modelcontextprotocol.github.io/csharp-sdk/api/Microsoft.Extensions.DependencyInjection.McpServerBuilderExtensions.html#Microsoft_Extensions_DependencyInjection_McpServerBuilderExtensions_WithSetLoggingLevelHandler_Microsoft_Extensions_DependencyInjection_IMcpServerBuilder_System_Func_ModelContextProtocol_Server_RequestContext_ModelContextProtocol_Protocol_SetLevelRequestParams__System_Threading_CancellationToken_System_Threading_Tasks_ValueTask_ModelContextProtocol_Protocol_EmptyResult___ +[LoggingLevel]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.IMcpServer.html#ModelContextProtocol_Server_IMcpServer_LoggingLevel + +MCP Servers using the MCP C# SDK can obtain an [ILoggerProvider] from the IMcpServer [AsClientLoggerProvider] extension method, +and from that can create an [ILogger] instance for logging messages that should be sent to the MCP client. + +[!code-csharp[](samples/server/Tools/LoggingTools.cs?name=snippet_LoggingConfiguration)] + +[ILoggerProvider]: https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.iloggerprovider +[AsClientLoggerProvider]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.McpServerExtensions.html#ModelContextProtocol_Server_McpServerExtensions_AsClientLoggerProvider_ModelContextProtocol_Server_IMcpServer_ +[ILogger]: https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.ilogger + +### Client support for logging + +When the server indicates that it supports logging, clients should configure +the logging level to specify which messages the server should send to the client. + +Clients should check if the server supports logging by checking the [Logging] property of the [ServerCapabilities] field of [IMcpClient]. + +[IMcpClient]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Client.IMcpClient.html +[ServerCapabilities]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Client.IMcpClient.html#ModelContextProtocol_Client_IMcpClient_ServerCapabilities +[Logging]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Protocol.ServerCapabilities.html#ModelContextProtocol_Protocol_ServerCapabilities_Logging + +[!code-csharp[](samples/client/Program.cs?name=snippet_LoggingCapabilities)] + +If the server supports logging, the client should set the level of log messages it wishes to receive with +the [SetLoggingLevel] method on [IMcpClient]. If the client does not set a logging level, the server might choose +to send all log messages or none -- this is not specified in the protocol -- so it is important that the client +sets a logging level to ensure it receives the desired log messages and only those messages. + +The `loggingLevel` set by the client is an MCP logging level. +See the [Logging Levels](#logging-levels) section above for the mapping between MCP and .NET logging levels. + +[SetLoggingLevel]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Client.McpClientExtensions.html#ModelContextProtocol_Client_McpClientExtensions_SetLoggingLevel_ModelContextProtocol_Client_IMcpClient_Microsoft_Extensions_Logging_LogLevel_System_Threading_CancellationToken_ + +[!code-csharp[](samples/client/Program.cs?name=snippet_LoggingLevel)] + +Lastly, the client must configure a notification handler for [NotificationMethods.LoggingMessageNotification] notifications. +The following example simply writes the log messages to the console. + +[NotificationMethods.LoggingMessageNotification]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Protocol.NotificationMethods.html#ModelContextProtocol_Protocol_NotificationMethods_LoggingMessageNotification + +[!code-csharp[](samples/client/Program.cs?name=snippet_LoggingHandler)] diff --git a/docs/concepts/logging/samples/client/LoggingClient.csproj b/docs/concepts/logging/samples/client/LoggingClient.csproj new file mode 100644 index 000000000..9f020005d --- /dev/null +++ b/docs/concepts/logging/samples/client/LoggingClient.csproj @@ -0,0 +1,16 @@ + + + + Exe + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/logging/samples/client/Program.cs b/docs/concepts/logging/samples/client/Program.cs new file mode 100644 index 000000000..29a15726a --- /dev/null +++ b/docs/concepts/logging/samples/client/Program.cs @@ -0,0 +1,67 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Client; +using System.Text.Json; + +var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; + +var clientTransport = new HttpClientTransport(new() +{ + Endpoint = new Uri(endpoint), + TransportMode = HttpTransportMode.StreamableHttp, +}); + +await using var mcpClient = await McpClient.CreateAsync(clientTransport); + +// +// Verify that the server supports logging +if (mcpClient.ServerCapabilities.Logging is null) +{ + Console.WriteLine("Server does not support logging."); + return; +} +// + +// Get the first argument if it was specified +var firstArgument = args.Length > 0 ? args[0] : null; + +if (firstArgument is not null) +{ + // Set the logging level to the value from the first argument + if (Enum.TryParse(firstArgument, true, out var loggingLevel)) + { + // + await mcpClient.SetLoggingLevel(loggingLevel); + // + } + else + { + Console.WriteLine($"Invalid logging level: {firstArgument}"); + // Print a list of valid logging levels + Console.WriteLine("Valid logging levels are:"); + foreach (var level in Enum.GetValues()) + { + Console.WriteLine($" - {level}"); + } + } +} + +// +mcpClient.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, + (notification, ct) => + { + if (JsonSerializer.Deserialize(notification.Params) is { } ln) + { + Console.WriteLine($"[{ln.Level}] {ln.Logger} {ln.Data}"); + } + else + { + Console.WriteLine($"Received unexpected logging notification: {notification.Params}"); + } + + return default; + }); +// + +// Now call the "logging_tool" tool +await mcpClient.CallToolAsync("logging_tool"); + diff --git a/docs/concepts/logging/samples/server/Logging.csproj b/docs/concepts/logging/samples/server/Logging.csproj new file mode 100644 index 000000000..f4998aa12 --- /dev/null +++ b/docs/concepts/logging/samples/server/Logging.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/logging/samples/server/Logging.http b/docs/concepts/logging/samples/server/Logging.http new file mode 100644 index 000000000..3f0f028b7 --- /dev/null +++ b/docs/concepts/logging/samples/server/Logging.http @@ -0,0 +1,40 @@ +@HostAddress = http://localhost:3001 + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json + +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "clientInfo": { + "name": "RestClient", + "version": "0.1.0" + }, + "capabilities": {}, + "protocolVersion": "2025-06-18" + } +} + +### + +@SessionId = JCo3W4Q7KA_evyWoFE5qwA + +### + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json +MCP-Protocol-Version: 2025-06-18 +Mcp-Session-Id: {{SessionId}} + +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "logging_tool" + } +} \ No newline at end of file diff --git a/docs/concepts/logging/samples/server/Program.cs b/docs/concepts/logging/samples/server/Program.cs new file mode 100644 index 000000000..7de039e09 --- /dev/null +++ b/docs/concepts/logging/samples/server/Program.cs @@ -0,0 +1,20 @@ +using Logging.Tools; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. + +builder.Services.AddMcpServer() + .WithHttpTransport(options => + options.IdleTimeout = Timeout.InfiniteTimeSpan // Never timeout + ) + .WithTools(); + // .WithSetLoggingLevelHandler(async (ctx, ct) => new EmptyResult()); + +var app = builder.Build(); + +app.UseHttpsRedirection(); + +app.MapMcp(); + +app.Run(); diff --git a/docs/concepts/logging/samples/server/Properties/launchSettings.json b/docs/concepts/logging/samples/server/Properties/launchSettings.json new file mode 100644 index 000000000..c09325b27 --- /dev/null +++ b/docs/concepts/logging/samples/server/Properties/launchSettings.json @@ -0,0 +1,23 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "https://localhost:7207;http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/docs/concepts/logging/samples/server/Tools/LoggingTools.cs b/docs/concepts/logging/samples/server/Tools/LoggingTools.cs new file mode 100644 index 000000000..33fa3c040 --- /dev/null +++ b/docs/concepts/logging/samples/server/Tools/LoggingTools.cs @@ -0,0 +1,45 @@ +using System.ComponentModel; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace Logging.Tools; + +[McpServerToolType] +public class LoggingTools +{ + [McpServerTool, Description("Demonstrates a tool that produces log messages")] + public static async Task LoggingTool( + RequestContext context, + int duration = 10, + int steps = 10) + { + var progressToken = context.Params?.ProgressToken; + var stepDuration = duration / steps; + + // + ILoggerProvider loggerProvider = context.Server.AsClientLoggerProvider(); + ILogger logger = loggerProvider.CreateLogger("LoggingTools"); + // + + for (int i = 1; i <= steps; i++) + { + await Task.Delay(stepDuration * 1000); + + try + { + logger.LogCritical("A critical log message"); + logger.LogError("An error log message"); + logger.LogWarning("A warning log message"); + logger.LogInformation("An informational log message"); + logger.LogDebug("A debug log message"); + logger.LogTrace("A trace log message"); + } + catch (Exception ex) + { + logger.LogError(ex, "An error occurred while logging messages"); + } + } + + return $"Long running tool completed. Duration: {duration} seconds. Steps: {steps}."; + } +} diff --git a/docs/concepts/logging/samples/server/appsettings.Development.json b/docs/concepts/logging/samples/server/appsettings.Development.json new file mode 100644 index 000000000..0c208ae91 --- /dev/null +++ b/docs/concepts/logging/samples/server/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/docs/concepts/logging/samples/server/appsettings.json b/docs/concepts/logging/samples/server/appsettings.json new file mode 100644 index 000000000..10f68b8c8 --- /dev/null +++ b/docs/concepts/logging/samples/server/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} diff --git a/docs/concepts/progress/progress.md b/docs/concepts/progress/progress.md new file mode 100644 index 000000000..ccdf9f19c --- /dev/null +++ b/docs/concepts/progress/progress.md @@ -0,0 +1,69 @@ +--- +title: Progress +author: mikekistler +description: +uid: progress +--- + +## Progress + +The Model Context Protocol (MCP) supports [progress tracking] for long-running operations through notification messages. + +[progress tracking]: https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress + +Typically progress tracking is supported by server tools that perform operations that take a significant amount of time to complete, such as image generation or complex calculations. +However, progress tracking is defined in the MCP specification as a general feature that can be implemented for any request that is handled by either a server or a client. +This project illustrates the common case of a server tool that performs a long-running operation and sends progress updates to the client. + +### Server Implementation + +When processing a request, the server can use the [sendNotificationAsync] extension method of [IMcpServer] to send progress updates, +specifying `"notifications/progress"` as the notification method name. +The C# SDK registers an instance of [IMcpServer] with the dependency injection container, +so tools can simply add a parameter of type [IMcpServer] to their method signature to access it. +The parameters passed to [sendNotificationAsync] should be an instance of [ProgressNotificationParams], which includes the current progress, total steps, and an optional message. + +[sendNotificationAsync]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.McpEndpointExtensions.html#ModelContextProtocol_McpEndpointExtensions_SendNotificationAsync_ModelContextProtocol_IMcpEndpoint_System_String_System_Threading_CancellationToken_ +[IMcpServer]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Server.IMcpServer.html +[ProgressNotificationParams]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Protocol.ProgressNotificationParams.html + +The server must verify that the caller provided a `progressToken` in the request and include it in the call to [sendNotificationAsync]. The following example demonstrates how a server can send a progress notification: + +[!code-csharp[](samples/server/Tools/LongRunningTools.cs?name=snippet_SendProgress)] + +### Client Implementation + +Clients request progress updates by including a `progressToken` in the parameters of a request. +Note that servers are not required to support progress tracking, so clients should not depend on receiving progress updates. + +In the MCP C# SDK, clients can specify a `progressToken` in the request parameters when calling a tool method. +The client should also provide a notification handler to process "notifications/progress" notifications. +There are two way to do this. The first is to register a notification handler using the [RegisterNotificationHandler] method on the [IMcpClient] instance. A handler registered this way will receive all progress notifications sent by the server. + +[IMcpClient]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.Client.IMcpClient.html +[RegisterNotificationHandler]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.IMcpEndpoint.html#ModelContextProtocol_IMcpEndpoint_RegisterNotificationHandler_System_String_System_Func_ModelContextProtocol_Protocol_JsonRpcNotification_System_Threading_CancellationToken_System_Threading_Tasks_ValueTask__ + +```csharp +mcpClient.RegisterNotificationHandler(NotificationMethods.ProgressNotification, + (notification, cancellationToken) => + { + if (JsonSerializer.Deserialize(notification.Params) is { } pn && + pn.ProgressToken == progressToken) + { + // progress.Report(pn.Progress); + Console.WriteLine($"Tool progress: {pn.Progress.Progress} of {pn.Progress.Total} - {pn.Progress.Message}"); + } + return ValueTask.CompletedTask; + }).ConfigureAwait(false); +``` + +The second way is to pass a [Progress``] instance to the tool method. [Progress``] is a standard .NET type that provides a way to receive progress updates. +For the purposes of MCP progress notifications, `T` should be [ProgressNotificationValue]. +The MCP C# SDK will automatically handle progress notifications and report them through the [Progress``] instance. +This notification handler will only receive progress updates for the specific request that was made, +rather than all progress notifications from the server. + +[Progress``]: https://learn.microsoft.com/en-us/dotnet/api/system.progress-1 +[ProgressNotificationValue]: https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.ProgressNotificationValue.html + +[!code-csharp[](samples/client/Program.cs?name=snippet_ProgressHandler)] diff --git a/docs/concepts/progress/samples/client/Program.cs b/docs/concepts/progress/samples/client/Program.cs new file mode 100644 index 000000000..2a5f589de --- /dev/null +++ b/docs/concepts/progress/samples/client/Program.cs @@ -0,0 +1,52 @@ +using System.Text.Json; +using ModelContextProtocol; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; + +var clientTransport = new HttpClientTransport(new() +{ + Endpoint = new Uri(endpoint), + TransportMode = HttpTransportMode.StreamableHttp, +}); + +McpClientOptions options = new() +{ + ClientInfo = new() + { + Name = "ProgressClient", + Version = "1.0.0" + } +}; + +await using var mcpClient = await McpClient.CreateAsync(clientTransport, options); + +var tools = await mcpClient.ListToolsAsync(); +foreach (var tool in tools) +{ + Console.WriteLine($"Connected to server with tools: {tool.Name}"); +} + +Console.WriteLine($"Calling tool: {tools.First().Name}"); + +// +var progressHandler = new Progress(value => +{ + Console.WriteLine($"Tool progress: {value.Progress} of {value.Total} - {value.Message}"); +}); + +var result = await mcpClient.CallToolAsync(toolName: tools.First().Name, progress: progressHandler); +// + +foreach (var block in result.Content) +{ + if (block is TextContentBlock textBlock) + { + Console.WriteLine(textBlock.Text); + } + else + { + Console.WriteLine($"Received unexpected result content of type {block.GetType()}"); + } +} diff --git a/docs/concepts/progress/samples/client/ProgressClient.csproj b/docs/concepts/progress/samples/client/ProgressClient.csproj new file mode 100644 index 000000000..9f020005d --- /dev/null +++ b/docs/concepts/progress/samples/client/ProgressClient.csproj @@ -0,0 +1,16 @@ + + + + Exe + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/progress/samples/server/Program.cs b/docs/concepts/progress/samples/server/Program.cs new file mode 100644 index 000000000..7216b2fe1 --- /dev/null +++ b/docs/concepts/progress/samples/server/Program.cs @@ -0,0 +1,22 @@ +using Progress.Tools; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. + +builder.Services.AddMcpServer() + .WithHttpTransport() + .WithTools(); + +builder.Logging.AddConsole(options => +{ + options.LogToStandardErrorThreshold = LogLevel.Information; +}); + +var app = builder.Build(); + +app.UseHttpsRedirection(); + +app.MapMcp(); + +app.Run(); diff --git a/docs/concepts/progress/samples/server/Progress.csproj b/docs/concepts/progress/samples/server/Progress.csproj new file mode 100644 index 000000000..f4998aa12 --- /dev/null +++ b/docs/concepts/progress/samples/server/Progress.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + false + false + + + + + + + diff --git a/docs/concepts/progress/samples/server/Progress.http b/docs/concepts/progress/samples/server/Progress.http new file mode 100644 index 000000000..3b40db854 --- /dev/null +++ b/docs/concepts/progress/samples/server/Progress.http @@ -0,0 +1,18 @@ +@HostAddress = http://localhost:3001 + +POST {{HostAddress}}/ +Accept: application/json, text/event-stream +Content-Type: application/json +MCP-Protocol-Version: 2025-06-18 + +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "_meta": { + "progressToken": "abc123" + }, + "name": "long_running_tool" + } +} diff --git a/docs/concepts/progress/samples/server/Properties/launchSettings.json b/docs/concepts/progress/samples/server/Properties/launchSettings.json new file mode 100644 index 000000000..f5b342d69 --- /dev/null +++ b/docs/concepts/progress/samples/server/Properties/launchSettings.json @@ -0,0 +1,23 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "https://localhost:7175;http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs b/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs new file mode 100644 index 000000000..7fcd1244a --- /dev/null +++ b/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs @@ -0,0 +1,44 @@ +using System.ComponentModel; +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace Progress.Tools; + +[McpServerToolType] +public class LongRunningTools +{ + [McpServerTool, Description("Demonstrates a long running tool with progress updates")] + public static async Task LongRunningTool( + McpServer server, + RequestContext context, + int duration = 10, + int steps = 5) + { + var progressToken = context.Params?.ProgressToken; + var stepDuration = duration / steps; + + for (int i = 1; i <= steps; i++) + { + await Task.Delay(stepDuration * 1000); + + // + if (progressToken is not null) + { + await server.SendNotificationAsync("notifications/progress", new ProgressNotificationParams + { + ProgressToken = progressToken.Value, + Progress = new ProgressNotificationValue + { + Progress = i, + Total = steps, + Message = $"Step {i} of {steps} completed.", + }, + }); + } + // + } + + return $"Long running tool completed. Duration: {duration} seconds. Steps: {steps}."; + } +} \ No newline at end of file diff --git a/docs/concepts/progress/samples/server/appsettings.Development.json b/docs/concepts/progress/samples/server/appsettings.Development.json new file mode 100644 index 000000000..0c208ae91 --- /dev/null +++ b/docs/concepts/progress/samples/server/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/docs/concepts/progress/samples/server/appsettings.json b/docs/concepts/progress/samples/server/appsettings.json new file mode 100644 index 000000000..10f68b8c8 --- /dev/null +++ b/docs/concepts/progress/samples/server/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} diff --git a/docs/concepts/toc.yml b/docs/concepts/toc.yml new file mode 100644 index 000000000..97ffbd16e --- /dev/null +++ b/docs/concepts/toc.yml @@ -0,0 +1,19 @@ +items: +- name: Overview + href: index.md +- name: Base Protocol + items: + - name: Progress + uid: progress +- name: Client Features + items: + - name: Elicitation + uid: elicitation +- name: Server Features + items: + - name: Logging + uid: logging + - name: HTTP Context + uid: httpcontext + - name: Filters + uid: filters \ No newline at end of file diff --git a/docs/docfx.json b/docs/docfx.json index 6b4feb833..fe8a18d95 100644 --- a/docs/docfx.json +++ b/docs/docfx.json @@ -42,6 +42,7 @@ "_appLogoPath": "images/mcp.svg", "_appFaviconPath": "images/favicon.ico", "_enableSearch": true, + "_disableNextArticle": true, "pdf": false } } diff --git a/docs/toc.yml b/docs/toc.yml index f63a01348..350a2ae3b 100644 --- a/docs/toc.yml +++ b/docs/toc.yml @@ -1,5 +1,7 @@ items: -- name: API Docs +- name: Documentation + href: concepts/index.md +- name: API Reference href: api/ModelContextProtocol.yml - name: Github href: https://github.com/ModelContextProtocol/csharp-sdk \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/AspNetCoreMcpPerSessionTools.csproj b/samples/AspNetCoreMcpPerSessionTools/AspNetCoreMcpPerSessionTools.csproj new file mode 100644 index 000000000..23e95062c --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/AspNetCoreMcpPerSessionTools.csproj @@ -0,0 +1,20 @@ + + + + net9.0 + enable + enable + true + + + + + + + + + + + + + \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/Program.cs b/samples/AspNetCoreMcpPerSessionTools/Program.cs new file mode 100644 index 000000000..3484978ec --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/Program.cs @@ -0,0 +1,86 @@ +using AspNetCoreMcpPerSessionTools.Tools; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Reflection; + +var builder = WebApplication.CreateBuilder(args); + +// Create and populate the tool dictionary at startup +var toolDictionary = new ConcurrentDictionary(); +PopulateToolDictionary(toolDictionary); + +// Register all MCP server tools - they will be filtered per session based on route +builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + // Configure per-session options to filter tools based on route category + options.ConfigureSessionOptions = async (httpContext, mcpOptions, cancellationToken) => + { + // Determine tool category from route parameters + var toolCategory = httpContext.Request.RouteValues["toolCategory"]?.ToString()?.ToLower() ?? "all"; + + // Get pre-populated tools for the requested category + if (toolDictionary.TryGetValue(toolCategory, out var tools)) + { + mcpOptions.Capabilities = new(); + mcpOptions.Capabilities.Tools = new(); + var toolCollection = mcpOptions.ToolCollection = new(); + + foreach (var tool in tools) + { + toolCollection.Add(tool); + } + } + }; + }); + +var app = builder.Build(); + +// Map MCP with route parameter for tool category filtering +app.MapMcp("/{toolCategory?}"); + +app.Run(); + +// Helper method to populate the tool dictionary at startup +static void PopulateToolDictionary(ConcurrentDictionary toolDictionary) +{ + // Get tools for each category + var clockTools = GetToolsForType(); + var calculatorTools = GetToolsForType(); + var userInfoTools = GetToolsForType(); + McpServerTool[] allTools = [.. clockTools, + .. calculatorTools, + .. userInfoTools]; + + // Populate the dictionary with tools for each category + toolDictionary.TryAdd("clock", clockTools); + toolDictionary.TryAdd("calculator", calculatorTools); + toolDictionary.TryAdd("userinfo", userInfoTools); + toolDictionary.TryAdd("all", allTools); +} + +// Helper method to get tools for a specific type using reflection +static McpServerTool[] GetToolsForType<[System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers( + System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicMethods)] T>() +{ + var tools = new List(); + var toolType = typeof(T); + var methods = toolType.GetMethods(BindingFlags.Public | BindingFlags.Static) + .Where(m => m.GetCustomAttributes(typeof(McpServerToolAttribute), false).Any()); + + foreach (var method in methods) + { + try + { + var tool = McpServerTool.Create(method, target: null, new McpServerToolCreateOptions()); + tools.Add(tool); + } + catch (Exception ex) + { + // Log error but continue with other tools + Console.WriteLine($"Failed to add tool {toolType.Name}.{method.Name}: {ex.Message}"); + } + } + + return [.. tools]; +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/Properties/launchSettings.json b/samples/AspNetCoreMcpPerSessionTools/Properties/launchSettings.json new file mode 100644 index 000000000..da8208a11 --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/Properties/launchSettings.json @@ -0,0 +1,13 @@ +{ + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/README.md b/samples/AspNetCoreMcpPerSessionTools/README.md new file mode 100644 index 000000000..8e3665100 --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/README.md @@ -0,0 +1,113 @@ +# ASP.NET Core MCP Server with Per-Session Tool Filtering + +This sample demonstrates how to create an MCP (Model Context Protocol) server that provides different sets of tools based on route-based session configuration. This showcases the technique of using `ConfigureSessionOptions` to dynamically modify the `ToolCollection` based on route parameters for each MCP session. + +## Overview + +The sample demonstrates route-based tool filtering using the SDK's `ConfigureSessionOptions` callback. You could use any mechanism, routing is just one way to achieve this. The point of the sample is to show how an MCP server can dynamically adjust the available tools for each session based on arbitrary criteria, in this case, the URL route. + +## Route-Based Configuration + +The server uses route parameters to determine which tools to make available: + +- `GET /clock` - MCP server with only clock/time tools +- `GET /calculator` - MCP server with only calculation tools +- `GET /userinfo` - MCP server with only session/system info tools +- `GET /all` or `GET /` - MCP server with all tools (default) + +## Running the Sample + +1. Navigate to the sample directory: + ```bash + cd samples/AspNetCoreMcpPerSessionTools + ``` + +2. Run the server: + ```bash + dotnet run + ``` + +3. The server will start on `https://localhost:5001` (or the port shown in the console) + +## Testing Tool Categories + +### Testing Clock Tools +Connect your MCP client to: `https://localhost:5001/clock` +- Available tools: GetTime, GetDate, ConvertTimeZone + +### Testing Calculator Tools +Connect your MCP client to: `https://localhost:5001/calculator` +- Available tools: Calculate, CalculatePercentage, SquareRoot + +### Testing UserInfo Tools +Connect your MCP client to: `https://localhost:5001/userinfo` +- Available tools: GetUserInfo + +### Testing All Tools +Connect your MCP client to: `https://localhost:5001/all` or `https://localhost:5001/` +- Available tools: All tools from all categories + +## How It Works + +### 1. Tool Registration +All tools are registered during startup using the normal MCP tool registration: + +```csharp +builder.Services.AddMcpServer() + .WithTools() + .WithTools() + .WithTools(); +``` + +### 2. Route-Based Session Filtering +The key technique is using `ConfigureSessionOptions` to modify the tool collection per session based on the route: + +```csharp +.WithHttpTransport(options => +{ + options.ConfigureSessionOptions = async (httpContext, mcpOptions, cancellationToken) => + { + var toolCategory = GetToolCategoryFromRoute(httpContext); + var toolCollection = mcpOptions.Capabilities?.Tools?.ToolCollection; + + if (toolCollection != null) + { + // Clear all tools and add back only those for this category + toolCollection.Clear(); + + switch (toolCategory?.ToLower()) + { + case "clock": + AddToolsForType(toolCollection); + break; + case "calculator": + AddToolsForType(toolCollection); + break; + case "userinfo": + AddToolsForType(toolCollection); + break; + default: + // All tools for default/all category + AddToolsForType(toolCollection); + AddToolsForType(toolCollection); + AddToolsForType(toolCollection); + break; + } + } + }; +}) +``` + +### 3. Route Parameter Detection +The `GetToolCategoryFromRoute` method extracts the tool category from the URL route: + +```csharp +static string? GetToolCategoryFromRoute(HttpContext context) +{ + if (context.Request.RouteValues.TryGetValue("toolCategory", out var categoryObj) && categoryObj is string category) + { + return string.IsNullOrEmpty(category) ? "all" : category; + } + return "all"; // Default +} +``` diff --git a/samples/AspNetCoreMcpPerSessionTools/Tools/CalculatorTool.cs b/samples/AspNetCoreMcpPerSessionTools/Tools/CalculatorTool.cs new file mode 100644 index 000000000..c6d9f6216 --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/Tools/CalculatorTool.cs @@ -0,0 +1,81 @@ +using ModelContextProtocol.Server; +using System.ComponentModel; + +namespace AspNetCoreMcpPerSessionTools.Tools; + +/// +/// Calculator tools for mathematical operations +/// +[McpServerToolType] +public sealed class CalculatorTool +{ + [McpServerTool, Description("Performs basic arithmetic calculations (addition, subtraction, multiplication, division).")] + public static string Calculate([Description("Mathematical expression to evaluate (e.g., '5 + 3', '10 - 2', '4 * 6', '15 / 3')")] string expression) + { + try + { + // Simple calculator for demo purposes - supports basic operations + expression = expression.Trim(); + + if (expression.Contains("+")) + { + var parts = expression.Split('+'); + if (parts.Length == 2 && double.TryParse(parts[0].Trim(), out var a) && double.TryParse(parts[1].Trim(), out var b)) + { + return $"{expression} = {a + b}"; + } + } + else if (expression.Contains("-")) + { + var parts = expression.Split('-'); + if (parts.Length == 2 && double.TryParse(parts[0].Trim(), out var a) && double.TryParse(parts[1].Trim(), out var b)) + { + return $"{expression} = {a - b}"; + } + } + else if (expression.Contains("*")) + { + var parts = expression.Split('*'); + if (parts.Length == 2 && double.TryParse(parts[0].Trim(), out var a) && double.TryParse(parts[1].Trim(), out var b)) + { + return $"{expression} = {a * b}"; + } + } + else if (expression.Contains("/")) + { + var parts = expression.Split('/'); + if (parts.Length == 2 && double.TryParse(parts[0].Trim(), out var a) && double.TryParse(parts[1].Trim(), out var b)) + { + if (b == 0) + return "Error: Division by zero"; + return $"{expression} = {a / b}"; + } + } + + return $"Cannot evaluate expression: {expression}. Supported operations: +, -, *, / (e.g., '5 + 3')"; + } + catch (Exception ex) + { + return $"Error evaluating '{expression}': {ex.Message}"; + } + } + + [McpServerTool, Description("Calculates percentage of a number.")] + public static string CalculatePercentage( + [Description("The number to calculate percentage of")] double number, + [Description("The percentage value")] double percentage) + { + var result = (number * percentage) / 100; + return $"{percentage}% of {number} = {result}"; + } + + [McpServerTool, Description("Calculates the square root of a number.")] + public static string SquareRoot([Description("The number to find square root of")] double number) + { + if (number < 0) + return "Error: Cannot calculate square root of negative number"; + + var result = Math.Sqrt(number); + return $"√{number} = {result}"; + } +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/Tools/ClockTool.cs b/samples/AspNetCoreMcpPerSessionTools/Tools/ClockTool.cs new file mode 100644 index 000000000..d112de0f9 --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/Tools/ClockTool.cs @@ -0,0 +1,40 @@ +using ModelContextProtocol.Server; +using System.ComponentModel; + +namespace AspNetCoreMcpPerSessionTools.Tools; + +/// +/// Clock-related tools for time and date operations +/// +[McpServerToolType] +public sealed class ClockTool +{ + [McpServerTool, Description("Gets the current server time in various formats.")] + public static string GetTime() + { + return $"Current server time: {DateTime.Now:yyyy-MM-dd HH:mm:ss} UTC"; + } + + [McpServerTool, Description("Gets the current date in a specific format.")] + public static string GetDate([Description("Date format (e.g., 'yyyy-MM-dd', 'MM/dd/yyyy')")] string format = "yyyy-MM-dd") + { + try + { + return $"Current date: {DateTime.Now.ToString(format)}"; + } + catch (FormatException) + { + return $"Invalid format '{format}'. Using default: {DateTime.Now:yyyy-MM-dd}"; + } + } + + [McpServerTool, Description("Converts time between timezones.")] + public static string ConvertTimeZone( + [Description("Source timezone (e.g., 'UTC', 'EST')")] string fromTimeZone = "UTC", + [Description("Target timezone (e.g., 'PST', 'GMT')")] string toTimeZone = "PST") + { + // Simplified timezone conversion for demo purposes + var now = DateTime.Now; + return $"Time conversion from {fromTimeZone} to {toTimeZone}: {now:HH:mm:ss} (simulated)"; + } +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/Tools/UserInfoTool.cs b/samples/AspNetCoreMcpPerSessionTools/Tools/UserInfoTool.cs new file mode 100644 index 000000000..1fec18733 --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/Tools/UserInfoTool.cs @@ -0,0 +1,23 @@ +using ModelContextProtocol.Server; +using System.ComponentModel; + +namespace AspNetCoreMcpPerSessionTools.Tools; + +/// +/// User information tools +/// +[McpServerToolType] +public sealed class UserInfoTool +{ + [McpServerTool, Description("Gets information about the current user in the MCP session.")] + public static string GetUserInfo() + { + // Dummy user information for demonstration purposes + return $"User Information:\n" + + $"- User ID: {Guid.NewGuid():N}[..8] (simulated)\n" + + $"- Username: User{new Random().Next(1, 1000)}\n" + + $"- Roles: User, Guest\n" + + $"- Last Login: {DateTime.Now.AddMinutes(-new Random().Next(1, 60)):HH:mm:ss}\n" + + $"- Account Status: Active"; + } +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/appsettings.Development.json b/samples/AspNetCoreMcpPerSessionTools/appsettings.Development.json new file mode 100644 index 000000000..f999bc20e --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/appsettings.Development.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "System": "Information", + "Microsoft": "Information" + } + } +} \ No newline at end of file diff --git a/samples/AspNetCoreMcpPerSessionTools/appsettings.json b/samples/AspNetCoreMcpPerSessionTools/appsettings.json new file mode 100644 index 000000000..88c89fa7d --- /dev/null +++ b/samples/AspNetCoreMcpPerSessionTools/appsettings.json @@ -0,0 +1,10 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning", + "AspNetCoreMcpPerSessionTools": "Debug" + } + }, + "AllowedHosts": "*" +} \ No newline at end of file diff --git a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj b/samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj similarity index 100% rename from samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj rename to samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreMcpServer/Program.cs similarity index 61% rename from samples/AspNetCoreSseServer/Program.cs rename to samples/AspNetCoreMcpServer/Program.cs index c21b328f6..96f89bffa 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreMcpServer/Program.cs @@ -1,14 +1,16 @@ using OpenTelemetry; using OpenTelemetry.Metrics; using OpenTelemetry.Trace; -using TestServerWithHosting.Tools; -using TestServerWithHosting.Resources; +using AspNetCoreMcpServer.Tools; +using AspNetCoreMcpServer.Resources; +using System.Net.Http.Headers; var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer() .WithHttpTransport() .WithTools() .WithTools() + .WithTools() .WithResources(); builder.Services.AddOpenTelemetry() @@ -21,6 +23,13 @@ .WithLogging() .UseOtlpExporter(); +// Configure HttpClientFactory for weather.gov API +builder.Services.AddHttpClient("WeatherApi", client => +{ + client.BaseAddress = new Uri("https://api.weather.gov"); + client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); +}); + var app = builder.Build(); app.MapMcp(); diff --git a/samples/AspNetCoreMcpServer/Properties/launchSettings.json b/samples/AspNetCoreMcpServer/Properties/launchSettings.json new file mode 100644 index 000000000..6670029e1 --- /dev/null +++ b/samples/AspNetCoreMcpServer/Properties/launchSettings.json @@ -0,0 +1,23 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "applicationUrl": "http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "applicationUrl": "https://localhost:7133;http://localhost:3001", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" + } + } + } +} diff --git a/samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs b/samples/AspNetCoreMcpServer/Resources/SimpleResourceType.cs similarity index 76% rename from samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs rename to samples/AspNetCoreMcpServer/Resources/SimpleResourceType.cs index e73ce133c..aaf6d11a5 100644 --- a/samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs +++ b/samples/AspNetCoreMcpServer/Resources/SimpleResourceType.cs @@ -1,8 +1,7 @@ -using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.ComponentModel; -namespace TestServerWithHosting.Resources; +namespace AspNetCoreMcpServer.Resources; [McpServerResourceType] public class SimpleResourceType diff --git a/samples/AspNetCoreSseServer/Tools/EchoTool.cs b/samples/AspNetCoreMcpServer/Tools/EchoTool.cs similarity index 88% rename from samples/AspNetCoreSseServer/Tools/EchoTool.cs rename to samples/AspNetCoreMcpServer/Tools/EchoTool.cs index 7913b73e4..a9dc0a665 100644 --- a/samples/AspNetCoreSseServer/Tools/EchoTool.cs +++ b/samples/AspNetCoreMcpServer/Tools/EchoTool.cs @@ -1,7 +1,7 @@ using ModelContextProtocol.Server; using System.ComponentModel; -namespace TestServerWithHosting.Tools; +namespace AspNetCoreMcpServer.Tools; [McpServerToolType] public sealed class EchoTool diff --git a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs similarity index 93% rename from samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs rename to samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs index 247619dbb..e69477452 100644 --- a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs @@ -2,7 +2,7 @@ using ModelContextProtocol.Server; using System.ComponentModel; -namespace TestServerWithHosting.Tools; +namespace AspNetCoreMcpServer.Tools; /// /// This tool uses dependency injection and async method @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/AspNetCoreMcpServer/Tools/WeatherTools.cs b/samples/AspNetCoreMcpServer/Tools/WeatherTools.cs new file mode 100644 index 000000000..b4e3a7414 --- /dev/null +++ b/samples/AspNetCoreMcpServer/Tools/WeatherTools.cs @@ -0,0 +1,73 @@ +using ModelContextProtocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; + +namespace AspNetCoreMcpServer.Tools; + +[McpServerToolType] +public sealed class WeatherTools +{ + private readonly IHttpClientFactory _httpClientFactory; + + public WeatherTools(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; + } + + [McpServerTool, Description("Get weather alerts for a US state.")] + public async Task GetAlerts( + [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + using var responseStream = await client.GetStreamAsync($"/alerts/active/area/{state}"); + using var jsonDocument = await JsonDocument.ParseAsync(responseStream) + ?? throw new McpException("No JSON returned from alerts endpoint"); + + var alerts = jsonDocument.RootElement.GetProperty("features").EnumerateArray(); + + if (!alerts.Any()) + { + return "No active alerts for this state."; + } + + return string.Join("\n--\n", alerts.Select(alert => + { + JsonElement properties = alert.GetProperty("properties"); + return $""" + Event: {properties.GetProperty("event").GetString()} + Area: {properties.GetProperty("areaDesc").GetString()} + Severity: {properties.GetProperty("severity").GetString()} + Description: {properties.GetProperty("description").GetString()} + Instruction: {properties.GetProperty("instruction").GetString()} + """; + })); + } + + [McpServerTool, Description("Get weather forecast for a location.")] + public async Task GetForecast( + [Description("Latitude of the location.")] double latitude, + [Description("Longitude of the location.")] double longitude) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); + + using var locationResponseStream = await client.GetStreamAsync(pointUrl); + using var locationDocument = await JsonDocument.ParseAsync(locationResponseStream); + var forecastUrl = locationDocument?.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastResponseStream = await client.GetStreamAsync(forecastUrl); + using var forecastDocument = await JsonDocument.ParseAsync(forecastResponseStream); + var periods = forecastDocument?.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray() + ?? throw new McpException("No JSON returned from forecast endpoint"); + + return string.Join("\n---\n", periods.Select(period => $""" + {period.GetProperty("name").GetString()} + Temperature: {period.GetProperty("temperature").GetInt32()}°F + Wind: {period.GetProperty("windSpeed").GetString()} {period.GetProperty("windDirection").GetString()} + Forecast: {period.GetProperty("detailedForecast").GetString()} + """)); + } +} diff --git a/samples/AspNetCoreMcpServer/appsettings.Development.json b/samples/AspNetCoreMcpServer/appsettings.Development.json new file mode 100644 index 000000000..0c208ae91 --- /dev/null +++ b/samples/AspNetCoreMcpServer/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/samples/AspNetCoreMcpServer/appsettings.json b/samples/AspNetCoreMcpServer/appsettings.json new file mode 100644 index 000000000..10f68b8c8 --- /dev/null +++ b/samples/AspNetCoreMcpServer/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index ba597ae8a..c6fca0493 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -32,7 +32,7 @@ .UseOpenTelemetry(loggerFactory: loggerFactory, configure: o => o.EnableSensitiveData = true) .Build(); -var mcpClient = await McpClientFactory.CreateAsync( +var mcpClient = await McpClient.CreateAsync( new StdioClientTransport(new() { Command = "npx", @@ -41,7 +41,10 @@ }), clientOptions: new() { - Capabilities = new() { Sampling = new() { SamplingHandler = samplingClient.CreateSamplingHandler() } }, + Handlers = new() + { + SamplingHandler = samplingClient.CreateSamplingHandler() + } }, loggerFactory: loggerFactory); diff --git a/samples/EverythingServer/LoggingUpdateMessageSender.cs b/samples/EverythingServer/LoggingUpdateMessageSender.cs index 844aa70d8..5f524ad8a 100644 --- a/samples/EverythingServer/LoggingUpdateMessageSender.cs +++ b/samples/EverythingServer/LoggingUpdateMessageSender.cs @@ -5,7 +5,7 @@ namespace EverythingServer; -public class LoggingUpdateMessageSender(IMcpServer server, Func getMinLevel) : BackgroundService +public class LoggingUpdateMessageSender(McpServer server, Func getMinLevel) : BackgroundService { readonly Dictionary _loggingLevelMap = new() { diff --git a/samples/EverythingServer/SubscriptionMessageSender.cs b/samples/EverythingServer/SubscriptionMessageSender.cs index 774d98523..b071965dc 100644 --- a/samples/EverythingServer/SubscriptionMessageSender.cs +++ b/samples/EverythingServer/SubscriptionMessageSender.cs @@ -2,7 +2,7 @@ using ModelContextProtocol; using ModelContextProtocol.Server; -internal class SubscriptionMessageSender(IMcpServer server, HashSet subscriptions) : BackgroundService +internal class SubscriptionMessageSender(McpServer server, HashSet subscriptions) : BackgroundService { protected override async Task ExecuteAsync(CancellationToken stoppingToken) { diff --git a/samples/EverythingServer/Tools/LongRunningTool.cs b/samples/EverythingServer/Tools/LongRunningTool.cs index 27f6ac20f..405b5e823 100644 --- a/samples/EverythingServer/Tools/LongRunningTool.cs +++ b/samples/EverythingServer/Tools/LongRunningTool.cs @@ -10,7 +10,7 @@ public class LongRunningTool { [McpServerTool(Name = "longRunningOperation"), Description("Demonstrates a long running operation with progress updates")] public static async Task LongRunningOperation( - IMcpServer server, + McpServer server, RequestContext context, int duration = 10, int steps = 5) diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index a58675c30..6bbe6e51d 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -9,7 +9,7 @@ public class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer server, + McpServer server, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/InMemoryTransport/InMemoryTransport.csproj b/samples/InMemoryTransport/InMemoryTransport.csproj new file mode 100644 index 000000000..7c1161ce9 --- /dev/null +++ b/samples/InMemoryTransport/InMemoryTransport.csproj @@ -0,0 +1,15 @@ + + + + Exe + net8.0 + enable + enable + true + + + + + + + diff --git a/samples/InMemoryTransport/Program.cs b/samples/InMemoryTransport/Program.cs new file mode 100644 index 000000000..dbffaa34d --- /dev/null +++ b/samples/InMemoryTransport/Program.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.IO.Pipelines; + +Pipe clientToServerPipe = new(), serverToClientPipe = new(); + +// Create a server using a stream-based transport over an in-memory pipe. +await using McpServer server = McpServer.Create( + new StreamServerTransport(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()), + new McpServerOptions() + { + ToolCollection = [McpServerTool.Create((string arg) => $"Echo: {arg}", new() { Name = "Echo" })] + }); +_ = server.RunAsync(); + +// Connect a client using a stream-based transport over the same in-memory pipe. +await using McpClient client = await McpClient.CreateAsync( + new StreamClientTransport(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream())); + +// List all tools. +var tools = await client.ListToolsAsync(); +foreach (var tool in tools) +{ + Console.WriteLine($"Tool Name: {tool.Name}"); +} +Console.WriteLine(); + +// Invoke a tool. +var echo = tools.First(t => t.Name == "Echo"); +Console.WriteLine(await echo.InvokeAsync(new() +{ + ["arg"] = "Hello World" +})); \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/HttpClientExt.cs b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs deleted file mode 100644 index f7b2b5499..000000000 --- a/samples/ProtectedMCPServer/Tools/HttpClientExt.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.Text.Json; - -namespace ModelContextProtocol; - -internal static class HttpClientExt -{ - public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) - { - using var response = await client.GetAsync(requestUri); - response.EnsureSuccessStatusCode(); - return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); - } -} \ No newline at end of file diff --git a/samples/ProtectedMCPClient/Program.cs b/samples/ProtectedMcpClient/Program.cs similarity index 94% rename from samples/ProtectedMCPClient/Program.cs rename to samples/ProtectedMcpClient/Program.cs index 516227b37..9dc2410ea 100644 --- a/samples/ProtectedMCPClient/Program.cs +++ b/samples/ProtectedMcpClient/Program.cs @@ -25,19 +25,22 @@ builder.AddConsole(); }); -var transport = new SseClientTransport(new() +var transport = new HttpClientTransport(new() { Endpoint = new Uri(serverUrl), Name = "Secure Weather Client", OAuth = new() { - ClientName = "ProtectedMcpClient", RedirectUri = new Uri("http://localhost:1179/callback"), AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + DynamicClientRegistration = new() + { + ClientName = "ProtectedMcpClient", + }, } }, httpClient, consoleLoggerFactory); -var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); +var client = await McpClient.CreateAsync(transport, loggerFactory: consoleLoggerFactory); var tools = await client.ListToolsAsync(); if (tools.Count == 0) diff --git a/samples/ProtectedMCPClient/ProtectedMCPClient.csproj b/samples/ProtectedMcpClient/ProtectedMcpClient.csproj similarity index 100% rename from samples/ProtectedMCPClient/ProtectedMCPClient.csproj rename to samples/ProtectedMcpClient/ProtectedMcpClient.csproj diff --git a/samples/ProtectedMCPClient/README.md b/samples/ProtectedMcpClient/README.md similarity index 92% rename from samples/ProtectedMCPClient/README.md rename to samples/ProtectedMcpClient/README.md index 977331a04..81ae67cee 100644 --- a/samples/ProtectedMCPClient/README.md +++ b/samples/ProtectedMcpClient/README.md @@ -14,7 +14,7 @@ The Protected MCP Client sample shows how to: - .NET 9.0 or later - A running TestOAuthServer (for OAuth authentication) -- A running ProtectedMCPServer (for MCP services) +- A running ProtectedMcpServer (for MCP services) ## Setup and Running @@ -31,10 +31,10 @@ The OAuth server will start at `https://localhost:7029` ### Step 2: Start the Protected MCP Server -Next, start the ProtectedMCPServer which provides the weather tools: +Next, start the ProtectedMcpServer which provides the weather tools: ```bash -cd samples\ProtectedMCPServer +cd samples\ProtectedMcpServer dotnet run ``` @@ -45,7 +45,7 @@ The protected server will start at `http://localhost:7071` Finally, run this client: ```bash -cd samples\ProtectedMCPClient +cd samples\ProtectedMcpClient dotnet run ``` @@ -90,4 +90,4 @@ Once authenticated, the client can access weather tools including: ## Key Files - `Program.cs`: Main client application with OAuth flow implementation -- `ProtectedMCPClient.csproj`: Project file with dependencies \ No newline at end of file +- `ProtectedMcpClient.csproj`: Project file with dependencies \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Program.cs b/samples/ProtectedMcpServer/Program.cs similarity index 99% rename from samples/ProtectedMCPServer/Program.cs rename to samples/ProtectedMcpServer/Program.cs index ef70fe731..a36e0367f 100644 --- a/samples/ProtectedMCPServer/Program.cs +++ b/samples/ProtectedMcpServer/Program.cs @@ -1,7 +1,7 @@ using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.IdentityModel.Tokens; using ModelContextProtocol.AspNetCore.Authentication; -using ProtectedMCPServer.Tools; +using ProtectedMcpServer.Tools; using System.Net.Http.Headers; using System.Security.Claims; diff --git a/samples/ProtectedMCPServer/Properties/launchSettings.json b/samples/ProtectedMcpServer/Properties/launchSettings.json similarity index 89% rename from samples/ProtectedMCPServer/Properties/launchSettings.json rename to samples/ProtectedMcpServer/Properties/launchSettings.json index 31b04db83..dbc9a1147 100644 --- a/samples/ProtectedMCPServer/Properties/launchSettings.json +++ b/samples/ProtectedMcpServer/Properties/launchSettings.json @@ -1,6 +1,6 @@ { "profiles": { - "ProtectedMCPServer": { + "ProtectedMcpServer": { "commandName": "Project", "launchBrowser": true, "environmentVariables": { diff --git a/samples/ProtectedMCPServer/ProtectedMCPServer.csproj b/samples/ProtectedMcpServer/ProtectedMcpServer.csproj similarity index 100% rename from samples/ProtectedMCPServer/ProtectedMCPServer.csproj rename to samples/ProtectedMcpServer/ProtectedMcpServer.csproj diff --git a/samples/ProtectedMCPServer/README.md b/samples/ProtectedMcpServer/README.md similarity index 96% rename from samples/ProtectedMCPServer/README.md rename to samples/ProtectedMcpServer/README.md index f0ac708a0..ecbfee633 100644 --- a/samples/ProtectedMCPServer/README.md +++ b/samples/ProtectedMcpServer/README.md @@ -34,7 +34,7 @@ The OAuth server will start at `https://localhost:7029` Run this protected server: ```bash -cd samples\ProtectedMCPServer +cd samples\ProtectedMcpServer dotnet run ``` @@ -42,10 +42,10 @@ The protected server will start at `http://localhost:7071` ### Step 3: Test with Protected MCP Client -You can test the server using the ProtectedMCPClient sample: +You can test the server using the ProtectedMcpClient sample: ```bash -cd samples\ProtectedMCPClient +cd samples\ProtectedMcpClient dotnet run ``` diff --git a/samples/ProtectedMCPServer/Tools/WeatherTools.cs b/samples/ProtectedMcpServer/Tools/WeatherTools.cs similarity index 70% rename from samples/ProtectedMCPServer/Tools/WeatherTools.cs rename to samples/ProtectedMcpServer/Tools/WeatherTools.cs index 7c8c08514..94cc03892 100644 --- a/samples/ProtectedMCPServer/Tools/WeatherTools.cs +++ b/samples/ProtectedMcpServer/Tools/WeatherTools.cs @@ -4,7 +4,7 @@ using System.Globalization; using System.Text.Json; -namespace ProtectedMCPServer.Tools; +namespace ProtectedMcpServer.Tools; [McpServerToolType] public sealed class WeatherTools @@ -21,9 +21,10 @@ public async Task GetAlerts( [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) { var client = _httpClientFactory.CreateClient("WeatherApi"); - using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); - var jsonElement = jsonDocument.RootElement; - var alerts = jsonElement.GetProperty("features").EnumerateArray(); + using var jsonDocument = await client.GetFromJsonAsync($"/alerts/active/area/{state}") + ?? throw new McpException("No JSON returned from alerts endpoint"); + + var alerts = jsonDocument.RootElement.GetProperty("features").EnumerateArray(); if (!alerts.Any()) { @@ -50,12 +51,14 @@ public async Task GetForecast( { var client = _httpClientFactory.CreateClient("WeatherApi"); var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); - using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); - var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() - ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); - using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); - var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); + using var locationDocument = await client.GetFromJsonAsync(pointUrl); + var forecastUrl = locationDocument?.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastDocument = await client.GetFromJsonAsync(forecastUrl); + var periods = forecastDocument?.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray() + ?? throw new McpException("No JSON returned from forecast endpoint"); return string.Join("\n---\n", periods.Select(period => $""" {period.GetProperty("name").GetString()} diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index 423af627f..cd1c4c60a 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Client; using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Text; var builder = Host.CreateApplicationBuilder(args); @@ -12,16 +13,27 @@ .AddEnvironmentVariables() .AddUserSecrets(); +IClientTransport clientTransport; var (command, arguments) = GetCommandAndArguments(args); -var clientTransport = new StdioClientTransport(new() +if (command == "http") { - Name = "Demo Server", - Command = command, - Arguments = arguments, -}); - -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport); + // make sure AspNetCoreMcpServer is running + clientTransport = new HttpClientTransport(new() + { + Endpoint = new Uri("http://localhost:3001") + }); +} +else +{ + clientTransport = new StdioClientTransport(new() + { + Name = "Demo Server", + Command = command, + Arguments = arguments, + }); +} +await using var mcpClient = await McpClient.CreateAsync(clientTransport!); var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) @@ -46,8 +58,11 @@ Console.WriteLine("MCP Client Started!"); Console.ResetColor(); +var messages = new List(); +var sb = new StringBuilder(); + PromptForInput(); -while(Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) +while (Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) { if (string.IsNullOrWhiteSpace(query)) { @@ -55,11 +70,17 @@ continue; } - await foreach (var message in anthropicClient.GetStreamingResponseAsync(query, options)) + messages.Add(new ChatMessage(ChatRole.User, query)); + await foreach (var message in anthropicClient.GetStreamingResponseAsync(messages, options)) { Console.Write(message); + sb.Append(message.ToString()); } + Console.WriteLine(); + sb.AppendLine(); + messages.Add(new ChatMessage(ChatRole.Assistant, sb.ToString())); + sb.Clear(); PromptForInput(); } @@ -79,15 +100,16 @@ static void PromptForInput() /// /// This method uses the file extension of the first argument to determine the command, if it's py, it'll run python, /// if it's js, it'll run node, if it's a directory or a csproj file, it'll run dotnet. -/// +/// /// If no arguments are provided, it defaults to running the QuickstartWeatherServer project from the current repo. -/// +/// /// This method would only be required if you're creating a generic client, such as we use for the quickstart. /// static (string command, string[] arguments) GetCommandAndArguments(string[] args) { return args switch { + [var mode] when mode.Equals("http", StringComparison.OrdinalIgnoreCase) => ("http", args), [var script] when script.EndsWith(".py") => ("python", args), [var script] when script.EndsWith(".js") => ("node", args), [var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", ["run", "--project", script]), diff --git a/samples/QuickstartWeatherServer/Program.cs b/samples/QuickstartWeatherServer/Program.cs index 4e6216ee4..9bc050b54 100644 --- a/samples/QuickstartWeatherServer/Program.cs +++ b/samples/QuickstartWeatherServer/Program.cs @@ -15,11 +15,8 @@ options.LogToStandardErrorThreshold = LogLevel.Trace; }); -builder.Services.AddSingleton(_ => -{ - var client = new HttpClient { BaseAddress = new Uri("https://api.weather.gov") }; - client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); - return client; -}); +using var httpClient = new HttpClient { BaseAddress = new Uri("https://api.weather.gov") }; +httpClient.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); +builder.Services.AddSingleton(httpClient); await builder.Build().RunAsync(); diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs index e02d4c327..61dc0a0ee 100644 --- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs +++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs @@ -43,9 +43,9 @@ public static async Task GetForecast( [Description("Longitude of the location.")] double longitude) { var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); - using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); - var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() - ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + using var locationDocument = await client.ReadJsonDocumentAsync(pointUrl); + var forecastUrl = locationDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new McpException($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index a096f9301..2c96b8c35 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/src/Common/CancellableStreamReader/ValueStringBuilder.cs b/src/Common/CancellableStreamReader/ValueStringBuilder.cs index 27bea693e..9f1dedcd5 100644 --- a/src/Common/CancellableStreamReader/ValueStringBuilder.cs +++ b/src/Common/CancellableStreamReader/ValueStringBuilder.cs @@ -8,310 +8,309 @@ #nullable enable -namespace System.Text +namespace System.Text; + +internal ref partial struct ValueStringBuilder { - internal ref partial struct ValueStringBuilder + private char[]? _arrayToReturnToPool; + private Span _chars; + private int _pos; + + public ValueStringBuilder(Span initialBuffer) { - private char[]? _arrayToReturnToPool; - private Span _chars; - private int _pos; + _arrayToReturnToPool = null; + _chars = initialBuffer; + _pos = 0; + } - public ValueStringBuilder(Span initialBuffer) + public ValueStringBuilder(int initialCapacity) + { + _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity); + _chars = _arrayToReturnToPool; + _pos = 0; + } + + public int Length + { + get => _pos; + set { - _arrayToReturnToPool = null; - _chars = initialBuffer; - _pos = 0; + Debug.Assert(value >= 0); + Debug.Assert(value <= _chars.Length); + _pos = value; } + } + + public int Capacity => _chars.Length; - public ValueStringBuilder(int initialCapacity) + public void EnsureCapacity(int capacity) + { + // This is not expected to be called this with negative capacity + Debug.Assert(capacity >= 0); + + // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception. + if ((uint)capacity > (uint)_chars.Length) + Grow(capacity - _pos); + } + + /// + /// Get a pinnable reference to the builder. + /// Does not ensure there is a null char after + /// This overload is pattern matched in the C# 7.3+ compiler so you can omit + /// the explicit method call, and write eg "fixed (char* c = builder)" + /// + public ref char GetPinnableReference() + { + return ref MemoryMarshal.GetReference(_chars); + } + + /// + /// Get a pinnable reference to the builder. + /// + /// Ensures that the builder has a null char after + public ref char GetPinnableReference(bool terminate) + { + if (terminate) { - _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity); - _chars = _arrayToReturnToPool; - _pos = 0; + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; } + return ref MemoryMarshal.GetReference(_chars); + } - public int Length + public ref char this[int index] + { + get { - get => _pos; - set - { - Debug.Assert(value >= 0); - Debug.Assert(value <= _chars.Length); - _pos = value; - } + Debug.Assert(index < _pos); + return ref _chars[index]; } + } - public int Capacity => _chars.Length; + public override string ToString() + { + string s = _chars.Slice(0, _pos).ToString(); + Dispose(); + return s; + } - public void EnsureCapacity(int capacity) - { - // This is not expected to be called this with negative capacity - Debug.Assert(capacity >= 0); + /// Returns the underlying storage of the builder. + public Span RawChars => _chars; - // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception. - if ((uint)capacity > (uint)_chars.Length) - Grow(capacity - _pos); + /// + /// Returns a span around the contents of the builder. + /// + /// Ensures that the builder has a null char after + public ReadOnlySpan AsSpan(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; } + return _chars.Slice(0, _pos); + } + + public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos); + public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start); + public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length); - /// - /// Get a pinnable reference to the builder. - /// Does not ensure there is a null char after - /// This overload is pattern matched in the C# 7.3+ compiler so you can omit - /// the explicit method call, and write eg "fixed (char* c = builder)" - /// - public ref char GetPinnableReference() + public bool TryCopyTo(Span destination, out int charsWritten) + { + if (_chars.Slice(0, _pos).TryCopyTo(destination)) { - return ref MemoryMarshal.GetReference(_chars); + charsWritten = _pos; + Dispose(); + return true; } - - /// - /// Get a pinnable reference to the builder. - /// - /// Ensures that the builder has a null char after - public ref char GetPinnableReference(bool terminate) + else { - if (terminate) - { - EnsureCapacity(Length + 1); - _chars[Length] = '\0'; - } - return ref MemoryMarshal.GetReference(_chars); + charsWritten = 0; + Dispose(); + return false; } + } - public ref char this[int index] + public void Insert(int index, char value, int count) + { + if (_pos > _chars.Length - count) { - get - { - Debug.Assert(index < _pos); - return ref _chars[index]; - } + Grow(count); } - public override string ToString() + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + _chars.Slice(index, count).Fill(value); + _pos += count; + } + + public void Insert(int index, string? s) + { + if (s == null) { - string s = _chars.Slice(0, _pos).ToString(); - Dispose(); - return s; + return; } - /// Returns the underlying storage of the builder. - public Span RawChars => _chars; + int count = s.Length; - /// - /// Returns a span around the contents of the builder. - /// - /// Ensures that the builder has a null char after - public ReadOnlySpan AsSpan(bool terminate) + if (_pos > (_chars.Length - count)) { - if (terminate) - { - EnsureCapacity(Length + 1); - _chars[Length] = '\0'; - } - return _chars.Slice(0, _pos); + Grow(count); } - public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos); - public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start); - public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length); + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + s +#if !NET + .AsSpan() +#endif + .CopyTo(_chars.Slice(index)); + _pos += count; + } - public bool TryCopyTo(Span destination, out int charsWritten) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(char c) + { + int pos = _pos; + Span chars = _chars; + if ((uint)pos < (uint)chars.Length) { - if (_chars.Slice(0, _pos).TryCopyTo(destination)) - { - charsWritten = _pos; - Dispose(); - return true; - } - else - { - charsWritten = 0; - Dispose(); - return false; - } + chars[pos] = c; + _pos = pos + 1; } - - public void Insert(int index, char value, int count) + else { - if (_pos > _chars.Length - count) - { - Grow(count); - } - - int remaining = _pos - index; - _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); - _chars.Slice(index, count).Fill(value); - _pos += count; + GrowAndAppend(c); } + } - public void Insert(int index, string? s) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(string? s) + { + if (s == null) { - if (s == null) - { - return; - } - - int count = s.Length; - - if (_pos > (_chars.Length - count)) - { - Grow(count); - } - - int remaining = _pos - index; - _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); - s -#if !NET - .AsSpan() -#endif - .CopyTo(_chars.Slice(index)); - _pos += count; + return; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void Append(char c) + int pos = _pos; + if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc. { - int pos = _pos; - Span chars = _chars; - if ((uint)pos < (uint)chars.Length) - { - chars[pos] = c; - _pos = pos + 1; - } - else - { - GrowAndAppend(c); - } + _chars[pos] = s[0]; + _pos = pos + 1; } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void Append(string? s) + else { - if (s == null) - { - return; - } - - int pos = _pos; - if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc. - { - _chars[pos] = s[0]; - _pos = pos + 1; - } - else - { - AppendSlow(s); - } + AppendSlow(s); } + } - private void AppendSlow(string s) + private void AppendSlow(string s) + { + int pos = _pos; + if (pos > _chars.Length - s.Length) { - int pos = _pos; - if (pos > _chars.Length - s.Length) - { - Grow(s.Length); - } + Grow(s.Length); + } - s + s #if !NET - .AsSpan() + .AsSpan() #endif - .CopyTo(_chars.Slice(pos)); - _pos += s.Length; - } + .CopyTo(_chars.Slice(pos)); + _pos += s.Length; + } - public void Append(char c, int count) + public void Append(char c, int count) + { + if (_pos > _chars.Length - count) { - if (_pos > _chars.Length - count) - { - Grow(count); - } - - Span dst = _chars.Slice(_pos, count); - for (int i = 0; i < dst.Length; i++) - { - dst[i] = c; - } - _pos += count; + Grow(count); } - public void Append(scoped ReadOnlySpan value) + Span dst = _chars.Slice(_pos, count); + for (int i = 0; i < dst.Length; i++) { - int pos = _pos; - if (pos > _chars.Length - value.Length) - { - Grow(value.Length); - } - - value.CopyTo(_chars.Slice(_pos)); - _pos += value.Length; + dst[i] = c; } + _pos += count; + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Span AppendSpan(int length) + public void Append(scoped ReadOnlySpan value) + { + int pos = _pos; + if (pos > _chars.Length - value.Length) { - int origPos = _pos; - if (origPos > _chars.Length - length) - { - Grow(length); - } - - _pos = origPos + length; - return _chars.Slice(origPos, length); + Grow(value.Length); } - [MethodImpl(MethodImplOptions.NoInlining)] - private void GrowAndAppend(char c) + value.CopyTo(_chars.Slice(_pos)); + _pos += value.Length; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AppendSpan(int length) + { + int origPos = _pos; + if (origPos > _chars.Length - length) { - Grow(1); - Append(c); + Grow(length); } - /// - /// Resize the internal buffer either by doubling current buffer size or - /// by adding to - /// whichever is greater. - /// - /// - /// Number of chars requested beyond current position. - /// - [MethodImpl(MethodImplOptions.NoInlining)] - private void Grow(int additionalCapacityBeyondPos) - { - Debug.Assert(additionalCapacityBeyondPos > 0); - Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); + _pos = origPos + length; + return _chars.Slice(origPos, length); + } - const uint ArrayMaxLength = 0x7FFFFFC7; // same as Array.MaxLength + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowAndAppend(char c) + { + Grow(1); + Append(c); + } + + /// + /// Resize the internal buffer either by doubling current buffer size or + /// by adding to + /// whichever is greater. + /// + /// + /// Number of chars requested beyond current position. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void Grow(int additionalCapacityBeyondPos) + { + Debug.Assert(additionalCapacityBeyondPos > 0); + Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); - // Increase to at least the required size (_pos + additionalCapacityBeyondPos), but try - // to double the size if possible, bounding the doubling to not go beyond the max array length. - int newCapacity = (int)Math.Max( - (uint)(_pos + additionalCapacityBeyondPos), - Math.Min((uint)_chars.Length * 2, ArrayMaxLength)); + const uint ArrayMaxLength = 0x7FFFFFC7; // same as Array.MaxLength - // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative. - // This could also go negative if the actual required length wraps around. - char[] poolArray = ArrayPool.Shared.Rent(newCapacity); + // Increase to at least the required size (_pos + additionalCapacityBeyondPos), but try + // to double the size if possible, bounding the doubling to not go beyond the max array length. + int newCapacity = (int)Math.Max( + (uint)(_pos + additionalCapacityBeyondPos), + Math.Min((uint)_chars.Length * 2, ArrayMaxLength)); - _chars.Slice(0, _pos).CopyTo(poolArray); + // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative. + // This could also go negative if the actual required length wraps around. + char[] poolArray = ArrayPool.Shared.Rent(newCapacity); - char[]? toReturn = _arrayToReturnToPool; - _chars = _arrayToReturnToPool = poolArray; - if (toReturn != null) - { - ArrayPool.Shared.Return(toReturn); - } + _chars.Slice(0, _pos).CopyTo(poolArray); + + char[]? toReturn = _arrayToReturnToPool; + _chars = _arrayToReturnToPool = poolArray; + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); } + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void Dispose() + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Dispose() + { + char[]? toReturn = _arrayToReturnToPool; + this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again + if (toReturn != null) { - char[]? toReturn = _arrayToReturnToPool; - this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again - if (toReturn != null) - { - ArrayPool.Shared.Return(toReturn); - } + ArrayPool.Shared.Return(toReturn); } } } \ No newline at end of file diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 7859ba39a..f229a3eaf 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,7 +6,7 @@ https://github.com/modelcontextprotocol/csharp-sdk git 0.3.0 - preview.3 + preview.5 ModelContextProtocolOfficial © Anthropic and Contributors. ModelContextProtocol;mcp;ai;llm diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs index f8c6f41cd..46b8e898b 100644 --- a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs @@ -43,8 +43,7 @@ public async Task HandleRequestAsync() return false; } - await HandleResourceMetadataRequestAsync(); - return true; + return await HandleResourceMetadataRequestAsync(); } /// @@ -78,10 +77,7 @@ private string GetAbsoluteResourceMetadataUri() return absoluteUri.ToString(); } - /// - /// Handles the resource metadata request. - /// - private async Task HandleResourceMetadataRequestAsync() + private async Task HandleResourceMetadataRequestAsync() { var resourceMetadata = Options.ResourceMetadata; @@ -93,6 +89,23 @@ private async Task HandleResourceMetadataRequestAsync() }; await Options.Events.OnResourceMetadataRequest(context); + + if (context.Result is not null) + { + if (context.Result.Handled) + { + return true; + } + else if (context.Result.Skipped) + { + return false; + } + else if (context.Result.Failure is not null) + { + throw new AuthenticationFailureException("An error occurred from the OnResourceMetadataRequest event.", context.Result.Failure); + } + } + resourceMetadata = context.ResourceMetadata; } @@ -104,6 +117,7 @@ private async Task HandleResourceMetadataRequestAsync() } await Results.Json(resourceMetadata, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))).ExecuteAsync(Context); + return true; } /// diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs new file mode 100644 index 000000000..2cfb74d09 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -0,0 +1,359 @@ +using System.Diagnostics.CodeAnalysis; +using System.Security.Claims; +using Microsoft.AspNetCore.Authorization; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Evaluates authorization policies from endpoint metadata. +/// +internal sealed class AuthorizationFilterSetup(IAuthorizationPolicyProvider? policyProvider = null) : IConfigureOptions, IPostConfigureOptions +{ + private static readonly string AuthorizationFilterInvokedKey = "ModelContextProtocol.AspNetCore.AuthorizationFilter.Invoked"; + + public void Configure(McpServerOptions options) + { + ConfigureListToolsFilter(options); + ConfigureCallToolFilter(options); + + ConfigureListResourcesFilter(options); + ConfigureListResourceTemplatesFilter(options); + ConfigureReadResourceFilter(options); + + ConfigureListPromptsFilter(options); + ConfigureGetPromptFilter(options); + } + + public void PostConfigure(string? name, McpServerOptions options) + { + CheckListToolsFilter(options); + CheckCallToolFilter(options); + + CheckListResourcesFilter(options); + CheckListResourceTemplatesFilter(options); + CheckReadResourceFilter(options); + + CheckListPromptsFilter(options); + CheckGetPromptFilter(options); + } + + private void ConfigureListToolsFilter(McpServerOptions options) + { + options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Tools, static tool => tool.McpServerTool, + context.User, context.Services, context); + return result; + }); + } + + private void CheckListToolsFilter(McpServerOptions options) + { + options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Tools.Select(static tool => tool.McpServerTool)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for tools/list operation, but authorization metadata was found on the tools. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + + private void ConfigureCallToolFilter(McpServerOptions options) + { + options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This tool requires authorization.", McpErrorCode.InvalidRequest); + } + + context.Items[AuthorizationFilterInvokedKey] = true; + + return await next(context, cancellationToken); + }); + } + + private void CheckCallToolFilter(McpServerOptions options) + { + options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for tools/call operation, but authorization metadata was found on the tool. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListResourcesFilter(McpServerOptions options) + { + options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Resources, static resource => resource.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void CheckListResourcesFilter(McpServerOptions options) + { + options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Resources.Select(static resource => resource.McpServerResource)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/list operation, but authorization metadata was found on the resources. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + + private void ConfigureListResourceTemplatesFilter(McpServerOptions options) + { + options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.ResourceTemplates, static resourceTemplate => resourceTemplate.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void CheckListResourceTemplatesFilter(McpServerOptions options) + { + options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.ResourceTemplates.Select(static resourceTemplate => resourceTemplate.McpServerResource)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/templates/list operation, but authorization metadata was found on the resource templates. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + + private void ConfigureReadResourceFilter(McpServerOptions options) + { + options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This resource requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + private void CheckReadResourceFilter(McpServerOptions options) + { + options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/read operation, but authorization metadata was found on the resource. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListPromptsFilter(McpServerOptions options) + { + options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Prompts, static prompt => prompt.McpServerPrompt, + context.User, context.Services, context); + return result; + }); + } + + private void CheckListPromptsFilter(McpServerOptions options) + { + options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Prompts.Select(static prompt => prompt.McpServerPrompt)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for prompts/list operation, but authorization metadata was found on the prompts. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + + private void ConfigureGetPromptFilter(McpServerOptions options) + { + options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => + { + context.Items[AuthorizationFilterInvokedKey] = true; + + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This prompt requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + private void CheckGetPromptFilter(McpServerOptions options) + { + options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for prompts/get operation, but authorization metadata was found on the prompt. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return await next(context, cancellationToken); + }); + } + + /// + /// Filters a collection of items based on authorization policies in their metadata. + /// For list operations where we need to filter results by authorization. + /// + private async ValueTask FilterAuthorizedItemsAsync(IList items, Func primitiveSelector, + ClaimsPrincipal? user, IServiceProvider? requestServices, object context) + { + for (int i = items.Count - 1; i >= 0; i--) + { + var authorizationResult = await GetAuthorizationResultAsync( + user, primitiveSelector(items[i]), requestServices, context); + + if (!authorizationResult.Succeeded) + { + items.RemoveAt(i); + } + } + } + + private async ValueTask GetAuthorizationResultAsync( + ClaimsPrincipal? user, IMcpServerPrimitive? primitive, IServiceProvider? requestServices, object context) + { + if (!HasAuthorizationMetadata(primitive)) + { + return AuthorizationResult.Success(); + } + + if (policyProvider is null) + { + throw new InvalidOperationException($"You must call AddAuthorization() because an authorization related attribute was found on {primitive.Id}"); + } + + var policy = await CombineAsync(policyProvider, primitive.Metadata); + if (policy is null) + { + return AuthorizationResult.Success(); + } + + if (requestServices is null) + { + // The IAuthorizationPolicyProvider service must be non-null to get to this line, so it's very unexpected for RequestContext.Services to not be set. + throw new InvalidOperationException("RequestContext.Services is not set! The McpServer must be initialized with a non-null IServiceProvider."); + } + + // ASP.NET Core's AuthorizationMiddleware resolves the IAuthorizationService from scoped request services, so we do the same. + var authService = requestServices.GetRequiredService(); + return await authService.AuthorizeAsync(user ?? new ClaimsPrincipal(new ClaimsIdentity()), context, policy); + } + + /// + /// Combines authorization policies and requirements from endpoint metadata without considering . + /// + /// The authorization policy provider. + /// The endpoint metadata collection. + /// The combined authorization policy, or null if no authorization is required. + private static async ValueTask CombineAsync(IAuthorizationPolicyProvider policyProvider, IReadOnlyList endpointMetadata) + { + // https://github.com/dotnet/aspnetcore/issues/63365 tracks adding this as public API to AuthorizationPolicy itself. + // Copied from https://github.com/dotnet/aspnetcore/blob/9f2977bf9cfb539820983bda3bedf81c8cda9f20/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs#L116-L138 + var authorizeData = endpointMetadata.OfType(); + var policies = endpointMetadata.OfType(); + + var policy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData, policies); + + AuthorizationPolicyBuilder? reqPolicyBuilder = null; + + foreach (var m in endpointMetadata) + { + if (m is not IAuthorizationRequirementData requirementData) + { + continue; + } + + reqPolicyBuilder ??= new AuthorizationPolicyBuilder(); + foreach (var requirement in requirementData.GetRequirements()) + { + reqPolicyBuilder.AddRequirements(requirement); + } + } + + if (reqPolicyBuilder is null) + { + return policy; + } + + // Combine policy with requirements or just use requirements if no policy + return (policy is null) + ? reqPolicyBuilder.Build() + : AuthorizationPolicy.Combine(policy, reqPolicyBuilder.Build()); + } + + private static bool HasAuthorizationMetadata([NotNullWhen(true)] IMcpServerPrimitive? primitive) + { + // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, + // the request should go through as normal. + if (primitive is null || primitive.Metadata.Any(static m => m is IAllowAnonymous)) + { + return false; + } + + return primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData); + } + + private static bool HasAuthorizationMetadata(IEnumerable primitives) + => primitives.Any(HasAuthorizationMetadata); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 0cdc4e37b..fbceab4b1 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,6 @@ +using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Server; @@ -23,11 +25,14 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder { ArgumentNullException.ThrowIfNull(builder); + builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); builder.Services.AddDataProtection(); + builder.Services.TryAddEnumerable(ServiceDescriptor.Transient, AuthorizationFilterSetup>()); + if (configureOptions is not null) { builder.Services.Configure(configureOptions); @@ -35,4 +40,27 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder return builder; } + + /// + /// Adds authorization filters to support + /// on MCP server tools, prompts, and resources. This method should always be called when using + /// ASP.NET Core integration to ensure proper authorization support. + /// + /// The builder instance. + /// The builder provided in . + /// is . + /// + /// This method automatically configures authorization filters for all MCP server handlers. These filters respect + /// authorization attributes such as + /// and . + /// + public static IMcpServerBuilder AddAuthorizationFilters(this IMcpServerBuilder builder) + { + ArgumentNullException.ThrowIfNull(builder); + + // Allow the authorization filters to get added multiple times in case other middleware changes the matched primitive. + builder.Services.AddTransient, AuthorizationFilterSetup>(); + + return builder; + } } diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs deleted file mode 100644 index c34aba6c7..000000000 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpSession.cs +++ /dev/null @@ -1,83 +0,0 @@ -using ModelContextProtocol.AspNetCore.Stateless; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Security.Claims; - -namespace ModelContextProtocol.AspNetCore; - -internal sealed class HttpMcpSession( - string sessionId, - TTransport transport, - UserIdClaim? userId, - TimeProvider timeProvider) : IAsyncDisposable - where TTransport : ITransport -{ - private int _referenceCount; - private int _getRequestStarted; - private CancellationTokenSource _disposeCts = new(); - - public string Id { get; } = sessionId; - public TTransport Transport { get; } = transport; - public UserIdClaim? UserIdClaim { get; } = userId; - - public CancellationToken SessionClosed => _disposeCts.Token; - - public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0; - public long LastActivityTicks { get; private set; } = timeProvider.GetTimestamp(); - - public IMcpServer? Server { get; set; } - public Task? ServerRunTask { get; set; } - - public IDisposable AcquireReference() - { - Interlocked.Increment(ref _referenceCount); - return new UnreferenceDisposable(this, timeProvider); - } - - public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; - - public async ValueTask DisposeAsync() - { - try - { - await _disposeCts.CancelAsync(); - - if (ServerRunTask is not null) - { - await ServerRunTask; - } - } - catch (OperationCanceledException) - { - } - finally - { - try - { - if (Server is not null) - { - await Server.DisposeAsync(); - } - } - finally - { - await Transport.DisposeAsync(); - _disposeCts.Dispose(); - } - } - } - - public bool HasSameUserId(ClaimsPrincipal user) - => UserIdClaim == StreamableHttpHandler.GetUserIdClaim(user); - - private sealed class UnreferenceDisposable(HttpMcpSession session, TimeProvider timeProvider) : IDisposable - { - public void Dispose() - { - if (Interlocked.Decrement(ref session._referenceCount) == 0) - { - session.LastActivityTicks = timeProvider.GetTimestamp(); - } - } - } -} diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 2a34a17a1..8d71f5166 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -20,7 +20,7 @@ public class HttpServerTransportOptions /// Gets or sets an optional asynchronous callback for running new MCP sessions manually. /// This is useful for running logic before a sessions starts and after it completes. /// - public Func? RunSessionHandler { get; set; } + public Func? RunSessionHandler { get; set; } /// /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session @@ -66,9 +66,9 @@ public class HttpServerTransportOptions /// Past this limit, the server will log a critical error and terminate the oldest idle sessions even if they have not reached /// their until the idle session count is below this limit. Clients that keep their session open by /// keeping a GET request open will not count towards this limit. - /// Defaults to 100,000 sessions. + /// Defaults to 10,000 sessions. /// - public int MaxIdleSessionCount { get; set; } = 100_000; + public int MaxIdleSessionCount { get; set; } = 10_000; /// /// Used for testing the . diff --git a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs index 26ffd44bb..a4ae569ba 100644 --- a/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs +++ b/src/ModelContextProtocol.AspNetCore/IdleTrackingBackgroundService.cs @@ -1,17 +1,16 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Server; namespace ModelContextProtocol.AspNetCore; internal sealed partial class IdleTrackingBackgroundService( - StreamableHttpHandler handler, + StatefulSessionManager sessions, IOptions options, IHostApplicationLifetime appLifetime, ILogger logger) : BackgroundService { - // The compiler will complain about the parameter being unused otherwise despite the source generator. + // Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later. private readonly ILogger _logger = logger; protected override async Task ExecuteAsync(CancellationToken stoppingToken) @@ -21,6 +20,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.IdleTimeout, TimeSpan.Zero); } + ArgumentOutOfRangeException.ThrowIfLessThan(options.Value.MaxIdleSessionCount, 0); try @@ -28,54 +28,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) var timeProvider = options.Value.TimeProvider; using var timer = new PeriodicTimer(TimeSpan.FromSeconds(5), timeProvider); - var idleTimeoutTicks = options.Value.IdleTimeout.Ticks; - var maxIdleSessionCount = options.Value.MaxIdleSessionCount; - - // The default ValueTuple Comparer will check the first item then the second which preserves both order and uniqueness. - var idleSessions = new SortedSet<(long Timestamp, string SessionId)>(); - while (!stoppingToken.IsCancellationRequested && await timer.WaitForNextTickAsync(stoppingToken)) { - var idleActivityCutoff = idleTimeoutTicks switch - { - < 0 => long.MinValue, - var ticks => timeProvider.GetTimestamp() - ticks, - }; - - foreach (var (_, session) in handler.Sessions) - { - if (session.IsActive || session.SessionClosed.IsCancellationRequested) - { - // There's a request currently active or the session is already being closed. - continue; - } - - if (session.LastActivityTicks < idleActivityCutoff) - { - RemoveAndCloseSession(session.Id); - continue; - } - - idleSessions.Add((session.LastActivityTicks, session.Id)); - - // Emit critical log at most once every 5 seconds the idle count it exceeded, - // since the IdleTimeout will no longer be respected. - if (idleSessions.Count == maxIdleSessionCount + 1) - { - LogMaxSessionIdleCountExceeded(maxIdleSessionCount); - } - } - - if (idleSessions.Count > maxIdleSessionCount) - { - var sessionsToPrune = idleSessions.ToArray()[..^maxIdleSessionCount]; - foreach (var (_, id) in sessionsToPrune) - { - RemoveAndCloseSession(id); - } - } - - idleSessions.Clear(); + await sessions.PruneIdleSessionsAsync(stoppingToken); } } catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) @@ -85,17 +40,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { try { - List disposeSessionTasks = []; - - foreach (var (sessionKey, _) in handler.Sessions) - { - if (handler.Sessions.TryRemove(sessionKey, out var session)) - { - disposeSessionTasks.Add(DisposeSessionAsync(session)); - } - } - - await Task.WhenAll(disposeSessionTasks); + await sessions.DisposeAllSessionsAsync(); } finally { @@ -110,39 +55,6 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) } } - private void RemoveAndCloseSession(string sessionId) - { - if (!handler.Sessions.TryRemove(sessionId, out var session)) - { - return; - } - - LogSessionIdle(session.Id); - // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. - _ = DisposeSessionAsync(session); - } - - private async Task DisposeSessionAsync(HttpMcpSession session) - { - try - { - await session.DisposeAsync(); - } - catch (Exception ex) - { - LogSessionDisposeError(session.Id, ex); - } - } - - [LoggerMessage(Level = LogLevel.Information, Message = "Closing idle session {sessionId}.")] - private partial void LogSessionIdle(string sessionId); - - [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {sessionId}.")] - private partial void LogSessionDisposeError(string sessionId, Exception ex); - - [LoggerMessage(Level = LogLevel.Critical, Message = "Exceeded maximum of {maxIdleSessionCount} idle sessions. Now closing sessions active more recently than configured IdleTimeout.")] - private partial void LogMaxSessionIdleCountExceeded(int maxIdleSessionCount); - [LoggerMessage(Level = LogLevel.Critical, Message = "The IdleTrackingBackgroundService has stopped unexpectedly.")] private partial void IdleTrackingBackgroundServiceStoppedUnexpectedly(); -} +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index c5ac5a948..eefe0d29e 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -2,7 +2,6 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Collections.Concurrent; using System.Diagnostics; @@ -16,7 +15,7 @@ internal sealed class SseHandler( IHostApplicationLifetime hostApplicationLifetime, ILoggerFactory loggerFactory) { - private readonly ConcurrentDictionary> _sessions = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); public async Task HandleSseRequestAsync(HttpContext context) { @@ -34,9 +33,9 @@ public async Task HandleSseRequestAsync(HttpContext context) await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId); var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User); - await using var httpMcpSession = new HttpMcpSession(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider); + var sseSession = new SseSession(transport, userIdClaim); - if (!_sessions.TryAdd(sessionId, httpMcpSession)) + if (!_sessions.TryAdd(sessionId, sseSession)) { throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } @@ -54,13 +53,11 @@ public async Task HandleSseRequestAsync(HttpContext context) try { - await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); - httpMcpSession.Server = mcpServer; + await using var mcpServer = McpServer.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); context.Features.Set(mcpServer); var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync; - httpMcpSession.ServerRunTask = runSessionAsync(context, mcpServer, cancellationToken); - await httpMcpSession.ServerRunTask; + await runSessionAsync(context, mcpServer, cancellationToken); } finally { @@ -87,27 +84,29 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - if (!_sessions.TryGetValue(sessionId.ToString(), out var httpMcpSession)) + if (!_sessions.TryGetValue(sessionId.ToString(), out var sseSession)) { await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); return; } - if (!httpMcpSession.HasSameUserId(context.User)) + if (sseSession.UserId != StreamableHttpHandler.GetUserIdClaim(context.User)) { await Results.Forbid().ExecuteAsync(context); return; } - var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); + var message = await StreamableHttpHandler.ReadJsonRpcMessageAsync(context); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); return; } - await httpMcpSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); + await sseSession.Transport.OnMessageReceivedAsync(message, context.RequestAborted); context.Response.StatusCode = StatusCodes.Status202Accepted; await context.Response.WriteAsync("Accepted"); } + + private record SseSession(SseResponseStreamTransport Transport, UserIdClaim? UserId); } diff --git a/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs b/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs new file mode 100644 index 000000000..960488af7 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StatefulSessionManager.cs @@ -0,0 +1,243 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed partial class StatefulSessionManager( + IOptions httpServerTransportOptions, + ILogger logger) +{ + // Workaround for https://github.com/dotnet/runtime/issues/91121. This is fixed in .NET 9 and later. + private readonly ILogger _logger = logger; + + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + + private readonly TimeProvider _timeProvider = httpServerTransportOptions.Value.TimeProvider; + private readonly TimeSpan _idleTimeout = httpServerTransportOptions.Value.IdleTimeout; + private readonly long _idleTimeoutTicks = httpServerTransportOptions.Value.IdleTimeout.Ticks; + private readonly int _maxIdleSessionCount = httpServerTransportOptions.Value.MaxIdleSessionCount; + + private readonly object _idlePruningLock = new(); + private readonly List _idleTimestamps = []; + private readonly List _idleSessionIds = []; + private int _nextIndexToPrune; + + private long _currentIdleSessionCount; + + public TimeProvider TimeProvider => _timeProvider; + + public void IncrementIdleSessionCount() => Interlocked.Increment(ref _currentIdleSessionCount); + public void DecrementIdleSessionCount() => Interlocked.Decrement(ref _currentIdleSessionCount); + + public bool TryGetValue(string key, [NotNullWhen(true)] out StreamableHttpSession? value) => _sessions.TryGetValue(key, out value); + public bool TryRemove(string key, [NotNullWhen(true)] out StreamableHttpSession? value) => _sessions.TryRemove(key, out value); + + public async ValueTask StartNewSessionAsync(StreamableHttpSession newSession, CancellationToken cancellationToken) + { + while (!TryAddSessionImmediately(newSession)) + { + StreamableHttpSession? sessionToPrune = null; + + lock (_idlePruningLock) + { + EnsureIdleSessionsSortedUnsynchronized(); + + while (_nextIndexToPrune < _idleSessionIds.Count) + { + var pruneId = _idleSessionIds[_nextIndexToPrune++]; + if (_sessions.TryRemove(pruneId, out sessionToPrune)) + { + LogIdleSessionLimit(pruneId, _maxIdleSessionCount); + break; + } + } + + if (sessionToPrune is null) + { + // If we couldn't find any active idle sessions to dispose, start another full prune to repopulate _idleSessionIds. + PruneIdleSessionsUnsynchronized(); + + if (_idleSessionIds.Count > 0) + { + continue; + } + else + { + // This indicates all idle sessions are in the process of being disposed which should not happen during normal operation. + // Since there are no idle sessions to prune right now, log a critical error and create the new session anyway. + LogTooManyIdleSessionsClosingConcurrently(newSession.Id, _maxIdleSessionCount, Volatile.Read(ref _currentIdleSessionCount)); + AddSession(newSession); + return; + } + } + } + + try + { + // Since we're at or above the maximum idle session count, we're intentionally waiting for the idle session to be disposed + // before adding a new session to the dictionary to ensure sessions not created faster than they're removed. + await DisposeSessionAsync(sessionToPrune); + + // Take one last chance to check if the initialize request was aborted before we incur the cost of managing a new session. + cancellationToken.ThrowIfCancellationRequested(); + AddSession(newSession); + return; + } + catch + { + await newSession.DisposeAsync(); + throw; + } + } + } + + /// + /// Performs a single pass of idle session pruning, removing sessions that exceed the idle timeout + /// or when the maximum idle session count is exceeded. + /// + public async Task PruneIdleSessionsAsync(CancellationToken cancellationToken) + { + lock (_idlePruningLock) + { + PruneIdleSessionsUnsynchronized(); + } + } + + private void PruneIdleSessionsUnsynchronized() + { + var idleActivityCutoff = _idleTimeoutTicks switch + { + < 0 => long.MinValue, + var ticks => _timeProvider.GetTimestamp() - ticks, + }; + + // We clear the lists at the start of pruning rather than the end so we can use them between runs + // to find the most idle sessions to remove one-at-a-time if necessary to make room for new sessions. + _idleTimestamps.Clear(); + _idleSessionIds.Clear(); + _nextIndexToPrune = -1; + + foreach (var (_, session) in _sessions) + { + if (session.IsActive || session.SessionClosed.IsCancellationRequested) + { + // There's a request currently active or the session is already being closed. + continue; + } + + if (session.LastActivityTicks < idleActivityCutoff) + { + LogIdleSessionTimeout(session.Id, _idleTimeout); + RemoveAndCloseSession(session.Id); + continue; + } + + // Add the timestamp and the session + _idleTimestamps.Add(session.LastActivityTicks); + _idleSessionIds.Add(session.Id); + } + + if (_idleTimestamps.Count > _maxIdleSessionCount) + { + // Sort only if the maximum is breached and sort solely by the timestamp. + EnsureIdleSessionsSortedUnsynchronized(); + + var sessionsToPrune = CollectionsMarshal.AsSpan(_idleSessionIds)[..^_maxIdleSessionCount]; + foreach (var id in sessionsToPrune) + { + LogIdleSessionLimit(id, _maxIdleSessionCount); + RemoveAndCloseSession(id); + } + _nextIndexToPrune = _maxIdleSessionCount; + } + } + + private void EnsureIdleSessionsSortedUnsynchronized() + { + if (_nextIndexToPrune > -1) + { + // Already sorted. + return; + } + + var timestamps = CollectionsMarshal.AsSpan(_idleTimestamps); + timestamps.Sort(CollectionsMarshal.AsSpan(_idleSessionIds)); + _nextIndexToPrune = 0; + } + + /// + /// Disposes all sessions in the manager, typically called during graceful shutdown. + /// + public async Task DisposeAllSessionsAsync() + { + List disposeSessionTasks = []; + + foreach (var (sessionKey, _) in _sessions) + { + if (_sessions.TryRemove(sessionKey, out var session)) + { + disposeSessionTasks.Add(DisposeSessionAsync(session)); + } + } + + await Task.WhenAll(disposeSessionTasks); + } + + private bool TryAddSessionImmediately(StreamableHttpSession session) + { + if (Volatile.Read(ref _currentIdleSessionCount) < _maxIdleSessionCount) + { + AddSession(session); + return true; + } + + return false; + } + + private void AddSession(StreamableHttpSession session) + { + if (!_sessions.TryAdd(session.Id, session)) + { + throw new UnreachableException($"Unreachable given good entropy! Session with ID '{session.Id}' has already been created."); + } + } + + private void RemoveAndCloseSession(string sessionId) + { + if (!_sessions.TryRemove(sessionId, out var session)) + { + return; + } + + // Don't slow down the idle tracking loop. DisposeSessionAsync logs. We only await during graceful shutdown. + _ = DisposeSessionAsync(session); + } + + private async Task DisposeSessionAsync(StreamableHttpSession session) + { + try + { + await session.DisposeAsync(); + } + catch (Exception ex) + { + LogSessionDisposeError(session.Id, ex); + } + } + + [LoggerMessage(Level = LogLevel.Information, Message = "IdleTimeout of {IdleTimeout} exceeded. Closing idle session {SessionId}.")] + private partial void LogIdleSessionTimeout(string sessionId, TimeSpan idleTimeout); + + [LoggerMessage(Level = LogLevel.Information, Message = "MaxIdleSessionCount of {MaxIdleSessionCount} exceeded. Closing idle session {SessionId} despite it being active more recently than the configured IdleTimeout to make room for new sessions.")] + private partial void LogIdleSessionLimit(string sessionId, int maxIdleSessionCount); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error disposing session {SessionId}.")] + private partial void LogSessionDisposeError(string sessionId, Exception ex); + + [LoggerMessage(Level = LogLevel.Critical, Message = "MaxIdleSessionCount of {MaxIdleSessionCount} exceeded, and {CurrentIdleSessionCount} sessions are currently in the process of closing. Creating new session {SessionId} anyway.")] + private partial void LogTooManyIdleSessionsClosingConcurrently(string sessionId, int maxIdleSessionCount, long currentIdleSessionCount); +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6dac1c3e4..14093facc 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -8,9 +8,6 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Diagnostics; -using System.IO.Pipelines; using System.Security.Claims; using System.Security.Cryptography; using System.Text.Json; @@ -22,14 +19,15 @@ internal sealed class StreamableHttpHandler( IOptions mcpServerOptionsSnapshot, IOptionsFactory mcpServerOptionsFactory, IOptions httpServerTransportOptions, + StatefulSessionManager sessionManager, IDataProtectionProvider dataProtection, ILoggerFactory loggerFactory, IServiceProvider applicationServices) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; - private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); - public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); + private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; @@ -56,28 +54,24 @@ await WriteJsonRpcErrorAsync(context, return; } - try - { - using var _ = session.AcquireReference(); + await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); - InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); - if (!wroteResponse) - { - // We wound up writing nothing, so there should be no Content-Type response header. - context.Response.Headers.ContentType = (string?)null; - context.Response.StatusCode = StatusCodes.Status202Accepted; - } + var message = await ReadJsonRpcMessageAsync(context); + if (message is null) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The POST body did not contain a valid JSON-RPC message.", + StatusCodes.Status400BadRequest); + return; } - finally + + InitializeSseResponse(context); + var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted); + if (!wroteResponse) { - // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the Mcp-Session-Id. - // Non-stateless sessions are 1:1 with the Mcp-Session-Id and outlive the POST request. - // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. - if (HttpServerTransportOptions.Stateless) - { - await session.DisposeAsync(); - } + // We wound up writing nothing, so there should be no Content-Type response header. + context.Response.Headers.ContentType = (string?)null; + context.Response.StatusCode = StatusCodes.Status202Accepted; } } @@ -106,7 +100,7 @@ await WriteJsonRpcErrorAsync(context, return; } - using var _ = session.AcquireReference(); + await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); InitializeSseResponse(context); // We should flush headers to indicate a 200 success quickly, because the initialization response @@ -119,17 +113,22 @@ await WriteJsonRpcErrorAsync(context, public async Task HandleDeleteRequestAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); - if (Sessions.TryRemove(sessionId, out var session)) + if (sessionManager.TryRemove(sessionId, out var session)) { await session.DisposeAsync(); } } - private async ValueTask?> GetSessionAsync(HttpContext context, string sessionId) + private async ValueTask GetSessionAsync(HttpContext context, string sessionId) { - HttpMcpSession? session; + StreamableHttpSession? session; - if (HttpServerTransportOptions.Stateless) + if (string.IsNullOrEmpty(sessionId)) + { + await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest); + return null; + } + else if (HttpServerTransportOptions.Stateless) { var sessionJson = Protector.Unprotect(sessionId); var statelessSessionId = JsonSerializer.Deserialize(sessionJson, StatelessSessionIdJsonContext.Default.StatelessSessionId); @@ -140,7 +139,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) }; session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); } - else if (!Sessions.TryGetValue(sessionId, out session)) + else if (!sessionManager.TryGetValue(sessionId, out session)) { // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this @@ -163,7 +162,7 @@ await WriteJsonRpcErrorAsync(context, return session; } - private async ValueTask?> GetOrCreateSessionAsync(HttpContext context) + private async ValueTask GetOrCreateSessionAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); @@ -177,7 +176,7 @@ await WriteJsonRpcErrorAsync(context, } } - private async ValueTask> StartNewSessionAsync(HttpContext context) + private async ValueTask StartNewSessionAsync(HttpContext context) { string sessionId; StreamableHttpServerTransport transport; @@ -204,21 +203,10 @@ private async ValueTask> StartNewS ScheduleStatelessSessionIdWrite(context, transport); } - var session = await CreateSessionAsync(context, transport, sessionId); - - // The HttpMcpSession is not stored between requests in stateless mode. Instead, the session is recreated from the MCP-Session-Id. - if (!HttpServerTransportOptions.Stateless) - { - if (!Sessions.TryAdd(sessionId, session)) - { - throw new UnreachableException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); - } - } - - return session; + return await CreateSessionAsync(context, transport, sessionId); } - private async ValueTask> CreateSessionAsync( + private async ValueTask CreateSessionAsync( HttpContext context, StreamableHttpServerTransport transport, string sessionId, @@ -244,14 +232,11 @@ private async ValueTask> CreateSes } } - var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); + var server = McpServer.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); - var session = new HttpMcpSession(sessionId, transport, userIdClaim, HttpServerTransportOptions.TimeProvider) - { - Server = server, - }; + var session = new StreamableHttpSession(sessionId, transport, server, userIdClaim, sessionManager); var runSessionAsync = HttpServerTransportOptions.RunSessionHandler ?? RunSessionAsync; session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed); @@ -289,6 +274,22 @@ internal static string MakeNewSessionId() return WebEncoders.Base64UrlEncode(buffer); } + internal static async Task ReadJsonRpcMessageAsync(HttpContext context) + { + // Implementation for reading a JSON-RPC message from the request body + var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); + + if (context.User?.Identity?.IsAuthenticated == true && message is not null) + { + message.Context = new() + { + User = context.User, + }; + } + + return message; + } + private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) { transport.OnInitRequestReceived = initRequestParams => @@ -306,7 +307,7 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp }; } - internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + internal static Task RunSessionAsync(HttpContext httpContext, McpServer session, CancellationToken requestAborted) => session.RunAsync(requestAborted); // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. @@ -329,17 +330,11 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session return null; } - private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + internal static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("application/json"); private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("text/event-stream"); - - private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe - { - public PipeReader Input => context.Request.BodyReader; - public PipeWriter Output => context.Response.BodyWriter; - } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs new file mode 100644 index 000000000..1e8d22dec --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -0,0 +1,154 @@ +using ModelContextProtocol.Server; +using System.Diagnostics; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore; + +internal sealed class StreamableHttpSession( + string sessionId, + StreamableHttpServerTransport transport, + McpServer server, + UserIdClaim? userId, + StatefulSessionManager sessionManager) : IAsyncDisposable +{ + private int _referenceCount; + private SessionState _state; + private readonly object _stateLock = new(); + + private int _getRequestStarted; + private readonly CancellationTokenSource _disposeCts = new(); + + public string Id => sessionId; + public StreamableHttpServerTransport Transport => transport; + public McpServer Server => server; + private StatefulSessionManager SessionManager => sessionManager; + + public CancellationToken SessionClosed => _disposeCts.Token; + public bool IsActive => !SessionClosed.IsCancellationRequested && _referenceCount > 0; + public long LastActivityTicks { get; private set; } = sessionManager.TimeProvider.GetTimestamp(); + + public Task ServerRunTask { get; set; } = Task.CompletedTask; + + public async ValueTask AcquireReferenceAsync(CancellationToken cancellationToken) + { + // The StreamableHttpSession is not stored between requests in stateless mode. Instead, the session is recreated from the MCP-Session-Id. + // Stateless sessions are 1:1 with HTTP requests and are outlived by the MCP session tracked by the Mcp-Session-Id. + // Non-stateless sessions are 1:1 with the Mcp-Session-Id and outlive the POST request. + // Non-stateless sessions get disposed by a DELETE request or the IdleTrackingBackgroundService. + if (transport.Stateless) + { + return this; + } + + SessionState startingState; + + lock (_stateLock) + { + startingState = _state; + _referenceCount++; + + switch (startingState) + { + case SessionState.Uninitialized: + Debug.Assert(_referenceCount == 1, "The _referenceCount should start at 1 when the StreamableHttpSession is uninitialized."); + _state = SessionState.Started; + break; + case SessionState.Started: + if (_referenceCount == 1) + { + sessionManager.DecrementIdleSessionCount(); + } + break; + case SessionState.Disposed: + throw new ObjectDisposedException(nameof(StreamableHttpSession)); + } + } + + if (startingState == SessionState.Uninitialized) + { + await sessionManager.StartNewSessionAsync(this, cancellationToken); + } + + return new UnreferenceDisposable(this); + } + + public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; + public bool HasSameUserId(ClaimsPrincipal user) => userId == StreamableHttpHandler.GetUserIdClaim(user); + + public async ValueTask DisposeAsync() + { + var wasIdle = false; + + lock (_stateLock) + { + switch (_state) + { + case SessionState.Uninitialized: + break; + case SessionState.Started: + if (_referenceCount == 0) + { + wasIdle = true; + } + break; + case SessionState.Disposed: + return; + } + + _state = SessionState.Disposed; + } + + try + { + try + { + // Dispose transport first to complete the incoming MessageReader gracefully and avoid a potentially unnecessary OCE. + await transport.DisposeAsync(); + await _disposeCts.CancelAsync(); + + await ServerRunTask; + } + finally + { + await server.DisposeAsync(); + } + } + catch (OperationCanceledException) + { + } + finally + { + if (wasIdle) + { + sessionManager.DecrementIdleSessionCount(); + } + _disposeCts.Dispose(); + } + } + + private sealed class UnreferenceDisposable(StreamableHttpSession session) : IAsyncDisposable + { + public ValueTask DisposeAsync() + { + lock (session._stateLock) + { + Debug.Assert(session._state != SessionState.Uninitialized, "The session should have been initialized."); + if (session._state != SessionState.Disposed && --session._referenceCount == 0) + { + var sessionManager = session.SessionManager; + session.LastActivityTicks = sessionManager.TimeProvider.GetTimestamp(); + sessionManager.IncrementIdleSessionCount(); + } + } + + return default; + } + } + + private enum SessionState + { + Uninitialized, + Started, + Disposed + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs b/src/ModelContextProtocol.AspNetCore/UserIdClaim.cs similarity index 58% rename from src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs rename to src/ModelContextProtocol.AspNetCore/UserIdClaim.cs index f18c1c5ff..5b5951d3d 100644 --- a/src/ModelContextProtocol.AspNetCore/Stateless/UserIdClaim.cs +++ b/src/ModelContextProtocol.AspNetCore/UserIdClaim.cs @@ -1,3 +1,3 @@ -namespace ModelContextProtocol.AspNetCore.Stateless; +namespace ModelContextProtocol.AspNetCore; internal sealed record UserIdClaim(string Type, string Value, string Issuer); diff --git a/src/ModelContextProtocol.Core/AssemblyNameHelper.cs b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs new file mode 100644 index 000000000..292ed2f96 --- /dev/null +++ b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs @@ -0,0 +1,9 @@ +using System.Reflection; + +namespace ModelContextProtocol; + +internal static class AssemblyNameHelper +{ + /// Cached naming information used for MCP session name/version when none is specified. + public static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); +} diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index 686316f55..cc6a8952e 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -68,22 +68,12 @@ public sealed class ClientOAuthOptions public Func, Uri?>? AuthServerSelector { get; set; } /// - /// Gets or sets the client name to use during dynamic client registration. + /// Gets or sets the options to use during dynamic client registration. /// /// - /// This is a human-readable name for the client that may be displayed to users during authorization. /// Only used when a is not specified. /// - public string? ClientName { get; set; } - - /// - /// Gets or sets the client URI to use during dynamic client registration. - /// - /// - /// This should be a URL pointing to the client's home page or information page. - /// Only used when a is not specified. - /// - public Uri? ClientUri { get; set; } + public DynamicClientRegistrationOptions? DynamicClientRegistration { get; set; } /// /// Gets or sets additional parameters to include in the query string of the OAuth authorization request diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 96356028f..b72f775c4 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -1,7 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using System.Collections.Specialized; using System.Diagnostics.CodeAnalysis; +using System.Net.Http.Headers; using System.Security.Cryptography; using System.Text; using System.Text.Json; @@ -28,9 +28,11 @@ internal sealed partial class ClientOAuthProvider private readonly Func, Uri?> _authServerSelector; private readonly AuthorizationRedirectDelegate _authorizationRedirectDelegate; - // _clientName and _client URI is used for dynamic client registration (RFC 7591) - private readonly string? _clientName; - private readonly Uri? _clientUri; + // _dcrClientName, _dcrClientUri, _dcrInitialAccessToken and _dcrResponseDelegate are used for dynamic client registration (RFC 7591) + private readonly string? _dcrClientName; + private readonly Uri? _dcrClientUri; + private readonly string? _dcrInitialAccessToken; + private readonly Func? _dcrResponseDelegate; private readonly HttpClient _httpClient; private readonly ILogger _logger; @@ -66,9 +68,7 @@ public ClientOAuthProvider( _clientId = options.ClientId; _clientSecret = options.ClientSecret; - _redirectUri = options.RedirectUri ?? throw new ArgumentException("ClientOAuthOptions.RedirectUri must configured."); - _clientName = options.ClientName; - _clientUri = options.ClientUri; + _redirectUri = options.RedirectUri ?? throw new ArgumentException("ClientOAuthOptions.RedirectUri must configured.", nameof(options)); _scopes = options.Scopes?.ToArray(); _additionalAuthorizationParameters = options.AdditionalAuthorizationParameters; @@ -77,6 +77,11 @@ public ClientOAuthProvider( // Set up authorization URL handler (use default if not provided) _authorizationRedirectDelegate = options.AuthorizationRedirectDelegate ?? DefaultAuthorizationUrlHandler; + + _dcrClientName = options.DynamicClientRegistration?.ClientName; + _dcrClientUri = options.DynamicClientRegistration?.ClientUri; + _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; + _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; } /// @@ -212,11 +217,6 @@ private async Task PerformOAuthAuthorizationAsync( // Get auth server metadata var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false); - if (authServerMetadata is null) - { - ThrowFailedToHandleUnauthorizedResponse($"Failed to retrieve metadata for authorization server: '{selectedAuthServer}'"); - } - // Store auth server metadata for future refresh operations _authServerMetadata = authServerMetadata; @@ -238,7 +238,7 @@ private async Task PerformOAuthAuthorizationAsync( LogOAuthAuthorizationCompleted(); } - private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) + private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) { if (!authServerUri.OriginalString.EndsWith("/")) { @@ -249,7 +249,9 @@ private async Task PerformOAuthAuthorizationAsync( { try { - var response = await _httpClient.GetAsync(new Uri(authServerUri, path), cancellationToken).ConfigureAwait(false); + var wellKnownEndpoint = new Uri(authServerUri, path); + + var response = await _httpClient.GetAsync(wellKnownEndpoint, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { continue; @@ -258,15 +260,28 @@ private async Task PerformOAuthAuthorizationAsync( using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); var metadata = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata, cancellationToken).ConfigureAwait(false); - if (metadata != null) + if (metadata is null) + { + continue; + } + + if (metadata.AuthorizationEndpoint is null) { - metadata.ResponseTypesSupported ??= ["code"]; - metadata.GrantTypesSupported ??= ["authorization_code", "refresh_token"]; - metadata.TokenEndpointAuthMethodsSupported ??= ["client_secret_post"]; - metadata.CodeChallengeMethodsSupported ??= ["S256"]; + ThrowFailedToHandleUnauthorizedResponse($"No authorization_endpoint was provided via '{wellKnownEndpoint}'."); + } - return metadata; + if (metadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttp && + metadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttps) + { + ThrowFailedToHandleUnauthorizedResponse($"AuthorizationEndpoint must use HTTP or HTTPS. '{metadata.AuthorizationEndpoint}' does not meet this requirement."); } + + metadata.ResponseTypesSupported ??= ["code"]; + metadata.GrantTypesSupported ??= ["authorization_code", "refresh_token"]; + metadata.TokenEndpointAuthMethodsSupported ??= ["client_secret_post"]; + metadata.CodeChallengeMethodsSupported ??= ["S256"]; + + return metadata; } catch (Exception ex) { @@ -274,7 +289,7 @@ private async Task PerformOAuthAuthorizationAsync( } } - return null; + throw new McpException($"Failed to find .well-known/openid-configuration or .well-known/oauth-authorization-server metadata for authorization server: '{authServerUri}'"); } private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) @@ -320,12 +335,6 @@ private Uri BuildAuthorizationUrl( AuthorizationServerMetadata authServerMetadata, string codeChallenge) { - if (authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttp && - authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttps) - { - throw new ArgumentException("AuthorizationEndpoint must use HTTP or HTTPS.", nameof(authServerMetadata)); - } - var queryParamsDictionary = new Dictionary { ["client_id"] = GetClientIdOrThrow(), @@ -443,8 +452,8 @@ private async Task PerformDynamicClientRegistrationAsync( GrantTypes = ["authorization_code", "refresh_token"], ResponseTypes = ["code"], TokenEndpointAuthMethod = "client_secret_post", - ClientName = _clientName, - ClientUri = _clientUri?.ToString(), + ClientName = _dcrClientName, + ClientUri = _dcrClientUri?.ToString(), Scope = _scopes is not null ? string.Join(" ", _scopes) : null }; @@ -456,6 +465,11 @@ private async Task PerformDynamicClientRegistrationAsync( Content = requestContent }; + if (!string.IsNullOrEmpty(_dcrInitialAccessToken)) + { + request.Headers.Authorization = new AuthenticationHeaderValue(BearerScheme, _dcrInitialAccessToken); + } + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); if (!httpResponse.IsSuccessStatusCode) @@ -483,6 +497,11 @@ private async Task PerformDynamicClientRegistrationAsync( } LogDynamicClientRegistrationSuccessful(_clientId!); + + if (_dcrResponseDelegate is not null) + { + await _dcrResponseDelegate(registrationResponse, cancellationToken).ConfigureAwait(false); + } } /// diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationOptions.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationOptions.cs new file mode 100644 index 000000000..c7337122e --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationOptions.cs @@ -0,0 +1,49 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Provides configuration options for the related to dynamic client registration (RFC 7591). +/// +public sealed class DynamicClientRegistrationOptions +{ + /// + /// Gets or sets the client name to use during dynamic client registration. + /// + /// + /// This is a human-readable name for the client that may be displayed to users during authorization. + /// + public string? ClientName { get; set; } + + /// + /// Gets or sets the client URI to use during dynamic client registration. + /// + /// + /// This should be a URL pointing to the client's home page or information page. + /// + public Uri? ClientUri { get; set; } + + /// + /// Gets or sets the initial access token to use during dynamic client registration. + /// + /// + /// + /// This token is used to authenticate the client during the registration process. + /// + /// + /// This is required if the authorization server does not allow anonymous client registration. + /// + /// + public string? InitialAccessToken { get; set; } + + /// + /// Gets or sets the delegate used for handling the dynamic client registration response. + /// + /// + /// + /// This delegate is responsible for processing the response from the dynamic client registration endpoint. + /// + /// + /// The implementation should save the client credentials securely for future use. + /// + /// + public Func? ResponseDelegate { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs index dcd51d68a..1dfe12294 100644 --- a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a client registration response for OAuth 2.0 Dynamic Client Registration (RFC 7591). /// -internal sealed class DynamicClientRegistrationResponse +public sealed class DynamicClientRegistrationResponse { /// /// Gets or sets the client identifier. diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 06f2e0bfb..2e49babcf 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -12,14 +12,14 @@ namespace ModelContextProtocol.Client; /// internal sealed partial class AutoDetectingClientSessionTransport : ITransport { - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly McpHttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; - public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) + public AutoDetectingClientSessionTransport(string endpointName, HttpClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs similarity index 86% rename from src/ModelContextProtocol.Core/Client/SseClientTransport.cs rename to src/ModelContextProtocol.Core/Client/HttpClientTransport.cs index b31c3479b..322b9175e 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs @@ -13,26 +13,26 @@ namespace ModelContextProtocol.Client; /// Unlike the , this transport connects to an existing server /// rather than launching a new process. /// -public sealed class SseClientTransport : IClientTransport, IAsyncDisposable +public sealed class HttpClientTransport : IClientTransport, IAsyncDisposable { - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly McpHttpClient _mcpHttpClient; private readonly ILoggerFactory? _loggerFactory; private readonly HttpClient? _ownedHttpClient; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Configuration options for the transport. /// Logger factory for creating loggers used for diagnostic output during transport operations. - public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFactory? loggerFactory = null) + public HttpClientTransport(HttpClientTransportOptions transportOptions, ILoggerFactory? loggerFactory = null) : this(transportOptions, new HttpClient(), loggerFactory, ownsHttpClient: true) { } /// - /// Initializes a new instance of the class with a provided HTTP client. + /// Initializes a new instance of the class with a provided HTTP client. /// /// Configuration options for the transport. /// The HTTP client instance used for requests. @@ -41,7 +41,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFac /// to dispose of when the transport is disposed; /// if the caller is retaining ownership of the 's lifetime. /// - public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) + public HttpClientTransport(HttpClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs similarity index 96% rename from src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs rename to src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 4097844cf..94b95eecb 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -3,9 +3,9 @@ namespace ModelContextProtocol.Client; /// -/// Provides options for configuring instances. +/// Provides options for configuring instances. /// -public sealed class SseClientTransportOptions +public sealed class HttpClientTransportOptions { /// /// Gets or sets the base address of the server for SSE connections. diff --git a/src/ModelContextProtocol.Core/Client/IClientTransport.cs b/src/ModelContextProtocol.Core/Client/IClientTransport.cs index 525178957..2201e9b4f 100644 --- a/src/ModelContextProtocol.Core/Client/IClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/IClientTransport.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Client; /// and servers, allowing different transport protocols to be used interchangeably. /// /// -/// When creating an , is typically used, and is +/// When creating an , is typically used, and is /// provided with the based on expected server configuration. /// /// @@ -39,7 +39,7 @@ public interface IClientTransport /// the transport session as well. /// /// - /// This method is used by to initialize the connection. + /// This method is used by to initialize the connection. /// /// /// The transport connection could not be established. diff --git a/src/ModelContextProtocol.Core/Client/IMcpClient.cs b/src/ModelContextProtocol.Core/Client/IMcpClient.cs index 68a92a2d9..141add86a 100644 --- a/src/ModelContextProtocol.Core/Client/IMcpClient.cs +++ b/src/ModelContextProtocol.Core/Client/IMcpClient.cs @@ -1,10 +1,13 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; +using System.ComponentModel; namespace ModelContextProtocol.Client; /// /// Represents an instance of a Model Context Protocol (MCP) client that connects to and communicates with an MCP server. /// +[Obsolete($"Use {nameof(McpClient)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 +[EditorBrowsable(EditorBrowsableState.Never)] public interface IMcpClient : IMcpEndpoint { /// @@ -44,4 +47,4 @@ public interface IMcpClient : IMcpEndpoint /// /// string? ServerInstructions { get; } -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs new file mode 100644 index 000000000..5550e786e --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -0,0 +1,713 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Runtime.CompilerServices; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpClient : McpSession, IMcpClient +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// Creates an , connecting it to the specified server. + /// The transport instance used to communicate with the server. + /// + /// A client configuration object which specifies client capabilities and protocol version. + /// If , details based on the current process will be employed. + /// + /// A logger factory for creating loggers for clients. + /// The to monitor for cancellation requests. The default is . + /// An that's connected to the specified server. + /// is . + /// is . + public static async Task CreateAsync( + IClientTransport clientTransport, + McpClientOptions? clientOptions = null, + ILoggerFactory? loggerFactory = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(clientTransport); + + var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + var endpointName = clientTransport.Name; + + var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory); + try + { + await clientSession.ConnectAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + await clientSession.DisposeAsync().ConfigureAwait(false); + throw; + } + + return clientSession; + } + + /// + /// Sends a ping request to verify server connectivity. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the ping is successful. + /// Thrown when the server cannot be reached or returns an error response. + public Task PingAsync(CancellationToken cancellationToken = default) + { + var opts = McpJsonUtilities.DefaultOptions; + opts.MakeReadOnly(); + return SendRequestAsync( + RequestMethods.Ping, + parameters: null, + serializerOptions: opts, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Retrieves a list of available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// A list of all available tools as instances. + public async ValueTask> ListToolsAsync( + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + List? tools = null; + string? cursor = null; + do + { + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + tools ??= new List(toolResults.Tools.Count); + foreach (var tool in toolResults.Tools) + { + tools.Add(new McpClientTool(this, tool, serializerOptions)); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); + + return tools; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available tools as instances. + public async IAsyncEnumerable EnumerateToolsAsync( + JsonSerializerOptions? serializerOptions = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + string? cursor = null; + do + { + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var tool in toolResults.Tools) + { + yield return new McpClientTool(this, tool, serializerOptions); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a list of available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available prompts as instances. + public async ValueTask> ListPromptsAsync( + CancellationToken cancellationToken = default) + { + List? prompts = null; + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + prompts ??= new List(promptResults.Prompts.Count); + foreach (var prompt in promptResults.Prompts) + { + prompts.Add(new McpClientPrompt(this, prompt)); + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); + + return prompts; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available prompts as instances. + public async IAsyncEnumerable EnumeratePromptsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var prompt in promptResults.Prompts) + { + yield return new(this, prompt); + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a specific prompt from the MCP server. + /// + /// The name of the prompt to retrieve. + /// Optional arguments for the prompt. Keys are parameter names, and values are the argument values. + /// The serialization options governing argument serialization. + /// The to monitor for cancellation requests. The default is . + /// A task containing the prompt's result with content and messages. + public ValueTask GetPromptAsync( + string name, + IReadOnlyDictionary? arguments = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(name); + + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + return SendRequestAsync( + RequestMethods.PromptsGet, + new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult, + cancellationToken: cancellationToken); + } + + /// + /// Retrieves a list of available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resource templates as instances. + public async ValueTask> ListResourceTemplatesAsync( + CancellationToken cancellationToken = default) + { + List? resourceTemplates = null; + + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); + foreach (var template in templateResults.ResourceTemplates) + { + resourceTemplates.Add(new McpClientResourceTemplate(this, template)); + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); + + return resourceTemplates; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resource templates as instances. + public async IAsyncEnumerable EnumerateResourceTemplatesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var templateResult in templateResults.ResourceTemplates) + { + yield return new McpClientResourceTemplate(this, templateResult); + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a list of available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resources as instances. + public async ValueTask> ListResourcesAsync( + CancellationToken cancellationToken = default) + { + List? resources = null; + + string? cursor = null; + do + { + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resources ??= new List(resourceResults.Resources.Count); + foreach (var resource in resourceResults.Resources) + { + resources.Add(new McpClientResource(this, resource)); + } + + cursor = resourceResults.NextCursor; + } + while (cursor is not null); + + return resources; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resources as instances. + public async IAsyncEnumerable EnumerateResourcesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var resource in resourceResults.Resources) + { + yield return new McpClientResource(this, resource); + } + + cursor = resourceResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } + + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return ReadResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Reads a resource from the server. + /// + /// The uri template of the resource. + /// Arguments to use to format . + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uriTemplate); + Throw.IfNull(arguments); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests completion suggestions for a prompt argument or resource reference. + /// + /// The reference object specifying the type and optional URI or name. + /// The name of the argument for which completions are requested. + /// The current value of the argument, used to filter relevant completions. + /// The to monitor for cancellation requests. The default is . + /// A containing completion suggestions. + public ValueTask CompleteAsync(Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) + { + Throw.IfNull(reference); + Throw.IfNullOrWhiteSpace(argumentName); + + return SendRequestAsync( + RequestMethods.CompletionComplete, + new() + { + Ref = reference, + Argument = new Argument { Name = argumentName, Value = argumentValue } + }, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult, + cancellationToken: cancellationToken); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesSubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return SubscribeToResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesUnsubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return UnsubscribeFromResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Invokes a tool on the server. + /// + /// The name of the tool to call on the server.. + /// An optional dictionary of arguments to pass to the tool. + /// Optional progress reporter for server notifications. + /// JSON serializer options. + /// A cancellation token. + /// The from the tool execution. + public ValueTask CallToolAsync( + string toolName, + IReadOnlyDictionary? arguments = null, + IProgress? progress = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(toolName); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + if (progress is not null) + { + return SendRequestWithProgressAsync(toolName, arguments, progress, serializerOptions, cancellationToken); + } + + return SendRequestAsync( + RequestMethods.ToolsCall, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken); + + async ValueTask SendRequestWithProgressAsync( + string toolName, + IReadOnlyDictionary? arguments, + IProgress progress, + JsonSerializerOptions serializerOptions, + CancellationToken cancellationToken) + { + ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); + + await using var _ = RegisterNotificationHandler(NotificationMethods.ProgressNotification, + (notification, cancellationToken) => + { + if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && + pn.ProgressToken == progressToken) + { + progress.Report(pn.Progress); + } + + return default; + }).ConfigureAwait(false); + + return await SendRequestAsync( + RequestMethods.ToolsCall, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + ProgressToken = progressToken, + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Converts the contents of a into a pair of + /// and instances to use + /// as inputs into a operation. + /// + /// + /// The created pair of messages and options. + /// is . + internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( + CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); + + ChatOptions? options = null; + + if (requestParams.MaxTokens is int maxTokens) + { + (options ??= new()).MaxOutputTokens = maxTokens; + } + + if (requestParams.Temperature is float temperature) + { + (options ??= new()).Temperature = temperature; + } + + if (requestParams.StopSequences is { } stopSequences) + { + (options ??= new()).StopSequences = stopSequences.ToArray(); + } + + List messages = + (from sm in requestParams.Messages + let aiContent = sm.Content.ToAIContent() + where aiContent is not null + select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) + .ToList(); + + return (messages, options); + } + + /// Converts the contents of a into a . + /// The whose contents should be extracted. + /// The created . + /// is . + internal static CreateMessageResult ToCreateMessageResult(ChatResponse chatResponse) + { + Throw.IfNull(chatResponse); + + // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports + // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one + // in any of the response messages, or we'll use all the text from them concatenated, otherwise. + + ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); + + ContentBlock? content = null; + if (lastMessage is not null) + { + foreach (var lmc in lastMessage.Contents) + { + if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) + { + content = dc.ToContent(); + } + } + } + + return new() + { + Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, + Model = chatResponse.ModelId ?? "unknown", + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, + StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", + }; + } + + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = ToChatClientArguments(requestParams); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return ToCreateMessageResult(updates.ToChatResponse()); + }; + } + + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LoggingLevel level, CancellationToken cancellationToken = default) + { + return SendRequestAsync( + RequestMethods.LoggingSetLevel, + new() { Level = level }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LogLevel level, CancellationToken cancellationToken = default) => + SetLoggingLevel(McpServerImpl.ToLoggingLevel(level), cancellationToken); + + /// Convers a dictionary with values to a dictionary with values. + private static Dictionary? ToArgumentsDictionary( + IReadOnlyDictionary? arguments, JsonSerializerOptions options) + { + var typeInfo = options.GetTypeInfo(); + + Dictionary? result = null; + if (arguments is not null) + { + result = new(arguments.Count); + foreach (var kvp in arguments) + { + result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); + } + } + + return result; + } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index dd8c7fe09..c4abe33b7 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -1,236 +1,49 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol; -using System.Text.Json; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; -/// -internal sealed partial class McpClient : McpEndpoint, IMcpClient +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpClient : McpSession, IMcpClient +#pragma warning restore CS0618 // Type or member is obsolete { - private static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpClient), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly IClientTransport _clientTransport; - private readonly McpClientOptions _options; - - private ITransport? _sessionTransport; - private CancellationTokenSource? _connectCts; - - private ServerCapabilities? _serverCapabilities; - private Implementation? _serverInfo; - private string? _serverInstructions; - /// - /// Initializes a new instance of the class. + /// Gets the capabilities supported by the connected server. /// - /// The transport to use for communication with the server. - /// Options for the client, defining protocol version and capabilities. - /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory) - : base(loggerFactory) - { - options ??= new(); - - _clientTransport = clientTransport; - _options = options; - - EndpointName = clientTransport.Name; - - if (options.Capabilities is { } capabilities) - { - if (capabilities.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - if (capabilities.Sampling is { } samplingCapability) - { - if (samplingCapability.SamplingHandler is not { } samplingHandler) - { - throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } - - if (capabilities.Roots is { } rootsCapability) - { - if (rootsCapability.RootsHandler is not { } rootsHandler) - { - throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); - } - - if (capabilities.Elicitation is { } elicitationCapability) - { - if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) - { - throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.ElicitationCreate, - (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult); - } - } - } - - /// - public string? SessionId - { - get - { - if (_sessionTransport is null) - { - throw new InvalidOperationException("Must have already initialized a session when invoking this property."); - } - - return _sessionTransport.SessionId; - } - } - - /// - public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); - - /// - public Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); - - /// - public string? ServerInstructions => _serverInstructions; - - /// - public override string EndpointName { get; } + /// The client is not connected. + public abstract ServerCapabilities ServerCapabilities { get; } /// - /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// Gets the implementation information of the connected server. /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cancellationToken = _connectCts.Token; - - try - { - // Connect transport - _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - InitializeSession(_sessionTransport); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); - - // Perform initialization sequence - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); - - try - { - // Send initialize request - string requestProtocol = _options.ProtocolVersion ?? McpSession.LatestProtocolVersion; - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = requestProtocol, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo ?? DefaultImplementation, - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); - - // Store server information - if (_logger.IsEnabled(LogLevel.Information)) - { - LogServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - } - - _serverCapabilities = initializeResponse.Capabilities; - _serverInfo = initializeResponse.ServerInfo; - _serverInstructions = initializeResponse.Instructions; - - // Validate protocol version - bool isResponseProtocolValid = - _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : - McpSession.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); - if (!isResponseProtocolValid) - { - LogServerProtocolVersionMismatch(EndpointName, requestProtocol, initializeResponse.ProtocolVersion); - throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); - } - - // Send initialized notification - await this.SendNotificationAsync( - NotificationMethods.InitializedNotification, - new InitializedNotificationParams(), - McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, - cancellationToken: initializationCts.Token).ConfigureAwait(false); + /// + /// + /// This property provides identification details about the connected server, including its name and version. + /// It is populated during the initialization handshake and is available after a successful connection. + /// + /// + /// This information can be useful for logging, debugging, compatibility checks, and displaying server + /// information to users. + /// + /// + /// The client is not connected. + public abstract Implementation ServerInfo { get; } - } - catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) - { - LogClientInitializationTimeout(EndpointName); - throw new TimeoutException("Initialization timed out", oce); - } - } - catch (Exception e) - { - LogClientInitializationError(EndpointName, e); - await DisposeAsync().ConfigureAwait(false); - throw; - } - } - - /// - public override async ValueTask DisposeUnsynchronizedAsync() - { - try - { - if (_connectCts is not null) - { - await _connectCts.CancelAsync().ConfigureAwait(false); - _connectCts.Dispose(); - } - - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - finally - { - if (_sessionTransport is not null) - { - await _sessionTransport.DisposeAsync().ConfigureAwait(false); - } - } - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] - private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] - private partial void LogClientInitializationError(string endpointName, Exception exception); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] - private partial void LogClientInitializationTimeout(string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] - private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); -} \ No newline at end of file + /// + /// Gets any instructions describing how to use the connected server and its features. + /// + /// + /// + /// This property contains instructions provided by the server during initialization that explain + /// how to effectively use its capabilities. These instructions can include details about available + /// tools, expected input formats, limitations, or any other helpful information. + /// + /// + /// This can be used by clients to improve an LLM's understanding of available tools, prompts, and resources. + /// It can be thought of like a "hint" to the model and may be added to a system prompt. + /// + /// + public abstract string? ServerInstructions { get; } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs index 60a9c3a64..f0cd3c4f9 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs @@ -1,14 +1,14 @@ using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text.Json; namespace ModelContextProtocol.Client; /// -/// Provides extension methods for interacting with an . +/// Provides extension methods for interacting with an . /// /// /// @@ -19,6 +19,53 @@ namespace ModelContextProtocol.Client; /// public static class McpClientExtensions { + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// + /// + /// This method creates a function that converts MCP message requests into chat client calls, enabling + /// an MCP client to generate text or other content using an actual AI model via the provided chat client. + /// + /// + /// The handler can process text messages, image messages, and resource messages as defined in the + /// Model Context Protocol. + /// + /// + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + this IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = requestParams.ToChatClientArguments(); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); + }; + } + /// /// Sends a ping request to verify server connectivity. /// @@ -38,17 +85,10 @@ public static class McpClientExtensions /// /// is . /// Thrown when the server cannot be reached or returns an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.PingAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.Ping, - parameters: null, - McpJsonUtilities.JsonContext.Default.Object!, - McpJsonUtilities.JsonContext.Default.Object, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).PingAsync(cancellationToken); /// /// Retrieves a list of available tools from the server. @@ -89,39 +129,13 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat /// /// /// is . - public static async ValueTask> ListToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListToolsAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask> ListToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - List? tools = null; - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - tools ??= new List(toolResults.Tools.Count); - foreach (var tool in toolResults.Tools) - { - tools.Add(new McpClientTool(client, tool, serializerOptions)); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - - return tools; - } + => AsClientOrThrow(client).ListToolsAsync(serializerOptions, cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available tools from the server. @@ -155,35 +169,13 @@ public static async ValueTask> ListToolsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateToolsAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable EnumerateToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var tool in toolResults.Tools) - { - yield return new McpClientTool(client, tool, serializerOptions); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - } + CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateToolsAsync(serializerOptions, cancellationToken); /// /// Retrieves a list of available prompts from the server. @@ -202,34 +194,11 @@ public static async IAsyncEnumerable EnumerateToolsAsync( /// /// /// is . - public static async ValueTask> ListPromptsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListPromptsAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask> ListPromptsAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? prompts = null; - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - prompts ??= new List(promptResults.Prompts.Count); - foreach (var prompt in promptResults.Prompts) - { - prompts.Add(new McpClientPrompt(client, prompt)); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - - return prompts; - } + => AsClientOrThrow(client).ListPromptsAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available prompts from the server. @@ -258,30 +227,11 @@ public static async ValueTask> ListPromptsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumeratePromptsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var prompt in promptResults.Prompts) - { - yield return new(client, prompt); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumeratePromptsAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable EnumeratePromptsAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumeratePromptsAsync(cancellationToken); /// /// Retrieves a specific prompt from the MCP server. @@ -308,26 +258,15 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// /// Thrown when the prompt does not exist, when required arguments are missing, or when the server encounters an error processing the prompt. /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.GetPromptAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask GetPromptAsync( this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(name); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - return client.SendRequestAsync( - RequestMethods.PromptsGet, - new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).GetPromptAsync(name, arguments, serializerOptions, cancellationToken); /// /// Retrieves a list of available resource templates from the server. @@ -346,35 +285,11 @@ public static ValueTask GetPromptAsync( /// /// /// is . - public static async ValueTask> ListResourceTemplatesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourceTemplatesAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask> ListResourceTemplatesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resourceTemplates = null; - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); - foreach (var template in templateResults.ResourceTemplates) - { - resourceTemplates.Add(new McpClientResourceTemplate(client, template)); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - - return resourceTemplates; - } + => AsClientOrThrow(client).ListResourceTemplatesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. @@ -403,30 +318,11 @@ public static async ValueTask> ListResourceTemp /// /// /// is . - public static async IAsyncEnumerable EnumerateResourceTemplatesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var templateResult in templateResults.ResourceTemplates) - { - yield return new McpClientResourceTemplate(client, templateResult); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourceTemplatesAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable EnumerateResourceTemplatesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourceTemplatesAsync(cancellationToken); /// /// Retrieves a list of available resources from the server. @@ -457,35 +353,11 @@ public static async IAsyncEnumerable EnumerateResourc /// /// /// is . - public static async ValueTask> ListResourcesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourcesAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask> ListResourcesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resources = null; - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resources ??= new List(resourceResults.Resources.Count); - foreach (var resource in resourceResults.Resources) - { - resources.Add(new McpClientResource(client, resource)); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - - return resources; - } + => AsClientOrThrow(client).ListResourcesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resources from the server. @@ -514,30 +386,11 @@ public static async ValueTask> ListResourcesAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateResourcesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var resource in resourceResults.Resources) - { - yield return new McpClientResource(client, resource); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourcesAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable EnumerateResourcesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourcesAsync(cancellationToken); /// /// Reads a resource from the server. @@ -548,19 +401,11 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask ReadResourceAsync( this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -570,14 +415,11 @@ public static ValueTask ReadResourceAsync( /// The to monitor for cancellation requests. The default is . /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask ReadResourceAsync( this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return ReadResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -589,20 +431,11 @@ public static ValueTask ReadResourceAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask ReadResourceAsync( this IMcpClient client, string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uriTemplate); - Throw.IfNull(arguments); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uriTemplate, arguments, cancellationToken); /// /// Requests completion suggestions for a prompt argument or resource reference. @@ -633,23 +466,10 @@ public static ValueTask ReadResourceAsync( /// is . /// is empty or composed entirely of whitespace. /// The server returned an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CompleteAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask CompleteAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(reference); - Throw.IfNullOrWhiteSpace(argumentName); - - return client.SendRequestAsync( - RequestMethods.CompletionComplete, - new() - { - Ref = reference, - Argument = new Argument { Name = argumentName, Value = argumentValue } - }, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).CompleteAsync(reference, argumentName, argumentValue, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -676,18 +496,10 @@ public static ValueTask CompleteAsync(this IMcpClient client, Re /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesSubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -713,13 +525,10 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return SubscribeToResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -745,18 +554,10 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, Can /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesUnsubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -781,13 +582,10 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return UnsubscribeFromResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Invokes a tool on the server. @@ -824,6 +622,8 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, /// }); /// /// + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CallToolAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask CallToolAsync( this IMcpClient client, string toolName, @@ -831,62 +631,28 @@ public static ValueTask CallToolAsync( IProgress? progress = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(toolName); - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); + => AsClientOrThrow(client).CallToolAsync(toolName, arguments, progress, serializerOptions, cancellationToken); - if (progress is not null) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpClient AsClientOrThrow(IMcpClient client, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete + { + if (client is not McpClient mcpClient) { - return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken); + ThrowInvalidEndpointType(memberName); } - return client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken); - - static async ValueTask SendRequestWithProgressAsync( - IMcpClient client, - string toolName, - IReadOnlyDictionary? arguments, - IProgress progress, - JsonSerializerOptions serializerOptions, - CancellationToken cancellationToken) - { - ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); - - await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, - (notification, cancellationToken) => - { - if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && - pn.ProgressToken == progressToken) - { - progress.Report(pn.Progress); - } - - return default; - }).ConfigureAwait(false); + return mcpClient; - return await client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - ProgressToken = progressToken, - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - } + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpClient)}' are supported. " + + $"Prefer using '{nameof(McpClient)}.{memberName}' instead, as " + + $"'{nameof(McpClientExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } /// @@ -963,132 +729,4 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", }; } - - /// - /// Creates a sampling handler for use with that will - /// satisfy sampling requests using the specified . - /// - /// The with which to satisfy sampling requests. - /// The created handler delegate that can be assigned to . - /// - /// - /// This method creates a function that converts MCP message requests into chat client calls, enabling - /// an MCP client to generate text or other content using an actual AI model via the provided chat client. - /// - /// - /// The handler can process text messages, image messages, and resource messages as defined in the - /// Model Context Protocol. - /// - /// - /// is . - public static Func, CancellationToken, ValueTask> CreateSamplingHandler( - this IChatClient chatClient) - { - Throw.IfNull(chatClient); - - return async (requestParams, progress, cancellationToken) => - { - Throw.IfNull(requestParams); - - var (messages, options) = requestParams.ToChatClientArguments(); - var progressToken = requestParams.ProgressToken; - - List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - - if (progressToken is not null) - { - progress.Report(new() - { - Progress = updates.Count, - }); - } - } - - return updates.ToChatResponse().ToCreateMessageResult(); - }; - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// , , and - /// level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.LoggingSetLevel, - new() { Level = level }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// and level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) => - SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken); - - /// Convers a dictionary with values to a dictionary with values. - private static Dictionary? ToArgumentsDictionary( - IReadOnlyDictionary? arguments, JsonSerializerOptions options) - { - var typeInfo = options.GetTypeInfo(); - - Dictionary? result = null; - if (arguments is not null) - { - result = new(arguments.Count); - foreach (var kvp in arguments) - { - result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); - } - } - - return result; - } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs index 30b3a9476..805787256 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; +using System.ComponentModel; namespace ModelContextProtocol.Client; @@ -10,6 +11,8 @@ namespace ModelContextProtocol.Client; /// that connect to MCP servers. It handles the creation and connection /// of appropriate implementations through the supplied transport. /// +[Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CreateAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 +[EditorBrowsable(EditorBrowsableState.Never)] public static partial class McpClientFactory { /// Creates an , connecting it to the specified server. @@ -28,27 +31,5 @@ public static async Task CreateAsync( McpClientOptions? clientOptions = null, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(clientTransport); - - McpClient client = new(clientTransport, clientOptions, loggerFactory); - try - { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - if (loggerFactory?.CreateLogger(typeof(McpClientFactory)) is ILogger logger) - { - logger.LogClientCreated(client.EndpointName); - } - } - catch - { - await client.DisposeAsync().ConfigureAwait(false); - throw; - } - - return client; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] - private static partial void LogClientCreated(this ILogger logger, string endpointName); + => await McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory, cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs new file mode 100644 index 000000000..fecb83299 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs @@ -0,0 +1,88 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Client; + +/// +/// Provides a container for handlers used in the creation of an MCP client. +/// +/// +/// +/// This class provides a centralized collection of delegates that implement various capabilities of the Model Context Protocol. +/// +/// +/// Each handler in this class corresponds to a specific client endpoint in the Model Context Protocol and +/// is responsible for processing a particular type of message. The handlers are used to customize +/// the behavior of the MCP server by providing implementations for the various protocol operations. +/// +/// +/// When a server sends a message to the client, the appropriate handler is invoked to process it +/// according to the protocol specification. Which handler is selected +/// is done based on an ordinal, case-sensitive string comparison. +/// +/// +public class McpClientHandlers +{ + /// Gets or sets notification handlers to register with the client. + /// + /// + /// When constructed, the client will enumerate these handlers once, which may contain multiple handlers per notification method key. + /// The client will not re-enumerate the sequence after initialization. + /// + /// + /// Notification handlers allow the client to respond to server-sent notifications for specific methods. + /// Each key in the collection is a notification method name, and each value is a callback that will be invoked + /// when a notification with that method is received. + /// + /// + /// Handlers provided via will be registered with the client for the lifetime of the client. + /// For transient handlers, may be used to register a handler that can + /// then be unregistered by disposing of the returned from the method. + /// + /// + public IEnumerable>>? NotificationHandlers { get; set; } + + /// + /// Gets or sets the handler for requests. + /// + /// + /// This handler is invoked when a client sends a request to retrieve available roots. + /// The handler receives request parameters and should return a containing the collection of available roots. + /// + public Func>? RootsHandler { get; set; } + + /// + /// Gets or sets the handler for processing requests. + /// + /// + /// + /// This handler function is called when an MCP server requests the client to provide additional + /// information during interactions. The client must set this property for the elicitation capability to work. + /// + /// + /// The handler receives message parameters and a cancellation token. + /// It should return a containing the response to the elicitation request. + /// + /// + public Func>? ElicitationHandler { get; set; } + + /// + /// Gets or sets the handler for processing requests. + /// + /// + /// + /// This handler function is called when an MCP server requests the client to generate content + /// using an AI model. The client must set this property for the sampling capability to work. + /// + /// + /// The handler receives message parameters, a progress reporter for updates, and a + /// cancellation token. It should return a containing the + /// generated content. + /// + /// + /// You can create a handler using the extension + /// method with any implementation of . + /// + /// + public Func, CancellationToken, ValueTask>? SamplingHandler { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs new file mode 100644 index 000000000..3a289d13e --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -0,0 +1,252 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +internal sealed partial class McpClientImpl : McpClient +{ + private static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpClient), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _transport; + private readonly string _endpointName; + private readonly McpClientOptions _options; + private readonly McpSessionHandler _sessionHandler; + private readonly SemaphoreSlim _disposeLock = new(1, 1); + + private CancellationTokenSource? _connectCts; + + private ServerCapabilities? _serverCapabilities; + private Implementation? _serverInfo; + private string? _serverInstructions; + private string? _negotiatedProtocolVersion; + + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// The transport to use for communication with the server. + /// The name of the endpoint for logging and debug purposes. + /// Options for the client, defining protocol version and capabilities. + /// The logger factory. + internal McpClientImpl(ITransport transport, string endpointName, McpClientOptions? options, ILoggerFactory? loggerFactory) + { + options ??= new(); + + _transport = transport; + _endpointName = $"Client ({options.ClientInfo?.Name ?? DefaultImplementation.Name} {options.ClientInfo?.Version ?? DefaultImplementation.Version})"; + _options = options; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + var notificationHandlers = new NotificationHandlers(); + var requestHandlers = new RequestHandlers(); + + RegisterHandlers(options, notificationHandlers, requestHandlers); + + _sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger); + } + + private void RegisterHandlers(McpClientOptions options, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers) + { + McpClientHandlers handlers = options.Handlers; + +#pragma warning disable CS0618 // Type or member is obsolete + var notificationHandlersFromOptions = handlers.NotificationHandlers ?? options.Capabilities?.NotificationHandlers; + var samplingHandler = handlers.SamplingHandler ?? options.Capabilities?.Sampling?.SamplingHandler; + var rootsHandler = handlers.RootsHandler ?? options.Capabilities?.Roots?.RootsHandler; + var elicitationHandler = handlers.ElicitationHandler ?? options.Capabilities?.Elicitation?.ElicitationHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + if (notificationHandlersFromOptions is not null) + { + notificationHandlers.RegisterRange(notificationHandlersFromOptions); + } + + if (samplingHandler is not null) + { + requestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, _, cancellationToken) => samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); + + _options.Capabilities ??= new(); + _options.Capabilities.Sampling ??= new(); + } + + if (rootsHandler is not null) + { + requestHandlers.Set( + RequestMethods.RootsList, + (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + + _options.Capabilities ??= new(); + _options.Capabilities.Roots ??= new(); + } + + if (elicitationHandler is not null) + { + requestHandlers.Set( + RequestMethods.ElicitationCreate, + (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); + + _options.Capabilities ??= new(); + _options.Capabilities.Elicitation ??= new(); + } + } + + /// + public override string? SessionId => _transport.SessionId; + + /// + public override string? NegotiatedProtocolVersion => _negotiatedProtocolVersion; + + /// + public override ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override string? ServerInstructions => _serverInstructions; + + /// + /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _connectCts.Token; + + try + { + // We don't want the ConnectAsync token to cancel the message processing loop after we've successfully connected. + // The session handler handles cancelling the loop upon its disposal. + _ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None); + + // Perform initialization sequence + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); + + try + { + // Send initialize request + string requestProtocol = _options.ProtocolVersion ?? McpSessionHandler.LatestProtocolVersion; + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = requestProtocol, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo ?? DefaultImplementation, + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + // Store server information + if (_logger.IsEnabled(LogLevel.Information)) + { + LogServerCapabilitiesReceived(_endpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + } + + _serverCapabilities = initializeResponse.Capabilities; + _serverInfo = initializeResponse.ServerInfo; + _serverInstructions = initializeResponse.Instructions; + + // Validate protocol version + bool isResponseProtocolValid = + _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : + McpSessionHandler.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); + if (!isResponseProtocolValid) + { + LogServerProtocolVersionMismatch(_endpointName, requestProtocol, initializeResponse.ProtocolVersion); + throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); + } + + _negotiatedProtocolVersion = initializeResponse.ProtocolVersion; + + // Send initialized notification + await this.SendNotificationAsync( + NotificationMethods.InitializedNotification, + new InitializedNotificationParams(), + McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + } + catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + LogClientInitializationTimeout(_endpointName); + throw new TimeoutException("Initialization timed out", oce); + } + } + catch (Exception e) + { + LogClientInitializationError(_endpointName, e); + await DisposeAsync().ConfigureAwait(false); + throw; + } + + LogClientConnected(_endpointName); + } + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _disposed = true; + + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + await _transport.DisposeAsync().ConfigureAwait(false); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] + private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] + private partial void LogClientInitializationError(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] + private partial void LogClientInitializationTimeout(string endpointName); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] + private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] + private partial void LogClientConnected(string endpointName); +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs index 76099d0d9..ff71f5899 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs @@ -3,14 +3,16 @@ namespace ModelContextProtocol.Client; /// -/// Provides configuration options for creating instances. +/// Provides configuration options for creating instances. /// /// -/// These options are typically passed to when creating a client. +/// These options are typically passed to when creating a client. /// They define client capabilities, protocol version, and other client-specific settings. /// public sealed class McpClientOptions { + private McpClientHandlers? _handlers; + /// /// Gets or sets information about this client implementation, including its name and version. /// @@ -63,4 +65,17 @@ public sealed class McpClientOptions /// The default value is 60 seconds. /// public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60); + + /// + /// Gets or sets the container of handlers used by the client for processing protocol messages. + /// + public McpClientHandlers Handlers + { + get => _handlers ??= new(); + set + { + Throw.IfNull(value); + _handlers = value; + } + } } diff --git a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs index 43fc759a0..5a618242f 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs @@ -10,8 +10,8 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a prompt defined on an MCP server. It allows /// retrieving the prompt's content by sending a request to the server with optional arguments. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// /// Each prompt has a name and optionally a description, and it can be invoked with arguments @@ -20,9 +20,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientPrompt { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientPrompt(IMcpClient client, Prompt prompt) + internal McpClientPrompt(McpClient client, Prompt prompt) { _client = client; ProtocolPrompt = prompt; @@ -63,7 +63,7 @@ internal McpClientPrompt(IMcpClient client, Prompt prompt) /// The server will process the request and return a result containing messages or other content. /// /// - /// This is a convenience method that internally calls + /// This is a convenience method that internally calls /// with this prompt's name and arguments. /// /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientResource.cs b/src/ModelContextProtocol.Core/Client/McpClientResource.cs index 06f8aff67..19f11bfdf 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResource.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResource.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource defined on an MCP server. It allows /// retrieving the resource's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResource { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResource(IMcpClient client, Resource resource) + internal McpClientResource(McpClient client, Resource resource) { _client = client; ProtocolResource = resource; @@ -58,7 +58,7 @@ internal McpClientResource(IMcpClient client, Resource resource) /// A containing the resource's result with content and messages. /// /// - /// This is a convenience method that internally calls . + /// This is a convenience method that internally calls . /// /// public ValueTask ReadAsync( diff --git a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs index 4da1bd0c3..033f7cf00 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource template defined on an MCP server. It allows /// retrieving the resource template's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResourceTemplate { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResourceTemplate(IMcpClient client, ResourceTemplate resourceTemplate) + internal McpClientResourceTemplate(McpClient client, ResourceTemplate resourceTemplate) { _client = client; ProtocolResourceTemplate = resourceTemplate; diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index 1810e9c56..c7af513ef 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -6,11 +6,11 @@ namespace ModelContextProtocol.Client; /// -/// Provides an that calls a tool via an . +/// Provides an that calls a tool via an . /// /// /// -/// The class encapsulates an along with a description of +/// The class encapsulates an along with a description of /// a tool available via that client, allowing it to be invoked as an . This enables integration /// with AI models that support function calling capabilities. /// @@ -19,8 +19,8 @@ namespace ModelContextProtocol.Client; /// and without changing the underlying tool functionality. /// /// -/// Typically, you would get instances of this class by calling the -/// or extension methods on an instance. +/// Typically, you would get instances of this class by calling the +/// or extension methods on an instance. /// /// public sealed class McpClientTool : AIFunction @@ -32,13 +32,13 @@ public sealed class McpClientTool : AIFunction ["Strict"] = false, // some MCP schemas may not meet "strict" requirements }); - private readonly IMcpClient _client; + private readonly McpClient _client; private readonly string _name; private readonly string _description; private readonly IProgress? _progress; internal McpClientTool( - IMcpClient client, + McpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index aba7bbcfb..60950dfa5 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -15,7 +15,7 @@ namespace ModelContextProtocol.Client; internal sealed partial class SseClientSessionTransport : TransportBase { private readonly McpHttpClient _httpClient; - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; private readonly CancellationTokenSource _connectionCts; @@ -29,7 +29,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase /// public SseClientSessionTransport( string endpointName, - SseClientTransportOptions transportOptions, + HttpClientTransportOptions transportOptions, McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) @@ -42,7 +42,7 @@ public SseClientSessionTransport( _sseEndpoint = transportOptions.Endpoint; _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _connectionEstablished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } @@ -193,6 +193,8 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation return; } + LogTransportReceivedMessageSensitive(Name, data); + try { var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index c026acb93..3ec0d2880 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -90,13 +90,13 @@ public async Task ConnectAsync(CancellationToken cancellationToken = #if NET foreach (string arg in arguments) { - startInfo.ArgumentList.Add(arg); + startInfo.ArgumentList.Add(EscapeArgumentString(arg)); } #else StringBuilder argsBuilder = new(); foreach (string arg in arguments) { - PasteArguments.AppendArgument(argsBuilder, arg); + PasteArguments.AppendArgument(argsBuilder, EscapeArgumentString(arg)); } startInfo.Arguments = argsBuilder.ToString(); @@ -236,6 +236,26 @@ internal static bool HasExited(Process process) } } + private static string EscapeArgumentString(string argument) => + RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && !ContainsWhitespaceRegex.IsMatch(argument) ? + WindowsCliSpecialArgumentsRegex.Replace(argument, static match => "^" + match.Value) : + argument; + + private const string WindowsCliSpecialArgumentsRegexString = "[&^><|]"; + +#if NET + private static Regex WindowsCliSpecialArgumentsRegex => GetWindowsCliSpecialArgumentsRegex(); + private static Regex ContainsWhitespaceRegex => GetContainsWhitespaceRegex(); + + [GeneratedRegex(WindowsCliSpecialArgumentsRegexString, RegexOptions.CultureInvariant)] + private static partial Regex GetWindowsCliSpecialArgumentsRegex(); + [GeneratedRegex(@"\s", RegexOptions.CultureInvariant)] + private static partial Regex GetContainsWhitespaceRegex(); +#else + private static Regex WindowsCliSpecialArgumentsRegex { get; } = new(WindowsCliSpecialArgumentsRegexString, RegexOptions.Compiled | RegexOptions.CultureInvariant); + private static Regex ContainsWhitespaceRegex { get; } = new(@"\s", RegexOptions.Compiled | RegexOptions.CultureInvariant); +#endif + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} connecting.")] private static partial void LogTransportConnecting(ILogger logger, string endpointName); diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 190bec0b2..f2fd55f16 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -17,7 +17,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); private readonly McpHttpClient _httpClient; - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; @@ -29,7 +29,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa public StreamableHttpClientSessionTransport( string endpointName, - SseClientTransportOptions transportOptions, + HttpClientTransportOptions transportOptions, McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) @@ -41,10 +41,10 @@ public StreamableHttpClientSessionTransport( _options = transportOptions; _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; // We connect with the initialization request with the MCP transport. This means that any errors won't be observed - // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync + // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. SetConnected(); } @@ -211,6 +211,8 @@ private async Task ReceiveUnsolicitedMessagesAsync() private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) { + LogTransportReceivedMessageSensitive(Name, data); + try { var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); @@ -289,7 +291,7 @@ internal static void CopyAdditionalHeaders( { if (!headers.TryAddWithoutValidation(header.Key, header.Value)) { - throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(HttpClientTransportOptions.AdditionalHeaders)}."); } } } diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index ea825e682..40106cb07 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -1,3 +1,4 @@ +using System.ComponentModel; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -26,6 +27,8 @@ namespace ModelContextProtocol; /// All MCP endpoints should be properly disposed after use as they implement . /// /// +[Obsolete($"Use {nameof(McpSession)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 +[EditorBrowsable(EditorBrowsableState.Never)] public interface IMcpEndpoint : IAsyncDisposable { /// Gets an identifier associated with the current MCP session. diff --git a/src/ModelContextProtocol.Core/McpEndpoint.cs b/src/ModelContextProtocol.Core/McpEndpoint.cs deleted file mode 100644 index 0d0ccbb98..000000000 --- a/src/ModelContextProtocol.Core/McpEndpoint.cs +++ /dev/null @@ -1,144 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; - -namespace ModelContextProtocol; - -/// -/// Base class for an MCP JSON-RPC endpoint. This covers both MCP clients and servers. -/// It is not supported, nor necessary, to implement both client and server functionality in the same class. -/// If an application needs to act as both a client and a server, it should use separate objects for each. -/// This is especially true as a client represents a connection to one and only one server, and vice versa. -/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction. -/// -internal abstract partial class McpEndpoint : IAsyncDisposable -{ - /// Cached naming information used for name/version when none is specified. - internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - - private McpSession? _session; - private CancellationTokenSource? _sessionCts; - - private readonly SemaphoreSlim _disposeLock = new(1, 1); - private bool _disposed; - - protected readonly ILogger _logger; - - /// - /// Initializes a new instance of the class. - /// - /// The logger factory. - protected McpEndpoint(ILoggerFactory? loggerFactory = null) - { - _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; - } - - protected RequestHandlers RequestHandlers { get; } = []; - - protected NotificationHandlers NotificationHandlers { get; } = new(); - - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => - GetSessionOrThrow().RegisterNotificationHandler(method, handler); - - /// - /// Gets the name of the endpoint for logging and debug purposes. - /// - public abstract string EndpointName { get; } - - /// - /// Task that processes incoming messages from the transport. - /// - protected Task? MessageProcessingTask { get; private set; } - - protected void InitializeSession(ITransport sessionTransport) - { - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); - } - - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken) - { - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); - } - - protected void CancelSession() => _sessionCts?.Cancel(); - - public async ValueTask DisposeAsync() - { - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - - if (_disposed) - { - return; - } - _disposed = true; - - await DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - /// - /// Cleans up the endpoint and releases resources. - /// - /// - public virtual async ValueTask DisposeUnsynchronizedAsync() - { - LogEndpointShuttingDown(EndpointName); - - try - { - if (_sessionCts is not null) - { - await _sessionCts.CancelAsync().ConfigureAwait(false); - } - - if (MessageProcessingTask is not null) - { - try - { - await MessageProcessingTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - // Ignore cancellation - } - } - } - finally - { - _session?.Dispose(); - _sessionCts?.Dispose(); - } - - LogEndpointShutDown(EndpointName); - } - - protected McpSession GetSessionOrThrow() - { -#if NET - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) - { - throw new ObjectDisposedException(GetType().Name); - } -#endif - - return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] - private partial void LogEndpointShuttingDown(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shut down.")] - private partial void LogEndpointShutDown(string endpointName); -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs index 4e4abe5ce..1a5b5c1e2 100644 --- a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs @@ -1,9 +1,10 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol; @@ -34,6 +35,8 @@ public static class McpEndpointExtensions /// The options governing request serialization. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. The task result contains the deserialized result. + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendRequestAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask SendRequestAsync( this IMcpEndpoint endpoint, string method, @@ -42,53 +45,7 @@ public static ValueTask SendRequestAsync( RequestId requestId = default, CancellationToken cancellationToken = default) where TResult : notnull - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); - JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); - return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); - } - - /// - /// Sends a JSON-RPC request and attempts to deserialize the result to . - /// - /// The type of the request parameters to serialize from. - /// The type of the result to deserialize to. - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The type information for request parameter deserialization. - /// The request id for the request. - /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains the deserialized result. - internal static async ValueTask SendRequestAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - JsonTypeInfo resultTypeInfo, - RequestId requestId = default, - CancellationToken cancellationToken = default) - where TResult : notnull - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - Throw.IfNull(resultTypeInfo); - - JsonRpcRequest jsonRpcRequest = new() - { - Id = requestId, - Method = method, - Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), - }; - - JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); - return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); - } + => AsSessionOrThrow(endpoint).SendRequestAsync(method, parameters, serializerOptions, requestId, cancellationToken); /// /// Sends a parameterless notification to the connected endpoint. @@ -104,12 +61,10 @@ internal static async ValueTask SendRequestAsync( /// changes in state. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(method); - return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); - } + => AsSessionOrThrow(client).SendNotificationAsync(method, cancellationToken); /// /// Sends a notification with parameters to the connected endpoint. @@ -135,42 +90,15 @@ public static Task SendNotificationAsync(this IMcpEndpoint client, string method /// but custom methods can also be used for application-specific notifications. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task SendNotificationAsync( this IMcpEndpoint endpoint, string method, TParameters parameters, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); - return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken); - } - - /// - /// Sends a notification to the server with parameters. - /// - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The to monitor for cancellation requests. The default is . - internal static Task SendNotificationAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - CancellationToken cancellationToken = default) - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - - JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); - return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); - } + => AsSessionOrThrow(endpoint).SendNotificationAsync(method, parameters, serializerOptions, cancellationToken); /// /// Notifies the connected endpoint of progress for a long-running operation. @@ -191,22 +119,34 @@ internal static Task SendNotificationAsync( /// Progress notifications are sent asynchronously and don't block the operation from continuing. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.NotifyProgressAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static Task NotifyProgressAsync( this IMcpEndpoint endpoint, ProgressToken progressToken, - ProgressNotificationValue progress, + ProgressNotificationValue progress, CancellationToken cancellationToken = default) + => AsSessionOrThrow(endpoint).NotifyProgressAsync(progressToken, progress, cancellationToken); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpSession AsSessionOrThrow(IMcpEndpoint endpoint, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - Throw.IfNull(endpoint); + if (endpoint is not McpSession session) + { + ThrowInvalidEndpointType(memberName); + } + + return session; - return endpoint.SendNotificationAsync( - NotificationMethods.ProgressNotification, - new ProgressNotificationParams - { - ProgressToken = progressToken, - Progress = progress, - }, - McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, - cancellationToken); + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpSession)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpEndpointExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 21e2468d9..8bc9e21b0 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -146,6 +146,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(AudioContentBlock))] [JsonSerializable(typeof(EmbeddedResourceBlock))] [JsonSerializable(typeof(ResourceLinkBlock))] + [JsonSerializable(typeof(IEnumerable))] [JsonSerializable(typeof(PromptReference))] [JsonSerializable(typeof(ResourceTemplateReference))] [JsonSerializable(typeof(BlobResourceContents))] diff --git a/src/ModelContextProtocol.Core/McpSession.Methods.cs b/src/ModelContextProtocol.Core/McpSession.Methods.cs new file mode 100644 index 000000000..c537732f1 --- /dev/null +++ b/src/ModelContextProtocol.Core/McpSession.Methods.cs @@ -0,0 +1,183 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol; + +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpSession : IMcpEndpoint, IAsyncDisposable +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The request id for the request. + /// The options governing request serialization. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + public ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); + JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); + return SendRequestAsync(method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); + } + + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The type information for request parameter deserialization. + /// The request id for the request. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + internal async ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + JsonTypeInfo resultTypeInfo, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + Throw.IfNull(resultTypeInfo); + + JsonRpcRequest jsonRpcRequest = new() + { + Id = requestId, + Method = method, + Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), + }; + + JsonRpcResponse response = await SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); + return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); + } + + /// + /// Sends a parameterless notification to the connected session. + /// + /// The notification method name. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification without any parameters. Notifications are one-way messages + /// that don't expect a response. They are commonly used for events, status updates, or to signal + /// changes in state. + /// + /// + public Task SendNotificationAsync(string method, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(method); + return SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); + } + + /// + /// Sends a notification with parameters to the connected session. + /// + /// The type of the notification parameters to serialize. + /// The JSON-RPC method name for the notification. + /// Object representing the notification parameters. + /// The options governing parameter serialization. If null, default options are used. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification with parameters to the connected session. Notifications are one-way + /// messages that don't expect a response, commonly used for events, status updates, or signaling changes. + /// + /// + /// The parameters object is serialized to JSON according to the provided serializer options or the default + /// options if none are specified. + /// + /// + /// The Model Context Protocol defines several standard notification methods in , + /// but custom methods can also be used for application-specific notifications. + /// + /// + public Task SendNotificationAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); + return SendNotificationAsync(method, parameters, parametersTypeInfo, cancellationToken); + } + + /// + /// Sends a notification to the server with parameters. + /// + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The to monitor for cancellation requests. The default is . + internal Task SendNotificationAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + + JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); + return SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); + } + + /// + /// Notifies the connected session of progress for a long-running operation. + /// + /// The identifying the operation for which progress is being reported. + /// The progress update to send, containing information such as percentage complete or status message. + /// The to monitor for cancellation requests. The default is . + /// A task representing the completion of the notification operation (not the operation being tracked). + /// The current session instance is . + /// + /// + /// This method sends a progress notification to the connected session using the Model Context Protocol's + /// standardized progress notification format. Progress updates are identified by a + /// that allows the recipient to correlate multiple updates with a specific long-running operation. + /// + /// + /// Progress notifications are sent asynchronously and don't block the operation from continuing. + /// + /// + public Task NotifyProgressAsync( + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) + { + return SendNotificationAsync( + NotificationMethods.ProgressNotification, + new ProgressNotificationParams + { + ProgressToken = progressToken, + Progress = progress, + }, + McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, + cancellationToken); + } +} diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 06b2894b0..429fdbfd4 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -1,786 +1,95 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Diagnostics; -using System.Diagnostics.Metrics; -using System.Text.Json; -using System.Text.Json.Nodes; -#if !NET -using System.Threading.Channels; -#endif namespace ModelContextProtocol; /// -/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// Represents a client or server Model Context Protocol (MCP) session. /// -internal sealed partial class McpSession : IDisposable +/// +/// +/// The MCP session provides the core communication functionality used by both clients and servers: +/// +/// Sending JSON-RPC requests and receiving responses. +/// Sending notifications to the connected session. +/// Registering handlers for receiving notifications. +/// +/// +/// +/// serves as the base interface for both and +/// interfaces, providing the common functionality needed for MCP protocol +/// communication. Most applications will use these more specific interfaces rather than working with +/// directly. +/// +/// +/// All MCP sessions should be properly disposed after use as they implement . +/// +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpSession : IMcpEndpoint, IAsyncDisposable +#pragma warning restore CS0618 // Type or member is obsolete { - private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); - private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); - private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); - private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); + /// Gets an identifier associated with the current MCP session. + /// + /// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE. + /// Can return if the session hasn't initialized or if the transport doesn't + /// support multiple sessions (as is the case with STDIO). + /// + public abstract string? SessionId { get; } - /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; - - /// All protocol versions supported by this implementation. - internal static readonly string[] SupportedProtocolVersions = - [ - "2024-11-05", - "2025-03-26", - LatestProtocolVersion, - ]; - - private readonly bool _isServer; - private readonly string _transportKind; - private readonly ITransport _transport; - private readonly RequestHandlers _requestHandlers; - private readonly NotificationHandlers _notificationHandlers; - private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); - - private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; - - /// Collection of requests sent on this session and waiting for responses. - private readonly ConcurrentDictionary> _pendingRequests = []; /// - /// Collection of requests received on this session and currently being handled. The value provides a - /// that can be used to request cancellation of the in-flight handler. + /// Gets the negotiated protocol version for the current MCP session. /// - private readonly ConcurrentDictionary _handlingRequests = new(); - private readonly ILogger _logger; - - // This _sessionId is solely used to identify the session in telemetry and logs. - private readonly string _sessionId = Guid.NewGuid().ToString("N"); - private long _lastRequestId; + /// + /// Returns the protocol version negotiated during session initialization, + /// or if initialization hasn't yet occurred. + /// + public abstract string? NegotiatedProtocolVersion { get; } /// - /// Initializes a new instance of the class. + /// Sends a JSON-RPC request to the connected session and waits for a response. /// - /// true if this is a server; false if it's a client. - /// An MCP transport implementation. - /// The name of the endpoint for logging and debug purposes. - /// A collection of request handlers. - /// A collection of notification handlers. - /// The logger. - public McpSession( - bool isServer, - ITransport transport, - string endpointName, - RequestHandlers requestHandlers, - NotificationHandlers notificationHandlers, - ILogger logger) - { - Throw.IfNull(transport); - - _transportKind = transport switch - { - StdioClientSessionTransport or StdioServerTransport => "stdio", - StreamClientSessionTransport or StreamServerTransport => "stream", - SseClientSessionTransport or SseResponseStreamTransport => "sse", - StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", - _ => "unknownTransport" - }; - - _isServer = isServer; - _transport = transport; - EndpointName = endpointName; - _requestHandlers = requestHandlers; - _notificationHandlers = notificationHandlers; - _logger = logger ?? NullLogger.Instance; - } - - /// - /// Gets and sets the name of the endpoint for logging and debug purposes. - /// - public string EndpointName { get; set; } - - /// - /// Starts processing messages from the transport. This method will block until the transport is disconnected. - /// This is generally started in a background task or thread from the initialization logic of the derived class. - /// - public async Task ProcessMessagesAsync(CancellationToken cancellationToken) - { - try - { - await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) - { - LogMessageRead(EndpointName, message.GetType().Name); - - // Fire and forget the message handling to avoid blocking the transport. - if (message.ExecutionContext is null) - { - _ = ProcessMessageAsync(); - } - else - { - // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); - } - - async Task ProcessMessageAsync() - { - JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; - CancellationTokenSource? combinedCts = null; - try - { - // Register before we yield, so that the tracking is guaranteed to be there - // when subsequent messages arrive, even if the asynchronous processing happens - // out of order. - if (messageWithId is not null) - { - combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _handlingRequests[messageWithId.Id] = combinedCts; - } - - // If we await the handler without yielding first, the transport may not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back. -#if NET - await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); -#else - await default(ForceYielding); -#endif - - // Handle the message. - await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - // Only send responses for request errors that aren't user-initiated cancellation. - bool isUserCancellation = - ex is OperationCanceledException && - !cancellationToken.IsCancellationRequested && - combinedCts?.IsCancellationRequested is true; - - if (!isUserCancellation && message is JsonRpcRequest request) - { - LogRequestHandlerException(EndpointName, request.Method, ex); - - JsonRpcErrorDetail detail = ex is McpException mcpe ? - new() - { - Code = (int)mcpe.ErrorCode, - Message = mcpe.Message, - } : - new() - { - Code = (int)McpErrorCode.InternalError, - Message = "An error occurred.", - }; - - await SendMessageAsync(new JsonRpcError - { - Id = request.Id, - JsonRpc = "2.0", - Error = detail, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); - } - else if (ex is not OperationCanceledException) - { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); - } - else - { - LogMessageHandlerException(EndpointName, message.GetType().Name, ex); - } - } - } - finally - { - if (messageWithId is not null) - { - _handlingRequests.TryRemove(messageWithId.Id, out _); - combinedCts!.Dispose(); - } - } - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // Normal shutdown - LogEndpointMessageProcessingCanceled(EndpointName); - } - finally - { - // Fail any pending requests, as they'll never be satisfied. - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); - } - } - } - - private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) - { - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - - Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity( - CreateActivityName(method), - ActivityKind.Server, - parentContext: _propagator.ExtractActivityContext(message), - links: Diagnostics.ActivityLinkFromCurrent()) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - switch (message) - { - case JsonRpcRequest request: - var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); - AddResponseTags(ref tags, activity, result, method); - break; - - case JsonRpcNotification notification: - await HandleNotification(notification, cancellationToken).ConfigureAwait(false); - break; - - case JsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; - - default: - LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; - } - } - catch (Exception e) when (addTags) - { - AddExceptionTags(ref tags, activity, e); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) - { - // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) - if (notification.Method == NotificationMethods.CancelledNotification) - { - try - { - if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _handlingRequests.TryGetValue(cn.RequestId, out var cts)) - { - await cts.CancelAsync().ConfigureAwait(false); - LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); - } - } - catch - { - // "Invalid cancellation notifications SHOULD be ignored" - } - } - - // Handle user-defined notifications. - await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); - } - - private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) - { - if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) - { - tcs.TrySetResult(message); - } - else - { - LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); - } - } - - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) - { - if (!_requestHandlers.TryGetValue(request.Method, out var handler)) - { - LogNoHandlerFoundForRequest(EndpointName, request.Method); - throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); - } - - LogRequestHandlerCalled(EndpointName, request.Method); - JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); - LogRequestHandlerCompleted(EndpointName, request.Method); - - await SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = result, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); - - return result; - } - - private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) - { - if (!cancellationToken.CanBeCanceled) - { - return default; - } - - return cancellationToken.Register(static objState => - { - var state = (Tuple)objState!; - _ = state.Item1.SendMessageAsync(new JsonRpcNotification - { - Method = NotificationMethods.CancelledNotification, - Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), - RelatedTransport = state.Item2.RelatedTransport, - }); - }, Tuple.Create(this, request)); - } - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) - { - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(handler); - - return _notificationHandlers.Register(method, handler); - } + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the session's response. + /// The transport is not connected, or another error occurs during request processing. + /// An error occurred during request processing. + /// + /// This method provides low-level access to send raw JSON-RPC requests. For most use cases, + /// consider using the strongly-typed methods that provide a more convenient API. + /// + public abstract Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); /// - /// Sends a JSON-RPC request to the server. - /// It is strongly recommended use the capability-specific methods instead of this one. - /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// Sends a JSON-RPC message to the connected session. /// - /// The JSON-RPC request to send. + /// + /// The JSON-RPC message to send. This can be any type that implements JsonRpcMessage, such as + /// JsonRpcRequest, JsonRpcResponse, JsonRpcNotification, or JsonRpcError. + /// /// The to monitor for cancellation requests. The default is . - /// A task containing the server's response. - public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = request.Method; - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - // Set request ID - if (request.Id.Id is null) - { - request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); - } - - _propagator.InjectActivityContext(activity, request); - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _pendingRequests[request.Id] = tcs; - try - { - if (addTags) - { - AddTags(ref tags, activity, request, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingRequest(EndpointName, request.Method); - } - - await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); - - // Now that the request has been sent, register for cancellation. If we registered before, - // a cancellation request could arrive before the server knew about that request ID, in which - // case the server could ignore it. - LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); - JsonRpcMessage? response; - using (var registration = RegisterCancellation(cancellationToken, request)) - { - response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); - } - - if (response is JsonRpcError error) - { - LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); - } - - if (response is JsonRpcResponse success) - { - if (addTags) - { - AddResponseTags(ref tags, activity, success.Result, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); - } - else - { - LogRequestResponseReceived(EndpointName, request.Method); - } - - return success; - } - - // Unexpected response type - LogSendingRequestInvalidResponseType(EndpointName, request.Method); - throw new McpException("Invalid response type"); - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - _pendingRequests.TryRemove(request.Id, out _); - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - Throw.IfNull(message); - - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - // propagate trace context - _propagator?.InjectActivityContext(activity, message); - - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingMessage(EndpointName); - } - - await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); - - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) - { - tcs.TrySetCanceled(default); - } - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the - // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in - // the HTTP response body for the POST request containing the corresponding JSON-RPC request. - private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) - => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); - - private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) - { - try - { - return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); - } - catch - { - return null; - } - } - - private string CreateActivityName(string method) => method; - - private static string GetMethodName(JsonRpcMessage message) => - message switch - { - JsonRpcRequest request => request.Method, - JsonRpcNotification notification => notification.Method, - _ => "unknownMethod" - }; - - private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) - { - tags.Add("mcp.method.name", method); - tags.Add("network.transport", _transportKind); - - // TODO: When using SSE transport, add: - // - server.address and server.port on client spans and metrics - // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport - if (activity is { IsAllDataRequested: true }) - { - // session and request id have high cardinality, so not applying to metric tags - activity.AddTag("mcp.session.id", _sessionId); - - if (message is JsonRpcMessageWithId withId) - { - activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); - } - } - - JsonObject? paramsObj = message switch - { - JsonRpcRequest request => request.Params as JsonObject, - JsonRpcNotification notification => notification.Params as JsonObject, - _ => null - }; - - if (paramsObj == null) - { - return; - } - - string? target = null; - switch (method) - { - case RequestMethods.ToolsCall: - case RequestMethods.PromptsGet: - target = GetStringProperty(paramsObj, "name"); - if (target is not null) - { - tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); - } - break; - - case RequestMethods.ResourcesRead: - case RequestMethods.ResourcesSubscribe: - case RequestMethods.ResourcesUnsubscribe: - case NotificationMethods.ResourceUpdatedNotification: - target = GetStringProperty(paramsObj, "uri"); - if (target is not null) - { - tags.Add("mcp.resource.uri", target); - } - break; - } - - if (activity is { IsAllDataRequested: true }) - { - activity.DisplayName = target == null ? method : $"{method} {target}"; - } - } - - private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) - { - if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) - { - e = ae.InnerException; - } - - int? intErrorCode = - (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : - e is JsonException ? (int)McpErrorCode.ParseError : - null; - - string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; - tags.Add("error.type", errorType); - if (intErrorCode is not null) - { - tags.Add("rpc.jsonrpc.error_code", errorType); - } - - if (activity is { IsAllDataRequested: true }) - { - activity.SetStatus(ActivityStatusCode.Error, e.Message); - } - } - - private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) - { - if (response is JsonObject jsonObject - && jsonObject.TryGetPropertyValue("isError", out var isError) - && isError?.GetValueKind() == JsonValueKind.True) - { - if (activity is { IsAllDataRequested: true }) - { - string? content = null; - if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) - { - content = prop.ToJsonString(); - } - - activity.SetStatus(ActivityStatusCode.Error, content); - } - - tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); - } - } - - private static void FinalizeDiagnostics( - Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) - { - try - { - if (startingTimestamp is not null) - { - durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); - } - - if (activity is { IsAllDataRequested: true }) - { - foreach (var tag in tags) - { - activity.AddTag(tag.Key, tag.Value); - } - } - } - finally - { - activity?.Dispose(); - } - } - - public void Dispose() - { - Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; - if (durationMetric.Enabled) - { - TagList tags = default; - tags.Add("network.transport", _transportKind); - - // TODO: Add server.address and server.port on client-side when using SSE transport, - // client.* attributes are not added to metrics because of cardinality - durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); - } - - // Complete all pending requests with cancellation - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetCanceled(); - } - - _pendingRequests.Clear(); - } - -#if !NET - private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; -#endif - - private static TimeSpan GetElapsed(long startingTimestamp) => -#if NET - Stopwatch.GetElapsedTime(startingTimestamp); -#else - new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); -#endif - - private static string? GetStringProperty(JsonObject parameters, string propName) - { - if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) - { - return prop.GetValue(); - } - - return null; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] - private partial void LogEndpointMessageProcessingCanceled(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] - private partial void LogRequestHandlerCalled(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] - private partial void LogRequestHandlerCompleted(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] - private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] - private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] - private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] - private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] - private partial void LogSendingRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] - private partial void LogSendingRequestSensitive(string endpointName, string method, string request); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] - private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] - private partial void LogRequestResponseReceived(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] - private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] - private partial void LogMessageRead(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] - private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] - private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] - private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] - private partial void LogNoHandlerFoundForRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] - private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] - private partial void LogSendingMessage(string endpointName); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] - private partial void LogSendingMessageSensitive(string endpointName, string message); + /// A task that represents the asynchronous send operation. + /// The transport is not connected. + /// is . + /// + /// + /// This method provides low-level access to send any JSON-RPC message. For specific message types, + /// consider using the higher-level methods such as or methods + /// on this class that provide a simpler API. + /// + /// + /// The method will serialize the message and transmit it using the underlying transport mechanism. + /// + /// + public abstract Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); + + /// Registers a handler to be invoked when a notification for the specified method is received. + /// The notification method. + /// The handler to be invoked. + /// An that will remove the registered handler when disposed. + public abstract IAsyncDisposable RegisterNotificationHandler(string method, Func handler); + + /// + public abstract ValueTask DisposeAsync(); } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs new file mode 100644 index 000000000..749486e4b --- /dev/null +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -0,0 +1,831 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Text.Json; +using System.Text.Json.Nodes; +#if !NET +using System.Threading.Channels; +#endif + +namespace ModelContextProtocol; + +/// +/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// +internal sealed partial class McpSessionHandler : IAsyncDisposable +{ + private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); + private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); + private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); + private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); + + /// The latest version of the protocol supported by this implementation. + internal const string LatestProtocolVersion = "2025-06-18"; + + /// All protocol versions supported by this implementation. + internal static readonly string[] SupportedProtocolVersions = + [ + "2024-11-05", + "2025-03-26", + LatestProtocolVersion, + ]; + + private readonly bool _isServer; + private readonly string _transportKind; + private readonly ITransport _transport; + private readonly RequestHandlers _requestHandlers; + private readonly NotificationHandlers _notificationHandlers; + private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); + + private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; + + /// Collection of requests sent on this session and waiting for responses. + private readonly ConcurrentDictionary> _pendingRequests = []; + /// + /// Collection of requests received on this session and currently being handled. The value provides a + /// that can be used to request cancellation of the in-flight handler. + /// + private readonly ConcurrentDictionary _handlingRequests = new(); + private readonly ILogger _logger; + + // This _sessionId is solely used to identify the session in telemetry and logs. + private readonly string _sessionId = Guid.NewGuid().ToString("N"); + private long _lastRequestId; + + private CancellationTokenSource? _messageProcessingCts; + private Task? _messageProcessingTask; + + /// + /// Initializes a new instance of the class. + /// + /// true if this is a server; false if it's a client. + /// An MCP transport implementation. + /// The name of the endpoint for logging and debug purposes. + /// A collection of request handlers. + /// A collection of notification handlers. + /// The logger. + public McpSessionHandler( + bool isServer, + ITransport transport, + string endpointName, + RequestHandlers requestHandlers, + NotificationHandlers notificationHandlers, + ILogger logger) + { + Throw.IfNull(transport); + + _transportKind = transport switch + { + StdioClientSessionTransport or StdioServerTransport => "stdio", + StreamClientSessionTransport or StreamServerTransport => "stream", + SseClientSessionTransport or SseResponseStreamTransport => "sse", + StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", + _ => "unknownTransport" + }; + + _isServer = isServer; + _transport = transport; + EndpointName = endpointName; + _requestHandlers = requestHandlers; + _notificationHandlers = notificationHandlers; + _logger = logger ?? NullLogger.Instance; + LogSessionCreated(EndpointName, _sessionId, _transportKind); + } + + /// + /// Gets and sets the name of the endpoint for logging and debug purposes. + /// + public string EndpointName { get; set; } + + /// + /// Starts processing messages from the transport. This method will block until the transport is disconnected. + /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// + public Task ProcessMessagesAsync(CancellationToken cancellationToken) + { + if (_messageProcessingTask is not null) + { + throw new InvalidOperationException("The message processing loop has already started."); + } + + Debug.Assert(_messageProcessingCts is null); + + _messageProcessingCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _messageProcessingTask = ProcessMessagesCoreAsync(_messageProcessingCts.Token); + return _messageProcessingTask; + } + + private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) + { + try + { + await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + LogMessageRead(EndpointName, message.GetType().Name); + + // Fire and forget the message handling to avoid blocking the transport. + if (message.Context?.ExecutionContext is null) + { + _ = ProcessMessageAsync(); + } + else + { + // Flow the execution context from the HTTP request corresponding to this message if provided. + ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + } + + async Task ProcessMessageAsync() + { + JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; + CancellationTokenSource? combinedCts = null; + try + { + // Register before we yield, so that the tracking is guaranteed to be there + // when subsequent messages arrive, even if the asynchronous processing happens + // out of order. + if (messageWithId is not null) + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _handlingRequests[messageWithId.Id] = combinedCts; + } + + // If we await the handler without yielding first, the transport may not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back. +#if NET + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); +#else + await default(ForceYielding); +#endif + + // Handle the message. + await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // Only send responses for request errors that aren't user-initiated cancellation. + bool isUserCancellation = + ex is OperationCanceledException && + !cancellationToken.IsCancellationRequested && + combinedCts?.IsCancellationRequested is true; + + if (!isUserCancellation && message is JsonRpcRequest request) + { + LogRequestHandlerException(EndpointName, request.Method, ex); + + JsonRpcErrorDetail detail = ex is McpException mcpe ? + new() + { + Code = (int)mcpe.ErrorCode, + Message = mcpe.Message, + } : + new() + { + Code = (int)McpErrorCode.InternalError, + Message = "An error occurred.", + }; + + var errorMessage = new JsonRpcError + { + Id = request.Id, + JsonRpc = "2.0", + Error = detail, + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, + }; + await SendMessageAsync(errorMessage, cancellationToken).ConfigureAwait(false); + } + else if (ex is not OperationCanceledException) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); + } + else + { + LogMessageHandlerException(EndpointName, message.GetType().Name, ex); + } + } + } + finally + { + if (messageWithId is not null) + { + _handlingRequests.TryRemove(messageWithId.Id, out _); + combinedCts!.Dispose(); + } + } + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + LogEndpointMessageProcessingCanceled(EndpointName); + } + finally + { + // Fail any pending requests, as they'll never be satisfied. + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); + } + } + } + + private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + + Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity( + CreateActivityName(method), + ActivityKind.Server, + parentContext: _propagator.ExtractActivityContext(message), + links: Diagnostics.ActivityLinkFromCurrent()) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + switch (message) + { + case JsonRpcRequest request: + var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); + AddResponseTags(ref tags, activity, result, method); + break; + + case JsonRpcNotification notification: + await HandleNotification(notification, cancellationToken).ConfigureAwait(false); + break; + + case JsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + default: + LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + catch (Exception e) when (addTags) + { + AddExceptionTags(ref tags, activity, e); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) + { + // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) + if (notification.Method == NotificationMethods.CancelledNotification) + { + try + { + if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _handlingRequests.TryGetValue(cn.RequestId, out var cts)) + { + await cts.CancelAsync().ConfigureAwait(false); + LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); + } + } + catch + { + // "Invalid cancellation notifications SHOULD be ignored" + } + } + + // Handle user-defined notifications. + await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); + } + + private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) + { + if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) + { + tcs.TrySetResult(message); + } + else + { + LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); + } + } + + private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + { + if (!_requestHandlers.TryGetValue(request.Method, out var handler)) + { + LogNoHandlerFoundForRequest(EndpointName, request.Method); + throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); + } + + LogRequestHandlerCalled(EndpointName, request.Method); + JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); + LogRequestHandlerCompleted(EndpointName, request.Method); + + await SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + Result = result, + Context = request.Context, + }, cancellationToken).ConfigureAwait(false); + + return result; + } + + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) + { + if (!cancellationToken.CanBeCanceled) + { + return default; + } + + return cancellationToken.Register(static objState => + { + var state = (Tuple)objState!; + _ = state.Item1.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.CancelledNotification, + Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), + Context = new JsonRpcMessageContext { RelatedTransport = state.Item2.Context?.RelatedTransport }, + }); + }, Tuple.Create(this, request)); + } + + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(handler); + + return _notificationHandlers.Register(method, handler); + } + + /// + /// Sends a JSON-RPC request to the server. + /// It is strongly recommended use the capability-specific methods instead of this one. + /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the server's response. + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + Throw.IfNull(request); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = request.Method; + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + // Set request ID + if (request.Id.Id is null) + { + request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); + } + + _propagator.InjectActivityContext(activity, request); + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingRequests[request.Id] = tcs; + try + { + if (addTags) + { + AddTags(ref tags, activity, request, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingRequest(EndpointName, request.Method); + } + + await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); + + // Now that the request has been sent, register for cancellation. If we registered before, + // a cancellation request could arrive before the server knew about that request ID, in which + // case the server could ignore it. + LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); + JsonRpcMessage? response; + using (var registration = RegisterCancellation(cancellationToken, request)) + { + response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + if (response is JsonRpcError error) + { + LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); + throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); + } + + if (response is JsonRpcResponse success) + { + if (addTags) + { + AddResponseTags(ref tags, activity, success.Result, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); + } + else + { + LogRequestResponseReceived(EndpointName, request.Method); + } + + return success; + } + + // Unexpected response type + LogSendingRequestInvalidResponseType(EndpointName, request.Method); + throw new McpException("Invalid response type"); + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + _pendingRequests.TryRemove(request.Id, out _); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + Throw.IfNull(message); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + // propagate trace context + _propagator?.InjectActivityContext(activity, message); + + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingMessage(EndpointName); + } + + await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); + + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the + // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in + // the HTTP response body for the POST request containing the corresponding JSON-RPC request. + private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + => (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + + private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) + { + try + { + return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); + } + catch + { + return null; + } + } + + private string CreateActivityName(string method) => method; + + private static string GetMethodName(JsonRpcMessage message) => + message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => "unknownMethod" + }; + + private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) + { + tags.Add("mcp.method.name", method); + tags.Add("network.transport", _transportKind); + + // TODO: When using SSE transport, add: + // - server.address and server.port on client spans and metrics + // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport + if (activity is { IsAllDataRequested: true }) + { + // session and request id have high cardinality, so not applying to metric tags + activity.AddTag("mcp.session.id", _sessionId); + + if (message is JsonRpcMessageWithId withId) + { + activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); + } + } + + JsonObject? paramsObj = message switch + { + JsonRpcRequest request => request.Params as JsonObject, + JsonRpcNotification notification => notification.Params as JsonObject, + _ => null + }; + + if (paramsObj == null) + { + return; + } + + string? target = null; + switch (method) + { + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + target = GetStringProperty(paramsObj, "name"); + if (target is not null) + { + tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); + } + break; + + case RequestMethods.ResourcesRead: + case RequestMethods.ResourcesSubscribe: + case RequestMethods.ResourcesUnsubscribe: + case NotificationMethods.ResourceUpdatedNotification: + target = GetStringProperty(paramsObj, "uri"); + if (target is not null) + { + tags.Add("mcp.resource.uri", target); + } + break; + } + + if (activity is { IsAllDataRequested: true }) + { + activity.DisplayName = target == null ? method : $"{method} {target}"; + } + } + + private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) + { + if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) + { + e = ae.InnerException; + } + + int? intErrorCode = + (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : + e is JsonException ? (int)McpErrorCode.ParseError : + null; + + string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; + tags.Add("error.type", errorType); + if (intErrorCode is not null) + { + tags.Add("rpc.jsonrpc.error_code", errorType); + } + + if (activity is { IsAllDataRequested: true }) + { + activity.SetStatus(ActivityStatusCode.Error, e.Message); + } + } + + private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) + { + if (response is JsonObject jsonObject + && jsonObject.TryGetPropertyValue("isError", out var isError) + && isError?.GetValueKind() == JsonValueKind.True) + { + if (activity is { IsAllDataRequested: true }) + { + string? content = null; + if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) + { + content = prop.ToJsonString(); + } + + activity.SetStatus(ActivityStatusCode.Error, content); + } + + tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); + } + } + + private static void FinalizeDiagnostics( + Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + { + try + { + if (startingTimestamp is not null) + { + durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); + } + + if (activity is { IsAllDataRequested: true }) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } + } + } + finally + { + activity?.Dispose(); + } + } + + public async ValueTask DisposeAsync() + { + Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; + if (durationMetric.Enabled) + { + TagList tags = default; + tags.Add("network.transport", _transportKind); + + // TODO: Add server.address and server.port on client-side when using SSE transport, + // client.* attributes are not added to metrics because of cardinality + durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); + } + + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetCanceled(); + } + + _pendingRequests.Clear(); + + if (_messageProcessingCts is not null) + { + await _messageProcessingCts.CancelAsync().ConfigureAwait(false); + } + + if (_messageProcessingTask is not null) + { + try + { + await _messageProcessingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Ignore cancellation + } + } + + LogSessionDisposed(EndpointName, _sessionId, _transportKind); + } + +#if !NET + private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; +#endif + + private static TimeSpan GetElapsed(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); +#endif + + private static string? GetStringProperty(JsonObject parameters, string propName) + { + if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) + { + return prop.GetValue(); + } + + return null; + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] + private partial void LogEndpointMessageProcessingCanceled(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] + private partial void LogRequestHandlerCalled(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] + private partial void LogRequestHandlerCompleted(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] + private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] + private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] + private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] + private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] + private partial void LogSendingRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] + private partial void LogSendingRequestSensitive(string endpointName, string method, string request); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] + private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] + private partial void LogRequestResponseReceived(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] + private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] + private partial void LogMessageRead(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] + private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] + private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] + private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] + private partial void LogNoHandlerFoundForRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] + private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] + private partial void LogSendingMessage(string endpointName); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] + private partial void LogSendingMessageSensitive(string endpointName, string message); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")] + private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")] + private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind); +} diff --git a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs index ebe698135..f133e8dca 100644 --- a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs @@ -1,5 +1,7 @@ -using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Client; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -44,8 +46,8 @@ public sealed class ClientCapabilities /// server requests for listing root URIs. Root URIs serve as entry points for resource navigation in the protocol. /// /// - /// The server can use to request the list of - /// available roots from the client, which will trigger the client's . + /// The server can use to request the list of + /// available roots from the client, which will trigger the client's . /// /// [JsonPropertyName("roots")] @@ -78,10 +80,12 @@ public sealed class ClientCapabilities /// /// /// Handlers provided via will be registered with the client for the lifetime of the client. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpClientOptions.Handlers.NotificationHandlers)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public IEnumerable>>? NotificationHandlers { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs b/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs index f411c2975..8e28e67d3 100644 --- a/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs @@ -1,5 +1,6 @@ -using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -9,7 +10,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// -/// When enabled, this capability allows a Model Context Protocol server to provide +/// When enabled, this capability allows a Model Context Protocol server to provide /// auto-completion suggestions. This capability is advertised to clients during the initialize handshake. /// /// @@ -19,11 +20,14 @@ namespace ModelContextProtocol.Protocol; /// /// See the schema for details. /// +/// +/// This class is intentionally empty as the Model Context Protocol specification does not +/// currently define additional properties for sampling capabilities. Future versions of the +/// specification may extend this capability with additional configuration options. +/// /// public sealed class CompletionsCapability { - // Currently empty in the spec, but may be extended in the future. - /// /// Gets or sets the handler for completion requests. /// @@ -33,5 +37,7 @@ public sealed class CompletionsCapability /// and should return appropriate completion suggestions. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.CompleteHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? CompleteHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs index 516ea2446..04de39db4 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs @@ -103,6 +103,10 @@ public class Converter : JsonConverter text = reader.GetString(); break; + case "name": + name = reader.GetString(); + break; + case "data": data = reader.GetString(); break; diff --git a/src/ModelContextProtocol.Core/Protocol/ElicitResult.cs b/src/ModelContextProtocol.Core/Protocol/ElicitResult.cs index 39387f500..024f5eb19 100644 --- a/src/ModelContextProtocol.Core/Protocol/ElicitResult.cs +++ b/src/ModelContextProtocol.Core/Protocol/ElicitResult.cs @@ -11,6 +11,9 @@ public sealed class ElicitResult : Result /// /// Gets or sets the user action in response to the elicitation. /// + /// + /// Defaults to "cancel" if not explicitly set. + /// /// /// /// @@ -23,13 +26,22 @@ public sealed class ElicitResult : Result /// /// /// "cancel" - /// User dismissed without making an explicit choice + /// User dismissed without making an explicit choice (default) /// /// /// [JsonPropertyName("action")] public string Action { get; set; } = "cancel"; + /// + /// Convenience indicator for whether the elicitation was accepted by the user. + /// + /// + /// Indicates that the elicitation request completed successfully and value of has been populated with a value. + /// + [JsonIgnore] + public bool IsAccepted => string.Equals(Action, "accept", StringComparison.OrdinalIgnoreCase); + /// /// Gets or sets the submitted form data. /// @@ -44,4 +56,29 @@ public sealed class ElicitResult : Result /// [JsonPropertyName("content")] public IDictionary? Content { get; set; } +} + +/// +/// Represents the client's response to an elicitation request, with typed content payload. +/// +/// The type of the expected content payload. +public sealed class ElicitResult : Result +{ + /// + /// Gets or sets the user action in response to the elicitation. + /// + public string Action { get; set; } = "cancel"; + + /// + /// Convenience indicator for whether the elicitation was accepted by the user. + /// + /// + /// Indicates that the elicitation request completed successfully and value of has been populated with a value. + /// + public bool IsAccepted => string.Equals(Action, "accept", StringComparison.OrdinalIgnoreCase); + + /// + /// Gets or sets the submitted form data as a typed value. + /// + public T? Content { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/ElicitationCapability.cs b/src/ModelContextProtocol.Core/Protocol/ElicitationCapability.cs index d88247d2d..e096e0e09 100644 --- a/src/ModelContextProtocol.Core/Protocol/ElicitationCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ElicitationCapability.cs @@ -1,4 +1,6 @@ +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Client; namespace ModelContextProtocol.Protocol; @@ -11,13 +13,16 @@ namespace ModelContextProtocol.Protocol; /// /// /// When this capability is enabled, an MCP server can request the client to provide additional information -/// during interactions. The client must set a to process these requests. +/// during interactions. The client must set a to process these requests. +/// +/// +/// This class is intentionally empty as the Model Context Protocol specification does not +/// currently define additional properties for sampling capabilities. Future versions of the +/// specification may extend this capability with additional configuration options. /// /// public sealed class ElicitationCapability { - // Currently empty in the spec, but may be extended in the future. - /// /// Gets or sets the handler for processing requests. /// @@ -32,5 +37,7 @@ public sealed class ElicitationCapability /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpClientOptions.Handlers.ElicitationHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public Func>? ElicitationHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/ITransport.cs b/src/ModelContextProtocol.Core/Protocol/ITransport.cs index e35b3a6fb..148472e90 100644 --- a/src/ModelContextProtocol.Core/Protocol/ITransport.cs +++ b/src/ModelContextProtocol.Core/Protocol/ITransport.cs @@ -62,8 +62,8 @@ public interface ITransport : IAsyncDisposable /// /// /// This is a core method used by higher-level abstractions in the MCP protocol implementation. - /// Most client code should use the higher-level methods provided by , - /// , , or , + /// Most client code should use the higher-level methods provided by , + /// , or , /// rather than accessing this method directly. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index b3176937c..ae15453db 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Server; using System.ComponentModel; +using System.Security.Claims; using System.Text.Json; using System.Text.Json.Serialization; @@ -29,28 +30,21 @@ private protected JsonRpcMessage() public string JsonRpc { get; init; } = "2.0"; /// - /// Gets or sets the transport the was received on or should be sent over. + /// Gets or sets the contextual information for this JSON-RPC message. /// /// - /// This is used to support the Streamable HTTP transport where the specification states that the server - /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing - /// the corresponding JSON-RPC request. It may be for other transports. + /// This property contains transport-specific and runtime context information that accompanies + /// JSON-RPC messages but is not serialized as part of the JSON-RPC payload. This includes + /// transport references, execution context, and authenticated user information. /// - [JsonIgnore] - public ITransport? RelatedTransport { get; set; } - - /// - /// Gets or sets the that should be used to run any handlers - /// /// - /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, - /// the outlives the initial HTTP request context it was created on, and new - /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the - /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor - /// in tool calls. + /// This property should only be set when implementing a custom + /// that needs to pass additional per-message context or to pass a + /// to + /// or . /// [JsonIgnore] - public ExecutionContext? ExecutionContext { get; set; } + public JsonRpcMessageContext? Context { get; set; } /// /// Provides a for messages, diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs new file mode 100644 index 000000000..261796b5f --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -0,0 +1,61 @@ +using ModelContextProtocol.Server; +using System.Security.Claims; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Contains contextual information for JSON-RPC messages that is not part of the JSON-RPC protocol specification. +/// +/// +/// This class holds transport-specific and runtime context information that accompanies JSON-RPC messages +/// but is not serialized as part of the JSON-RPC payload. This includes transport references, execution context, +/// and authenticated user information. +/// +public class JsonRpcMessageContext +{ + /// + /// Gets or sets the transport the was received on or should be sent over. + /// + /// + /// This is used to support the Streamable HTTP transport where the specification states that the server + /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing + /// the corresponding JSON-RPC request. It may be for other transports. + /// + public ITransport? RelatedTransport { get; set; } + + /// + /// Gets or sets the that should be used to run any handlers + /// + /// + /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, + /// the outlives the initial HTTP request context it was created on, and new + /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the + /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor + /// in tool calls. + /// + public ExecutionContext? ExecutionContext { get; set; } + + /// + /// Gets or sets the authenticated user associated with this JSON-RPC message. + /// + /// + /// + /// This property contains the representing the authenticated user + /// who initiated this JSON-RPC message. This enables request handlers to access user identity + /// and authorization information without requiring dependency on HTTP context accessors + /// or other HTTP-specific abstractions. + /// + /// + /// The user information is automatically populated by the transport layer when processing + /// incoming HTTP requests in ASP.NET Core scenarios. For other transport types or scenarios + /// where user authentication is not applicable, this property may be . + /// + /// + /// This property is particularly useful in the Streamable HTTP transport where JSON-RPC messages + /// may outlive the original HTTP request context, allowing user identity to be preserved + /// throughout the message processing pipeline. + /// + /// + public ClaimsPrincipal? User { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs index ed6c8982a..e80b25f47 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Protocol; /// /// Requests are messages that require a response from the receiver. Each request includes a unique ID /// that will be included in the corresponding response message (either a success response or an error). -/// +/// /// The receiver of a request message is expected to execute the specified method with the provided parameters /// and return either a with the result, or a /// if the method execution fails. @@ -36,7 +36,7 @@ internal JsonRpcRequest WithId(RequestId id) Id = id, Method = Method, Params = Params, - RelatedTransport = RelatedTransport, + Context = Context, }; } } diff --git a/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs b/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs index ab43fb066..c166a223a 100644 --- a/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs @@ -1,5 +1,6 @@ -using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -7,16 +8,23 @@ namespace ModelContextProtocol.Protocol; /// Represents the logging capability configuration for a Model Context Protocol server. /// /// +/// /// This capability allows clients to set the logging level and receive log messages from the server. /// See the schema for details. +/// +/// +/// This class is intentionally empty as the Model Context Protocol specification does not +/// currently define additional properties for sampling capabilities. Future versions of the +/// specification may extend this capability with additional configuration options. +/// /// public sealed class LoggingCapability { - // Currently empty in the spec, but may be extended in the future - /// /// Gets or sets the handler for set logging level requests from clients. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.SetLoggingLevelHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? SetLoggingLevelHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/Prompt.cs b/src/ModelContextProtocol.Core/Protocol/Prompt.cs index 1a5004065..fcd3053f5 100644 --- a/src/ModelContextProtocol.Core/Protocol/Prompt.cs +++ b/src/ModelContextProtocol.Core/Protocol/Prompt.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -59,4 +60,10 @@ public sealed class Prompt : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server prompt corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerPrompt? McpServerPrompt { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs b/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs index 8fad1c0e0..223576254 100644 --- a/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -9,7 +10,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// The prompts capability allows a server to expose a collection of predefined prompt templates that clients -/// can discover and use. These prompts can be static (defined in the ) or +/// can discover and use. These prompts can be static (defined in the ) or /// dynamically generated through handlers. /// /// @@ -22,10 +23,10 @@ public sealed class PromptsCapability /// Gets or sets whether this server supports notifications for changes to the prompt list. /// /// - /// When set to , the server will send notifications using - /// when prompts are added, + /// When set to , the server will send notifications using + /// when prompts are added, /// removed, or modified. Clients can register handlers for these notifications to - /// refresh their prompt cache. This capability enables clients to stay synchronized with server-side changes + /// refresh their prompt cache. This capability enables clients to stay synchronized with server-side changes /// to available prompts. /// [JsonPropertyName("listChanged")] @@ -40,15 +41,17 @@ public sealed class PromptsCapability /// along with any prompts defined in . /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.ListPromptsHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// - /// This handler is invoked when a client requests details for a specific prompt by name and provides arguments - /// for the prompt if needed. The handler receives the request context containing the prompt name and any arguments, + /// This handler is invoked when a client requests details for a specific prompt by name and provides arguments + /// for the prompt if needed. The handler receives the request context containing the prompt name and any arguments, /// and should return a with the prompt messages and other details. /// /// @@ -57,7 +60,9 @@ public sealed class PromptsCapability /// /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.GetPromptHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? GetPromptHandler { get; set; } /// /// Gets or sets a collection of prompts that will be served by the server. @@ -69,7 +74,7 @@ public sealed class PromptsCapability /// when those are provided: /// /// - /// - For requests: The server returns all prompts from this collection + /// - For requests: The server returns all prompts from this collection /// plus any additional prompts provided by the if it's set. /// /// @@ -78,5 +83,7 @@ public sealed class PromptsCapability /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpServerOptions.PromptCollection)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public McpServerPrimitiveCollection? PromptCollection { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/Reference.cs b/src/ModelContextProtocol.Core/Protocol/Reference.cs index a9c87fe49..af95cf330 100644 --- a/src/ModelContextProtocol.Core/Protocol/Reference.cs +++ b/src/ModelContextProtocol.Core/Protocol/Reference.cs @@ -12,7 +12,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// -/// References are commonly used with to request completion suggestions for arguments, +/// References are commonly used with to request completion suggestions for arguments, /// and with other methods that need to reference resources or prompts. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/Resource.cs b/src/ModelContextProtocol.Core/Protocol/Resource.cs index 63dce7fdc..1b8a0e9cd 100644 --- a/src/ModelContextProtocol.Core/Protocol/Resource.cs +++ b/src/ModelContextProtocol.Core/Protocol/Resource.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -87,4 +88,10 @@ public sealed class Resource : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; init; } + + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs index d2959d182..f0f294985 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -84,6 +85,12 @@ public sealed class ResourceTemplate : IBaseMetadata [JsonIgnore] public bool IsTemplated => UriTemplate.Contains('{'); + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } + /// Converts the into a . /// A if is ; otherwise, . public Resource? AsResource() @@ -102,6 +109,7 @@ public sealed class ResourceTemplate : IBaseMetadata MimeType = MimeType, Annotations = Annotations, Meta = Meta, + McpServerResource = McpServerResource, }; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs index f6486488b..15ab02a06 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -21,8 +22,8 @@ public sealed class ResourcesCapability /// Gets or sets whether this server supports notifications for changes to the resource list. /// /// - /// When set to , the server will send notifications using - /// when resources are added, + /// When set to , the server will send notifications using + /// when resources are added, /// removed, or modified. Clients can register handlers for these notifications to /// refresh their resource cache. /// @@ -39,7 +40,9 @@ public sealed class ResourcesCapability /// allowing clients to discover available resource types and their access patterns. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.ListResourceTemplatesHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -49,7 +52,9 @@ public sealed class ResourcesCapability /// The implementation should return a with the matching resources. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.ListResourcesHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -61,7 +66,9 @@ public sealed class ResourcesCapability /// its contents in a ReadResourceResult object. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.ReadResourceHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -74,7 +81,9 @@ public sealed class ResourcesCapability /// requiring polling. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.SubscribeToResourcesHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -85,7 +94,9 @@ public sealed class ResourcesCapability /// about the specified resource. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.UnsubscribeFromResourcesHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? UnsubscribeFromResourcesHandler { get; set; } /// /// Gets or sets a collection of resources served by the server. @@ -103,5 +114,7 @@ public sealed class ResourcesCapability /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpServerOptions.ResourceCollection)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public McpServerResourceCollection? ResourceCollection { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/RootsCapability.cs b/src/ModelContextProtocol.Core/Protocol/RootsCapability.cs index 60d20b94f..8e2bcacfe 100644 --- a/src/ModelContextProtocol.Core/Protocol/RootsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/RootsCapability.cs @@ -1,4 +1,6 @@ +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Client; namespace ModelContextProtocol.Protocol; @@ -40,5 +42,7 @@ public sealed class RootsCapability /// The handler receives request parameters and should return a containing the collection of available roots. /// [JsonIgnore] + [Obsolete($"Use {nameof(McpClientOptions.Handlers.RootsHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public Func>? RootsHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs index 6e0f1190a..e917b2af9 100644 --- a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs @@ -1,6 +1,7 @@ +using System.ComponentModel; +using System.Text.Json.Serialization; using Microsoft.Extensions.AI; using ModelContextProtocol.Client; -using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -13,13 +14,16 @@ namespace ModelContextProtocol.Protocol; /// /// /// When this capability is enabled, an MCP server can request the client to generate content -/// using an AI model. The client must set a to process these requests. +/// using an AI model. The client must set a to process these requests. +/// +/// +/// This class is intentionally empty as the Model Context Protocol specification does not +/// currently define additional properties for sampling capabilities. Future versions of the +/// specification may extend this capability with additional configuration options. /// /// public sealed class SamplingCapability { - // Currently empty in the spec, but may be extended in the future - /// /// Gets or sets the handler for processing requests. /// @@ -34,10 +38,12 @@ public sealed class SamplingCapability /// generated content. /// /// - /// You can create a handler using the extension + /// You can create a handler using the extension /// method with any implementation of . /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpClientOptions.Handlers.SamplingHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public Func, CancellationToken, ValueTask>? SamplingHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs index 6a4b2e62a..ffe38d221 100644 --- a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs @@ -1,4 +1,6 @@ +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -21,13 +23,13 @@ public sealed class ServerCapabilities /// /// /// - /// The dictionary allows servers to advertise support for features that are not yet - /// standardized in the Model Context Protocol specification. This extension mechanism enables + /// The dictionary allows servers to advertise support for features that are not yet + /// standardized in the Model Context Protocol specification. This extension mechanism enables /// future protocol enhancements while maintaining backward compatibility. /// /// - /// Values in this dictionary are implementation-specific and should be coordinated between client - /// and server implementations. Clients should not assume the presence of any experimental capability + /// Values in this dictionary are implementation-specific and should be coordinated between client + /// and server implementations. Clients should not assume the presence of any experimental capability /// without checking for it first. /// /// @@ -77,10 +79,12 @@ public sealed class ServerCapabilities /// /// /// Handlers provided via will be registered with the server for the lifetime of the server. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// [JsonIgnore] + [Obsolete($"Use {nameof(McpServerOptions.Handlers.NotificationHandlers)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/Tool.cs b/src/ModelContextProtocol.Core/Protocol/Tool.cs index c09598ca7..1c4716691 100644 --- a/src/ModelContextProtocol.Core/Protocol/Tool.cs +++ b/src/ModelContextProtocol.Core/Protocol/Tool.cs @@ -1,6 +1,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -43,7 +44,7 @@ public sealed class Tool : IBaseMetadata /// if an invalid schema is provided. /// /// - /// The schema typically defines the properties (parameters) that the tool accepts, + /// The schema typically defines the properties (parameters) that the tool accepts, /// their types, and which ones are required. This helps AI models understand /// how to structure their calls to the tool. /// @@ -52,9 +53,9 @@ public sealed class Tool : IBaseMetadata /// /// [JsonPropertyName("inputSchema")] - public JsonElement InputSchema - { - get => field; + public JsonElement InputSchema + { + get => field; set { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) @@ -114,4 +115,10 @@ public JsonElement? OutputSchema /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server tool corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerTool? McpServerTool { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs b/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs index 5a3bec5ca..b6903a7db 100644 --- a/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs @@ -1,5 +1,6 @@ -using ModelContextProtocol.Server; +using System.ComponentModel; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -13,10 +14,10 @@ public sealed class ToolsCapability /// Gets or sets whether this server supports notifications for changes to the tool list. /// /// - /// When set to , the server will send notifications using - /// when tools are added, + /// When set to , the server will send notifications using + /// when tools are added, /// removed, or modified. Clients can register handlers for these notifications to - /// refresh their tool cache. This capability enables clients to stay synchronized with server-side + /// refresh their tool cache. This capability enables clients to stay synchronized with server-side /// changes to available tools. /// [JsonPropertyName("listChanged")] @@ -33,19 +34,23 @@ public sealed class ToolsCapability /// and the tools from the collection will be combined to form the complete list of available tools. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.ListToolsHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// This handler is invoked when a client makes a call to a tool that isn't found in the . - /// The handler should implement logic to execute the requested tool and return appropriate results. - /// It receives a containing information about the tool + /// The handler should implement logic to execute the requested tool and return appropriate results. + /// It receives a containing information about the tool /// being called and its arguments, and should return a with the execution results. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } + [Obsolete($"Use {nameof(McpServerOptions.Handlers.CallToolHandler)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public McpRequestHandler? CallToolHandler { get; set; } /// /// Gets or sets a collection of tools served by the server. @@ -59,5 +64,7 @@ public sealed class ToolsCapability /// will be invoked as a fallback. /// [JsonIgnore] + [Obsolete($"Use {nameof(McpServerOptions.ToolCollection)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public McpServerPrimitiveCollection? ToolCollection { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/README.md b/src/ModelContextProtocol.Core/README.md index beb365c80..f6cffaf68 100644 --- a/src/ModelContextProtocol.Core/README.md +++ b/src/ModelContextProtocol.Core/README.md @@ -27,8 +27,8 @@ dotnet add package ModelContextProtocol.Core --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -38,7 +38,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) diff --git a/src/ModelContextProtocol.Core/RequestHandlers.cs b/src/ModelContextProtocol.Core/RequestHandlers.cs index 854a4bddf..fd95751d9 100644 --- a/src/ModelContextProtocol.Core/RequestHandlers.cs +++ b/src/ModelContextProtocol.Core/RequestHandlers.cs @@ -10,8 +10,8 @@ internal sealed class RequestHandlers : Dictionary /// Registers a handler for incoming requests of a specific method in the MCP protocol. /// - /// Type of request payload that will be deserialized from incoming JSON - /// Type of response payload that will be serialized to JSON (not full RPC response) + /// Type of request payload that will be deserialized from incoming JSON + /// Type of response payload that will be serialized to JSON (not full RPC response) /// Method identifier to register for (e.g., "tools/list", "logging/setLevel") /// Handler function to be called when a request with the specified method identifier is received /// The JSON contract governing request parameter deserialization @@ -23,15 +23,15 @@ internal sealed class RequestHandlers : Dictionary /// - /// The handler function receives the deserialized request object and a cancellation token, and should return - /// a response object that will be serialized back to the client. + /// The handler function receives the deserialized request object, the full JSON-RPC request, and a cancellation token, + /// and should return a response object that will be serialized back to the client. /// /// - public void Set( + public void Set( string method, - Func> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) + Func> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) { Throw.IfNull(method); Throw.IfNull(handler); @@ -40,8 +40,8 @@ public void Set( this[method] = async (request, cancellationToken) => { - TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); - object? result = await handler(typedRequest, request.RelatedTransport, cancellationToken).ConfigureAwait(false); + TParams? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); + object? result = await handler(typedRequest, request, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs index d651d7ee3..ef068c551 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs @@ -11,6 +11,7 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt { + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. /// @@ -136,7 +137,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Arguments = args, }; - return new AIFunctionMcpServerPrompt(function, prompt); + return new AIFunctionMcpServerPrompt(function, prompt, options?.Metadata ?? []); } private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, McpServerPromptCreateOptions? options) @@ -154,6 +155,9 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + return newOptions; } @@ -161,15 +165,20 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt) + private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt, IReadOnlyList metadata) { AIFunction = function; ProtocolPrompt = prompt; + ProtocolPrompt.McpServerPrompt = this; + _metadata = metadata; } /// public override Prompt ProtocolPrompt { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -177,7 +186,7 @@ public override async ValueTask GetAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index a8b0d2486..69b8deb8d 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Globalization; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; @@ -17,6 +18,7 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource { private readonly Regex? _uriParser; private readonly string[] _templateVariableNames = []; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -218,7 +220,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MimeType = options?.MimeType ?? "application/octet-stream", }; - return new AIFunctionMcpServerResource(function, resource); + return new AIFunctionMcpServerResource(function, resource, options?.Metadata ?? []); } private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, McpServerResourceCreateOptions? options) @@ -238,6 +240,12 @@ private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, M newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided and the member is a MethodInfo + if (member is MethodInfo method) + { + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + } + return newOptions; } @@ -270,11 +278,13 @@ private static string DeriveUriTemplate(string name, AIFunction function) internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate) + private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate, IReadOnlyList metadata) { AIFunction = function; ProtocolResourceTemplate = resourceTemplate; + ProtocolResourceTemplate.McpServerResource = this; ProtocolResource = resourceTemplate.AsResource(); + _metadata = metadata; if (ProtocolResource is null) { @@ -289,6 +299,9 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour /// public override Resource? ProtocolResource { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask ReadAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -316,7 +329,7 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour } // Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI. - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; // For templates, populate the arguments from the URI template. diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index afd3912b6..cb4758486 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -1,7 +1,5 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics; @@ -15,8 +13,8 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed partial class AIFunctionMcpServerTool : McpServerTool { - private readonly ILogger _logger; private readonly bool _structuredOutputRequiresWrapping; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -26,7 +24,7 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool McpServerToolCreateOptions? options) { Throw.IfNull(method); - + options = DeriveOptions(method.Method, options); return Create(method.Method, method.Target, options); @@ -146,7 +144,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping); + return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Metadata ?? []); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -186,6 +184,9 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= CreateMetadata(method); + return newOptions; } @@ -193,17 +194,22 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping, IReadOnlyList metadata) { AIFunction = function; ProtocolTool = tool; - _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; + ProtocolTool.McpServerTool = this; + _structuredOutputRequiresWrapping = structuredOutputRequiresWrapping; + _metadata = metadata; } /// public override Tool ProtocolTool { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -211,7 +217,7 @@ public override async ValueTask InvokeAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) @@ -223,24 +229,7 @@ public override async ValueTask InvokeAsync( } object? result; - try - { - result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is not OperationCanceledException) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); - - string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; - - return new() - { - IsError = true, - Content = [new TextContentBlock { Text = errorMessage }], - }; - } + result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); JsonNode? structuredContent = CreateStructuredResponse(result); return result switch @@ -257,33 +246,27 @@ public override async ValueTask InvokeAsync( Content = [], StructuredContent = structuredContent, }, - + string text => new() { Content = [new TextContentBlock { Text = text }], StructuredContent = structuredContent, }, - + ContentBlock content => new() { Content = [content], StructuredContent = structuredContent, }, - - IEnumerable texts => new() - { - Content = [.. texts.Select(x => new TextContentBlock { Text = x ?? string.Empty })], - StructuredContent = structuredContent, - }, - + IEnumerable contentItems => ConvertAIContentEnumerableToCallToolResult(contentItems, structuredContent), - + IEnumerable contents => new() { Content = [.. contents], StructuredContent = structuredContent, }, - + CallToolResult callToolResponse => callToolResponse, _ => new() @@ -342,6 +325,26 @@ static bool IsAsyncMethod(MethodInfo method) } } + /// Creates metadata from attributes on the specified method and its declaring class, with the MethodInfo as the first item. + internal static IReadOnlyList CreateMetadata(MethodInfo method) + { + // Add the MethodInfo to the start of the metadata similar to what RouteEndpointDataSource does for minimal endpoints. + List metadata = [method]; + + // Add class-level attributes first, since those are less specific. + if (method.DeclaringType is not null) + { + metadata.AddRange(method.DeclaringType.GetCustomAttributes()); + } + + // Add method-level attributes second, since those are more specific. + // When metadata conflicts, later metadata usually takes precedence with exceptions for metadata like + // IAllowAnonymous which always take precedence over IAuthorizeData no matter the order. + metadata.AddRange(method.GetCustomAttributes()); + + return metadata.AsReadOnly(); + } + /// Regex that flags runs of characters other than ASCII digits or letters. #if NET [GeneratedRegex("[^0-9A-Za-z]+")] @@ -452,7 +455,4 @@ private static CallToolResult ConvertAIContentEnumerableToCallToolResult(IEnumer IsError = allErrorContent && hasAny }; } - - [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] - private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index d286d1ef4..bbbc45dcc 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -3,34 +3,44 @@ namespace ModelContextProtocol.Server; -internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer +internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport? transport) : McpServer { - public string EndpointName => server.EndpointName; - public string? SessionId => transport?.SessionId ?? server.SessionId; - public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; - public Implementation? ClientInfo => server.ClientInfo; - public McpServerOptions ServerOptions => server.ServerOptions; - public IServiceProvider? Services => server.Services; - public LoggingLevel? LoggingLevel => server.LoggingLevel; + public override string? SessionId => transport?.SessionId ?? server.SessionId; + public override string? NegotiatedProtocolVersion => server.NegotiatedProtocolVersion; + public override ClientCapabilities? ClientCapabilities => server.ClientCapabilities; + public override Implementation? ClientInfo => server.ClientInfo; + public override McpServerOptions ServerOptions => server.ServerOptions; + public override IServiceProvider? Services => server.Services; + public override LoggingLevel? LoggingLevel => server.LoggingLevel; - public ValueTask DisposeAsync() => server.DisposeAsync(); + public override ValueTask DisposeAsync() => server.DisposeAsync(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); // This will throw because the server must already be running for this class to be constructed, but it should give us a good Exception message. - public Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); + public override Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - Debug.Assert(message.RelatedTransport is null); - message.RelatedTransport = transport; + if (message.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + message.Context = new JsonRpcMessageContext(); + message.Context.RelatedTransport = transport; return server.SendMessageAsync(message, cancellationToken); } - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { - Debug.Assert(request.RelatedTransport is null); - request.RelatedTransport = transport; + if (request.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + request.Context = new JsonRpcMessageContext(); + request.Context.RelatedTransport = transport; return server.SendRequestAsync(request, cancellationToken); } } diff --git a/src/ModelContextProtocol.Core/Server/IMcpServer.cs b/src/ModelContextProtocol.Core/Server/IMcpServer.cs index ec2b87ade..8b88aa7a2 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServer.cs @@ -1,10 +1,13 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; +using System.ComponentModel; namespace ModelContextProtocol.Server; /// /// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. /// +[Obsolete($"Use {nameof(McpServer)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 +[EditorBrowsable(EditorBrowsableState.Never)] public interface IMcpServer : IMcpEndpoint { /// diff --git a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs index 597fdec97..f3ec62219 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs @@ -7,4 +7,13 @@ public interface IMcpServerPrimitive { /// Gets the unique identifier of the primitive. string Id { get; } + + /// + /// Gets the metadata for this primitive instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + IReadOnlyList Metadata { get; } } diff --git a/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs b/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs new file mode 100644 index 000000000..bc1cabc45 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs @@ -0,0 +1,11 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for applying filters to incoming MCP requests with specific parameter and result types. +/// +/// The type of the parameters sent with the request. +/// The type of the response returned by the handler. +/// The next request handler in the pipeline. +/// The next request handler wrapped with the filter. +public delegate McpRequestHandler McpRequestFilter( + McpRequestHandler next); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs b/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs new file mode 100644 index 000000000..651e070e5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs @@ -0,0 +1,13 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for handling incoming MCP requests with specific parameter and result types. +/// +/// The type of the parameters sent with the request. +/// The type of the response returned by the handler. +/// The request context containing the parameters and other metadata. +/// A cancellation token to cancel the operation. +/// A task representing the asynchronous operation, with the result of the handler. +public delegate ValueTask McpRequestHandler( + RequestContext request, + CancellationToken cancellationToken); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs new file mode 100644 index 000000000..00fc0a7cc --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -0,0 +1,557 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Server; + +/// +/// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpServer : McpSession, IMcpServer +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// + /// Caches request schemas for elicitation requests based on the type and serializer options. + /// + private static readonly ConditionalWeakTable> s_elicitResultSchemaCache = new(); + + private static Dictionary>? s_elicitAllowedProperties = null; + + /// + /// Creates a new instance of an . + /// + /// Transport to use for the server representing an already-established MCP session. + /// Configuration options for this server, including capabilities. + /// Logger factory to use for logging. If null, logging will be disabled. + /// Optional service provider to create new instances of tools and other dependencies. + /// An instance that should be disposed when no longer needed. + /// is . + /// is . + public static McpServer Create( + ITransport transport, + McpServerOptions serverOptions, + ILoggerFactory? loggerFactory = null, + IServiceProvider? serviceProvider = null) + { + Throw.IfNull(transport); + Throw.IfNull(serverOptions); + + return new McpServerImpl(transport, serverOptions, loggerFactory, serviceProvider); + } + + /// + /// Requests to sample an LLM via the client using the specified request parameters. + /// + /// The parameters for the sampling request. + /// The to monitor for cancellation requests. + /// A task containing the sampling result from the client. + /// The client does not support sampling. + public ValueTask SampleAsync( + CreateMessageRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfSamplingUnsupported(); + + return SendRequestAsync( + RequestMethods.SamplingCreateMessage, + request, + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests to sample an LLM via the client using the provided chat messages and options. + /// + /// The messages to send as part of the request. + /// The options to use for the request, including model parameters and constraints. + /// The to monitor for cancellation requests. The default is . + /// A task containing the chat response from the model. + /// is . + /// The client does not support sampling. + public async Task SampleAsync( + IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) + { + Throw.IfNull(messages); + + StringBuilder? systemPrompt = null; + + if (options?.Instructions is { } instructions) + { + (systemPrompt ??= new()).Append(instructions); + } + + List samplingMessages = []; + foreach (var message in messages) + { + if (message.Role == ChatRole.System) + { + if (systemPrompt is null) + { + systemPrompt = new(); + } + else + { + systemPrompt.AppendLine(); + } + + systemPrompt.Append(message.Text); + continue; + } + + if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) + { + Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; + + foreach (var content in message.Contents) + { + switch (content) + { + case TextContent textContent: + samplingMessages.Add(new() + { + Role = role, + Content = new TextContentBlock { Text = textContent.Text }, + }); + break; + + case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): + samplingMessages.Add(new() + { + Role = role, + Content = dataContent.HasTopLevelMediaType("image") ? + new ImageContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + } : + new AudioContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + }, + }); + break; + } + } + } + } + + ModelPreferences? modelPreferences = null; + if (options?.ModelId is { } modelId) + { + modelPreferences = new() { Hints = [new() { Name = modelId }] }; + } + + var result = await SampleAsync(new() + { + Messages = samplingMessages, + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToArray(), + SystemPrompt = systemPrompt?.ToString(), + Temperature = options?.Temperature, + ModelPreferences = modelPreferences, + }, cancellationToken).ConfigureAwait(false); + + AIContent? responseContent = result.Content.ToAIContent(); + + return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) + { + ModelId = result.Model, + FinishReason = result.StopReason switch + { + "maxTokens" => ChatFinishReason.Length, + "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, + } + }; + } + + /// + /// Creates an wrapper that can be used to send sampling requests to the client. + /// + /// The that can be used to issue sampling requests to the client. + /// The client does not support sampling. + public IChatClient AsSamplingChatClient() + { + ThrowIfSamplingUnsupported(); + return new SamplingChatClient(this); + } + + /// Gets an on which logged messages will be sent as notifications to the client. + /// An that can be used to log to the client.. + public ILoggerProvider AsClientLoggerProvider() + { + return new ClientLoggerProvider(this); + } + + /// + /// Requests the client to list the roots it exposes. + /// + /// The parameters for the list roots request. + /// The to monitor for cancellation requests. + /// A task containing the list of roots exposed by the client. + /// The client does not support roots. + public ValueTask RequestRootsAsync( + ListRootsRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfRootsUnsupported(); + + return SendRequestAsync( + RequestMethods.RootsList, + request, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests additional information from the user via the client, allowing the server to elicit structured data. + /// + /// The parameters for the elicitation request. + /// The to monitor for cancellation requests. + /// A task containing the elicitation result. + /// The client does not support elicitation. + public ValueTask ElicitAsync( + ElicitRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfElicitationUnsupported(); + + return SendRequestAsync( + RequestMethods.ElicitationCreate, + request, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests additional information from the user via the client, constructing a request schema from the + /// public serializable properties of and deserializing the response into . + /// + /// The type describing the expected input shape. Only primitive members are supported (string, number, boolean, enum). + /// The message to present to the user. + /// Serializer options that influence property naming and deserialization. + /// The to monitor for cancellation requests. + /// An with the user's response, if accepted. + /// + /// Elicitation uses a constrained subset of JSON Schema and only supports strings, numbers/integers, booleans and string enums. + /// Unsupported member types are ignored when constructing the schema. + /// + public async ValueTask> ElicitAsync( + string message, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + ThrowIfElicitationUnsupported(); + + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + var dict = s_elicitResultSchemaCache.GetValue(serializerOptions, _ => new()); + +#if NET + var schema = dict.GetOrAdd(typeof(T), static (t, s) => BuildRequestSchema(t, s), serializerOptions); +#else + var schema = dict.GetOrAdd(typeof(T), type => BuildRequestSchema(type, serializerOptions)); +#endif + + var request = new ElicitRequestParams + { + Message = message, + RequestedSchema = schema, + }; + + var raw = await ElicitAsync(request, cancellationToken).ConfigureAwait(false); + + if (!raw.IsAccepted || raw.Content is null) + { + return new ElicitResult { Action = raw.Action, Content = default }; + } + + var obj = new JsonObject(); + foreach (var kvp in raw.Content) + { + obj[kvp.Key] = JsonNode.Parse(kvp.Value.GetRawText()); + } + + T? typed = JsonSerializer.Deserialize(obj, serializerOptions.GetTypeInfo()); + return new ElicitResult { Action = raw.Action, Content = typed }; + } + + /// + /// Builds a request schema for elicitation based on the public serializable properties of . + /// + /// The type of the schema being built. + /// The serializer options to use. + /// The built request schema. + /// + private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, JsonSerializerOptions serializerOptions) + { + var schema = new ElicitRequestParams.RequestSchema(); + var props = schema.Properties; + + JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(type); + + if (typeInfo.Kind != JsonTypeInfoKind.Object) + { + throw new McpException($"Type '{type.FullName}' is not supported for elicitation requests."); + } + + foreach (JsonPropertyInfo pi in typeInfo.Properties) + { + var def = CreatePrimitiveSchema(pi.PropertyType, serializerOptions); + props[pi.Name] = def; + } + + return schema; + } + + /// + /// Creates a primitive schema definition for the specified type, if supported. + /// + /// The type to create the schema for. + /// The serializer options to use. + /// The created primitive schema definition. + /// Thrown when the type is not supported. + private static ElicitRequestParams.PrimitiveSchemaDefinition CreatePrimitiveSchema(Type type, JsonSerializerOptions serializerOptions) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests. Nullable types are not supported."); + } + + var typeInfo = serializerOptions.GetTypeInfo(type); + + if (typeInfo.Kind != JsonTypeInfoKind.None) + { + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + } + + var jsonElement = AIJsonUtilities.CreateJsonSchema(type, serializerOptions: serializerOptions); + + if (!TryValidateElicitationPrimitiveSchema(jsonElement, type, out var error)) + { + throw new McpException(error); + } + + var primitiveSchemaDefinition = + jsonElement.Deserialize(McpJsonUtilities.JsonContext.Default.PrimitiveSchemaDefinition); + + if (primitiveSchemaDefinition is null) + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + + return primitiveSchemaDefinition; + } + + /// + /// Validate the produced schema strictly to the subset we support. We only accept an object schema + /// with a supported primitive type keyword and no additional unsupported keywords.Reject things like + /// {}, 'true', or schemas that include unrelated keywords(e.g.items, properties, patternProperties, etc.). + /// + /// The schema to validate. + /// The type of the schema being validated, just for reporting errors. + /// The error message, if validation fails. + /// + private static bool TryValidateElicitationPrimitiveSchema(JsonElement schema, Type type, + [NotNullWhen(false)] out string? error) + { + if (schema.ValueKind is not JsonValueKind.Object) + { + error = $"Schema generated for type '{type.FullName}' is invalid: expected an object schema."; + return false; + } + + if (!schema.TryGetProperty("type", out JsonElement typeProperty) + || typeProperty.ValueKind is not JsonValueKind.String) + { + error = $"Schema generated for type '{type.FullName}' is invalid: missing or invalid 'type' keyword."; + return false; + } + + var typeKeyword = typeProperty.GetString(); + + if (string.IsNullOrEmpty(typeKeyword)) + { + error = $"Schema generated for type '{type.FullName}' is invalid: empty 'type' value."; + return false; + } + + if (typeKeyword is not ("string" or "number" or "integer" or "boolean")) + { + error = $"Schema generated for type '{type.FullName}' is invalid: unsupported primitive type '{typeKeyword}'."; + return false; + } + + s_elicitAllowedProperties ??= new() + { + ["string"] = ["type", "title", "description", "minLength", "maxLength", "format", "enum", "enumNames"], + ["number"] = ["type", "title", "description", "minimum", "maximum"], + ["integer"] = ["type", "title", "description", "minimum", "maximum"], + ["boolean"] = ["type", "title", "description", "default"] + }; + + var allowed = s_elicitAllowedProperties[typeKeyword]; + + foreach (JsonProperty prop in schema.EnumerateObject()) + { + if (!allowed.Contains(prop.Name)) + { + error = $"The property '{type.FullName}.{prop.Name}' is not supported for elicitation."; + return false; + } + } + + error = string.Empty; + return true; + } + + private void ThrowIfSamplingUnsupported() + { + if (ClientCapabilities?.Sampling is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Sampling is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support sampling."); + } + } + + private void ThrowIfRootsUnsupported() + { + if (ClientCapabilities?.Roots is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Roots are not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support roots."); + } + } + + private void ThrowIfElicitationUnsupported() + { + if (ClientCapabilities?.Elicitation is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Elicitation is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support elicitation requests."); + } + } + + /// Provides an implementation that's implemented via client sampling. + private sealed class SamplingChatClient : IChatClient + { + private readonly McpServer _server; + + public SamplingChatClient(McpServer server) => _server = server; + + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + _server.SampleAsync(messages, options, cancellationToken); + + /// + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + foreach (var update in response.ToChatResponseUpdates()) + { + yield return update; + } + } + + /// + object? IChatClient.GetService(Type serviceType, object? serviceKey) + { + Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(_server) ? _server : + null; + } + + /// + void IDisposable.Dispose() { } // nop + } + + /// + /// Provides an implementation for creating loggers + /// that send logging message notifications to the client for logged messages. + /// + private sealed class ClientLoggerProvider : ILoggerProvider + { + private readonly McpServer _server; + + public ClientLoggerProvider(McpServer server) => _server = server; + + /// + public ILogger CreateLogger(string categoryName) + { + Throw.IfNull(categoryName); + + return new ClientLogger(_server, categoryName); + } + + /// + void IDisposable.Dispose() { } + + private sealed class ClientLogger : ILogger + { + private readonly McpServer _server; + private readonly string _categoryName; + + public ClientLogger(McpServer server, string categoryName) + { + _server = server; + _categoryName = categoryName; + } + + /// + public IDisposable? BeginScope(TState state) where TState : notnull => + null; + + /// + public bool IsEnabled(LogLevel logLevel) => + _server?.LoggingLevel is { } loggingLevel && + McpServerImpl.ToLoggingLevel(logLevel) >= loggingLevel; + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + Throw.IfNull(formatter); + + LogInternal(logLevel, formatter(state, exception)); + + void LogInternal(LogLevel level, string message) + { + _ = _server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams + { + Level = McpServerImpl.ToLoggingLevel(level), + Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), + Logger = _categoryName, + }); + } + } + } + } +} diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f91..02c17de1a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -1,597 +1,64 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using System.Runtime.CompilerServices; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; -/// -internal sealed class McpServer : McpEndpoint, IMcpServer +/// +/// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpServer : McpSession, IMcpServer +#pragma warning restore CS0618 // Type or member is obsolete { - internal static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpServer), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly ITransport _sessionTransport; - private readonly bool _servicesScopePerRequest; - private readonly List _disposables = []; - - private readonly string _serverOnlyEndpointName; - private string? _endpointName; - private int _started; - - /// Holds a boxed value for the server. + /// + /// Gets the capabilities supported by the client. + /// /// - /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box - /// rather than a nullable to be able to manipulate it atomically. + /// + /// These capabilities are established during the initialization handshake and indicate + /// which features the client supports, such as sampling, roots, and other + /// protocol-specific functionality. + /// + /// + /// Server implementations can check these capabilities to determine which features + /// are available when interacting with the client. + /// /// - private StrongBox? _loggingLevel; + public abstract ClientCapabilities? ClientCapabilities { get; } /// - /// Creates a new instance of . + /// Gets the version and implementation information of the connected client. /// - /// Transport to use for the server representing an already-established session. - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// Logger factory to use for logging - /// Optional service provider to use for dependency injection - /// The server was incorrectly configured. - public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : base(loggerFactory) - { - Throw.IfNull(transport); - Throw.IfNull(options); - - options ??= new(); - - _sessionTransport = transport; - ServerOptions = options; - Services = serviceProvider; - _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; - _servicesScopePerRequest = options.ScopeRequests; - - ClientInfo = options.KnownClientInfo; - UpdateEndpointNameWithClientInfo(); - - // Configure all request handlers based on the supplied options. - ServerCapabilities = new(); - ConfigureInitialize(options); - ConfigureTools(options); - ConfigurePrompts(options); - ConfigureResources(options); - ConfigureLogging(options); - ConfigureCompletion(options); - ConfigureExperimental(options); - ConfigurePing(); - - // Register any notification handlers that were provided. - if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - // Now that everything has been configured, subscribe to any necessary notifications. - if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) - { - Register(ServerOptions.Capabilities?.Tools?.ToolCollection, NotificationMethods.ToolListChangedNotification); - Register(ServerOptions.Capabilities?.Prompts?.PromptCollection, NotificationMethods.PromptListChangedNotification); - Register(ServerOptions.Capabilities?.Resources?.ResourceCollection, NotificationMethods.ResourceListChangedNotification); - - void Register(McpServerPrimitiveCollection? collection, string notificationMethod) - where TPrimitive : IMcpServerPrimitive - { - if (collection is not null) - { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); - collection.Changed += changed; - _disposables.Add(() => collection.Changed -= changed); - } - } - } - - // And initialize the session. - InitializeSession(transport); - } - - /// - public string? SessionId => _sessionTransport.SessionId; - - /// - public ServerCapabilities ServerCapabilities { get; } = new(); - - /// - public ClientCapabilities? ClientCapabilities { get; set; } - - /// - public Implementation? ClientInfo { get; set; } - - /// - public McpServerOptions ServerOptions { get; } - - /// - public IServiceProvider? Services { get; } - - /// - public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; - - /// - public LoggingLevel? LoggingLevel => _loggingLevel?.Value; - - /// - public async Task RunAsync(CancellationToken cancellationToken = default) - { - if (Interlocked.Exchange(ref _started, 1) != 0) - { - throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); - } - - try - { - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); - } - finally - { - await DisposeAsync().ConfigureAwait(false); - } - } - - public override async ValueTask DisposeUnsynchronizedAsync() - { - _disposables.ForEach(d => d()); - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - private void ConfigurePing() - { - SetHandler(RequestMethods.Ping, - async (request, _) => new PingResult(), - McpJsonUtilities.JsonContext.Default.JsonNode, - McpJsonUtilities.JsonContext.Default.PingResult); - } - - private void ConfigureInitialize(McpServerOptions options) - { - RequestHandlers.Set(RequestMethods.Initialize, - async (request, _, _) => - { - ClientCapabilities = request?.Capabilities ?? new(); - ClientInfo = request?.ClientInfo; - - // Use the ClientInfo to update the session EndpointName for logging. - UpdateEndpointNameWithClientInfo(); - GetSessionOrThrow().EndpointName = EndpointName; - - // Negotiate a protocol version. If the server options provide one, use that. - // Otherwise, try to use whatever the client requested as long as it's supported. - // If it's not supported, fall back to the latest supported version. - string? protocolVersion = options.ProtocolVersion; - if (protocolVersion is null) - { - protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSession.SupportedProtocolVersions.Contains(clientProtocolVersion) ? - clientProtocolVersion : - McpSession.LatestProtocolVersion; - } - - return new InitializeResult - { - ProtocolVersion = protocolVersion, - Instructions = options.ServerInstructions, - ServerInfo = options.ServerInfo ?? DefaultImplementation, - Capabilities = ServerCapabilities ?? new(), - }; - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult); - } - - private void ConfigureCompletion(McpServerOptions options) - { - if (options.Capabilities?.Completions is not { } completionsCapability) - { - return; - } - - ServerCapabilities.Completions = new() - { - CompleteHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()) - }; - - SetHandler( - RequestMethods.CompletionComplete, - ServerCapabilities.Completions.CompleteHandler, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult); - } - - private void ConfigureExperimental(McpServerOptions options) - { - ServerCapabilities.Experimental = options.Capabilities?.Experimental; - } - - private void ConfigureResources(McpServerOptions options) - { - if (options.Capabilities?.Resources is not { } resourcesCapability) - { - return; - } - - ServerCapabilities.Resources = new(); - - var listResourcesHandler = resourcesCapability.ListResourcesHandler ?? (static async (_, __) => new ListResourcesResult()); - var listResourceTemplatesHandler = resourcesCapability.ListResourceTemplatesHandler ?? (static async (_, __) => new ListResourceTemplatesResult()); - var readResourceHandler = resourcesCapability.ReadResourceHandler ?? (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); - var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var resources = resourcesCapability.ResourceCollection; - var listChanged = resourcesCapability.ListChanged; - var subscribe = resourcesCapability.Subscribe; - - // Handle resources provided via DI. - if (resources is { IsEmpty: false }) - { - var originalListResourcesHandler = listResourcesHandler; - listResourcesHandler = async (request, cancellationToken) => - { - ListResourcesResult result = originalListResourcesHandler is not null ? - await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var r in resources) - { - if (r.ProtocolResource is { } resource) - { - result.Resources.Add(resource); - } - } - } - - return result; - }; - - var originalListResourceTemplatesHandler = listResourceTemplatesHandler; - listResourceTemplatesHandler = async (request, cancellationToken) => - { - ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? - await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var rt in resources) - { - if (rt.IsTemplated) - { - result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); - } - } - } - - return result; - }; - - // Synthesize read resource handler, which covers both resources and resource templates. - var originalReadResourceHandler = readResourceHandler; - readResourceHandler = async (request, cancellationToken) => - { - if (request.Params?.Uri is string uri) - { - // First try an O(1) lookup by exact match. - if (resources.TryGetPrimitive(uri, out var resource)) - { - if (await resource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - - // Fall back to an O(N) lookup, trying to match against each URI template. - // The number of templates is controlled by the server developer, and the number is expected to be - // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. - foreach (var resourceTemplate in resources) - { - if (await resourceTemplate.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - } - - // Finally fall back to the handler. - return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); - }; - - listChanged = true; - - // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. - // subscribe = true; - } - - ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; - ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; - ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; - ServerCapabilities.Resources.ResourceCollection = resources; - ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; - ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; - ServerCapabilities.Resources.ListChanged = listChanged; - ServerCapabilities.Resources.Subscribe = subscribe; - - SetHandler( - RequestMethods.ResourcesList, - listResourcesHandler, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult); - - SetHandler( - RequestMethods.ResourcesTemplatesList, - listResourceTemplatesHandler, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); - - SetHandler( - RequestMethods.ResourcesRead, - readResourceHandler, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult); - - SetHandler( - RequestMethods.ResourcesSubscribe, - subscribeHandler, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - - SetHandler( - RequestMethods.ResourcesUnsubscribe, - unsubscribeHandler, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - } - - private void ConfigurePrompts(McpServerOptions options) - { - if (options.Capabilities?.Prompts is not { } promptsCapability) - { - return; - } - - ServerCapabilities.Prompts = new(); - - var listPromptsHandler = promptsCapability.ListPromptsHandler ?? (static async (_, __) => new ListPromptsResult()); - var getPromptHandler = promptsCapability.GetPromptHandler ?? (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var prompts = promptsCapability.PromptCollection; - var listChanged = promptsCapability.ListChanged; - - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (prompts is { IsEmpty: false }) - { - var originalListPromptsHandler = listPromptsHandler; - listPromptsHandler = async (request, cancellationToken) => - { - ListPromptsResult result = originalListPromptsHandler is not null ? - await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var p in prompts) - { - result.Prompts.Add(p.ProtocolPrompt); - } - } - - return result; - }; - - var originalGetPromptHandler = getPromptHandler; - getPromptHandler = (request, cancellationToken) => - { - if (request.Params is not null && - prompts.TryGetPrimitive(request.Params.Name, out var prompt)) - { - return prompt.GetAsync(request, cancellationToken); - } - - return originalGetPromptHandler(request, cancellationToken); - }; - - listChanged = true; - } - - ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; - ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; - ServerCapabilities.Prompts.PromptCollection = prompts; - ServerCapabilities.Prompts.ListChanged = listChanged; - - SetHandler( - RequestMethods.PromptsList, - listPromptsHandler, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult); - - SetHandler( - RequestMethods.PromptsGet, - getPromptHandler, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult); - } - - private void ConfigureTools(McpServerOptions options) - { - if (options.Capabilities?.Tools is not { } toolsCapability) - { - return; - } - - ServerCapabilities.Tools = new(); - - var listToolsHandler = toolsCapability.ListToolsHandler ?? (static async (_, __) => new ListToolsResult()); - var callToolHandler = toolsCapability.CallToolHandler ?? (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var tools = toolsCapability.ToolCollection; - var listChanged = toolsCapability.ListChanged; - - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (tools is { IsEmpty: false }) - { - var originalListToolsHandler = listToolsHandler; - listToolsHandler = async (request, cancellationToken) => - { - ListToolsResult result = originalListToolsHandler is not null ? - await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var t in tools) - { - result.Tools.Add(t.ProtocolTool); - } - } - - return result; - }; - - var originalCallToolHandler = callToolHandler; - callToolHandler = (request, cancellationToken) => - { - if (request.Params is not null && - tools.TryGetPrimitive(request.Params.Name, out var tool)) - { - return tool.InvokeAsync(request, cancellationToken); - } - - return originalCallToolHandler(request, cancellationToken); - }; - - listChanged = true; - } - - ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; - ServerCapabilities.Tools.CallToolHandler = callToolHandler; - ServerCapabilities.Tools.ToolCollection = tools; - ServerCapabilities.Tools.ListChanged = listChanged; - - SetHandler( - RequestMethods.ToolsList, - listToolsHandler, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult); - - SetHandler( - RequestMethods.ToolsCall, - callToolHandler, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult); - } - - private void ConfigureLogging(McpServerOptions options) - { - // We don't require that the handler be provided, as we always store the provided log level to the server. - var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; - - ServerCapabilities.Logging = new(); - ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; - - RequestHandlers.Set( - RequestMethods.LoggingSetLevel, - (request, destinationTransport, cancellationToken) => - { - // Store the provided level. - if (request is not null) - { - if (_loggingLevel is null) - { - Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); - } - - _loggingLevel.Value = request.Level; - } - - // If a handler was provided, now delegate to it. - if (setLoggingLevelHandler is not null) - { - return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); - } - - // Otherwise, consider it handled. - return new ValueTask(EmptyResult.Instance); - }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - } - - private ValueTask InvokeHandlerAsync( - Func, CancellationToken, ValueTask> handler, - TParams? args, - ITransport? destinationTransport = null, - CancellationToken cancellationToken = default) - { - return _servicesScopePerRequest ? - InvokeScopedAsync(handler, args, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); - - async ValueTask InvokeScopedAsync( - Func, CancellationToken, ValueTask> handler, - TParams? args, - CancellationToken cancellationToken) - { - var scope = Services?.GetService()?.CreateAsyncScope(); - try - { - return await handler( - new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) - { - Services = scope?.ServiceProvider ?? Services, - Params = args - }, - cancellationToken).ConfigureAwait(false); - } - finally - { - if (scope is not null) - { - await scope.Value.DisposeAsync().ConfigureAwait(false); - } - } - } - } + /// + /// + /// This property contains identification information about the client that has connected to this server, + /// including its name and version. This information is provided by the client during initialization. + /// + /// + /// Server implementations can use this information for logging, tracking client versions, + /// or implementing client-specific behaviors. + /// + /// + public abstract Implementation? ClientInfo { get; } - private void SetHandler( - string method, - Func, CancellationToken, ValueTask> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) - { - RequestHandlers.Set(method, - (request, destinationTransport, cancellationToken) => - InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), - requestTypeInfo, responseTypeInfo); - } + /// + /// Gets the options used to construct this server. + /// + /// + /// These options define the server's capabilities, protocol version, and other configuration + /// settings that were used to initialize the server. + /// + public abstract McpServerOptions ServerOptions { get; } - private void UpdateEndpointNameWithClientInfo() - { - if (ClientInfo is null) - { - return; - } + /// + /// Gets the service provider for the server. + /// + public abstract IServiceProvider? Services { get; } - _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; - } + /// Gets the last logging level set by the client, or if it's never been set. + public abstract LoggingLevel? LoggingLevel { get; } - /// Maps a to a . - internal static LoggingLevel ToLoggingLevel(LogLevel level) => - level switch - { - LogLevel.Trace => Protocol.LoggingLevel.Debug, - LogLevel.Debug => Protocol.LoggingLevel.Debug, - LogLevel.Information => Protocol.LoggingLevel.Info, - LogLevel.Warning => Protocol.LoggingLevel.Warning, - LogLevel.Error => Protocol.LoggingLevel.Error, - LogLevel.Critical => Protocol.LoggingLevel.Critical, - _ => Protocol.LoggingLevel.Emergency, - }; + /// + /// Runs the server, listening for and handling client requests. + /// + public abstract Task RunAsync(CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index 277ed737b..cd8c368a1 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -1,8 +1,9 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; -using System.Text; using System.Text.Json; namespace ModelContextProtocol.Server; @@ -26,19 +27,11 @@ public static class McpServerExtensions /// It allows detailed control over sampling parameters including messages, system prompt, temperature, /// and token limits. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask SampleAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.SamplingCreateMessage, - request, - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).SampleAsync(request, cancellationToken); /// /// Requests to sample an LLM via the client using the provided chat messages and options. @@ -55,104 +48,12 @@ public static ValueTask SampleAsync( /// This method converts the provided chat messages into a format suitable for the sampling API, /// handling different content types such as text, images, and audio. /// - public static async Task SampleAsync( + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] + public static Task SampleAsync( this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - Throw.IfNull(messages); - - StringBuilder? systemPrompt = null; - - if (options?.Instructions is { } instructions) - { - (systemPrompt ??= new()).Append(instructions); - } - - List samplingMessages = []; - foreach (var message in messages) - { - if (message.Role == ChatRole.System) - { - if (systemPrompt is null) - { - systemPrompt = new(); - } - else - { - systemPrompt.AppendLine(); - } - - systemPrompt.Append(message.Text); - continue; - } - - if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) - { - Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; - - foreach (var content in message.Contents) - { - switch (content) - { - case TextContent textContent: - samplingMessages.Add(new() - { - Role = role, - Content = new TextContentBlock { Text = textContent.Text }, - }); - break; - - case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): - samplingMessages.Add(new() - { - Role = role, - Content = dataContent.HasTopLevelMediaType("image") ? - new ImageContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - } : - new AudioContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - }, - }); - break; - } - } - } - } - - ModelPreferences? modelPreferences = null; - if (options?.ModelId is { } modelId) - { - modelPreferences = new() { Hints = [new() { Name = modelId }] }; - } - - var result = await server.SampleAsync(new() - { - Messages = samplingMessages, - MaxTokens = options?.MaxOutputTokens, - StopSequences = options?.StopSequences?.ToArray(), - SystemPrompt = systemPrompt?.ToString(), - Temperature = options?.Temperature, - ModelPreferences = modelPreferences, - }, cancellationToken).ConfigureAwait(false); - - AIContent? responseContent = result.Content.ToAIContent(); - - return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) - { - ModelId = result.Model, - FinishReason = result.StopReason switch - { - "maxTokens" => ChatFinishReason.Length, - "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, - } - }; - } + => AsServerOrThrow(server).SampleAsync(messages, options, cancellationToken); /// /// Creates an wrapper that can be used to send sampling requests to the client. @@ -161,23 +62,18 @@ public static async Task SampleAsync( /// The that can be used to issue sampling requests to the client. /// is . /// The client does not support sampling. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static IChatClient AsSamplingChatClient(this IMcpServer server) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return new SamplingChatClient(server); - } + => AsServerOrThrow(server).AsSamplingChatClient(); /// Gets an on which logged messages will be sent as notifications to the client. /// The server to wrap as an . /// An that can be used to log to the client.. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) - { - Throw.IfNull(server); - - return new ClientLoggerProvider(server); - } + => AsServerOrThrow(server).AsClientLoggerProvider(); /// /// Requests the client to list the roots it exposes. @@ -194,19 +90,11 @@ public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) /// navigated and accessed by the server. These resources might include file systems, databases, /// or other structured data sources that the client makes available through the protocol. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.RequestRootsAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask RequestRootsAsync( this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfRootsUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.RootsList, - request, - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).RequestRootsAsync(request, cancellationToken); /// /// Requests additional information from the user via the client, allowing the server to elicit structured data. @@ -220,143 +108,31 @@ public static ValueTask RequestRootsAsync( /// /// This method requires the client to support the elicitation capability. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.ElicitAsync)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + [EditorBrowsable(EditorBrowsableState.Never)] public static ValueTask ElicitAsync( this IMcpServer server, ElicitRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfElicitationUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.ElicitationCreate, - request, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult, - cancellationToken: cancellationToken); - } - - private static void ThrowIfSamplingUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Sampling is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Sampling is not supported in stateless mode."); - } + => AsServerOrThrow(server).ElicitAsync(request, cancellationToken); - throw new InvalidOperationException("Client does not support sampling."); - } - } - - private static void ThrowIfRootsUnsupported(IMcpServer server) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpServer AsServerOrThrow(IMcpServer server, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - if (server.ClientCapabilities?.Roots is null) + if (server is not McpServer mcpServer) { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Roots are not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support roots."); + ThrowInvalidSessionType(memberName); } - } - private static void ThrowIfElicitationUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Elicitation is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Elicitation is not supported in stateless mode."); - } + return mcpServer; - throw new InvalidOperationException("Client does not support elicitation requests."); - } - } - - /// Provides an implementation that's implemented via client sampling. - private sealed class SamplingChatClient(IMcpServer server) : IChatClient - { - /// - public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - server.SampleAsync(messages, options, cancellationToken); - - /// - async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( - IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) - { - var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - foreach (var update in response.ToChatResponseUpdates()) - { - yield return update; - } - } - - /// - object? IChatClient.GetService(Type serviceType, object? serviceKey) - { - Throw.IfNull(serviceType); - - return - serviceKey is not null ? null : - serviceType.IsInstanceOfType(this) ? this : - serviceType.IsInstanceOfType(server) ? server : - null; - } - - /// - void IDisposable.Dispose() { } // nop - } - - /// - /// Provides an implementation for creating loggers - /// that send logging message notifications to the client for logged messages. - /// - private sealed class ClientLoggerProvider(IMcpServer server) : ILoggerProvider - { - /// - public ILogger CreateLogger(string categoryName) - { - Throw.IfNull(categoryName); - - return new ClientLogger(server, categoryName); - } - - /// - void IDisposable.Dispose() { } - - private sealed class ClientLogger(IMcpServer server, string categoryName) : ILogger - { - /// - public IDisposable? BeginScope(TState state) where TState : notnull => - null; - - /// - public bool IsEnabled(LogLevel logLevel) => - server?.LoggingLevel is { } loggingLevel && - McpServer.ToLoggingLevel(logLevel) >= loggingLevel; - - /// - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) - { - if (!IsEnabled(logLevel)) - { - return; - } - - Throw.IfNull(formatter); - - Log(logLevel, formatter(state, exception)); - - void Log(LogLevel logLevel, string message) - { - _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams - { - Level = McpServer.ToLoggingLevel(logLevel), - Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), - Logger = categoryName, - }); - } - } - } + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidSessionType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpServer)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpServerExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs index 50d4188b5..7a6609d0d 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs @@ -1,5 +1,6 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.ComponentModel; namespace ModelContextProtocol.Server; @@ -10,6 +11,8 @@ namespace ModelContextProtocol.Server; /// This is the recommended way to create instances. /// The factory handles proper initialization of server instances with the required dependencies. /// +[Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.Create)} instead. This member will be removed in a subsequent release.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 +[EditorBrowsable(EditorBrowsableState.Never)] public static class McpServerFactory { /// @@ -27,10 +30,5 @@ public static IMcpServer Create( McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, IServiceProvider? serviceProvider = null) - { - Throw.IfNull(transport); - Throw.IfNull(serverOptions); - - return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); - } + => McpServer.Create(transport, serverOptions, loggerFactory, serviceProvider); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs new file mode 100644 index 000000000..e38421bc1 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs @@ -0,0 +1,161 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Provides filter collections for MCP server handlers. +/// +/// +/// This class contains collections of filters that can be applied to various MCP server handlers. +/// This allows for middleware-style composition where filters can perform actions before and after the inner handler. +/// +public sealed class McpServerFilters +{ + /// + /// Gets the filters for the list tools handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available tools when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// These filters work alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public List> ListToolsFilters { get; } = new(); + + /// + /// Gets the filters for the call tool handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + public List> CallToolFilters { get; } = new(); + + /// + /// Gets the filters for the list prompts handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available prompts when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// These filters work alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public List> ListPromptsFilters { get; } = new(); + + /// + /// Gets the filters for the get prompt handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + public List> GetPromptFilters { get; } = new(); + + /// + /// Gets the filters for the list resource templates handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resource templates when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + public List> ListResourceTemplatesFilters { get; } = new(); + + /// + /// Gets the filters for the list resources handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resources when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + public List> ListResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the read resource handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + public List> ReadResourceFilters { get; } = new(); + + /// + /// Gets the filters for the complete handler pipeline. + /// + /// + /// These filters wrap handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + public List> CompleteFilters { get; } = new(); + + /// + /// Gets the filters for the subscribe to resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public List> SubscribeToResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the unsubscribe from resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public List> UnsubscribeFromResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the set logging level handler pipeline. + /// + /// + /// + /// These filters wrap handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public List> SetLoggingLevelFilters { get; } = new(); +} diff --git a/src/ModelContextProtocol/McpServerHandlers.cs b/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs similarity index 55% rename from src/ModelContextProtocol/McpServerHandlers.cs rename to src/ModelContextProtocol.Core/Server/McpServerHandlers.cs index a07c81b54..0d8deba13 100644 --- a/src/ModelContextProtocol/McpServerHandlers.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs @@ -1,4 +1,3 @@ -using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -10,17 +9,12 @@ namespace ModelContextProtocol.Server; /// /// This class provides a centralized collection of delegates that implement various capabilities of the Model Context Protocol. /// Each handler in this class corresponds to a specific endpoint in the Model Context Protocol and -/// is responsible for processing a particular type of request. The handlers are used to customize +/// is responsible for processing a particular type of message. The handlers are used to customize /// the behavior of the MCP server by providing implementations for the various protocol operations. /// /// -/// Handlers can be configured individually using the extension methods in -/// such as and -/// . -/// -/// -/// When a client sends a request to the server, the appropriate handler is invoked to process the -/// request and produce a response according to the protocol specification. Which handler is selected +/// When a client sends a message to the server, the appropriate handler is invoked to process it +/// according to the protocol specification. Which handler is selected /// is done based on an ordinal, case-sensitive string comparison. /// /// @@ -40,7 +34,7 @@ public sealed class McpServerHandlers /// Tools from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } + public McpRequestHandler? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -49,7 +43,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client makes a call to a tool that isn't found in the collection. /// The handler should implement logic to execute the requested tool and return appropriate results. /// - public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } + public McpRequestHandler? CallToolHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -65,7 +59,7 @@ public sealed class McpServerHandlers /// Prompts from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } + public McpRequestHandler? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -74,7 +68,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests details for a specific prompt that isn't found in the collection. /// The handler should implement logic to fetch or generate the requested prompt and return appropriate results. /// - public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } + public McpRequestHandler? GetPromptHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -84,7 +78,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resource templates. /// - public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } + public McpRequestHandler? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -94,7 +88,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resources. /// - public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } + public McpRequestHandler? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -103,17 +97,17 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests the content of a specific resource identified by its URI. /// The handler should implement logic to locate and retrieve the requested resource. /// - public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } + public McpRequestHandler? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// This handler provides auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. - /// The handler processes auto-completion requests, returning a list of suggestions based on the + /// The handler processes auto-completion requests, returning a list of suggestions based on the /// reference type and current argument value. /// - public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } + public McpRequestHandler? CompleteHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -129,7 +123,7 @@ public sealed class McpServerHandlers /// whenever a relevant resource is created, updated, or deleted. /// /// - public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } + public McpRequestHandler? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -145,7 +139,7 @@ public sealed class McpServerHandlers /// to the client for the specified resources. /// /// - public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } + public McpRequestHandler? UnsubscribeFromResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -160,67 +154,24 @@ public sealed class McpServerHandlers /// at or above the specified level to the client as notifications/message notifications. /// /// - public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } - - /// - /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. - /// - /// - /// - internal void OverwriteWithSetHandlers(McpServerOptions options) - { - PromptsCapability? promptsCapability = options.Capabilities?.Prompts; - if (ListPromptsHandler is not null || GetPromptHandler is not null) - { - promptsCapability ??= new(); - promptsCapability.ListPromptsHandler = ListPromptsHandler ?? promptsCapability.ListPromptsHandler; - promptsCapability.GetPromptHandler = GetPromptHandler ?? promptsCapability.GetPromptHandler; - } - - ResourcesCapability? resourcesCapability = options.Capabilities?.Resources; - if (ListResourcesHandler is not null || - ReadResourceHandler is not null) - { - resourcesCapability ??= new(); - resourcesCapability.ListResourceTemplatesHandler = ListResourceTemplatesHandler ?? resourcesCapability.ListResourceTemplatesHandler; - resourcesCapability.ListResourcesHandler = ListResourcesHandler ?? resourcesCapability.ListResourcesHandler; - resourcesCapability.ReadResourceHandler = ReadResourceHandler ?? resourcesCapability.ReadResourceHandler; - - if (SubscribeToResourcesHandler is not null || UnsubscribeFromResourcesHandler is not null) - { - resourcesCapability.SubscribeToResourcesHandler = SubscribeToResourcesHandler ?? resourcesCapability.SubscribeToResourcesHandler; - resourcesCapability.UnsubscribeFromResourcesHandler = UnsubscribeFromResourcesHandler ?? resourcesCapability.UnsubscribeFromResourcesHandler; - resourcesCapability.Subscribe = true; - } - } - - ToolsCapability? toolsCapability = options.Capabilities?.Tools; - if (ListToolsHandler is not null || CallToolHandler is not null) - { - toolsCapability ??= new(); - toolsCapability.ListToolsHandler = ListToolsHandler ?? toolsCapability.ListToolsHandler; - toolsCapability.CallToolHandler = CallToolHandler ?? toolsCapability.CallToolHandler; - } - - LoggingCapability? loggingCapability = options.Capabilities?.Logging; - if (SetLoggingLevelHandler is not null) - { - loggingCapability ??= new(); - loggingCapability.SetLoggingLevelHandler = SetLoggingLevelHandler; - } + public McpRequestHandler? SetLoggingLevelHandler { get; set; } - CompletionsCapability? completionsCapability = options.Capabilities?.Completions; - if (CompleteHandler is not null) - { - completionsCapability ??= new(); - completionsCapability.CompleteHandler = CompleteHandler; - } - - options.Capabilities ??= new(); - options.Capabilities.Prompts = promptsCapability; - options.Capabilities.Resources = resourcesCapability; - options.Capabilities.Tools = toolsCapability; - options.Capabilities.Logging = loggingCapability; - options.Capabilities.Completions = completionsCapability; - } + /// Gets or sets notification handlers to register with the server. + /// + /// + /// When constructed, the server will enumerate these handlers once, which may contain multiple handlers per notification method key. + /// The server will not re-enumerate the sequence after initialization. + /// + /// + /// Notification handlers allow the server to respond to client-sent notifications for specific methods. + /// Each key in the collection is a notification method name, and each value is a callback that will be invoked + /// when a notification with that method is received. + /// + /// + /// Handlers provided via will be registered with the server for the lifetime of the server. + /// For transient handlers, may be used to register a handler that can + /// then be unregistered by disposing of the returned from the method. + /// + /// + public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs new file mode 100644 index 000000000..c152d3a0a --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -0,0 +1,786 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Runtime.CompilerServices; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Server; + +/// +internal sealed partial class McpServerImpl : McpServer +{ + internal static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpServer), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _sessionTransport; + private readonly bool _servicesScopePerRequest; + private readonly List _disposables = []; + private readonly NotificationHandlers _notificationHandlers; + private readonly RequestHandlers _requestHandlers; + private readonly McpSessionHandler _sessionHandler; + private readonly SemaphoreSlim _disposeLock = new(1, 1); + + private ClientCapabilities? _clientCapabilities; + private Implementation? _clientInfo; + + private readonly string _serverOnlyEndpointName; + private string? _negotiatedProtocolVersion; + private string _endpointName; + private int _started; + + private bool _disposed; + + /// Holds a boxed value for the server. + /// + /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box + /// rather than a nullable to be able to manipulate it atomically. + /// + private StrongBox? _loggingLevel; + + /// + /// Creates a new instance of . + /// + /// Transport to use for the server representing an already-established session. + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Logger factory to use for logging + /// Optional service provider to use for dependency injection + /// The server was incorrectly configured. + public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) + { + Throw.IfNull(transport); + Throw.IfNull(options); + + options ??= new(); + + _sessionTransport = transport; + ServerOptions = options; + Services = serviceProvider; + _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _endpointName = _serverOnlyEndpointName; + _servicesScopePerRequest = options.ScopeRequests; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + _clientInfo = options.KnownClientInfo; + UpdateEndpointNameWithClientInfo(); + + _notificationHandlers = new(); + _requestHandlers = []; + + // Configure all request handlers based on the supplied options. + ServerCapabilities = new(); + ConfigureInitialize(options); + ConfigureTools(options); + ConfigurePrompts(options); + ConfigureResources(options); + ConfigureLogging(options); + ConfigureCompletion(options); + ConfigureExperimental(options); + ConfigurePing(); + + // Register any notification handlers that were provided. + if (options.Handlers.NotificationHandlers is { } notificationHandlers) + { + _notificationHandlers.RegisterRange(notificationHandlers); + } + + // Now that everything has been configured, subscribe to any necessary notifications. + if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) + { + Register(ServerOptions.ToolCollection, NotificationMethods.ToolListChangedNotification); + Register(ServerOptions.PromptCollection, NotificationMethods.PromptListChangedNotification); + Register(ServerOptions.ResourceCollection, NotificationMethods.ResourceListChangedNotification); + + void Register(McpServerPrimitiveCollection? collection, string notificationMethod) + where TPrimitive : IMcpServerPrimitive + { + if (collection is not null) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); + collection.Changed += changed; + _disposables.Add(() => collection.Changed -= changed); + } + } + } + + // And initialize the session. + _sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger); + } + + /// + public override string? SessionId => _sessionTransport.SessionId; + + /// + public override string? NegotiatedProtocolVersion => _negotiatedProtocolVersion; + + /// + public ServerCapabilities ServerCapabilities { get; } = new(); + + /// + public override ClientCapabilities? ClientCapabilities => _clientCapabilities; + + /// + public override Implementation? ClientInfo => _clientInfo; + + /// + public override McpServerOptions ServerOptions { get; } + + /// + public override IServiceProvider? Services { get; } + + /// + public override LoggingLevel? LoggingLevel => _loggingLevel?.Value; + + /// + public override async Task RunAsync(CancellationToken cancellationToken = default) + { + if (Interlocked.Exchange(ref _started, 1) != 0) + { + throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); + } + + try + { + await _sessionHandler.ProcessMessagesAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + await DisposeAsync().ConfigureAwait(false); + } + } + + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _disposed = true; + + _disposables.ForEach(d => d()); + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + } + + private void ConfigurePing() + { + SetHandler(RequestMethods.Ping, + async (request, _) => new PingResult(), + McpJsonUtilities.JsonContext.Default.JsonNode, + McpJsonUtilities.JsonContext.Default.PingResult); + } + + private void ConfigureInitialize(McpServerOptions options) + { + _requestHandlers.Set(RequestMethods.Initialize, + async (request, _, _) => + { + _clientCapabilities = request?.Capabilities ?? new(); + _clientInfo = request?.ClientInfo; + + // Use the ClientInfo to update the session EndpointName for logging. + UpdateEndpointNameWithClientInfo(); + _sessionHandler.EndpointName = _endpointName; + + // Negotiate a protocol version. If the server options provide one, use that. + // Otherwise, try to use whatever the client requested as long as it's supported. + // If it's not supported, fall back to the latest supported version. + string? protocolVersion = options.ProtocolVersion; + if (protocolVersion is null) + { + protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + clientProtocolVersion : + McpSessionHandler.LatestProtocolVersion; + } + + _negotiatedProtocolVersion = protocolVersion; + + return new InitializeResult + { + ProtocolVersion = protocolVersion, + Instructions = options.ServerInstructions, + ServerInfo = options.ServerInfo ?? DefaultImplementation, + Capabilities = ServerCapabilities ?? new(), + }; + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult); + } + + private void ConfigureCompletion(McpServerOptions options) + { + var completeHandler = options.Handlers.CompleteHandler; + var completionsCapability = options.Capabilities?.Completions; + +#pragma warning disable CS0618 // Type or member is obsolete + completeHandler ??= completionsCapability?.CompleteHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + if (completeHandler is null && completionsCapability is null) + { + return; + } + + completeHandler ??= (static async (_, __) => new CompleteResult()); + completeHandler = BuildFilterPipeline(completeHandler, options.Filters.CompleteFilters); + + ServerCapabilities.Completions = new(); + + SetHandler( + RequestMethods.CompletionComplete, + completeHandler, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult); + } + + private void ConfigureExperimental(McpServerOptions options) + { + ServerCapabilities.Experimental = options.Capabilities?.Experimental; + } + + private void ConfigureResources(McpServerOptions options) + { + var listResourcesHandler = options.Handlers.ListResourcesHandler; + var listResourceTemplatesHandler = options.Handlers.ListResourceTemplatesHandler; + var readResourceHandler = options.Handlers.ReadResourceHandler; + var subscribeHandler = options.Handlers.SubscribeToResourcesHandler; + var unsubscribeHandler = options.Handlers.UnsubscribeFromResourcesHandler; + var resources = options.ResourceCollection; + var resourcesCapability = options.Capabilities?.Resources; + +#pragma warning disable CS0618 // Type or member is obsolete + listResourcesHandler ??= resourcesCapability?.ListResourcesHandler; + listResourceTemplatesHandler ??= resourcesCapability?.ListResourceTemplatesHandler; + readResourceHandler ??= resourcesCapability?.ReadResourceHandler; + subscribeHandler ??= resourcesCapability?.SubscribeToResourcesHandler; + unsubscribeHandler ??= resourcesCapability?.UnsubscribeFromResourcesHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + if (listResourcesHandler is null && listResourceTemplatesHandler is null && readResourceHandler is null && + subscribeHandler is null && unsubscribeHandler is null && resources is null && + resourcesCapability is null) + { + return; + } + + ServerCapabilities.Resources = new(); + + listResourcesHandler ??= (static async (_, __) => new ListResourcesResult()); + listResourceTemplatesHandler ??= (static async (_, __) => new ListResourceTemplatesResult()); + readResourceHandler ??= (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); + subscribeHandler ??= (static async (_, __) => new EmptyResult()); + unsubscribeHandler ??= (static async (_, __) => new EmptyResult()); + var listChanged = resourcesCapability?.ListChanged; + var subscribe = resourcesCapability?.Subscribe; + + // Handle resources provided via DI. + if (resources is { IsEmpty: false }) + { + var originalListResourcesHandler = listResourcesHandler; + listResourcesHandler = async (request, cancellationToken) => + { + ListResourcesResult result = originalListResourcesHandler is not null ? + await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var r in resources) + { + if (r.ProtocolResource is { } resource) + { + result.Resources.Add(resource); + } + } + } + + return result; + }; + + var originalListResourceTemplatesHandler = listResourceTemplatesHandler; + listResourceTemplatesHandler = async (request, cancellationToken) => + { + ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? + await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var rt in resources) + { + if (rt.IsTemplated) + { + result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); + } + } + } + + return result; + }; + + // Synthesize read resource handler, which covers both resources and resource templates. + var originalReadResourceHandler = readResourceHandler; + readResourceHandler = async (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerResource matchedResource) + { + if (await matchedResource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) + { + return result; + } + } + + return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); + }; + + listChanged = true; + + // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. + // subscribe = true; + } + + listResourcesHandler = BuildFilterPipeline(listResourcesHandler, options.Filters.ListResourcesFilters); + listResourceTemplatesHandler = BuildFilterPipeline(listResourceTemplatesHandler, options.Filters.ListResourceTemplatesFilters); + readResourceHandler = BuildFilterPipeline(readResourceHandler, options.Filters.ReadResourceFilters, handler => + async (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Uri is { } uri && resources is not null) + { + // First try an O(1) lookup by exact match. + if (resources.TryGetPrimitive(uri, out var resource)) + { + request.MatchedPrimitive = resource; + } + else + { + // Fall back to an O(N) lookup, trying to match against each URI template. + // The number of templates is controlled by the server developer, and the number is expected to be + // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. + foreach (var resourceTemplate in resources) + { + // Check if this template would handle the request by testing if ReadAsync would succeed + if (resourceTemplate.IsTemplated) + { + // This is a simplified check - a more robust implementation would match the URI pattern + // For now, we'll let the actual handler attempt the match + request.MatchedPrimitive = resourceTemplate; + break; + } + } + } + } + + return await handler(request, cancellationToken).ConfigureAwait(false); + }); + subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters); + unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters); + + ServerCapabilities.Resources.ListChanged = listChanged; + ServerCapabilities.Resources.Subscribe = subscribe; + + SetHandler( + RequestMethods.ResourcesList, + listResourcesHandler, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult); + + SetHandler( + RequestMethods.ResourcesTemplatesList, + listResourceTemplatesHandler, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); + + SetHandler( + RequestMethods.ResourcesRead, + readResourceHandler, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult); + + SetHandler( + RequestMethods.ResourcesSubscribe, + subscribeHandler, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + + SetHandler( + RequestMethods.ResourcesUnsubscribe, + unsubscribeHandler, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private void ConfigurePrompts(McpServerOptions options) + { + var listPromptsHandler = options.Handlers.ListPromptsHandler; + var getPromptHandler = options.Handlers.GetPromptHandler; + var prompts = options.PromptCollection; + var promptsCapability = options.Capabilities?.Prompts; + +#pragma warning disable CS0618 // Type or member is obsolete + listPromptsHandler ??= promptsCapability?.ListPromptsHandler; + getPromptHandler ??= promptsCapability?.GetPromptHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + if (listPromptsHandler is null && getPromptHandler is null && prompts is null && + promptsCapability is null) + { + return; + } + + ServerCapabilities.Prompts = new(); + + listPromptsHandler ??= (static async (_, __) => new ListPromptsResult()); + getPromptHandler ??= (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var listChanged = promptsCapability?.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (prompts is { IsEmpty: false }) + { + var originalListPromptsHandler = listPromptsHandler; + listPromptsHandler = async (request, cancellationToken) => + { + ListPromptsResult result = originalListPromptsHandler is not null ? + await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var p in prompts) + { + result.Prompts.Add(p.ProtocolPrompt); + } + } + + return result; + }; + + var originalGetPromptHandler = getPromptHandler; + getPromptHandler = (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerPrompt prompt) + { + return prompt.GetAsync(request, cancellationToken); + } + + return originalGetPromptHandler(request, cancellationToken); + }; + + listChanged = true; + } + + listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters); + getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } promptName && prompts is not null && + prompts.TryGetPrimitive(promptName, out var prompt)) + { + request.MatchedPrimitive = prompt; + } + + return handler(request, cancellationToken); + }); + + ServerCapabilities.Prompts.ListChanged = listChanged; + + SetHandler( + RequestMethods.PromptsList, + listPromptsHandler, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult); + + SetHandler( + RequestMethods.PromptsGet, + getPromptHandler, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult); + } + + private void ConfigureTools(McpServerOptions options) + { + var listToolsHandler = options.Handlers.ListToolsHandler; + var callToolHandler = options.Handlers.CallToolHandler; + var tools = options.ToolCollection; + var toolsCapability = options.Capabilities?.Tools; + +#pragma warning disable CS0618 // Type or member is obsolete + listToolsHandler ??= toolsCapability?.ListToolsHandler; + callToolHandler ??= toolsCapability?.CallToolHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + if (listToolsHandler is null && callToolHandler is null && tools is null && + toolsCapability is null) + { + return; + } + + ServerCapabilities.Tools = new(); + + listToolsHandler ??= (static async (_, __) => new ListToolsResult()); + callToolHandler ??= (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var listChanged = toolsCapability?.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (tools is { IsEmpty: false }) + { + var originalListToolsHandler = listToolsHandler; + listToolsHandler = async (request, cancellationToken) => + { + ListToolsResult result = originalListToolsHandler is not null ? + await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var t in tools) + { + result.Tools.Add(t.ProtocolTool); + } + } + + return result; + }; + + var originalCallToolHandler = callToolHandler; + callToolHandler = (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerTool tool) + { + return tool.InvokeAsync(request, cancellationToken); + } + + return originalCallToolHandler(request, cancellationToken); + }; + + listChanged = true; + } + + listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); + callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) + { + request.MatchedPrimitive = tool; + } + + return handler(request, cancellationToken); + }, handler => + async (request, cancellationToken) => + { + // Final handler that provides exception handling only for tool execution + // Only wrap tool execution in try-catch, not tool resolution + if (request.MatchedPrimitive is McpServerTool) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) when (e is not OperationCanceledException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() + { + IsError = true, + Content = [new TextContentBlock { Text = errorMessage }], + }; + } + } + else + { + // For unmatched tools, let exceptions bubble up as protocol errors + return await handler(request, cancellationToken).ConfigureAwait(false); + } + }); + + ServerCapabilities.Tools.ListChanged = listChanged; + + SetHandler( + RequestMethods.ToolsList, + listToolsHandler, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult); + + SetHandler( + RequestMethods.ToolsCall, + callToolHandler, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult); + } + + private void ConfigureLogging(McpServerOptions options) + { + // We don't require that the handler be provided, as we always store the provided log level to the server. + var setLoggingLevelHandler = options.Handlers.SetLoggingLevelHandler; + +#pragma warning disable CS0618 // Type or member is obsolete + setLoggingLevelHandler ??= options.Capabilities?.Logging?.SetLoggingLevelHandler; +#pragma warning restore CS0618 // Type or member is obsolete + + // Apply filters to the handler + if (setLoggingLevelHandler is not null) + { + setLoggingLevelHandler = BuildFilterPipeline(setLoggingLevelHandler, options.Filters.SetLoggingLevelFilters); + } + + ServerCapabilities.Logging = new(); + + _requestHandlers.Set( + RequestMethods.LoggingSetLevel, + (request, jsonRpcRequest, cancellationToken) => + { + // Store the provided level. + if (request is not null) + { + if (_loggingLevel is null) + { + Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); + } + + _loggingLevel.Value = request.Level; + } + + // If a handler was provided, now delegate to it. + if (setLoggingLevelHandler is not null) + { + return InvokeHandlerAsync(setLoggingLevelHandler, request, jsonRpcRequest, cancellationToken); + } + + // Otherwise, consider it handled. + return new ValueTask(EmptyResult.Instance); + }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private ValueTask InvokeHandlerAsync( + McpRequestHandler handler, + TParams? args, + JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken = default) + { + return _servicesScopePerRequest ? + InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : + handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); + + async ValueTask InvokeScopedAsync( + McpRequestHandler handler, + TParams? args, + JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken) + { + var scope = Services?.GetService()?.CreateAsyncScope(); + try + { + return await handler( + new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) + { + Services = scope?.ServiceProvider ?? Services, + Params = args + }, + cancellationToken).ConfigureAwait(false); + } + finally + { + if (scope is not null) + { + await scope.Value.DisposeAsync().ConfigureAwait(false); + } + } + } + } + + private void SetHandler( + string method, + McpRequestHandler handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) + { + _requestHandlers.Set(method, + (request, jsonRpcRequest, cancellationToken) => + InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), + requestTypeInfo, responseTypeInfo); + } + + private static McpRequestHandler BuildFilterPipeline( + McpRequestHandler baseHandler, + List> filters, + McpRequestFilter? initialHandler = null, + McpRequestFilter? finalHandler = null) + { + var current = baseHandler; + + if (finalHandler is not null) + { + current = finalHandler(current); + } + + for (int i = filters.Count - 1; i >= 0; i--) + { + current = filters[i](current); + } + + if (initialHandler is not null) + { + current = initialHandler(current); + } + + return current; + } + + private void UpdateEndpointNameWithClientInfo() + { + if (ClientInfo is null) + { + return; + } + + _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + } + + /// Maps a to a . + internal static LoggingLevel ToLoggingLevel(LogLevel level) => + level switch + { + LogLevel.Trace => Protocol.LoggingLevel.Debug, + LogLevel.Debug => Protocol.LoggingLevel.Debug, + LogLevel.Information => Protocol.LoggingLevel.Info, + LogLevel.Warning => Protocol.LoggingLevel.Warning, + LogLevel.Error => Protocol.LoggingLevel.Error, + LogLevel.Critical => Protocol.LoggingLevel.Critical, + _ => Protocol.LoggingLevel.Emergency, + }; + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index 8c50a9b55..833c852e3 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -7,6 +7,8 @@ namespace ModelContextProtocol.Server; /// public sealed class McpServerOptions { + private McpServerHandlers? _handlers; + /// /// Gets or sets information about this server implementation, including its name and version. /// @@ -79,4 +81,77 @@ public sealed class McpServerOptions /// /// public Implementation? KnownClientInfo { get; set; } + + /// + /// Gets the filter collections for MCP server handlers. + /// + /// + /// This property provides access to filter collections that can be used to modify the behavior + /// of various MCP server handlers. Filters are applied in reverse order, so the last filter + /// added will be the outermost (first to execute). + /// + public McpServerFilters Filters { get; } = new(); + + /// + /// Gets or sets the container of handlers used by the server for processing protocol messages. + /// + public McpServerHandlers Handlers + { + get => _handlers ??= new(); + set + { + Throw.IfNull(value); + _handlers = value; + } + } + + /// + /// Gets or sets a collection of tools served by the server. + /// + /// + /// Tools specified via augment the and + /// , if provided. ListTools requests will output information about every tool + /// in and then also any tools output by , if it's + /// non-. CallTool requests will first check for the tool + /// being requested, and if the tool is not found in the , any specified + /// will be invoked as a fallback. + /// + public McpServerPrimitiveCollection? ToolCollection { get; set; } + + /// + /// Gets or sets a collection of resources served by the server. + /// + /// + /// + /// Resources specified via augment the , + /// and handlers, if provided. Resources with template expressions in their URI templates are considered resource templates + /// and are listed via ListResourceTemplate, whereas resources without template parameters are considered static resources and are listed with ListResources. + /// + /// + /// ReadResource requests will first check the for the exact resource being requested. If no match is found, they'll proceed to + /// try to match the resource against each resource template in . If no match is still found, the request will fall back to + /// any handler registered for . + /// + /// + public McpServerResourceCollection? ResourceCollection { get; set; } + + /// + /// Gets or sets a collection of prompts that will be served by the server. + /// + /// + /// + /// The contains the predefined prompts that clients can request from the server. + /// This collection works in conjunction with and + /// when those are provided: + /// + /// + /// - For requests: The server returns all prompts from this collection + /// plus any additional prompts provided by the if it's set. + /// + /// + /// - For requests: The server first checks this collection for the requested prompt. + /// If not found, it will invoke the as a fallback if one is set. + /// + /// + public McpServerPrimitiveCollection? PromptCollection { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs index 68874df3e..a7fa0e242 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs @@ -15,12 +15,12 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP prompt for use in the server (as opposed /// to , which provides the protocol representation of a prompt, and , which /// provides a client-side representation of a prompt). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithPromptsFromAssembly and WithPrompts. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -45,7 +45,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -61,15 +61,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to +/// according to will be resolved from the provided to /// rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument collection. /// /// @@ -80,7 +80,7 @@ namespace ModelContextProtocol.Server; /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having /// the prompt be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -128,6 +128,15 @@ protected McpServerPrompt() /// public abstract Prompt ProtocolPrompt { get; } + /// + /// Gets the metadata for this prompt instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the prompt, rendering it with the provided request parameters and returning the prompt result. /// @@ -170,7 +179,7 @@ public static McpServerPrompt Create( /// is . /// is an instance method but is . public static McpServerPrompt Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerPromptCreateOptions? options = null) => AIFunctionMcpServerPrompt.Create(method, target, options); @@ -201,7 +210,7 @@ public static McpServerPrompt Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerPrompt Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs index c71e969db..ac9e247f6 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs @@ -25,7 +25,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -36,7 +36,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs index 95d712ffd..1853b0f1a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs @@ -68,6 +68,15 @@ public sealed class McpServerPromptCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the prompt. + /// + /// + /// Metadata includes information such as the attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -80,5 +89,6 @@ internal McpServerPromptCreateOptions Clone() => Description = Description, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerResource.cs b/src/ModelContextProtocol.Core/Server/McpServerResource.cs index 8e42d3e1c..2a43e3349 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResource.cs @@ -11,13 +11,13 @@ namespace ModelContextProtocol.Server; /// /// /// is an abstract base class that represents an MCP resource for use in the server (as opposed -/// to or , which provide the protocol representations of a resource). Instances of +/// to or , which provide the protocol representations of a resource). Instances of /// can be added into a to be picked up automatically when -/// is used to create an , or added into a . +/// is used to create an , or added into a . /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithResourcesFromAssembly and /// . The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -46,7 +46,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -62,15 +62,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to the +/// according to will be resolved from the provided to the /// resource invocation rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to the resource invocation rather than from the argument collection. /// /// @@ -149,6 +149,15 @@ protected McpServerResource() /// public virtual Resource? ProtocolResource => ProtocolResourceTemplate.AsResource(); + /// + /// Gets the metadata for this resource instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the resource, rendering it with the provided request parameters and returning the resource result. /// @@ -192,7 +201,7 @@ public static McpServerResource Create( /// is . /// is an instance method but is . public static McpServerResource Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerResourceCreateOptions? options = null) => AIFunctionMcpServerResource.Create(method, target, options); @@ -223,7 +232,7 @@ public static McpServerResource Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerResource Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs index bc2f138f0..66c593e47 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs @@ -23,7 +23,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs index 24051a7ff..2d6b66b32 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs @@ -83,6 +83,15 @@ public sealed class McpServerResourceCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the resource. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -97,5 +106,6 @@ internal McpServerResourceCreateOptions Clone() => MimeType = MimeType, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index e3958271b..4136f5913 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -15,12 +15,12 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP tool for use in the server (as opposed /// to , which provides the protocol representation of a tool, and , which /// provides a client-side representation of a tool). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithToolsFromAssembly and WithTools. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -47,7 +47,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -56,22 +56,22 @@ namespace ModelContextProtocol.Server; /// /// parameters accepting values /// are not included in the JSON schema and are bound to an instance manufactured -/// to forward progress notifications from the tool to the client. If the client included a in their request, +/// to forward progress notifications from the tool to the client. If the client included a in their request, /// progress reports issued to this instance will propagate to the client as notifications with /// that token. If the client did not include a , the instance will ignore any progress reports issued to it. /// /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will not be included in the generated JSON schema and will be resolved +/// according to will not be included in the generated JSON schema and will be resolved /// from the provided to rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument /// collection, and will not be included in the generated JSON schema. /// @@ -79,13 +79,13 @@ namespace ModelContextProtocol.Server; /// /// /// -/// All other parameters are deserialized from the s in the dictionary, -/// using the supplied in , or if none was provided, +/// All other parameters are deserialized from the s in the dictionary, +/// using the supplied in , or if none was provided, /// using . /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having /// the tool be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -141,6 +141,15 @@ protected McpServerTool() /// Gets the protocol type for this instance. public abstract Tool ProtocolTool { get; } + /// + /// Gets the metadata for this tool instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// Invokes the . /// The request information resulting in the invocation of this tool. /// The to monitor for cancellation requests. The default is . @@ -172,7 +181,7 @@ public static McpServerTool Create( /// is . /// is an instance method but is . public static McpServerTool Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerToolCreateOptions? options = null) => AIFunctionMcpServerTool.Create(method, target, options); @@ -203,7 +212,7 @@ public static McpServerTool Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerTool Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index d4ea9eb75..7d5bf488b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -26,7 +26,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -38,7 +38,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index bdb4ecb8d..d18af8c02 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -80,7 +80,7 @@ public sealed class McpServerToolCreateOptions public bool? Destructive { get; set; } /// - /// Gets or sets whether calling the tool repeatedly with the same arguments + /// Gets or sets whether calling the tool repeatedly with the same arguments /// will have no additional effect on its environment. /// /// @@ -155,6 +155,15 @@ public sealed class McpServerToolCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the tool. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -172,5 +181,6 @@ internal McpServerToolCreateOptions Clone() => UseStructuredContent = UseStructuredContent, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index b0ea9d993..f75cea80b 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -1,3 +1,6 @@ +using System.Security.Claims; +using ModelContextProtocol.Protocol; + namespace ModelContextProtocol.Server; /// @@ -12,22 +15,28 @@ namespace ModelContextProtocol.Server; public sealed class RequestContext { /// The server with which this instance is associated. - private IMcpServer _server; + private McpServer _server; + + private IDictionary? _items; /// - /// Initializes a new instance of the class with the specified server. + /// Initializes a new instance of the class with the specified server and JSON-RPC request. /// /// The server with which this instance is associated. - public RequestContext(IMcpServer server) + /// The JSON-RPC request associated with this context. + public RequestContext(McpServer server, JsonRpcRequest jsonRpcRequest) { Throw.IfNull(server); + Throw.IfNull(jsonRpcRequest); _server = server; + JsonRpcRequest = jsonRpcRequest; Services = server.Services; + User = jsonRpcRequest.Context?.User; } /// Gets or sets the server with which this instance is associated. - public IMcpServer Server + public McpServer Server { get => _server; set @@ -37,15 +46,47 @@ public IMcpServer Server } } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this request. + /// + public IDictionary Items + { + get + { + return _items ??= new Dictionary(); + } + set + { + _items = value; + } + } + /// Gets or sets the services associated with this request. /// - /// This may not be the same instance stored in + /// This may not be the same instance stored in /// if was true, in which case this /// might be a scoped derived from the server's - /// . + /// . /// public IServiceProvider? Services { get; set; } + /// Gets or sets the user associated with this request. + public ClaimsPrincipal? User { get; set; } + /// Gets or sets the parameters associated with this request. public TParams? Params { get; set; } + + /// + /// Gets or sets the primitive that matched the request. + /// + public IMcpServerPrimitive? MatchedPrimitive { get; set; } + + /// + /// Gets the JSON-RPC request associated with this context. + /// + /// + /// This property provides access to the complete JSON-RPC request that initiated this handler invocation, + /// including the method name, parameters, request ID, and associated transport and user information. + /// + public JsonRpcRequest JsonRpcRequest { get; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs similarity index 59% rename from src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs rename to src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs index 3372072fe..9359ea157 100644 --- a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs +++ b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs @@ -1,47 +1,55 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; +using System.Security.Claims; namespace ModelContextProtocol.Server; /// Augments a service provider with additional request-related services. -internal sealed class RequestServiceProvider( - RequestContext request, IServiceProvider? innerServices) : - IServiceProvider, IKeyedServiceProvider, - IServiceProviderIsService, IServiceProviderIsKeyedService, - IDisposable, IAsyncDisposable +internal sealed class RequestServiceProvider(RequestContext request) : + IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, + IDisposable, IAsyncDisposable where TRequestParams : RequestParams { + private readonly IServiceProvider? _innerServices = request.Services; + /// Gets the request associated with this instance. public RequestContext Request => request; /// Gets whether the specified type is in the list of additional types this service provider wraps around the one in a provided request's services. public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(RequestContext) || + serviceType == typeof(McpServer) || +#pragma warning disable CS0618 // Type or member is obsolete serviceType == typeof(IMcpServer) || - serviceType == typeof(IProgress); +#pragma warning restore CS0618 // Type or member is obsolete + serviceType == typeof(IProgress) || + serviceType == typeof(ClaimsPrincipal); /// public object? GetService(Type serviceType) => serviceType == typeof(RequestContext) ? request : - serviceType == typeof(IMcpServer) ? request.Server : +#pragma warning disable CS0618 // Type or member is obsolete + serviceType == typeof(McpServer) || serviceType == typeof(IMcpServer) ? request.Server : +#pragma warning restore CS0618 // Type or member is obsolete serviceType == typeof(IProgress) ? (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : - innerServices?.GetService(serviceType); + serviceType == typeof(ClaimsPrincipal) ? request.User : + _innerServices?.GetService(serviceType); /// public bool IsService(Type serviceType) => IsAugmentedWith(serviceType) || - (innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; + (_innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; /// public bool IsKeyedService(Type serviceType, object? serviceKey) => (serviceKey is null && IsService(serviceType)) || - (innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; + (_innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; /// public object? GetKeyedService(Type serviceType, object? serviceKey) => serviceKey is null ? GetService(serviceType) : - (innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); + (_innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); /// public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => @@ -50,9 +58,9 @@ public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => /// public void Dispose() => - (innerServices as IDisposable)?.Dispose(); + (_innerServices as IDisposable)?.Dispose(); /// public ValueTask DisposeAsync() => - innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; + _innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 438421f28..8941e4ed6 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -9,7 +10,7 @@ namespace ModelContextProtocol.Server; /// /// /// This transport provides one-way communication from server to client using the SSE protocol over HTTP, -/// while receiving client messages through a separate mechanism. It writes messages as +/// while receiving client messages through a separate mechanism. It writes messages as /// SSE events to a response stream, typically associated with an HTTP response. /// /// @@ -41,7 +42,7 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? /// /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task RunAsync(CancellationToken cancellationToken) + public async Task RunAsync(CancellationToken cancellationToken = default) { _isConnected = true; await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); @@ -64,6 +65,7 @@ public async ValueTask DisposeAsync() /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } @@ -76,8 +78,8 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// Thrown when there is an attempt to process a message before calling . /// /// - /// This method is the entry point for processing client-to-server communication in the SSE transport model. - /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional + /// This method is the entry point for processing client-to-server communication in the SSE transport model. + /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional /// communication by handling HTTP POST requests sent to the message endpoint. /// /// @@ -85,11 +87,11 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// process the message and make it available to the MCP server via the channel. /// /// - /// This method validates that the transport is connected before processing the message, ensuring proper - /// sequencing of operations in the transport lifecycle. + /// If an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. /// /// - public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken) + public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index 18571e2c9..4fb7feafe 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -26,6 +26,8 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { + Throw.IfNull(sseResponseStream); + // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single // item of a different type, so we fib and special-case the "endpoint" event type in the formatter. if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint"))) diff --git a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs index 556a31159..307c180a1 100644 --- a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs @@ -37,7 +37,7 @@ private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; + return serverOptions.ServerInfo?.Name ?? McpServerImpl.DefaultImplementation.Name; } // Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 9d225caa8..1992939de 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,7 +1,9 @@ using ModelContextProtocol.Protocol; +using System.Diagnostics; using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; +using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -9,14 +11,14 @@ namespace ModelContextProtocol.Server; /// /// Handles processing the request/response body pairs for the Streamable HTTP transport. -/// This is typically used via . +/// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, IDuplexPipe httpBodies) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream) : ITransport { private readonly SseWriter _sseWriter = new(); private RequestId _pendingRequest; - public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; @@ -25,11 +27,31 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async ValueTask RunAsync(CancellationToken cancellationToken) + public async ValueTask HandlePostAsync(JsonRpcMessage message, CancellationToken cancellationToken) { - var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), - McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); - await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + Debug.Assert(_pendingRequest.Id is null); + + if (message is JsonRpcRequest request) + { + _pendingRequest = request.Id; + + // Invoke the initialize request callback if applicable. + if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + { + var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + await onInitRequest(initializeRequest).ConfigureAwait(false); + } + } + + message.Context ??= new JsonRpcMessageContext(); + message.Context.RelatedTransport = this; + + if (parentTransport.FlowExecutionContextFromRequests) + { + message.Context.ExecutionContext = ExecutionContext.Capture(); + } + + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); if (_pendingRequest.Id is null) { @@ -37,12 +59,14 @@ public async ValueTask RunAsync(CancellationToken cancellationToken) } _sseWriter.MessageFilter = StopOnFinalResponseFilter; - await _sseWriter.WriteAllAsync(httpBodies.Output.AsStream(), cancellationToken).ConfigureAwait(false); + await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); return true; } public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (parentTransport.Stateless && message is JsonRpcRequest) { throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); @@ -69,33 +93,4 @@ public async ValueTask DisposeAsync() } } } - - private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, CancellationToken cancellationToken) - { - if (message is null) - { - throw new InvalidOperationException("Received invalid null message."); - } - - if (message is JsonRpcRequest request) - { - _pendingRequest = request.Id; - - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) - { - var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); - } - } - - message.RelatedTransport = this; - - if (parentTransport.FlowExecutionContextFromRequests) - { - message.ExecutionContext = ExecutionContext.Capture(); - } - - await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); - } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index b63c8a651..57283e9a2 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Protocol; using System.IO.Pipelines; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -49,8 +50,8 @@ public sealed class StreamableHttpServerTransport : ITransport public bool Stateless { get; init; } /// - /// Gets a value indicating whether the execution context should flow from the calls to - /// to the corresponding emitted by the . + /// Gets a value indicating whether the execution context should flow from the calls to + /// to the corresponding property contained in the instances returned by the . /// /// /// Defaults to . @@ -75,8 +76,10 @@ public sealed class StreamableHttpServerTransport : ITransport /// The response stream to write MCP JSON-RPC messages as SSE events to. /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken) + public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(sseResponseStream); + if (Stateless) { throw new InvalidOperationException("GET requests are not supported in stateless mode."); @@ -96,23 +99,33 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// and other correlated messages are sent back to the client directly in response /// to the that initiated the message. /// - /// The duplex pipe facilitates the reading and writing of HTTP request and response data. - /// This token allows for the operation to be canceled if needed. + /// The JSON-RPC message received from the client via the POST request body. + /// This token allows for the operation to be canceled if needed. The default is . + /// The POST response body to write MCP JSON-RPC messages to. /// /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) + /// + /// If 's an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. + /// + public async Task HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + Throw.IfNull(responseStream); + using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(this, httpBodies); - return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); + await using var postTransport = new StreamableHttpPostTransport(this, responseStream); + return await postTransport.HandlePostAsync(message, postCts.Token).ConfigureAwait(false); } /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (Stateless) { throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); @@ -126,6 +139,7 @@ public async ValueTask DisposeAsync() { try { + _incomingChannel.Writer.TryComplete(); await _disposeCts.CancelAsync(); } finally diff --git a/src/ModelContextProtocol.Core/TokenProgress.cs b/src/ModelContextProtocol.Core/TokenProgress.cs index f222fbf71..6b7a91e00 100644 --- a/src/ModelContextProtocol.Core/TokenProgress.cs +++ b/src/ModelContextProtocol.Core/TokenProgress.cs @@ -4,13 +4,13 @@ namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications on the supplied endpoint. +/// progress notifications on the supplied session. /// -internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(McpSession session, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); + _ = session.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/src/ModelContextProtocol/IMcpServerBuilder.cs b/src/ModelContextProtocol/IMcpServerBuilder.cs index 5ec37eba9..016e9eb3e 100644 --- a/src/ModelContextProtocol/IMcpServerBuilder.cs +++ b/src/ModelContextProtocol/IMcpServerBuilder.cs @@ -3,7 +3,7 @@ namespace Microsoft.Extensions.DependencyInjection; /// -/// Provides a builder for configuring instances. +/// Provides a builder for configuring instances. /// /// /// diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f6..d4c338262 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -53,6 +53,53 @@ public static partial class McpServerBuilderExtensions return builder; } + /// Adds instances to the service collection backing . + /// The tool type. + /// The builder instance. + /// The target instance from which the tools should be sourced. + /// The serializer options governing tool parameter marshalling. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those tools directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TToolType>( + this IMcpServerBuilder builder, + TToolType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable tools) + { + return builder.WithTools(tools); + } + + foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (toolMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerTool.Create( + toolMethod, + toolMethod.IsStatic ? null : target, + new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -137,7 +184,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnume /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] @@ -193,6 +240,50 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The prompt type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The serializer options governing prompt parameter marshalling. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those prompts directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TPromptType>( + this IMcpServerBuilder builder, + TPromptType target, + JsonSerializerOptions? serializerOptions = null) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable prompts) + { + return builder.WithPrompts(prompts); + } + + foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (promptMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerPrompt.Create(promptMethod, target, new() { Services = services, SerializerOptions = serializerOptions })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -277,7 +368,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnu /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)] @@ -311,7 +402,8 @@ where t.GetCustomAttribute() is not null /// instance for each. For instance members, an instance will be constructed for each invocation of the resource. /// public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( - DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | DynamicallyAccessedMemberTypes.PublicConstructors)] TResourceType>( this IMcpServerBuilder builder) { @@ -330,6 +422,48 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The resource type. + /// The builder instance. + /// The target instance from which the prompts should be sourced. + /// The builder provided in . + /// is . + /// + /// + /// This method discovers all methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each, using as the associated instance for instance methods. + /// + /// + /// However, if is itself an of , + /// this method will register those resources directly without scanning for methods on . + /// + /// + public static IMcpServerBuilder WithResources<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods)] TResourceType>( + this IMcpServerBuilder builder, + TResourceType target) + { + Throw.IfNull(builder); + Throw.IfNull(target); + + if (target is IEnumerable resources) + { + return builder.WithResources(resources); + } + + foreach (var resourceTemplateMethod in typeof(TResourceType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (resourceTemplateMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton(services => McpServerResource.Create(resourceTemplateMethod, target, new() { Services = services })); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// The instances to add to the server. @@ -412,7 +546,7 @@ public static IMcpServerBuilder WithResources(this IMcpServerBuilder builder, IE /// /// /// Note that this method performs reflection at runtime and may not work in Native AOT scenarios. For - /// Native AOT compatibility, consider using the generic method instead. + /// Native AOT compatibility, consider using the generic method instead. /// /// [RequiresUnreferencedCode(WithResourcesRequiresUnreferencedCodeMessage)] @@ -451,7 +585,7 @@ where t.GetCustomAttribute() is not null /// resource system where templates define the URI patterns and the read handler provides the actual content. /// /// - public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -484,7 +618,7 @@ public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServer /// executes them when invoked by clients. /// /// - public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -504,7 +638,7 @@ public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder buil /// This method is typically paired with to provide a complete tools implementation, /// where advertises available tools and this handler executes them. /// - public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -537,7 +671,7 @@ public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder build /// produces them when invoked by clients. /// /// - public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -552,7 +686,7 @@ public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder bu /// The handler function that processes prompt requests. /// The builder provided in . /// is . - public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -573,7 +707,7 @@ public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder buil /// where this handler advertises available resources and the read handler provides their content when requested. /// /// - public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -592,7 +726,7 @@ public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder /// This handler is typically paired with to provide a complete resources implementation, /// where the list handler advertises available resources and the read handler provides their content when requested. /// - public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -611,7 +745,7 @@ public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder b /// The completion handler is invoked when clients request suggestions for argument values. /// This enables auto-complete functionality for both prompt arguments and resource references. /// - public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -641,7 +775,7 @@ public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder build /// resources and to send appropriate notifications through the connection when resources change. /// /// - public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -671,7 +805,7 @@ public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerB /// to the specified resource. /// /// - public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -693,12 +827,12 @@ public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpSer /// and may begin sending log messages at or above the specified level to the client. /// /// - /// Regardless of whether a handler is provided, an should itself handle - /// such notifications by updating its property to return the + /// Regardless of whether a handler is provided, an should itself handle + /// such notifications by updating its property to return the /// most recently set level. /// /// - public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -707,6 +841,278 @@ public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilde } #endregion + #region Filters + /// + /// Adds a filter to the list resource templates handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resource templates when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + /// + public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourceTemplatesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list tools handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available tools when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// This filter works alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListToolsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the call tool handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + /// + public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CallToolFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list prompts handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available prompts when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// This filter works alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListPromptsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the get prompt handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + /// + public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.GetPromptFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resources when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + /// + public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the read resource handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + /// + public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ReadResourceFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the complete handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + /// + public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CompleteFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the subscribe to resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SubscribeToResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the unsubscribe from resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.UnsubscribeFromResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the set logging level handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder, McpRequestFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SetLoggingLevelFilters.Add(filter)); + return builder; + } + #endregion + #region Transports /// /// Adds a server transport that uses standard input (stdin) and standard output (stdout) for communication. @@ -774,7 +1180,7 @@ private static void AddSingleSessionServerDependencies(IServiceCollection servic ITransport serverTransport = services.GetRequiredService(); IOptions options = services.GetRequiredService>(); ILoggerFactory? loggerFactory = services.GetService(); - return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + return McpServer.Create(serverTransport, options.Value, loggerFactory, services); }); } #endregion diff --git a/src/ModelContextProtocol/McpServerOptionsSetup.cs b/src/ModelContextProtocol/McpServerOptionsSetup.cs index 7fe4f61cb..030c3012a 100644 --- a/src/ModelContextProtocol/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/McpServerOptionsSetup.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; namespace ModelContextProtocol; @@ -29,7 +30,7 @@ public void Configure(McpServerOptions options) // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants // change notifications, etc. - McpServerPrimitiveCollection toolCollection = options.Capabilities?.Tools?.ToolCollection ?? []; + McpServerPrimitiveCollection toolCollection = options.ToolCollection ?? []; foreach (var tool in serverTools) { toolCollection.TryAdd(tool); @@ -37,16 +38,14 @@ public void Configure(McpServerOptions options) if (!toolCollection.IsEmpty) { - options.Capabilities ??= new(); - options.Capabilities.Tools ??= new(); - options.Capabilities.Tools.ToolCollection = toolCollection; + options.ToolCollection = toolCollection; } // Collect all of the provided prompts into a prompts collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants // change notifications, etc. - McpServerPrimitiveCollection promptCollection = options.Capabilities?.Prompts?.PromptCollection ?? []; + McpServerPrimitiveCollection promptCollection = options.PromptCollection ?? []; foreach (var prompt in serverPrompts) { promptCollection.TryAdd(prompt); @@ -54,16 +53,14 @@ public void Configure(McpServerOptions options) if (!promptCollection.IsEmpty) { - options.Capabilities ??= new(); - options.Capabilities.Prompts ??= new(); - options.Capabilities.Prompts.PromptCollection = promptCollection; + options.PromptCollection = promptCollection; } // Collect all of the provided resources into a resources collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants // change notifications, etc. - McpServerResourceCollection resourceCollection = options.Capabilities?.Resources?.ResourceCollection ?? []; + McpServerResourceCollection resourceCollection = options.ResourceCollection ?? []; foreach (var resource in serverResources) { resourceCollection.TryAdd(resource); @@ -71,12 +68,71 @@ public void Configure(McpServerOptions options) if (!resourceCollection.IsEmpty) { - options.Capabilities ??= new(); - options.Capabilities.Resources ??= new(); - options.Capabilities.Resources.ResourceCollection = resourceCollection; + options.ResourceCollection = resourceCollection; } // Apply custom server handlers. - serverHandlers.Value.OverwriteWithSetHandlers(options); + OverwriteWithSetHandlers(serverHandlers.Value, options); + } + + /// + /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. + /// + private static void OverwriteWithSetHandlers(McpServerHandlers handlers, McpServerOptions options) + { + McpServerHandlers optionsHandlers = options.Handlers; + + PromptsCapability? promptsCapability = options.Capabilities?.Prompts; + if (handlers.ListPromptsHandler is not null || handlers.GetPromptHandler is not null) + { + promptsCapability ??= new(); + optionsHandlers.ListPromptsHandler = handlers.ListPromptsHandler ?? optionsHandlers.ListPromptsHandler; + optionsHandlers.GetPromptHandler = handlers.GetPromptHandler ?? optionsHandlers.GetPromptHandler; + } + + ResourcesCapability? resourcesCapability = options.Capabilities?.Resources; + if (handlers.ListResourceTemplatesHandler is not null || handlers.ListResourcesHandler is not null || handlers.ReadResourceHandler is not null) + { + resourcesCapability ??= new(); + optionsHandlers.ListResourceTemplatesHandler = handlers.ListResourceTemplatesHandler ?? optionsHandlers.ListResourceTemplatesHandler; + optionsHandlers.ListResourcesHandler = handlers.ListResourcesHandler ?? optionsHandlers.ListResourcesHandler; + optionsHandlers.ReadResourceHandler = handlers.ReadResourceHandler ?? optionsHandlers.ReadResourceHandler; + + if (handlers.SubscribeToResourcesHandler is not null || handlers.UnsubscribeFromResourcesHandler is not null) + { + optionsHandlers.SubscribeToResourcesHandler = handlers.SubscribeToResourcesHandler ?? optionsHandlers.SubscribeToResourcesHandler; + optionsHandlers.UnsubscribeFromResourcesHandler = handlers.UnsubscribeFromResourcesHandler ?? optionsHandlers.UnsubscribeFromResourcesHandler; + resourcesCapability.Subscribe = true; + } + } + + ToolsCapability? toolsCapability = options.Capabilities?.Tools; + if (handlers.ListToolsHandler is not null || handlers.CallToolHandler is not null) + { + toolsCapability ??= new(); + optionsHandlers.ListToolsHandler = handlers.ListToolsHandler ?? optionsHandlers.ListToolsHandler; + optionsHandlers.CallToolHandler = handlers.CallToolHandler ?? optionsHandlers.CallToolHandler; + } + + LoggingCapability? loggingCapability = options.Capabilities?.Logging; + if (handlers.SetLoggingLevelHandler is not null) + { + loggingCapability ??= new(); + optionsHandlers.SetLoggingLevelHandler = handlers.SetLoggingLevelHandler; + } + + CompletionsCapability? completionsCapability = options.Capabilities?.Completions; + if (handlers.CompleteHandler is not null) + { + completionsCapability ??= new(); + optionsHandlers.CompleteHandler = handlers.CompleteHandler; + } + + options.Capabilities ??= new(); + options.Capabilities.Prompts = promptsCapability; + options.Capabilities.Resources = resourcesCapability; + options.Capabilities.Tools = toolsCapability; + options.Capabilities.Logging = loggingCapability; + options.Capabilities.Completions = completionsCapability; } } diff --git a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs index b50e46140..80e8216a8 100644 --- a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs +++ b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol; /// /// The host's application lifetime. If available, it will have termination requested when the session's run completes. /// -internal sealed class SingleSessionMcpServerHostedService(IMcpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService +internal sealed class SingleSessionMcpServerHostedService(McpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService { /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) diff --git a/tests/Common/Utils/MockLoggerProvider.cs b/tests/Common/Utils/MockLoggerProvider.cs index f5264edc4..14a0f401a 100644 --- a/tests/Common/Utils/MockLoggerProvider.cs +++ b/tests/Common/Utils/MockLoggerProvider.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Tests.Utils; public class MockLoggerProvider() : ILoggerProvider { - public ConcurrentQueue<(string Category, LogLevel LogLevel, string Message, Exception? Exception)> LogMessages { get; } = []; + public ConcurrentQueue<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> LogMessages { get; } = []; public ILogger CreateLogger(string categoryName) { @@ -21,7 +21,7 @@ private class MockLogger(MockLoggerProvider mockProvider, string category) : ILo public void Log( LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - mockProvider.LogMessages.Enqueue((category, logLevel, formatter(state, exception), exception)); + mockProvider.LogMessages.Enqueue((category, logLevel, eventId, formatter(state, exception), exception)); } public bool IsEnabled(LogLevel logLevel) => true; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs index 6a48c21d2..9144121e8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs @@ -106,7 +106,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport( + await using var transport = new HttpClientTransport( new() { Endpoint = new(McpServerUrl), @@ -122,7 +122,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken @@ -140,7 +140,9 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport( + DynamicClientRegistrationResponse? dcrResponse = null; + + await using var transport = new HttpClientTransport( new() { Endpoint = new(McpServerUrl), @@ -148,20 +150,32 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() { RedirectUri = new Uri("http://localhost:1179/callback"), AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, - ClientName = "Test MCP Client", - ClientUri = new Uri("https://example.com"), Scopes = ["mcp:tools"], + DynamicClientRegistration = new() + { + ClientName = "Test MCP Client", + ClientUri = new Uri("https://example.com"), + ResponseDelegate = (response, cancellationToken) => + { + dcrResponse = response; + return Task.CompletedTask; + }, + }, }, }, HttpClient, LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken ); + + Assert.NotNull(dcrResponse); + Assert.False(string.IsNullOrEmpty(dcrResponse.ClientId)); + Assert.False(string.IsNullOrEmpty(dcrResponse.ClientSecret)); } [Fact] @@ -289,6 +303,87 @@ public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvide Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); } + [Fact] + public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCalled() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + // Override the configuration to test HandleResponse behavior + Builder.Services.Configure( + McpAuthenticationDefaults.AuthenticationScheme, + options => + { + options.ResourceMetadata = null; + options.Events.OnResourceMetadataRequest = async context => + { + // Call HandleResponse() to discontinue processing and return to client + context.HandleResponse(); + await Task.CompletedTask; + }; + } + ); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Make a direct request to the resource metadata endpoint + using var response = await HttpClient.GetAsync( + "/.well-known/oauth-protected-resource", + TestContext.Current.CancellationToken + ); + + // The request should be handled by the event handler without returning metadata + // Since HandleResponse() was called, the handler should have taken responsibility + // for generating the response, which in this case means an empty response + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // The response should be empty since the event handler called HandleResponse() + // but didn't write any content to the response + var content = await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken); + Assert.Empty(content); + } + + [Fact] + public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + // Override the configuration to test SkipHandler behavior + Builder.Services.Configure( + McpAuthenticationDefaults.AuthenticationScheme, + options => + { + options.ResourceMetadata = null; + options.Events.OnResourceMetadataRequest = async context => + { + // Call SkipHandler() to discontinue processing in the current handler + context.SkipHandler(); + await Task.CompletedTask; + }; + } + ); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + // Make a direct request to the resource metadata endpoint + using var response = await HttpClient.GetAsync( + "/.well-known/oauth-protected-resource", + TestContext.Current.CancellationToken + ); + + // When SkipHandler() is called, the authentication handler should skip processing + // and let other handlers in the pipeline handle the request. Since there are no + // other handlers configured for this endpoint, this should result in a 404 + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + private async Task HandleAuthorizationUrlAsync( Uri authorizationUri, Uri redirectUri, diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs index 2252b1b7c..fff7d6d42 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs @@ -97,7 +97,7 @@ public async Task CanAuthenticate() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -109,7 +109,7 @@ public async Task CanAuthenticate() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -124,12 +124,12 @@ public async Task CannotAuthenticate_WithoutOAuthConfiguration() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), }, HttpClient, LoggerFactory); - var httpEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var httpEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal(HttpStatusCode.Unauthorized, httpEx.StatusCode); @@ -146,7 +146,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -159,7 +159,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() }, HttpClient, LoggerFactory); // The EqualException is thrown by HandleAuthorizationUrlAsync when the /authorize request gets a 400 - var equalEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var equalEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } @@ -174,20 +174,23 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new ClientOAuthOptions() { RedirectUri = new Uri("http://localhost:1179/callback"), AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, - ClientName = "Test MCP Client", - ClientUri = new Uri("https://example.com"), - Scopes = ["mcp:tools"] + Scopes = ["mcp:tools"], + DynamicClientRegistration = new() + { + ClientName = "Test MCP Client", + ClientUri = new Uri("https://example.com"), + }, }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -202,7 +205,7 @@ public async Task CanAuthenticate_WithTokenRefresh() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -216,7 +219,7 @@ public async Task CanAuthenticate_WithTokenRefresh() // The test-refresh-client should get an expired token first, // then automatically refresh it to get a working token - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.True(_testOAuthServer.HasIssuedRefreshToken); @@ -233,7 +236,7 @@ public async Task CanAuthenticate_WithExtraParams() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -249,7 +252,7 @@ public async Task CanAuthenticate_WithExtraParams() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(_lastAuthorizationUri?.Query); @@ -267,7 +270,7 @@ public async Task CannotOverrideExistingParameters_WithExtraParams() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -283,7 +286,7 @@ public async Task CannotOverrideExistingParameters_WithExtraParams() }, }, HttpClient, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + await Assert.ThrowsAsync(() => McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs new file mode 100644 index 000000000..84d1c1a79 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -0,0 +1,535 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.ComponentModel; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for MCP authorization functionality with [Authorize], [AllowAnonymous] and role-based authorization. +/// +public class AuthorizeAttributeTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private readonly MockLoggerProvider _mockLoggerProvider = new(); + + private async Task ConnectAsync() + { + await using var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = new("http://localhost:5000"), + }, HttpClient, LoggerFactory); + + return await McpClient.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); + } + + [Fact] + public async Task Authorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This tool requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task ClassLevelAuthorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task AllowAnonymous_Tool_AllowsAnonymousAccess() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task Authorize_Tool_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Authorized: test", content.Text); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_RequiresAdminRole() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This tool requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_AllowsAdminUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Admin: test", content.Text); + } + + [Fact] + public async Task ListTools_Anonymous_OnlyReturnsAnonymousTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(tools); + Assert.Equal("anonymous_tool", tools[0].Name); + } + + [Fact] + public async Task ListTools_AuthenticatedUser_ReturnsAuthorizedTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Authenticated user should see anonymous and basic authorized tools, but not admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_AdminUser_ReturnsAllTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Admin user should see all tools + Assert.Equal(3, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["admin_tool", "anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_UserRole_DoesNotReturnAdminTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // User with User role should not see admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task Authorize_Prompt_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This prompt requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Prompt_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + var message = Assert.Single(result.Messages); + Assert.Equal(Role.User, message.Role); + var content = Assert.IsType(message.Content); + Assert.Equal("Authorized prompt: test", content.Text); + } + + [Fact] + public async Task ListPrompts_Anonymous_OnlyReturnsAnonymousPrompts() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Anonymous user should only see prompts marked with [AllowAnonymous] + Assert.Single(prompts); + Assert.Equal("anonymous_prompt", prompts[0].Name); + } + + [Fact] + public async Task Authorize_Resource_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This resource requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Resource_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(result.Contents.OfType()); + Assert.Equal("Authorized resource content", content.Text); + } + + [Fact] + public async Task ListResources_Anonymous_OnlyReturnsAnonymousResources() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(resources); + Assert.Equal("resource://anonymous", resources[0].Uri); + } + + [Fact] + public async Task ListTools_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for tools/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task CallTool_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for tools/call operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListPrompts_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for prompts/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task GetPrompt_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for prompts/get operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListResources_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ReadResource_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/read operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListResourceTemplates_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/templates/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + private async Task StartServerWithAuth(Action configure, string? userName = null, params string[] roles) + { + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport().AddAuthorizationFilters(); + configure(mcpServerBuilder); + + Builder.Services.AddAuthorization(); + Builder.Services.AddSingleton(_mockLoggerProvider); + + var app = Builder.Build(); + + if (userName is not null) + { + app.Use(next => + { + return async context => + { + context.User = CreateUser(userName, roles); + await next(context); + }; + }); + } + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task StartServerWithoutAuthFilters(Action configure) + { + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport(); // No AddAuthorizationFilters() call + configure(mcpServerBuilder); + + Builder.Services.AddAuthorization(); + Builder.Services.AddSingleton(_mockLoggerProvider); + + var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private ClaimsPrincipal CreateUser(string name, params string[] roles) + => new ClaimsPrincipal(new ClaimsIdentity( + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), .. roles.Select(role => new Claim("role", role))], + "TestAuthType", "name", "role")); + + [McpServerToolType] + private class AuthorizationTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + [Authorize] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + + [McpServerTool, Description("A tool that requires Admin role.")] + [Authorize(Roles = "Admin")] + public static string AdminTool(string message) + { + return $"Admin: {message}"; + } + } + + [McpServerToolType] + [Authorize] + private class AllowAnonymousTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + [AllowAnonymous] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + } + + [McpServerPromptType] + private class AuthorizationTestPrompts + { + [McpServerPrompt, Description("A prompt that allows anonymous access.")] + public static string AnonymousPrompt(string message) + { + return $"Anonymous prompt: {message}"; + } + + [McpServerPrompt, Description("A prompt that requires authorization.")] + [Authorize] + public static string AuthorizedPrompt(string message) + { + return $"Authorized prompt: {message}"; + } + } + + [McpServerResourceType] + private class AuthorizationTestResources + { + [McpServerResource(UriTemplate = "resource://anonymous"), Description("A resource that allows anonymous access.")] + public static string AnonymousResource() + { + return "Anonymous resource content"; + } + + [McpServerResource(UriTemplate = "resource://authorized"), Description("A resource that requires authorization.")] + [Authorize] + public static string AuthorizedResource() + { + return "Authorized resource content"; + } + + [McpServerResource(UriTemplate = "resource://authorized/{id}"), Description("A resource template that requires authorization.")] + [Authorize] + public static string AuthorizedResourceWithTemplate(string id) + { + return "Authorized resource content"; + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 9b3c91b94..1d27a219e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -21,9 +21,9 @@ public override void Dispose() base.Dispose(); } - protected abstract SseClientTransportOptions ClientTransportOptions { get; } + protected abstract HttpClientTransportOptions ClientTransportOptions { get; } - private Task GetClientAsync(McpClientOptions? options = null) + private Task GetClientAsync(McpClientOptions? options = null) { return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } @@ -52,6 +52,7 @@ public async Task Connect_TestServer_ShouldProvideServerFields() // Assert Assert.NotNull(client.ServerCapabilities); Assert.NotNull(client.ServerInfo); + Assert.NotNull(client.NegotiatedProtocolVersion); if (ClientTransportOptions.Endpoint.AbsolutePath.EndsWith("/sse")) { @@ -250,9 +251,7 @@ public async Task Sampling_Sse_TestServer() int samplingHandlerCalls = 0; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously McpClientOptions options = new(); - options.Capabilities = new(); - options.Capabilities.Sampling ??= new(); - options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => + options.Handlers.SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; return new CreateMessageResult diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index f31621307..728304070 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -13,7 +13,7 @@ public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(output [InlineData("/mcp/secondary")] public async Task Allows_Customizing_Route(string pattern) { - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless); + Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); app.MapMcp(pattern); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index cb1f86db9..f8b61aa21 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -171,13 +171,13 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await app.StartAsync(TestContext.Current.CancellationToken); - await using (var mcpClient = await ConnectAsync(clientOptions: new() + await using var mcpClient = await ConnectAsync(clientOptions: new() { ProtocolVersion = "2025-03-26", - })) - { - await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - } + }); + + Assert.Equal("2025-03-26", mcpClient.NegotiatedProtocolVersion); + await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // The header should be included in the GET request, the initialized notification, the tools/list call, and the delete request. Assert.NotEmpty(protocolVersionHeaderValues); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4d0d73562..341171c51 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -23,21 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync( + protected async Task ConnectAsync( string? path = null, - SseClientTransportOptions? transportOptions = null, + HttpClientTransportOptions? transportOptions = null, McpClientOptions? clientOptions = null) { // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; - await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions + await using var transport = new HttpClientTransport(transportOptions ?? new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:5000{path}"), TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); - return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); + return await McpClient.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } [Fact] @@ -111,6 +111,35 @@ public async Task Messages_FromNewUser_AreRejected() Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); } + [Fact] + public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod() + { + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => async context => + { + context.User = CreateUser("TestUser"); + await next(context); + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var client = await ConnectAsync(); + + var response = await client.CallToolAsync( + "echo_claims_principal", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content.OfType()); + Assert.Equal("TestUser: Hello world!", content.Text); + } + [Fact] public async Task Sampling_DoesNotCloseStream_Prematurely() { @@ -132,29 +161,26 @@ public async Task Sampling_DoesNotCloseStream_Prematurely() await app.StartAsync(TestContext.Current.CancellationToken); var sampleCount = 0; - var clientOptions = new McpClientOptions + var clientOptions = new McpClientOptions() { - Capabilities = new() + Handlers = new() { - Sampling = new() + SamplingHandler = async (parameters, _, _) => { - SamplingHandler = async (parameters, _, _) => + Assert.NotNull(parameters?.Messages); + var message = Assert.Single(parameters.Messages); + Assert.Equal(Role.User, message.Role); + Assert.Equal("Test prompt for sampling", Assert.IsType(message.Content).Text); + + sampleCount++; + return new CreateMessageResult { - Assert.NotNull(parameters?.Messages); - var message = Assert.Single(parameters.Messages); - Assert.Equal(Role.User, message.Role); - Assert.Equal("Test prompt for sampling", Assert.IsType(message.Content).Text); - - sampleCount++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = new TextContentBlock { Text = "Sampling response from client" }, - }; - }, - }, - }, + Model = "test-model", + Role = Role.Assistant, + Content = new TextContentBlock { Text = "Sampling response from client" }, + }; + } + } }; await using var mcpClient = await ConnectAsync(clientOptions: clientOptions); @@ -200,11 +226,22 @@ public string EchoWithUserName(string message) } } + [McpServerToolType] + protected class ClaimsPrincipalTools + { + [McpServerTool, Description("Echoes the input back to the client with the user name from ClaimsPrincipal.")] + public string EchoClaimsPrincipal(ClaimsPrincipal? user, string message) + { + var userName = user?.Identity?.Name ?? "anonymous"; + return $"{userName}: {message}"; + } + } + [McpServerToolType] private class SamplingRegressionTools { [McpServerTool(Name = "sampling-tool")] - public static async Task SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken) + public static async Task SamplingToolAsync(McpServer server, string prompt, CancellationToken cancellationToken) { // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464 // 1. The client calls tool with request ID 2, because it's the first request after the initialize request. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 8191f6091..ffec1a4be 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -15,15 +15,15 @@ namespace ModelContextProtocol.AspNetCore.Tests; public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - private readonly SseClientTransportOptions DefaultTransportOptions = new() + private readonly HttpClientTransportOptions DefaultTransportOptions = new() { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; - private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) - => McpClientFactory.CreateAsync( - new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), + private Task ConnectMcpClientAsync(HttpClient? httpClient = null, HttpClientTransportOptions? transportOptions = null) + => McpClient.CreateAsync( + new HttpClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -195,7 +195,7 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions + var sseOptions = new HttpClientTransportOptions { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", @@ -222,7 +222,7 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions + var sseOptions = new HttpClientTransportOptions { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", @@ -257,7 +257,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b try { var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); + await using var server = McpServer.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); try { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 2aa675c84..c382c4385 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -18,7 +18,7 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable // multiple tests, so this dispatches the output to the current test. private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private SseClientTransportOptions DefaultTransportOptions { get; set; } = new() + private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new() { Endpoint = new("http://localhost:5000/"), }; @@ -44,16 +44,16 @@ public SseServerIntegrationTestFixture() public HttpClient HttpClient { get; } - public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) + public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) { - return McpClientFactory.CreateAsync( - new SseClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), + return McpClient.CreateAsync( + new HttpClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), options, loggerFactory, TestContext.Current.CancellationToken); } - public void Initialize(ITestOutputHelper output, SseClientTransportOptions clientTransportOptions) + public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) { _delegatingTestOutputHelper.CurrentTestOutputHelper = output; DefaultTransportOptions = clientTransportOptions; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 2d4a78685..eb7db0110 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -8,7 +8,7 @@ public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, : HttpServerIntegrationTests(fixture, testOutputHelper) { - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index d16e510cc..2ce63a1bc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) { - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/stateless"), Name = "In-memory Streamable HTTP Client", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index b50a43edc..199d815eb 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -14,7 +14,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem { private WebApplication? _app; - private readonly SseClientTransportOptions DefaultTransportOptions = new() + private readonly HttpClientTransportOptions DefaultTransportOptions = new() { Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", @@ -58,9 +58,9 @@ private async Task StartAsync() HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } - private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) - => McpClientFactory.CreateAsync( - new SseClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), + private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) + => McpClient.CreateAsync( + new HttpClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), clientOptions, LoggerFactory, TestContext.Current.CancellationToken); public async ValueTask DisposeAsync() @@ -102,9 +102,7 @@ public async Task SamplingRequest_Fails_WithInvalidOperationException() await StartAsync(); var mcpClientOptions = new McpClientOptions(); - mcpClientOptions.Capabilities = new(); - mcpClientOptions.Capabilities.Sampling ??= new(); - mcpClientOptions.Capabilities.Sampling.SamplingHandler = (_, _, _) => + mcpClientOptions.Handlers.SamplingHandler = (_, _, _) => { throw new UnreachableException(); }; @@ -122,9 +120,7 @@ public async Task RootsRequest_Fails_WithInvalidOperationException() await StartAsync(); var mcpClientOptions = new McpClientOptions(); - mcpClientOptions.Capabilities = new(); - mcpClientOptions.Capabilities.Roots ??= new(); - mcpClientOptions.Capabilities.Roots.RootsHandler = (_, _) => + mcpClientOptions.Handlers.RootsHandler = (_, _) => { throw new UnreachableException(); }; @@ -142,9 +138,7 @@ public async Task ElicitRequest_Fails_WithInvalidOperationException() await StartAsync(); var mcpClientOptions = new McpClientOptions(); - mcpClientOptions.Capabilities = new(); - mcpClientOptions.Capabilities.Elicitation ??= new(); - mcpClientOptions.Capabilities.Elicitation.ElicitationHandler = (_, _) => + mcpClientOptions.Handlers.ElicitationHandler = (_, _) => { throw new UnreachableException(); }; @@ -194,7 +188,7 @@ public async Task ScopedServices_Resolve_FromRequestScope() } [McpServerTool(Name = "testSamplingErrors")] - public static async Task TestSamplingErrors(IMcpServer server) + public static async Task TestSamplingErrors(McpServer server) { const string expectedSamplingErrorMessage = "Sampling is not supported in stateless mode."; @@ -212,7 +206,7 @@ public static async Task TestSamplingErrors(IMcpServer server) } [McpServerTool(Name = "testRootsErrors")] - public static async Task TestRootsErrors(IMcpServer server) + public static async Task TestRootsErrors(McpServer server) { const string expectedRootsErrorMessage = "Roots are not supported in stateless mode."; @@ -227,7 +221,7 @@ public static async Task TestRootsErrors(IMcpServer server) } [McpServerTool(Name = "testElicitationErrors")] - public static async Task TestElicitationErrors(IMcpServer server) + public static async Task TestElicitationErrors(McpServer server) { const string expectedElicitationErrorMessage = "Elicitation is not supported in stateless mode."; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 7ce3516ef..f1cd458f9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -112,13 +112,13 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() { await StartAsync(); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -132,13 +132,13 @@ public async Task CanCallToolConcurrently() { await StartAsync(); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -158,13 +158,13 @@ public async Task SendsDeleteRequestOnDispose() { await StartAsync(enableDelete: true); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Dispose should trigger DELETE request await client.DisposeAsync(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 0b3ae4c2a..7b2be8f98 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -252,7 +252,7 @@ public async Task MultipleConcurrentJsonRpcRequests_IsHandled_InParallel() [Fact] public async Task GetRequest_Receives_UnsolicitedNotifications() { - IMcpServer? server = null; + McpServer? server = null; Builder.Services.AddMcpServer() .WithHttpTransport(options => @@ -505,7 +505,6 @@ public async Task IdleSessionsPastMaxIdleSessionCount_ArePruned_LongestIdleFirst Assert.NotEqual(secondSessionId, thirdSessionId); // Pruning of the second session results in a 404 since we used the first session more recently. - fakeTimeProvider.Advance(TimeSpan.FromSeconds(10)); SetSessionId(secondSessionId); using var response = await HttpClient.PostAsync("", JsonContent(EchoRequest), TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -517,8 +516,9 @@ public async Task IdleSessionsPastMaxIdleSessionCount_ArePruned_LongestIdleFirst SetSessionId(thirdSessionId); await CallEchoAndValidateAsync(); - var logMessage = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Critical); - Assert.StartsWith("Exceeded maximum of 2 idle sessions.", logMessage.Message); + var idleLimitLogMessage = Assert.Single(mockLoggerProvider.LogMessages, m => m.EventId.Name == "LogIdleSessionLimit"); + Assert.Equal(LogLevel.Information, idleLimitLogMessage.LogLevel); + Assert.StartsWith("MaxIdleSessionCount of 2 exceeded. Closing idle session", idleLimitLogMessage.Message); } private static StringContent JsonContent(string json) => new StringContent(json, Encoding.UTF8, "application/json"); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 3524c60a4..b2b0b5499 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -11,7 +11,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} """; - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs index 3970394b6..bb251035d 100644 --- a/tests/ModelContextProtocol.TestOAuthServer/Program.cs +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -14,7 +14,7 @@ public sealed class Program private const int _port = 7029; private static readonly string _url = $"https://localhost:{_port}"; - // Port 5000 is used by tests and port 7071 is used by the ProtectedMCPServer sample + // Port 5000 is used by tests and port 7071 is used by the ProtectedMcpServer sample private static readonly string[] ValidResources = ["http://localhost:5000/", "http://localhost:7071/"]; private readonly ConcurrentDictionary _authCodes = new(); diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 0bc4134fa..9765ed928 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -1,10 +1,10 @@ -using Microsoft.Extensions.Logging; +using System.Collections.Concurrent; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Serilog; -using System.Collections.Concurrent; -using System.Text; -using System.Text.Json; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously @@ -36,22 +36,22 @@ private static async Task Main(string[] args) { Log.Logger.Information("Starting server..."); + string? cliArg = ParseCliArgument(args); McpServerOptions options = new() { - Capabilities = new ServerCapabilities - { - Tools = ConfigureTools(), - Resources = ConfigureResources(), - Prompts = ConfigurePrompts(), - Logging = ConfigureLogging(), - Completions = ConfigureCompletions(), - }, + Capabilities = new ServerCapabilities(), ServerInstructions = "This is a test server with only stub functionality", }; + ConfigureTools(options, cliArg); + ConfigureResources(options); + ConfigurePrompts(options); + ConfigureLogging(options); + ConfigureCompletions(options); + using var loggerFactory = CreateLoggerFactory(); await using var stdioTransport = new StdioServerTransport("TestServer", loggerFactory); - await using IMcpServer server = McpServerFactory.Create(stdioTransport, options, loggerFactory); + await using McpServer server = McpServer.Create(stdioTransport, options, loggerFactory); Log.Logger.Information("Server running..."); @@ -61,7 +61,7 @@ private static async Task Main(string[] args) await server.RunAsync(); } - private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken cancellationToken = default) + private static async Task RunBackgroundLoop(McpServer server, CancellationToken cancellationToken = default) { var loggingLevels = (LoggingLevel[])Enum.GetValues(typeof(LoggingLevel)); var random = new Random(); @@ -105,222 +105,222 @@ await server.SendMessageAsync(new JsonRpcNotification } } - private static ToolsCapability ConfigureTools() + private static void ConfigureTools(McpServerOptions options, string? cliArg) { - return new() + options.Handlers.ListToolsHandler = async (request, cancellationToken) => { - ListToolsHandler = async (request, cancellationToken) => + return new ListToolsResult { - return new ListToolsResult - { - Tools = - [ - new Tool - { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "The input to echo back." - } - }, - "required": ["message"] - } - """), - }, - new Tool - { - Name = "echoSessionId", - Description = "Echoes the session id back to the client.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object" - } - """, McpJsonUtilities.DefaultOptions), - }, - new Tool - { - Name = "sampleLLM", - Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "The prompt to send to the LLM" - }, - "maxTokens": { - "type": "number", - "description": "Maximum number of tokens to generate" - } + Tools = + [ + new Tool + { + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] + } + """), + }, + new Tool + { + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object" + } + """, McpJsonUtilities.DefaultOptions), + }, + new Tool + { + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" }, - "required": ["prompt", "maxTokens"] - } - """), - } - ] - }; - }, - - CallToolHandler = async (request, cancellationToken) => + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] + } + """), + } + ] + }; + }; + options.Handlers.CallToolHandler = async (request, cancellationToken) => + { + if (request.Params?.Name == "echo") { - if (request.Params?.Name == "echo") + if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) - { - throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); - } - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"Echo: {message}" }] - }; + throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } - else if (request.Params?.Name == "echoSessionId") + return new CallToolResult { - return new CallToolResult - { - Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] - }; - } - else if (request.Params?.Name == "sampleLLM") + Content = [new TextContentBlock { Text = $"Echo: {message}" }] + }; + } + else if (request.Params?.Name == "echoSessionId") + { + return new CallToolResult { - if (request.Params?.Arguments is null || - !request.Params.Arguments.TryGetValue("prompt", out var prompt) || - !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) - { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); - } - var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), - cancellationToken); - - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] - }; - } - else + Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] + }; + } + else if (request.Params?.Name == "sampleLLM") + { + if (request.Params?.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - throw new McpException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); + throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } + var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), + cancellationToken); + + return new CallToolResult + { + Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] + }; + } + else if (request.Params?.Name == "echoCliArg") + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = cliArg ?? "null" }] + }; + } + else + { + throw new McpException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); } }; } - private static PromptsCapability ConfigurePrompts() + private static void ConfigurePrompts(McpServerOptions options) { - return new() + options.Handlers.ListPromptsHandler = async (request, cancellationToken) => { - ListPromptsHandler = async (request, cancellationToken) => + return new ListPromptsResult { - return new ListPromptsResult - { - Prompts = [ - new Prompt - { - Name = "simple_prompt", - Description = "A prompt without arguments" - }, - new Prompt - { - Name = "complex_prompt", - Description = "A prompt with arguments", - Arguments = - [ - new PromptArgument - { - Name = "temperature", - Description = "Temperature setting", - Required = true - }, - new PromptArgument - { - Name = "style", - Description = "Output style", - Required = false - } - ] - } - ] - }; - }, + Prompts = [ + new Prompt + { + Name = "simple_prompt", + Description = "A prompt without arguments" + }, + new Prompt + { + Name = "complex_prompt", + Description = "A prompt with arguments", + Arguments = + [ + new PromptArgument + { + Name = "temperature", + Description = "Temperature setting", + Required = true + }, + new PromptArgument + { + Name = "style", + Description = "Output style", + Required = false + } + ] + } + ] + }; + }; - GetPromptHandler = async (request, cancellationToken) => + options.Handlers.GetPromptHandler = async (request, cancellationToken) => + { + List messages = []; + if (request.Params?.Name == "simple_prompt") { - List messages = []; - if (request.Params?.Name == "simple_prompt") + messages.Add(new PromptMessage { - messages.Add(new PromptMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, - }); - } - else if (request.Params?.Name == "complex_prompt") + Role = Role.User, + Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, + }); + } + else if (request.Params?.Name == "complex_prompt") + { + string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; + messages.Add(new PromptMessage { - string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; - messages.Add(new PromptMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, - }); - messages.Add(new PromptMessage - { - Role = Role.Assistant, - Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, - }); - messages.Add(new PromptMessage - { - Role = Role.User, - Content = new ImageContentBlock - { - Data = MCP_TINY_IMAGE, - MimeType = "image/png" - } - }); - } - else + Role = Role.User, + Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, + }); + messages.Add(new PromptMessage { - throw new McpException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); - } - - return new GetPromptResult + Role = Role.Assistant, + Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, + }); + messages.Add(new PromptMessage { - Messages = messages - }; + Role = Role.User, + Content = new ImageContentBlock + { + Data = MCP_TINY_IMAGE, + MimeType = "image/png" + } + }); + } + else + { + throw new McpException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); } + + return new GetPromptResult + { + Messages = messages + }; }; } private static LoggingLevel? _minimumLoggingLevel = null; - private static LoggingCapability ConfigureLogging() + private static void ConfigureLogging(McpServerOptions options) { - return new() + options.Handlers.SetLoggingLevelHandler = async (request, cancellationToken) => { - SetLoggingLevelHandler = async (request, cancellationToken) => + if (request.Params?.Level is null) { - if (request.Params?.Level is null) - { - throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); - } + throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); + } - _minimumLoggingLevel = request.Params.Level; + _minimumLoggingLevel = request.Params.Level; - return new EmptyResult(); - } + return new EmptyResult(); }; } private static readonly ConcurrentDictionary _subscribedResources = new(); - private static ResourcesCapability ConfigureResources() + private static void ConfigureResources(McpServerOptions options) { + var capabilities = options.Capabilities ??= new(); + capabilities.Resources = new() { Subscribe = true }; + List resources = []; List resourceContents = []; for (int i = 0; i < 100; ++i) @@ -361,128 +361,123 @@ private static ResourcesCapability ConfigureResources() const int pageSize = 10; - return new() + options.Handlers.ListResourceTemplatesHandler = async (request, cancellationToken) => { - ListResourceTemplatesHandler = async (request, cancellationToken) => + return new ListResourceTemplatesResult { - return new ListResourceTemplatesResult - { - ResourceTemplates = [ - new ResourceTemplate - { - UriTemplate = "test://dynamic/resource/{id}", - Name = "Dynamic Resource", - } - ] - }; - }, - - ListResourcesHandler = async (request, cancellationToken) => - { - int startIndex = 0; - if (request.Params?.Cursor is not null) - { - try + ResourceTemplates = [ + new ResourceTemplate { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); - startIndex = Convert.ToInt32(startIndexAsString); + UriTemplate = "test://dynamic/resource/{id}", + Name = "Dynamic Resource", } - catch (Exception e) - { - throw new McpException($"Invalid cursor: '{request.Params.Cursor}'", e, McpErrorCode.InvalidParams); - } - } - - int endIndex = Math.Min(startIndex + pageSize, resources.Count); - string? nextCursor = null; + ] + }; + }; - if (endIndex < resources.Count) + options.Handlers.ListResourcesHandler = async (request, cancellationToken) => + { + int startIndex = 0; + if (request.Params?.Cursor is not null) + { + try { - nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); + var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); + startIndex = Convert.ToInt32(startIndexAsString); } - return new ListResourcesResult + catch (Exception e) { - NextCursor = nextCursor, - Resources = resources.GetRange(startIndex, endIndex - startIndex) - }; - }, + throw new McpException($"Invalid cursor: '{request.Params.Cursor}'", e, McpErrorCode.InvalidParams); + } + } - ReadResourceHandler = async (request, cancellationToken) => + int endIndex = Math.Min(startIndex + pageSize, resources.Count); + string? nextCursor = null; + + if (endIndex < resources.Count) { - if (request.Params?.Uri is null) - { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); - } + nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); + } + return new ListResourcesResult + { + NextCursor = nextCursor, + Resources = resources.GetRange(startIndex, endIndex - startIndex) + }; + }; - if (request.Params.Uri.StartsWith("test://dynamic/resource/")) - { - var id = request.Params.Uri.Split('/').LastOrDefault(); - if (string.IsNullOrEmpty(id)) - { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - } + options.Handlers.ReadResourceHandler = async (request, cancellationToken) => + { + if (request.Params?.Uri is null) + { + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + } - return new ReadResourceResult - { - Contents = [ - new TextResourceContents - { - Uri = request.Params.Uri, - MimeType = "text/plain", - Text = $"Dynamic resource {id}: This is a plaintext resource" - } - ] - }; + if (request.Params.Uri.StartsWith("test://dynamic/resource/")) + { + var id = request.Params.Uri.Split('/').LastOrDefault(); + if (string.IsNullOrEmpty(id)) + { + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } - ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) - ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - return new ReadResourceResult { - Contents = [contents] + Contents = [ + new TextResourceContents + { + Uri = request.Params.Uri, + MimeType = "text/plain", + Text = $"Dynamic resource {id}: This is a plaintext resource" + } + ] }; - }, + } - SubscribeToResourcesHandler = async (request, cancellationToken) => + ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) + ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + + return new ReadResourceResult { - if (request?.Params?.Uri is null) - { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); - } - if (!request.Params.Uri.StartsWith("test://static/resource/") - && !request.Params.Uri.StartsWith("test://dynamic/resource/")) - { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - } + Contents = [contents] + }; + }; + + options.Handlers.SubscribeToResourcesHandler = async (request, cancellationToken) => + { + if (request?.Params?.Uri is null) + { + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + } + if (!request.Params.Uri.StartsWith("test://static/resource/") + && !request.Params.Uri.StartsWith("test://dynamic/resource/")) + { + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + } - _subscribedResources.TryAdd(request.Params.Uri, true); + _subscribedResources.TryAdd(request.Params.Uri, true); - return new EmptyResult(); - }, + return new EmptyResult(); + }; - UnsubscribeFromResourcesHandler = async (request, cancellationToken) => + options.Handlers.UnsubscribeFromResourcesHandler = async (request, cancellationToken) => + { + if (request?.Params?.Uri is null) { - if (request?.Params?.Uri is null) - { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); - } - if (!request.Params.Uri.StartsWith("test://static/resource/") - && !request.Params.Uri.StartsWith("test://dynamic/resource/")) - { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - } - - _subscribedResources.TryRemove(request.Params.Uri, out _); + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + } + if (!request.Params.Uri.StartsWith("test://static/resource/") + && !request.Params.Uri.StartsWith("test://dynamic/resource/")) + { + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + } - return new EmptyResult(); - }, + _subscribedResources.TryRemove(request.Params.Uri, out _); - Subscribe = true + return new EmptyResult(); }; } - private static CompletionsCapability ConfigureCompletions() + private static void ConfigureCompletions(McpServerOptions options) { List sampleResourceIds = ["1", "2", "3", "4", "5"]; Dictionary> exampleCompletions = new() @@ -491,7 +486,7 @@ private static CompletionsCapability ConfigureCompletions() {"temperature", ["0", "0.5", "0.7", "1.0"]}, }; - Func, CancellationToken, ValueTask> handler = async (request, cancellationToken) => + options.Handlers.CompleteHandler = async (request, cancellationToken) => { string[]? values; switch (request.Params?.Ref) @@ -517,8 +512,6 @@ private static CompletionsCapability ConfigureCompletions() throw new McpException($"Unknown reference type: '{request.Params?.Ref.Type}'", McpErrorCode.InvalidParams); } }; - - return new() { CompleteHandler = handler }; } static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) @@ -537,6 +530,19 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; } + private static string? ParseCliArgument(string[] args) + { + foreach (var arg in args) + { + if (arg.StartsWith("--cli-arg=")) + { + return arg["--cli-arg=".Length..]; + } + } + + return null; + } + const string MCP_TINY_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAYAAACNiR0NAAAKsGlDQ1BJQ0MgUHJvZmlsZQAASImVlwdUU+kSgOfe9JDQEiIgJfQmSCeAlBBaAAXpYCMkAUKJMRBU7MriClZURLCs6KqIgo0idizYFsWC3QVZBNR1sWDDlXeBQ9jdd9575805c+a7c+efmf+e/z9nLgCdKZDJMlF1gCxpjjwyyI8dn5DIJvUABRiY0kBdIMyWcSMiwgCTUft3+dgGyJC9YzuU69/f/1fREImzhQBIBMbJomxhFsbHMe0TyuQ5ALg9mN9kbo5siK9gzJRjDWL8ZIhTR7hviJOHGY8fjomO5GGsDUCmCQTyVACaKeZn5wpTsTw0f4ztpSKJFGPsGbyzsmaLMMbqgiUWI8N4KD8n+S95Uv+WM1mZUyBIVfLIXoaF7C/JlmUK5v+fn+N/S1amYrSGOaa0NHlwJGaxvpAHGbNDlSxNnhI+yhLRcPwwpymCY0ZZmM1LHGWRwD9UuTZzStgop0gC+co8OfzoURZnB0SNsnx2pLJWipzHHWWBfKyuIiNG6U8T85X589Ki40Y5VxI7ZZSzM6JCx2J4Sr9cEansXywN8hurG6jce1b2X/Yr4SvX5qRFByv3LhjrXyzljuXMjlf2JhL7B4zFxCjjZTl+ylqyzAhlvDgzSOnPzo1Srs3BDuTY2gjlN0wXhESMMoRBELAhBjIhB+QggECQgBTEOeJ5Q2cUeLNl8+WS1LQcNhe7ZWI2Xyq0m8B2tHd0Bhi6syNH4j1r+C4irGtjvhWVAF4nBgcHT475Qm4BHEkCoNaO+SxnAKh3A1w5JVTIc0d8Q9cJCEAFNWCCDhiACViCLTiCK3iCLwRACIRDNCTATBBCGmRhnc+FhbAMCqAI1sNmKIOdsBv2wyE4CvVwCs7DZbgOt+AePIZ26IJX0AcfYQBBEBJCRxiIDmKImCE2iCPCQbyRACQMiUQSkCQkFZEiCmQhsgIpQoqRMmQXUokcQU4g55GrSCvyEOlAepF3yFcUh9JQJqqPmqMTUQ7KRUPRaHQGmorOQfPQfHQtWopWoAfROvQ8eh29h7ajr9B+HOBUcCycEc4Wx8HxcOG4RFwKTo5bjCvEleAqcNW4Rlwz7g6uHfca9wVPxDPwbLwt3hMfjI/BC/Fz8Ivxq/Fl+P34OvxF/B18B74P/51AJ+gRbAgeBD4hnpBKmEsoIJQQ9hJqCZcI9whdhI9EIpFFtCC6EYOJCcR04gLiauJ2Yg3xHLGV2EnsJ5FIOiQbkhcpnCQg5ZAKSFtJB0lnSbdJXaTPZBWyIdmRHEhOJEvJy8kl5APkM+Tb5G7yAEWdYkbxoIRTRJT5lHWUPZRGyk1KF2WAqkG1oHpRo6np1GXUUmo19RL1CfW9ioqKsYq7ylQVicpSlVKVwypXVDpUvtA0adY0Hm06TUFbS9tHO0d7SHtPp9PN6b70RHoOfS29kn6B/oz+WZWhaqfKVxWpLlEtV61Tva36Ro2iZqbGVZuplqdWonZM7abaa3WKurk6T12gvli9XP2E+n31fg2GhoNGuEaWxmqNAxpXNXo0SZrmmgGaIs18zd2aFzQ7GTiGCYPHEDJWMPYwLjG6mESmBZPPTGcWMQ8xW5h9WppazlqxWvO0yrVOa7WzcCxzFp+VyVrHOspqY30dpz+OO048btW46nG3x33SHq/tqy3WLtSu0b6n/VWHrROgk6GzQade56kuXtdad6ruXN0dupd0X49njvccLxxfOP7o+Ed6qJ61XqTeAr3dejf0+vUN9IP0Zfpb9S/ovzZgGfgapBtsMjhj0GvIMPQ2lBhuMjxr+JKtxeayM9ml7IvsPiM9o2AjhdEuoxajAWML4xjj5cY1xk9NqCYckxSTTSZNJn2mhqaTTReaVpk+MqOYcczSzLaYNZt9MrcwjzNfaV5v3mOhbcG3yLOosnhiSbf0sZxjWWF514poxbHKsNpudcsatXaxTrMut75pg9q42khsttu0TiBMcJ8gnVAx4b4tzZZrm2tbZdthx7ILs1tuV2/3ZqLpxMSJGyY2T/xu72Kfab/H/rGDpkOIw3KHRod3jtaOQsdyx7tOdKdApyVODU5vnW2cxc47nB+4MFwmu6x0aXL509XNVe5a7drrZuqW5LbN7T6HyYngrOZccSe4+7kvcT/l/sXD1SPH46jHH562nhmeBzx7JllMEk/aM6nTy9hL4LXLq92b7Z3k/ZN3u4+Rj8Cnwue5r4mvyHevbzfXipvOPch942fvJ/er9fvE8+At4p3zx/kH+Rf6twRoBsQElAU8CzQOTA2sCuwLcglaEHQumBAcGrwh+D5fny/kV/L7QtxCFoVcDKWFRoWWhT4Psw6ThzVORieHTN44+ckUsynSKfXhEM4P3xj+NMIiYk7EyanEqRFTy6e+iHSIXBjZHMWImhV1IOpjtF/0uujHMZYxipimWLXY6bGVsZ/i/OOK49rjJ8Yvir+eoJsgSWhIJCXGJu5N7J8WMG3ztK7pLtMLprfNsJgxb8bVmbozM2eenqU2SzDrWBIhKS7pQNI3QbigQtCfzE/eltwn5Am3CF+JfEWbRL1iL3GxuDvFK6U4pSfVK3Vjam+aT1pJ2msJT1ImeZsenL4z/VNGeMa+jMHMuMyaLHJWUtYJqaY0Q3pxtsHsebNbZTayAln7HI85m+f0yUPle7OR7BnZDTlMbDi6obBU/KDoyPXOLc/9PDd27rF5GvOk827Mt56/an53XmDezwvwC4QLmhYaLVy2sGMRd9Guxcji5MVNS0yW5C/pWhq0dP8y6rKMZb8st19evPzDirgVjfn6+UvzO38I+qGqQLVAXnB/pefKnT/if5T82LLKadXWVd8LRYXXiuyLSoq+rRauvrbGYU3pmsG1KWtb1rmu27GeuF66vm2Dz4b9xRrFecWdGydvrNvE3lS46cPmWZuvljiX7NxC3aLY0l4aVtqw1XTr+q3fytLK7pX7ldds09u2atun7aLtt3f47qjeqb+zaOfXnyQ/PdgVtKuuwryiZDdxd+7uF3ti9zT/zPm5cq/u3qK9f+6T7mvfH7n/YqVbZeUBvQPrqtAqRVXvwekHbx3yP9RQbVu9q4ZVU3QYDisOvzySdKTtaOjRpmOcY9XHzY5vq2XUFtYhdfPr+urT6tsbEhpaT4ScaGr0bKw9aXdy3ymjU+WntU6vO0M9k39m8Gze2f5zsnOvz6ee72ya1fT4QvyFuxenXmy5FHrpyuXAyxeauc1nr3hdOXXV4+qJa5xr9dddr9fdcLlR+4vLL7Utri11N91uNtzyv9XYOqn1zG2f2+fv+N+5fJd/9/q9Kfda22LaHtyffr/9gehBz8PMh28f5T4aeLz0CeFJ4VP1pyXP9J5V/Gr1a027a/vpDv+OG8+jnj/uFHa++i37t29d+S/oL0q6Dbsrexx7TvUG9t56Oe1l1yvZq4HXBb9r/L7tjeWb43/4/nGjL76v66387eC71e913u/74PyhqT+i/9nHrI8Dnwo/63ze/4Xzpflr3NfugbnfSN9K/7T6s/F76Pcng1mDgzKBXDA8CuAwRVNSAN7tA6AnADCwGYI6bWSmHhZk5D9gmOA/8cjcPSyuANWYGRqNeOcADmNqvhRAzRdgaCyK9gXUyUmpo/Pv8Kw+JAbYv8K0HECi2x6tebQU/iEjc/xf+v6nBWXWv9l/AV0EC6JTIblRAAAAeGVYSWZNTQAqAAAACAAFARIAAwAAAAEAAQAAARoABQAAAAEAAABKARsABQAAAAEAAABSASgAAwAAAAEAAgAAh2kABAAAAAEAAABaAAAAAAAAAJAAAAABAAAAkAAAAAEAAqACAAQAAAABAAAAFKADAAQAAAABAAAAFAAAAAAXNii1AAAACXBIWXMAABYlAAAWJQFJUiTwAAAB82lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNi4wLjAiPgogICA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogICAgICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgICAgICAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyI+CiAgICAgICAgIDx0aWZmOllSZXNvbHV0aW9uPjE0NDwvdGlmZjpZUmVzb2x1dGlvbj4KICAgICAgICAgPHRpZmY6T3JpZW50YXRpb24+MTwvdGlmZjpPcmllbnRhdGlvbj4KICAgICAgICAgPHRpZmY6WFJlc29sdXRpb24+MTQ0PC90aWZmOlhSZXNvbHV0aW9uPgogICAgICAgICA8dGlmZjpSZXNvbHV0aW9uVW5pdD4yPC90aWZmOlJlc29sdXRpb25Vbml0PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KReh49gAAAjRJREFUOBGFlD2vMUEUx2clvoNCcW8hCqFAo1dKhEQpvsF9KrWEBh/ALbQ0KkInBI3SWyGPCCJEQliXgsTLefaca/bBWjvJzs6cOf/fnDkzOQJIjWm06/XKBEGgD8c6nU5VIWgBtQDPZPWtJE8O63a7LBgMMo/Hw0ql0jPjcY4RvmqXy4XMjUYDUwLtdhtmsxnYbDbI5/O0djqdFFKmsEiGZ9jP9gem0yn0ej2Yz+fg9XpfycimAD7DttstQTDKfr8Po9GIIg6Hw1Cr1RTgB+A72GAwgMPhQLBMJgNSXsFqtUI2myUo18pA6QJogefsPrLBX4QdCVatViklw+EQRFGEj88P2O12pEUGATmsXq+TaLPZ0AXgMRF2vMEqlQoJTSYTpNNpApvNZliv1/+BHDaZTAi2Wq1A3Ig0xmMej7+RcZjdbodUKkWAaDQK+GHjHPnImB88JrZIJAKFQgH2+z2BOczhcMiwRCIBgUAA+NN5BP6mj2DYff35gk6nA61WCzBn2JxO5wPM7/fLz4vD0E+OECfn8xl/0Gw2KbLxeAyLxQIsFgt8p75pDSO7h/HbpUWpewCike9WLpfB7XaDy+WCYrFI/slk8i0MnRRAUt46hPMI4vE4+Hw+ec7t9/44VgWigEeby+UgFArJWjUYOqhWG6x50rpcSfR6PVUfNOgEVRlTX0HhrZBKz4MZjUYWi8VoA+lc9H/VaRZYjBKrtXR8tlwumcFgeMWRbZpA9ORQWfVm8A/FsrLaxebd5wAAAABJRU5ErkJggg=="; } \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index d4abf81f9..cf78c0896 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -95,284 +95,273 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } const int pageSize = 10; - - options.Capabilities = new() + options.Handlers = new() { - Tools = new() + ListToolsHandler = async (request, cancellationToken) => { - ListToolsHandler = async (request, cancellationToken) => - { - return new ListToolsResult - { - Tools = - [ - new Tool - { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "The input to echo back." - } - }, - "required": ["message"] - } - """, McpJsonUtilities.DefaultOptions), - }, - new Tool - { - Name = "echoSessionId", - Description = "Echoes the session id back to the client.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object" - } - """, McpJsonUtilities.DefaultOptions), - }, - new Tool - { - Name = "sampleLLM", - Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "The prompt to send to the LLM" - }, - "maxTokens": { - "type": "number", - "description": "Maximum number of tokens to generate" - } - }, - "required": ["prompt", "maxTokens"] - } - """, McpJsonUtilities.DefaultOptions), - } - ] - }; - }, - CallToolHandler = async (request, cancellationToken) => + return new ListToolsResult { - if (request.Params is null) - { - throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); - } - if (request.Params.Name == "echo") - { - if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + Tools = + [ + new Tool { - throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); - } - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"Echo: {message}" }] - }; - } - else if (request.Params.Name == "echoSessionId") - { - return new CallToolResult + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] + } + """, McpJsonUtilities.DefaultOptions), + }, + new Tool { - Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] - }; - } - else if (request.Params.Name == "sampleLLM") - { - if (request.Params.Arguments is null || - !request.Params.Arguments.TryGetValue("prompt", out var prompt) || - !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object" + } + """, McpJsonUtilities.DefaultOptions), + }, + new Tool { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" + }, + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] + } + """, McpJsonUtilities.DefaultOptions), } - var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), - cancellationToken); - - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] - }; - } - else + ] + }; + }, + CallToolHandler = async (request, cancellationToken) => + { + if (request.Params is null) + { + throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + } + if (request.Params.Name == "echo") + { + if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - throw new McpException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); + throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } + return new CallToolResult + { + Content = [new TextContentBlock { Text = $"Echo: {message}" }] + }; } - }, - Resources = new() - { - ListResourceTemplatesHandler = async (request, cancellationToken) => + else if (request.Params.Name == "echoSessionId") { - - return new ListResourceTemplatesResult + return new CallToolResult { - ResourceTemplates = [ - new ResourceTemplate - { - UriTemplate = "test://dynamic/resource/{id}", - Name = "Dynamic Resource", - } - ] + Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] }; - }, - - ListResourcesHandler = async (request, cancellationToken) => + } + else if (request.Params.Name == "sampleLLM") { - int startIndex = 0; - var requestParams = request.Params ?? new(); - if (requestParams.Cursor is not null) + if (request.Params.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - try - { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(requestParams.Cursor)); - startIndex = Convert.ToInt32(startIndexAsString); - } - catch (Exception e) - { - throw new McpException($"Invalid cursor: '{requestParams.Cursor}'", e, McpErrorCode.InvalidParams); - } + throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } + var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), + cancellationToken); - int endIndex = Math.Min(startIndex + pageSize, resources.Count); - string? nextCursor = null; - - if (endIndex < resources.Count) + return new CallToolResult { - nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); - } + Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] + }; + } + else + { + throw new McpException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); + } + }, + ListResourceTemplatesHandler = async (request, cancellationToken) => + { - return new ListResourcesResult + return new ListResourceTemplatesResult + { + ResourceTemplates = [ + new ResourceTemplate { - NextCursor = nextCursor, - Resources = resources.GetRange(startIndex, endIndex - startIndex) - }; - }, - ReadResourceHandler = async (request, cancellationToken) => + UriTemplate = "test://dynamic/resource/{id}", + Name = "Dynamic Resource", + } + ] + }; + }, + ListResourcesHandler = async (request, cancellationToken) => + { + int startIndex = 0; + var requestParams = request.Params ?? new(); + if (requestParams.Cursor is not null) { - if (request.Params?.Uri is null) + try { - throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(requestParams.Cursor)); + startIndex = Convert.ToInt32(startIndexAsString); } - - if (request.Params.Uri.StartsWith("test://dynamic/resource/")) + catch (Exception e) { - var id = request.Params.Uri.Split('/').LastOrDefault(); - if (string.IsNullOrEmpty(id)) - { - throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - } - - return new ReadResourceResult - { - Contents = [ - new TextResourceContents - { - Uri = request.Params.Uri, - MimeType = "text/plain", - Text = $"Dynamic resource {id}: This is a plaintext resource" - } - ] - }; + throw new McpException($"Invalid cursor: '{requestParams.Cursor}'", e, McpErrorCode.InvalidParams); } + } - ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? - throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + int endIndex = Math.Min(startIndex + pageSize, resources.Count); + string? nextCursor = null; - return new ReadResourceResult - { - Contents = [contents] - }; + if (endIndex < resources.Count) + { + nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } + + return new ListResourcesResult + { + NextCursor = nextCursor, + Resources = resources.GetRange(startIndex, endIndex - startIndex) + }; }, - Prompts = new() + ReadResourceHandler = async (request, cancellationToken) => { - ListPromptsHandler = async (request, cancellationToken) => + if (request.Params?.Uri is null) { - return new ListPromptsResult + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); + } + + if (request.Params.Uri.StartsWith("test://dynamic/resource/")) + { + var id = request.Params.Uri.Split('/').LastOrDefault(); + if (string.IsNullOrEmpty(id)) { - Prompts = [ - new Prompt - { - Name = "simple_prompt", - Description = "A prompt without arguments" - }, - new Prompt - { - Name = "complex_prompt", - Description = "A prompt with arguments", - Arguments = - [ - new PromptArgument - { - Name = "temperature", - Description = "Temperature setting", - Required = true - }, - new PromptArgument + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + } + + return new ReadResourceResult + { + Contents = [ + new TextResourceContents { - Name = "style", - Description = "Output style", - Required = false + Uri = request.Params.Uri, + MimeType = "text/plain", + Text = $"Dynamic resource {id}: This is a plaintext resource" } - ], - } ] }; - }, - GetPromptHandler = async (request, cancellationToken) => + } + + ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? + throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); + + return new ReadResourceResult { - if (request.Params is null) - { - throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); - } - List messages = new(); - if (request.Params.Name == "simple_prompt") - { - messages.Add(new PromptMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, - }); - } - else if (request.Params.Name == "complex_prompt") - { - string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; - messages.Add(new PromptMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, - }); - messages.Add(new PromptMessage + Contents = [contents] + }; + }, + ListPromptsHandler = async (request, cancellationToken) => + { + return new ListPromptsResult + { + Prompts = [ + new Prompt { - Role = Role.Assistant, - Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, - }); - messages.Add(new PromptMessage + Name = "simple_prompt", + Description = "A prompt without arguments" + }, + new Prompt { - Role = Role.User, - Content = new ImageContentBlock - { - Data = MCP_TINY_IMAGE, - MimeType = "image/png" - } - }); - } - else + Name = "complex_prompt", + Description = "A prompt with arguments", + Arguments = + [ + new PromptArgument + { + Name = "temperature", + Description = "Temperature setting", + Required = true + }, + new PromptArgument + { + Name = "style", + Description = "Output style", + Required = false + } + ], + } + ] + }; + }, + GetPromptHandler = async (request, cancellationToken) => + { + if (request.Params is null) + { + throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + } + List messages = new(); + if (request.Params.Name == "simple_prompt") + { + messages.Add(new PromptMessage { - throw new McpException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); - } - - return new GetPromptResult + Role = Role.User, + Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, + }); + } + else if (request.Params.Name == "complex_prompt") + { + string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; + messages.Add(new PromptMessage { - Messages = messages - }; + Role = Role.User, + Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, + }); + messages.Add(new PromptMessage + { + Role = Role.Assistant, + Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, + }); + messages.Add(new PromptMessage + { + Role = Role.User, + Content = new ImageContentBlock + { + Data = MCP_TINY_IMAGE, + MimeType = "image/png" + } + }); } - }, + else + { + throw new McpException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); + } + + return new GetPromptResult + { + Messages = messages + }; + } }; } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs similarity index 79% rename from tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs rename to tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs index 7516a2186..0eb84262b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs @@ -1,26 +1,24 @@ -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using Moq; using System.IO.Pipelines; using System.Text.Json; using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; -public class McpClientFactoryTests +public class McpClientCreationTests { [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("clientTransport", () => McpClientFactory.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("clientTransport", () => McpClient.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] public async Task CreateAsync_NopTransport_ReturnsClient() { // Act - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); @@ -39,7 +37,7 @@ public async Task Cancellation_ThrowsCancellationException(bool preCanceled) cts.Cancel(); } - Task t = McpClientFactory.CreateAsync( + Task t = McpClient.CreateAsync( new StreamClientTransport(new Pipe().Writer.AsStream(), new Pipe().Reader.AsStream()), cancellationToken: cts.Token); if (!preCanceled) @@ -65,29 +63,28 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) { Capabilities = new ClientCapabilities { - Sampling = new SamplingCapability - { - SamplingHandler = async (c, p, t) => - new CreateMessageResult - { - Content = new TextContentBlock { Text = "result" }, - Model = "test-model", - Role = Role.User, - StopReason = "endTurn" - }, - }, Roots = new RootsCapability { ListChanged = true, - RootsHandler = async (t, r) => new ListRootsResult { Roots = [] }, + } + }, + Handlers = new() + { + RootsHandler = async (t, r) => new ListRootsResult { Roots = [] }, + SamplingHandler = async (c, p, t) => new CreateMessageResult + { + Content = new TextContentBlock { Text = "result" }, + Model = "test-model", + Role = Role.User, + StopReason = "endTurn" } } }; var clientTransport = (IClientTransport)Activator.CreateInstance(transportType)!; - IMcpClient? client = null; + McpClient? client = null; - var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, new Mock().Object, CancellationToken.None); + var actionTask = McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory: null, CancellationToken.None); // Act if (clientTransport is FailureTransport) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index e3d7ce44c..f4e6062de 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,471 +1,387 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; using Moq; using System.Text.Json; -using System.Text.Json.Serialization.Metadata; -using System.Threading.Channels; -namespace ModelContextProtocol.Tests.Client; +namespace ModelContextProtocol.Tests; -public class McpClientExtensionsTests : ClientServerTestBase +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpClientExtensionsTests { - public McpClientExtensionsTests(ITestOutputHelper outputHelper) - : base(outputHelper) + [Fact] + public async Task PingAsync_Throws_When_Not_McpClient() { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.PingAsync(TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.PingAsync' instead", ex.Message); } - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + [Fact] + public async Task GetPromptAsync_Throws_When_Not_McpClient() { - for (int f = 0; f < 10; f++) - { - string name = $"Method{f}"; - mcpServerBuilder.WithTools([McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })]); - } - mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })]); - mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })]); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + "name", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.GetPromptAsync' instead", ex.Message); } - [Theory] - [InlineData(null, null)] - [InlineData(0.7f, 50)] - [InlineData(1.0f, 100)] - public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + [Fact] + public async Task CallToolAsync_Throws_When_Not_McpClient() { - // Arrange - var mockChatClient = new Mock(); - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = "Hello" } - } - ], - Temperature = temperature, - MaxTokens = maxTokens, - }; - - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - Role = ChatRole.Assistant, - Contents = - [ - new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal("Hello, World!", (result.Content as TextContentBlock)?.Text); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + "tool", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CallToolAsync' instead", ex.Message); } [Fact] - public async Task CreateSamplingHandler_ShouldHandleImageMessages() + public async Task ListResourcesAsync_Throws_When_Not_McpClient() { - // Arrange - var mockChatClient = new Mock(); - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new ImageContentBlock - { - MimeType = "image/png", - Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) - } - } - ], - MaxTokens = 100 - }; - - const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - Role = ChatRole.Assistant, - Contents = - [ - new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal(expectedData, (result.Content as ImageContentBlock)?.Data); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourcesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourcesAsync' instead", ex.Message); } [Fact] - public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + public void EnumerateResourcesAsync_Throws_When_Not_McpClient() { - // Arrange - const string data = "SGVsbG8sIFdvcmxkIQ=="; - string content = $"data:application/octet-stream;base64,{data}"; - var mockChatClient = new Mock(); - var resource = new BlobResourceContents - { - Blob = data, - MimeType = "application/octet-stream", - Uri = "data:application/octet-stream" - }; - - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new EmbeddedResourceBlock { Resource = resource }, - } - ], - MaxTokens = 100 - }; - - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - AuthorName = "bot", - Role = ChatRole.Assistant, - Contents = - [ - resource.ToAIContent() - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourcesAsync' instead", ex.Message); } [Fact] - public async Task ListToolsAsync_AllToolsReturned() + public async Task SubscribeToResourceAsync_String_Throws_When_Not_McpClient() { - await using IMcpClient client = await CreateMcpClientForServer(); - - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); - var echo = tools.Single(t => t.Name == "Method4"); - var result = await echo.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); - Assert.Contains("Method4 Result 42", result?.ToString()); - - var valuesSetViaAttr = tools.Single(t => t.Name == "ValuesSetViaAttr"); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.Title); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.ReadOnlyHint); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.IdempotentHint); - Assert.False(valuesSetViaAttr.ProtocolTool.Annotations?.DestructiveHint); - Assert.True(valuesSetViaAttr.ProtocolTool.Annotations?.OpenWorldHint); - - var valuesSetViaOptions = tools.Single(t => t.Name == "ValuesSetViaOptions"); - Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.Title); - Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.ReadOnlyHint); - Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.IdempotentHint); - Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.DestructiveHint); - Assert.False(valuesSetViaOptions.ProtocolTool.Annotations?.OpenWorldHint); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_AllToolsReturned() + public async Task SubscribeToResourceAsync_Uri_Throws_When_Not_McpClient() { - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) - { - if (tool.Name == "Method4") - { - var result = await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); - Assert.Contains("Method4 Result 42", result?.ToString()); - return; - } - } + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_String_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; - Assert.Fail("Couldn't find target method"); + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() + public async Task UnsubscribeFromResourceAsync_Uri_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); - bool hasTools = false; - - await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) - { - Assert.Same(options, tool.JsonSerializerOptions); - hasTools = true; - } - - foreach (var tool in await client.ListToolsAsync(options, TestContext.Current.CancellationToken)) - { - Assert.Same(options, tool.JsonSerializerOptions); - } - - Assert.True(hasTools); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_String_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); - await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task SendRequestAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_Uri_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task SendNotificationAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_Template_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task GetPromptsAsync_HonorsJsonSerializerOptions() + public async Task CompleteAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; + var reference = new PromptReference { Name = "prompt" }; - await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.CompleteAsync( + reference, "arg", "val", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CompleteAsync' instead", ex.Message); } [Fact] - public async Task WithName_ChangesToolName() + public async Task ListToolsAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); - var originalName = tool.Name; - var renamedTool = tool.WithName("RenamedTool"); + var ex = await Assert.ThrowsAsync(async () => await client.ListToolsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListToolsAsync' instead", ex.Message); + } - Assert.NotNull(renamedTool); - Assert.Equal("RenamedTool", renamedTool.Name); - Assert.Equal(originalName, tool?.Name); + [Fact] + public void EnumerateToolsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateToolsAsync' instead", ex.Message); } [Fact] - public async Task WithDescription_ChangesToolDescription() + public async Task ListPromptsAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); - var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); - var originalDescription = tool?.Description; - var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); - Assert.NotNull(redescribedTool); - Assert.Equal("ToolWithNewDescription", redescribedTool.Description); - Assert.Equal(originalDescription, tool?.Description); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListPromptsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListPromptsAsync' instead", ex.Message); } [Fact] - public async Task WithProgress_ProgressReported() + public void EnumeratePromptsAsync_Throws_When_Not_McpClient() { - const int TotalNotifications = 3; - int remainingProgress = TotalNotifications; - TaskCompletionSource allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously); + var client = new Mock(MockBehavior.Strict).Object; - Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress progress) => - { - for (int i = 0; i < TotalNotifications; i++) + var ex = Assert.Throws(() => client.EnumeratePromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumeratePromptsAsync' instead", ex.Message); + } + + [Fact] + public async Task ListResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourceTemplatesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public void EnumerateResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public async Task PingAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" }); - await Task.Delay(1); - } + Result = JsonSerializer.SerializeToNode(new object(), McpJsonUtilities.DefaultOptions), + }); - await allProgressReceived.Task; + IMcpClient client = mockClient.Object; - return 42; - }, new() { Name = "ProgressReporter" })); + await client.PingAsync(TestContext.Current.CancellationToken); - await using IMcpClient client = await CreateMcpClientForServer(); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task GetPromptAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; - var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); + var resultPayload = new GetPromptResult { Messages = [new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }; - IProgress progress = new SynchronousProgress(value => - { - Assert.True(value.Progress >= 0 && value.Progress <= 100); - Assert.Equal("making progress", value.Message); - if (Interlocked.Decrement(ref remainingProgress) == 0) + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - allProgressReceived.SetResult(true); - } - }); + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); - Assert.Throws("progress", () => tool.WithProgress(null!)); + IMcpClient client = mockClient.Object; - var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Contains("42", result?.ToString()); + var result = await client.GetPromptAsync("name", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("hi", Assert.IsType(result.Messages[0].Content).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } - private sealed class SynchronousProgress(Action callback) : IProgress + [Fact] + public async Task CallToolAsync_Forwards_To_McpClient_SendRequestAsync() { - public void Report(ProgressNotificationValue value) => callback(value); + var mockClient = new Mock { CallBase = true }; + + var callResult = new CallToolResult { Content = [new TextContentBlock { Text = "ok" }] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(callResult, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CallToolAsync("tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("ok", Assert.IsType(result.Content[0]).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } [Fact] - public async Task AsClientLoggerProvider_MessagesSentToClient() + public async Task SubscribeToResourceAsync_Forwards_To_McpClient_SendRequestAsync() { - await using IMcpClient client = await CreateMcpClientForServer(); - - ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); - Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); - - ILogger logger = loggerProvider.CreateLogger("TestLogger"); - Assert.NotNull(logger); - - Assert.Null(logger.BeginScope("")); - - Assert.Null(Server.LoggingLevel); - Assert.False(logger.IsEnabled(LogLevel.Trace)); - Assert.False(logger.IsEnabled(LogLevel.Debug)); - Assert.False(logger.IsEnabled(LogLevel.Information)); - Assert.False(logger.IsEnabled(LogLevel.Warning)); - Assert.False(logger.IsEnabled(LogLevel.Error)); - Assert.False(logger.IsEnabled(LogLevel.Critical)); - - await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); - - DateTime start = DateTime.UtcNow; - while (Server.LoggingLevel is null) - { - await Task.Delay(1, TestContext.Current.CancellationToken); - Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); - } - - Assert.Equal(LoggingLevel.Info, Server.LoggingLevel); - Assert.False(logger.IsEnabled(LogLevel.Trace)); - Assert.False(logger.IsEnabled(LogLevel.Debug)); - Assert.True(logger.IsEnabled(LogLevel.Information)); - Assert.True(logger.IsEnabled(LogLevel.Warning)); - Assert.True(logger.IsEnabled(LogLevel.Error)); - Assert.True(logger.IsEnabled(LogLevel.Critical)); - - List data = []; - var channel = Channel.CreateUnbounded(); - - await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, - (notification, cancellationToken) => + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions))); - return default; - })) - { - logger.LogTrace("Trace {Message}", "message"); - logger.LogDebug("Debug {Message}", "message"); - logger.LogInformation("Information {Message}", "message"); - logger.LogWarning("Warning {Message}", "message"); - logger.LogError("Error {Message}", "message"); - logger.LogCritical("Critical {Message}", "message"); - - for (int i = 0; i < 4; i++) + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.SubscribeToResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.UnsubscribeFromResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task CompleteAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var completion = new Completion { Values = ["one", "two"] }; + var resultPayload = new CompleteResult { Completion = completion }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CompleteAsync(new PromptReference { Name = "p" }, "arg", "val", TestContext.Current.CancellationToken); + + Assert.Contains("one", result.Completion.Values); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_String_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Uri_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); - Assert.NotNull(m); - Assert.NotNull(m.Data); - - Assert.Equal("TestLogger", m.Logger); - - string ? s = JsonSerializer.Deserialize(m.Data.Value, McpJsonUtilities.DefaultOptions); - Assert.NotNull(s); - - if (s.Contains("Information")) - { - Assert.Equal(LoggingLevel.Info, m.Level); - } - else if (s.Contains("Warning")) - { - Assert.Equal(LoggingLevel.Warning, m.Level); - } - else if (s.Contains("Error")) - { - Assert.Equal(LoggingLevel.Error, m.Level); - } - else if (s.Contains("Critical")) - { - Assert.Equal(LoggingLevel.Critical, m.Level); - } - - data.Add(s); - } - - channel.Writer.Complete(); - } - - Assert.False(await channel.Reader.WaitToReadAsync(TestContext.Current.CancellationToken)); - Assert.Equal( - [ - "Critical message", - "Error message", - "Information message", - "Warning message", - ], - data.OrderBy(s => s)); + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync(new Uri("mcp://resource/1"), TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Template_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } -} \ No newline at end of file +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs index 48c3c370d..2599d7485 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs @@ -73,7 +73,7 @@ public static IEnumerable UriTemplate_InputsProduceExpectedOutputs_Mem public async Task UriTemplate_InputsProduceExpectedOutputs( IReadOnlyDictionary variables, string uriTemplate, object expected) { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.ReadResourceAsync(uriTemplate, variables, TestContext.Current.CancellationToken); Assert.NotNull(result); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs new file mode 100644 index 000000000..3f5b80ae7 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -0,0 +1,480 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Moq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading.Channels; + +namespace ModelContextProtocol.Tests.Client; + +public class McpClientTests : ClientServerTestBase +{ + public McpClientTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + for (int f = 0; f < 10; f++) + { + string name = $"Method{f}"; + mcpServerBuilder.WithTools([McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })]); + } + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })]); + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })]); + } + + [Theory] + [InlineData(null, null)] + [InlineData(0.7f, 50)] + [InlineData(1.0f, 100)] + public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new TextContentBlock { Text = "Hello" } + } + ], + Temperature = temperature, + MaxTokens = maxTokens, + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("Hello, World!", (result.Content as TextContentBlock)?.Text); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleImageMessages() + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new ImageContentBlock + { + MimeType = "image/png", + Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) + } + } + ], + MaxTokens = 100 + }; + + const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedData, (result.Content as ImageContentBlock)?.Data); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + { + // Arrange + const string data = "SGVsbG8sIFdvcmxkIQ=="; + string content = $"data:application/octet-stream;base64,{data}"; + var mockChatClient = new Mock(); + var resource = new BlobResourceContents + { + Blob = data, + MimeType = "application/octet-stream", + Uri = "data:application/octet-stream" + }; + + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new EmbeddedResourceBlock { Resource = resource }, + } + ], + MaxTokens = 100 + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + AuthorName = "bot", + Role = ChatRole.Assistant, + Contents = + [ + resource.ToAIContent() + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task ListToolsAsync_AllToolsReturned() + { + await using McpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(12, tools.Count); + var echo = tools.Single(t => t.Name == "Method4"); + var result = await echo.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + + var valuesSetViaAttr = tools.Single(t => t.Name == "ValuesSetViaAttr"); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.Title); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.ReadOnlyHint); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.IdempotentHint); + Assert.False(valuesSetViaAttr.ProtocolTool.Annotations?.DestructiveHint); + Assert.True(valuesSetViaAttr.ProtocolTool.Annotations?.OpenWorldHint); + + var valuesSetViaOptions = tools.Single(t => t.Name == "ValuesSetViaOptions"); + Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.Title); + Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.ReadOnlyHint); + Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.IdempotentHint); + Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.DestructiveHint); + Assert.False(valuesSetViaOptions.ProtocolTool.Annotations?.OpenWorldHint); + } + + [Fact] + public async Task EnumerateToolsAsync_AllToolsReturned() + { + await using McpClient client = await CreateMcpClientForServer(); + + await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) + { + if (tool.Name == "Method4") + { + var result = await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + return; + } + } + + Assert.Fail("Couldn't find target method"); + } + + [Fact] + public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + bool hasTools = false; + + await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + hasTools = true; + } + + foreach (var tool in await client.ListToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + } + + Assert.True(hasTools); + } + + [Fact] + public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); + await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendRequestAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendNotificationAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task GetPromptsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task WithName_ChangesToolName() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); + var originalName = tool.Name; + var renamedTool = tool.WithName("RenamedTool"); + + Assert.NotNull(renamedTool); + Assert.Equal("RenamedTool", renamedTool.Name); + Assert.Equal(originalName, tool?.Name); + } + + [Fact] + public async Task WithDescription_ChangesToolDescription() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); + var originalDescription = tool?.Description; + var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); + Assert.NotNull(redescribedTool); + Assert.Equal("ToolWithNewDescription", redescribedTool.Description); + Assert.Equal(originalDescription, tool?.Description); + } + + [Fact] + public async Task WithProgress_ProgressReported() + { + const int TotalNotifications = 3; + int remainingProgress = TotalNotifications; + TaskCompletionSource allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously); + + Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create(async (IProgress progress) => + { + for (int i = 0; i < TotalNotifications; i++) + { + progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" }); + await Task.Delay(1); + } + + await allProgressReceived.Task; + + return 42; + }, new() { Name = "ProgressReporter" })); + + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); + + IProgress progress = new SynchronousProgress(value => + { + Assert.True(value.Progress >= 0 && value.Progress <= 100); + Assert.Equal("making progress", value.Message); + if (Interlocked.Decrement(ref remainingProgress) == 0) + { + allProgressReceived.SetResult(true); + } + }); + + Assert.Throws("progress", () => tool.WithProgress(null!)); + + var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Contains("42", result?.ToString()); + } + + private sealed class SynchronousProgress(Action callback) : IProgress + { + public void Report(ProgressNotificationValue value) => callback(value); + } + + [Fact] + public async Task AsClientLoggerProvider_MessagesSentToClient() + { + await using McpClient client = await CreateMcpClientForServer(); + + ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); + Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); + + ILogger logger = loggerProvider.CreateLogger("TestLogger"); + Assert.NotNull(logger); + + Assert.Null(logger.BeginScope("")); + + Assert.Null(Server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.False(logger.IsEnabled(LogLevel.Information)); + Assert.False(logger.IsEnabled(LogLevel.Warning)); + Assert.False(logger.IsEnabled(LogLevel.Error)); + Assert.False(logger.IsEnabled(LogLevel.Critical)); + + await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); + + DateTime start = DateTime.UtcNow; + while (Server.LoggingLevel is null) + { + await Task.Delay(1, TestContext.Current.CancellationToken); + Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); + } + + Assert.Equal(LoggingLevel.Info, Server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.True(logger.IsEnabled(LogLevel.Information)); + Assert.True(logger.IsEnabled(LogLevel.Warning)); + Assert.True(logger.IsEnabled(LogLevel.Error)); + Assert.True(logger.IsEnabled(LogLevel.Critical)); + + List data = []; + var channel = Channel.CreateUnbounded(); + + await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, + (notification, cancellationToken) => + { + Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions))); + return default; + })) + { + logger.LogTrace("Trace {Message}", "message"); + logger.LogDebug("Debug {Message}", "message"); + logger.LogInformation("Information {Message}", "message"); + logger.LogWarning("Warning {Message}", "message"); + logger.LogError("Error {Message}", "message"); + logger.LogCritical("Critical {Message}", "message"); + + for (int i = 0; i < 4; i++) + { + var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.NotNull(m); + Assert.NotNull(m.Data); + + Assert.Equal("TestLogger", m.Logger); + + string ? s = JsonSerializer.Deserialize(m.Data.Value, McpJsonUtilities.DefaultOptions); + Assert.NotNull(s); + + if (s.Contains("Information")) + { + Assert.Equal(LoggingLevel.Info, m.Level); + } + else if (s.Contains("Warning")) + { + Assert.Equal(LoggingLevel.Warning, m.Level); + } + else if (s.Contains("Error")) + { + Assert.Equal(LoggingLevel.Error, m.Level); + } + else if (s.Contains("Critical")) + { + Assert.Equal(LoggingLevel.Critical, m.Level); + } + + data.Add(s); + } + + channel.Writer.Complete(); + } + + Assert.False(await channel.Reader.WaitToReadAsync(TestContext.Current.CancellationToken)); + Assert.Equal( + [ + "Critical message", + "Error message", + "Information message", + "Warning message", + ], + data.OrderBy(s => s)); + } + + [Theory] + [InlineData(null)] + [InlineData("2025-03-26")] + public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) + { + await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion }); + Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index ebc7171e2..6f625866a 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -41,8 +41,8 @@ public void Initialize(ILoggerFactory loggerFactory) _loggerFactory = loggerFactory; } - public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => - McpClientFactory.CreateAsync(new StdioClientTransport(clientId switch + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + McpClient.CreateAsync(new StdioClientTransport(clientId switch { "everything" => EverythingServerTransportOptions, "test_server" => TestServerTransportOptions, diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 3e4361a57..20c6f374b 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -52,8 +52,12 @@ public async Task Connect_ShouldProvideServerFields(string clientId) // Assert Assert.NotNull(client.ServerCapabilities); Assert.NotNull(client.ServerInfo); + Assert.NotNull(client.NegotiatedProtocolVersion); + if (clientId != "everything") // Note: Comment the below assertion back when the everything server is updated to provide instructions + { Assert.NotNull(client.ServerInstructions); + } Assert.Null(client.SessionId); } @@ -272,7 +276,7 @@ public async Task SubscribeResource_Stdio() TaskCompletionSource tcs = new(); await using var client = await _fixture.CreateClientAsync(clientId, new() { - Capabilities = new() + Handlers = new() { NotificationHandlers = [ @@ -302,7 +306,7 @@ public async Task UnsubscribeResource_Stdio() TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId, new() { - Capabilities = new() + Handlers = new() { NotificationHandlers = [ @@ -370,22 +374,19 @@ public async Task Sampling_Stdio(string clientId) int samplingHandlerCalls = 0; await using var client = await _fixture.CreateClientAsync(clientId, new() { - Capabilities = new() + Handlers = new() { - Sampling = new() + SamplingHandler = async (_, _, _) => { - SamplingHandler = async (_, _, _) => + samplingHandlerCalls++; + return new CreateMessageResult { - samplingHandlerCalls++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = new TextContentBlock { Text = "Test response" }, - }; - }, - }, - }, + Model = "test-model", + Role = Role.Assistant, + Content = new TextContentBlock { Text = "Test response" }, + }; + } + } }); // Call the server's sampleLLM tool which should trigger our sampling handler @@ -471,10 +472,10 @@ public async Task CallTool_Stdio_MemoryServer() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(stdioOptions), - clientOptions, - loggerFactory: LoggerFactory, + clientOptions, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // act @@ -495,7 +496,7 @@ public async Task CallTool_Stdio_MemoryServer() public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { // Get the MCP client and tools from it. - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(_fixture.EverythingServerTransportOptions), cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -527,15 +528,12 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() var samplingHandler = new OpenAIClient(s_openAIKey).GetChatClient("gpt-4o-mini") .AsIChatClient() .CreateSamplingHandler(); - await using var client = await McpClientFactory.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() + await using var client = await McpClient.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() { - Capabilities = new() + Handlers = new() { - Sampling = new() - { - SamplingHandler = samplingHandler, - }, - }, + SamplingHandler = samplingHandler + } }, cancellationToken: TestContext.Current.CancellationToken); var result = await client.CallToolAsync("sampleLLM", new Dictionary() @@ -557,7 +555,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId, new() { - Capabilities = new() + Handlers = new() { NotificationHandlers = [ diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index ec1c85107..ff04c3b19 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -20,7 +20,8 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); + sc.AddLogging(); + sc.AddSingleton(XunitLoggerProvider); _builder = sc .AddMcpServer() .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); @@ -28,11 +29,11 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) ServiceProvider = sc.BuildServiceProvider(validateScopes: true); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - Server = ServiceProvider.GetRequiredService(); + Server = ServiceProvider.GetRequiredService(); _serverTask = Server.RunAsync(_cts.Token); } - protected IMcpServer Server { get; } + protected McpServer Server { get; } protected IServiceProvider ServiceProvider { get; } @@ -62,9 +63,9 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { - return await McpClientFactory.CreateAsync( + return await McpClient.CreateAsync( new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs new file mode 100644 index 000000000..00e67c247 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs @@ -0,0 +1,314 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerBuilderExtensionsFilterTests : ClientServerTestBase +{ + public McpServerBuilderExtensionsFilterTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + private MockLoggerProvider _mockLoggerProvider = new(); + + private static ILogger GetLogger(IServiceProvider? services, string categoryName) + { + var loggerFactory = services?.GetRequiredService() ?? throw new InvalidOperationException("LoggerFactory not available"); + return loggerFactory.CreateLogger(categoryName); + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder + .AddListResourceTemplatesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourceTemplatesFilter"); + logger.LogInformation("ListResourceTemplatesFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsFilter"); + logger.LogInformation("ListToolsFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder1"); + logger.LogInformation("ListToolsOrder1 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder1 after"); + return result; + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder2"); + logger.LogInformation("ListToolsOrder2 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder2 after"); + return result; + }) + .AddCallToolFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CallToolFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"CallToolFilter executed for tool: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListPromptsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListPromptsFilter"); + logger.LogInformation("ListPromptsFilter executed"); + return await next(request, cancellationToken); + }) + .AddGetPromptFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "GetPromptFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"GetPromptFilter executed for prompt: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourcesFilter"); + logger.LogInformation("ListResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddReadResourceFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ReadResourceFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"ReadResourceFilter executed for resource: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddCompleteFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CompleteFilter"); + logger.LogInformation("CompleteFilter executed"); + return await next(request, cancellationToken); + }) + .AddSubscribeToResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SubscribeToResourcesFilter"); + logger.LogInformation("SubscribeToResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddUnsubscribeFromResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "UnsubscribeFromResourcesFilter"); + logger.LogInformation("UnsubscribeFromResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddSetLoggingLevelFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SetLoggingLevelFilter"); + logger.LogInformation("SetLoggingLevelFilter executed"); + return await next(request, cancellationToken); + }) + .WithTools() + .WithPrompts() + .WithResources() + .WithSetLoggingLevelHandler(async (request, cancellationToken) => new EmptyResult()) + .WithListResourceTemplatesHandler(async (request, cancellationToken) => new ListResourceTemplatesResult + { + ResourceTemplates = [new() { Name = "test", UriTemplate = "test://resource/{id}" }] + }) + .WithCompleteHandler(async (request, cancellationToken) => new CompleteResult + { + Completion = new() { Values = ["test"] } + }); + + services.AddSingleton(_mockLoggerProvider); + } + + [Fact] + public async Task AddListResourceTemplatesFilter_Logs_When_ListResourceTemplates_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourceTemplatesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourceTemplatesFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Logs_When_ListTools_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListToolsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListToolsFilter", logMessage.Category); + } + + [Fact] + public async Task AddCallToolFilter_Logs_When_CallTool_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.CallToolAsync("test_tool_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CallToolFilter executed for tool: test_tool_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CallToolFilter", logMessage.Category); + } + + [Fact] + public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListPromptsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListPromptsFilter", logMessage.Category); + } + + [Fact] + public async Task AddGetPromptFilter_Logs_When_GetPrompt_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.GetPromptAsync("test_prompt_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "GetPromptFilter executed for prompt: test_prompt_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("GetPromptFilter", logMessage.Category); + } + + [Fact] + public async Task AddListResourcesFilter_Logs_When_ListResources_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddReadResourceFilter_Logs_When_ReadResource_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ReadResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ReadResourceFilter executed for resource: test://resource/{id}"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ReadResourceFilter", logMessage.Category); + } + + [Fact] + public async Task AddCompleteFilter_Logs_When_Complete_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + var reference = new PromptReference { Name = "test_prompt_method" }; + await client.CompleteAsync(reference, "argument", "value", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CompleteFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CompleteFilter", logMessage.Category); + } + + [Fact] + public async Task AddSubscribeToResourcesFilter_Logs_When_SubscribeToResources_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.SubscribeToResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SubscribeToResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SubscribeToResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddUnsubscribeFromResourcesFilter_Logs_When_UnsubscribeFromResources_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.UnsubscribeFromResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "UnsubscribeFromResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("UnsubscribeFromResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddSetLoggingLevelFilter_Logs_When_SetLoggingLevel_Called() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.SetLoggingLevel(LoggingLevel.Info, cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SetLoggingLevelFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SetLoggingLevelFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Multiple_Filters_Log_In_Expected_Order() + { + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessages = _mockLoggerProvider.LogMessages + .Where(m => m.Category.StartsWith("ListToolsOrder")) + .Select(m => m.Message); + + Assert.Collection(logMessages, + m => Assert.Equal("ListToolsOrder1 before", m), + m => Assert.Equal("ListToolsOrder2 before", m), + m => Assert.Equal("ListToolsOrder2 after", m), + m => Assert.Equal("ListToolsOrder1 after", m) + ); + } + + [McpServerToolType] + public sealed class TestTool + { + [McpServerTool] + public static string TestToolMethod() + { + return "test result"; + } + } + + [McpServerPromptType] + public sealed class TestPrompt + { + [McpServerPrompt] + public static Task TestPromptMethod() + { + return Task.FromResult(new GetPromptResult + { + Description = "Test prompt", + Messages = [new() { Role = Role.User, Content = new TextContentBlock { Text = "Test" } }] + }); + } + } + + [McpServerResourceType] + public sealed class TestResource + { + [McpServerResource(UriTemplate = "test://resource/{id}")] + public static string TestResourceMethod(string id) + { + return $"Test resource for ID: {id}"; + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index c446eb5da..adae22f24 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -21,7 +21,7 @@ public McpServerBuilderExtensionsHandlerTests() [Fact] public void WithListToolsHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListToolsResult(); + McpRequestHandler handler = async (context, token) => new ListToolsResult(); _builder.Object.WithListToolsHandler(handler); @@ -34,7 +34,7 @@ public void WithListToolsHandler_Sets_Handler() [Fact] public void WithCallToolHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new CallToolResult(); + McpRequestHandler handler = async (context, token) => new CallToolResult(); _builder.Object.WithCallToolHandler(handler); @@ -47,7 +47,7 @@ public void WithCallToolHandler_Sets_Handler() [Fact] public void WithListPromptsHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListPromptsResult(); + McpRequestHandler handler = async (context, token) => new ListPromptsResult(); _builder.Object.WithListPromptsHandler(handler); @@ -60,7 +60,7 @@ public void WithListPromptsHandler_Sets_Handler() [Fact] public void WithGetPromptHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new GetPromptResult(); + McpRequestHandler handler = async (context, token) => new GetPromptResult(); _builder.Object.WithGetPromptHandler(handler); @@ -73,7 +73,7 @@ public void WithGetPromptHandler_Sets_Handler() [Fact] public void WithListResourceTemplatesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourceTemplatesResult(); + McpRequestHandler handler = async (context, token) => new ListResourceTemplatesResult(); _builder.Object.WithListResourceTemplatesHandler(handler); @@ -86,7 +86,7 @@ public void WithListResourceTemplatesHandler_Sets_Handler() [Fact] public void WithListResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourcesResult(); + McpRequestHandler handler = async (context, token) => new ListResourcesResult(); _builder.Object.WithListResourcesHandler(handler); @@ -99,7 +99,7 @@ public void WithListResourcesHandler_Sets_Handler() [Fact] public void WithReadResourceHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ReadResourceResult(); + McpRequestHandler handler = async (context, token) => new ReadResourceResult(); _builder.Object.WithReadResourceHandler(handler); @@ -112,7 +112,7 @@ public void WithReadResourceHandler_Sets_Handler() [Fact] public void WithCompleteHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new CompleteResult(); + McpRequestHandler handler = async (context, token) => new CompleteResult(); _builder.Object.WithCompleteHandler(handler); @@ -125,7 +125,7 @@ public void WithCompleteHandler_Sets_Handler() [Fact] public void WithSubscribeToResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); + McpRequestHandler handler = async (context, token) => new EmptyResult(); _builder.Object.WithSubscribeToResourcesHandler(handler); @@ -138,7 +138,7 @@ public void WithSubscribeToResourcesHandler_Sets_Handler() [Fact] public void WithUnsubscribeFromResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); + McpRequestHandler handler = async (context, token) => new EmptyResult(); _builder.Object.WithUnsubscribeFromResourcesHandler(handler); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 3fa2ec78b..18db1f14b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -4,7 +4,10 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; +using System.Collections; using System.ComponentModel; +using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Channels; @@ -87,7 +90,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer public void Adds_Prompts_To_Server() { var serverOptions = ServiceProvider.GetRequiredService>().Value; - var prompts = serverOptions?.Capabilities?.Prompts?.PromptCollection; + var prompts = serverOptions.PromptCollection; Assert.NotNull(prompts); Assert.NotEmpty(prompts); } @@ -95,7 +98,7 @@ public void Adds_Prompts_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Prompts() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -124,7 +127,7 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -134,7 +137,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() Assert.False(notificationRead.IsCompleted); var serverOptions = ServiceProvider.GetRequiredService>().Value; - var serverPrompts = serverOptions.Capabilities?.Prompts?.PromptCollection; + var serverPrompts = serverOptions.PromptCollection; Assert.NotNull(serverPrompts); var newPrompt = McpServerPrompt.Create([McpServerPrompt(Name = "NewPrompt")] () => "42"); @@ -165,7 +168,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(prompts); @@ -179,7 +182,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Prompt_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), @@ -189,7 +192,7 @@ await Assert.ThrowsAsync(async () => await client.GetPromptAsync( [Fact] public async Task Throws_Exception_On_Unknown_Prompt() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", @@ -201,7 +204,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() [Fact] public async Task Throws_Exception_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "returns_chat_messages", @@ -217,13 +220,63 @@ public void WithPrompts_InvalidArgs_Throws() Assert.Throws("prompts", () => builder.WithPrompts((IEnumerable)null!)); Assert.Throws("promptTypes", () => builder.WithPrompts((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithPrompts(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithPrompts()); + Assert.Throws("builder", () => nullBuilder.WithPrompts(new object())); Assert.Throws("builder", () => nullBuilder.WithPrompts(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithPromptsFromAssembly()); } + [Fact] + public async Task WithPrompts_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new SimplePrompts(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithPrompts(target); + + McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); + var result = await prompt.GetAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) + { + Params = new GetPromptRequestParams + { + Name = "returns_string", + Arguments = new Dictionary + { + ["message"] = JsonSerializer.SerializeToElement("hello", AIJsonUtilities.DefaultOptions), + } + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString("hello"), (result.Messages[0].Content as TextContentBlock)?.Text); + } + + [Fact] + public async Task WithPrompts_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithPrompts(new MyPromptProvider()); + + var prompts = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, prompts.Length); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns42"); + Assert.Contains(prompts, t => t.ProtocolPrompt.Name == "Returns43"); + } + + private sealed class MyPromptProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerPrompt.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerPrompt.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index ed930b174..939904cb7 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -4,8 +4,12 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using Moq; +using System.Collections; using System.ComponentModel; +using System.Text.Json; using System.Threading.Channels; +using static ModelContextProtocol.Tests.Configuration.McpServerBuilderExtensionsPromptsTests; namespace ModelContextProtocol.Tests.Configuration; @@ -114,7 +118,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer public void Adds_Resources_To_Server() { var serverOptions = ServiceProvider.GetRequiredService>().Value; - var resources = serverOptions?.Capabilities?.Resources?.ResourceCollection; + var resources = serverOptions.ResourceCollection; Assert.NotNull(resources); Assert.NotEmpty(resources); } @@ -122,7 +126,7 @@ public void Adds_Resources_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Resources() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); Assert.NotNull(client.ServerCapabilities.Resources); @@ -141,7 +145,7 @@ public async Task Can_List_And_Call_Registered_Resources() [Fact] public async Task Can_List_And_Call_Registered_ResourceTemplates() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); Assert.Equal(3, resources.Count); @@ -158,7 +162,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates() [Fact] public async Task Can_Be_Notified_Of_Resource_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); Assert.Equal(5, resources.Count); @@ -168,7 +172,7 @@ public async Task Can_Be_Notified_Of_Resource_Changes() Assert.False(notificationRead.IsCompleted); var serverOptions = ServiceProvider.GetRequiredService>().Value; - var serverResources = serverOptions.Capabilities?.Resources?.ResourceCollection; + var serverResources = serverOptions.ResourceCollection; Assert.NotNull(serverResources); var newResource = McpServerResource.Create([McpServerResource(Name = "NewResource")] () => "42"); @@ -199,7 +203,7 @@ public async Task Can_Be_Notified_Of_Resource_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resources); @@ -217,7 +221,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Resource_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", @@ -227,7 +231,7 @@ await Assert.ThrowsAsync(async () => await client.ReadResourceAsyn [Fact] public async Task Throws_Exception_On_Unknown_Resource() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "test:///NotRegisteredResource", @@ -243,13 +247,59 @@ public void WithResources_InvalidArgs_Throws() Assert.Throws("resourceTemplates", () => builder.WithResources((IEnumerable)null!)); Assert.Throws("resourceTemplateTypes", () => builder.WithResources((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithResources(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithResources()); + Assert.Throws("builder", () => nullBuilder.WithResources(new object())); Assert.Throws("builder", () => nullBuilder.WithResources(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithResourcesFromAssembly()); } + [Fact] + public async Task WithResources_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new ResourceWithId(new ObjectWithId() { Id = "42" }); + sc.AddMcpServer().WithResources(target); + + McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); + var result = await resource.ReadAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) + { + Params = new() + { + Uri = "returns://string" + } + }, TestContext.Current.CancellationToken); + + Assert.Equal(target.ReturnsString(), (result?.Contents[0] as TextResourceContents)?.Text); + } + + [Fact] + public async Task WithResources_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithResources(new MyResourceProvider()); + + var resources = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, resources.Length); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns42"); + Assert.Contains(resources, t => t.ProtocolResource?.Name == "Returns43"); + } + + private sealed class MyResourceProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerResource.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerResource.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void Empty_Enumerables_Is_Allowed() { @@ -307,4 +357,11 @@ public sealed class MoreResources [McpServerResource, Description("Another neat direct resource")] public static string AnotherNeatDirectResource() => "This is a neat resource"; } + + [McpServerResourceType] + public sealed class ResourceWithId(ObjectWithId id) + { + [McpServerResource(UriTemplate = "returns://string")] + public string ReturnsString() => $"Id: {id.Id}"; + } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index d2080e1fc..cf2dfd0f7 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -5,6 +5,9 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using Moq; +using System.Collections; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -22,6 +25,8 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { } + private MockLoggerProvider _mockLoggerProvider = new(); + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { mcpServerBuilder @@ -107,13 +112,14 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); services.AddSingleton(new ObjectWithId()); + services.AddSingleton(_mockLoggerProvider); } [Fact] public void Adds_Tools_To_Server() { var serverOptions = ServiceProvider.GetRequiredService>().Value; - var tools = serverOptions.Capabilities?.Tools?.ToolCollection; + var tools = serverOptions.ToolCollection; Assert.NotNull(tools); Assert.NotEmpty(tools); } @@ -121,7 +127,7 @@ public void Adds_Tools_To_Server() [Fact] public async Task Can_List_Registered_Tools() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -150,13 +156,13 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdoutPipe = new Pipe(); await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - await using var server = McpServerFactory.Create(transport, options, loggerFactory, ServiceProvider); + await using var server = McpServer.Create(transport, options, loggerFactory, ServiceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (var client = await McpClientFactory.CreateAsync( + await using (var client = await McpClient.CreateAsync( new StreamClientTransport( - serverInput: stdinPipe.Writer.AsStream(), - serverOutput: stdoutPipe.Reader.AsStream(), + serverInput: stdinPipe.Writer.AsStream(), + serverOutput: stdoutPipe.Reader.AsStream(), LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) @@ -185,7 +191,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -195,7 +201,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() Assert.False(notificationRead.IsCompleted); var serverOptions = ServiceProvider.GetRequiredService>().Value; - var serverTools = serverOptions.Capabilities?.Tools?.ToolCollection; + var serverTools = serverOptions.ToolCollection; Assert.NotNull(serverTools); var newTool = McpServerTool.Create([McpServerTool(Name = "NewTool")] () => "42"); @@ -226,11 +232,11 @@ public async Task Can_Be_Notified_Of_Tool_Changes() [Fact] public async Task Can_Call_Registered_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", - new Dictionary() { ["message"] = "Peter" }, + new Dictionary() { ["message"] = "Peter" }, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -245,7 +251,7 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_array", @@ -254,8 +260,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); - Assert.Equal("hello Peter", (result.Content[0] as TextContentBlock)?.Text); - Assert.Equal("hello2 Peter", (result.Content[1] as TextContentBlock)?.Text); + Assert.Equal("""["hello Peter","hello2 Peter"]""", (result.Content[0] as TextContentBlock)?.Text); result = await client.CallToolAsync( "SecondCustomTool", @@ -269,7 +274,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_null", @@ -283,7 +288,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_json", @@ -300,7 +305,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_integer", @@ -315,7 +320,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_complex", @@ -332,7 +337,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Can_Call_Registered_Tool_With_Instance_Method() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); string[][] parts = new string[2][]; for (int i = 0; i < 2; i++) @@ -352,16 +357,16 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() string random1 = parts[0][0]; string random2 = parts[1][0]; Assert.NotEqual(random1, random2); - + string id1 = parts[0][1]; string id2 = parts[1][1]; Assert.Equal(id1, id2); } [Fact] - public async Task Returns_IsError_Content_When_Tool_Fails() + public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "throw_exception", @@ -371,12 +376,17 @@ public async Task Returns_IsError_Content_When_Tool_Fails() Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); Assert.Contains("An error occurred", (result.Content[0] as TextContentBlock)?.Text); + + var errorLog = Assert.Single(_mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"throw_exception\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal("Test error", errorLog.Exception.Message); } [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", @@ -388,7 +398,7 @@ public async Task Throws_Exception_On_Unknown_Tool() [Fact] public async Task Returns_IsError_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -404,9 +414,11 @@ public void WithTools_InvalidArgs_Throws() Assert.Throws("tools", () => builder.WithTools((IEnumerable)null!)); Assert.Throws("toolTypes", () => builder.WithTools((IEnumerable)null!)); + Assert.Throws("target", () => builder.WithTools(target: null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithTools()); + Assert.Throws("builder", () => nullBuilder.WithTools(new object())); Assert.Throws("builder", () => nullBuilder.WithTools(Array.Empty())); Assert.Throws("builder", () => nullBuilder.WithToolsFromAssembly()); } @@ -504,10 +516,48 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime } } + [Fact] + public async Task WithTools_TargetInstance_UsesTarget() + { + ServiceCollection sc = new(); + + var target = new EchoTool(new ObjectWithId()); + sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); + + McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken); + + Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); + } + + [Fact] + public async Task WithTools_TargetInstance_UsesEnumerableImplementation() + { + ServiceCollection sc = new(); + + sc.AddMcpServer().WithTools(new MyToolProvider()); + + var tools = sc.BuildServiceProvider().GetServices().ToArray(); + Assert.Equal(2, tools.Length); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns42"); + Assert.Contains(tools, t => t.ProtocolTool.Name == "Returns43"); + } + + private sealed class MyToolProvider : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return McpServerTool.Create(() => "42", new() { Name = "Returns42" }); + yield return McpServerTool.Create(() => "43", new() { Name = "Returns43" }); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public async Task Recognizes_Parameter_Types() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -582,7 +632,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -598,7 +648,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task HandlesIProgressParameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -652,7 +702,7 @@ public async Task HandlesIProgressParameter() [Fact] public async Task CancellationNotificationsPropagateToToolTokens() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs new file mode 100644 index 000000000..151111db2 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs @@ -0,0 +1,194 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerOptionsSetupTests +{ + #region Prompt Handler Tests + [Fact] + public void Configure_WithListPromptsHandler_CreatesPromptsCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListPromptsHandler(async (request, ct) => new ListPromptsResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListPromptsHandler); + Assert.NotNull(options.Capabilities?.Prompts); + } + + [Fact] + public void Configure_WithGetPromptHandler_CreatesPromptsCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithGetPromptHandler(async (request, ct) => new GetPromptResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.GetPromptHandler); + Assert.NotNull(options.Capabilities?.Prompts); + } + #endregion + + #region Resource Handler Tests + [Fact] + public void Configure_WithListResourceTemplatesHandler_CreatesResourcesCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListResourceTemplatesHandler(async (request, ct) => new ListResourceTemplatesResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListResourceTemplatesHandler); + Assert.NotNull(options.Capabilities?.Resources); + } + + [Fact] + public void Configure_WithListResourcesHandler_CreatesResourcesCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListResourcesHandler(async (request, ct) => new ListResourcesResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListResourcesHandler); + Assert.NotNull(options.Capabilities?.Resources); + } + + [Fact] + public void Configure_WithReadResourceHandler_CreatesResourcesCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithReadResourceHandler(async (request, ct) => new ReadResourceResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ReadResourceHandler); + Assert.NotNull(options.Capabilities?.Resources); + } + + [Fact] + public void Configure_WithSubscribeToResourcesHandler_And_WithOtherResourcesHandler_EnablesSubscription() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListResourcesHandler(async (request, ct) => new ListResourcesResult()) + .WithSubscribeToResourcesHandler(async (request, ct) => new EmptyResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListResourcesHandler); + Assert.NotNull(options.Handlers.SubscribeToResourcesHandler); + Assert.NotNull(options.Capabilities?.Resources); + Assert.True(options.Capabilities.Resources.Subscribe); + } + + [Fact] + public void Configure_WithUnsubscribeFromResourcesHandler_And_WithOtherResourcesHandler_EnablesSubscription() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListResourcesHandler(async (request, ct) => new ListResourcesResult()) + .WithUnsubscribeFromResourcesHandler(async (request, ct) => new EmptyResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListResourcesHandler); + Assert.NotNull(options.Handlers.UnsubscribeFromResourcesHandler); + Assert.NotNull(options.Capabilities?.Resources); + Assert.True(options.Capabilities.Resources.Subscribe); + } + + [Fact] + public void Configure_WithSubscribeToResourcesHandler_WithoutOtherResourcesHandler_DoesNotCreateResourcesCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithSubscribeToResourcesHandler(async (request, ct) => new EmptyResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.Null(options.Handlers.SubscribeToResourcesHandler); + Assert.Null(options.Capabilities?.Resources); + } + + [Fact] + public void Configure_WithUnsubscribeFromResourcesHandler_WithoutOtherResourcesHandler_DoesNotCreateResourcesCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithUnsubscribeFromResourcesHandler(async (request, ct) => new EmptyResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.Null(options.Handlers.UnsubscribeFromResourcesHandler); + Assert.Null(options.Capabilities?.Resources); + } + #endregion + + #region Tool Handler Tests + [Fact] + public void Configure_WithListToolsHandler_CreatesToolsCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithListToolsHandler(async (request, ct) => new ListToolsResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.ListToolsHandler); + Assert.NotNull(options.Capabilities?.Tools); + } + + [Fact] + public void Configure_WithCallToolHandler_CreatesToolsCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithCallToolHandler(async (request, ct) => new CallToolResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.CallToolHandler); + Assert.NotNull(options.Capabilities?.Tools); + } + #endregion + + #region Logging Handler Tests + [Fact] + public void Configure_WithSetLoggingLevelHandler_CreatesLoggingCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithSetLoggingLevelHandler(async (request, ct) => new EmptyResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.SetLoggingLevelHandler); + Assert.NotNull(options.Capabilities?.Logging); + } + #endregion + + #region Completion Handler Tests + [Fact] + public void Configure_WithCompleteHandler_CreatesCompletionsCapability() + { + var services = new ServiceCollection(); + services.AddMcpServer() + .WithCompleteHandler(async (request, ct) => new CompleteResult()); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.NotNull(options.Handlers.CompleteHandler); + Assert.NotNull(options.Capabilities?.Completions); + } + #endregion +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index b940c1c7c..5ddc3c54a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -22,7 +22,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task InjectScopedServiceAsArgument() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == "echo_complex"); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 116c62a15..8996b9962 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -128,7 +128,7 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action, List clientToServerLog) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); @@ -137,23 +137,17 @@ private static async Task RunConnected(Func action Task serverTask; - await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() + await using (McpServer server = McpServer.Create(serverTransport, new() { - Capabilities = new() - { - Tools = new() - { - ToolCollection = [ - McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." }), - McpServerTool.Create(() => { throw new Exception("boom"); }, new() { Name = "Throw", Description = "Throws error." }), - ], - } - } + ToolCollection = [ + McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." }), + McpServerTool.Create(() => { throw new Exception("boom"); }, new() { Name = "Throw", Description = "Throws error." }), + ] })) { serverTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (IMcpClient client = await McpClientFactory.CreateAsync( + await using (McpClient client = await McpClient.CreateAsync( clientTransport, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index ffd95076f..2d5ef5f2d 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -36,15 +36,15 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - var defaultConfig = new SseClientTransportOptions + var defaultConfig = new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", }; // Create client and run tests - await using var client = await McpClientFactory.CreateAsync( - new SseClientTransport(defaultConfig), + await using var client = await McpClient.CreateAsync( + new HttpClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -63,35 +63,32 @@ public async Task Sampling_Sse_EverythingServer() await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); - var defaultConfig = new SseClientTransportOptions + var defaultConfig = new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", }; int samplingHandlerCalls = 0; - var defaultOptions = new McpClientOptions + var defaultOptions = new McpClientOptions() { - Capabilities = new() + Handlers = new() { - Sampling = new() + SamplingHandler = async (_, _, _) => { - SamplingHandler = async (_, _, _) => + samplingHandlerCalls++; + return new CreateMessageResult { - samplingHandlerCalls++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = new TextContentBlock { Text = "Test response" }, - }; - }, - }, - }, + Model = "test-model", + Role = Role.Assistant, + Content = new TextContentBlock { Text = "Test response" }, + }; + } + } }; - await using var client = await McpClientFactory.CreateAsync( - new SseClientTransport(defaultConfig), + await using var client = await McpClient.CreateAsync( + new HttpClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs new file mode 100644 index 000000000..613c703c3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Protocol; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpEndpointExtensionsTests +{ + [Fact] + public async Task SendRequestAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendRequestAsync( + endpoint, "method", "param", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendRequestAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", "payload", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task NotifyProgressAsync_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.NotifyProgressAsync( + endpoint, new ProgressToken("t1"), new ProgressNotificationValue { Progress = 0.5f }, cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.NotifyProgressAsync' instead", ex.Message); + } + + [Fact] + public async Task SendRequestAsync_Generic_Forwards_To_McpSession_SendRequestAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(42, McpJsonUtilities.DefaultOptions), + }); + + IMcpEndpoint endpoint = mockSession.Object; + + var result = await endpoint.SendRequestAsync("method", "param", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(42, result); + mockSession.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", "payload", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task NotifyProgressAsync_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.NotifyProgressAsync(new ProgressToken("progress-token"), new ProgressNotificationValue { Progress = 1 }, cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs b/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs index e0af61eed..cc55746fe 100644 --- a/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs +++ b/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs @@ -1,6 +1,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Tests; @@ -43,6 +44,38 @@ public static void DefaultOptions_UnknownEnumHandling() } } + [Fact] + public static void DefaultOptions_CanSerializeIEnumerableOfContentBlock() + { + var options = McpJsonUtilities.DefaultOptions; + + // Create an IEnumerable with different content types + IEnumerable contentBlocks = new List + { + new TextContentBlock { Text = "Hello World" }, + new TextContentBlock { Text = "Test message" } + }; + + // Should not throw NotSupportedException + string json = JsonSerializer.Serialize(contentBlocks, options); + + Assert.NotNull(json); + Assert.Contains("Hello World", json); + Assert.Contains("Test message", json); + Assert.Contains("\"type\":\"text\"", json); + + // Should also be able to deserialize back + var deserialized = JsonSerializer.Deserialize>(json, options); + Assert.NotNull(deserialized); + var deserializedList = deserialized.ToList(); + Assert.Equal(2, deserializedList.Count); + Assert.All(deserializedList, cb => Assert.Equal("text", cb.Type)); + + var textBlocks = deserializedList.Cast().ToArray(); + Assert.Equal("Hello World", textBlocks[0].Text); + Assert.Equal("Test message", textBlocks[1].Text); + } + public enum EnumWithoutAnnotation { A = 1, B = 2, C = 3 } [JsonConverter(typeof(JsonStringEnumConverter))] diff --git a/tests/ModelContextProtocol.Tests/PlatformDetection.cs b/tests/ModelContextProtocol.Tests/PlatformDetection.cs index 1eef99420..f439147ff 100644 --- a/tests/ModelContextProtocol.Tests/PlatformDetection.cs +++ b/tests/ModelContextProtocol.Tests/PlatformDetection.cs @@ -1,6 +1,9 @@ +using System.Runtime.InteropServices; + namespace ModelContextProtocol.Tests; internal static class PlatformDetection { public static bool IsMonoRuntime { get; } = Type.GetType("Mono.Runtime") is not null; + public static bool IsWindows { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs new file mode 100644 index 000000000..c5ab88b3a --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs @@ -0,0 +1,83 @@ +using System.Text.Json; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Tests.Protocol; + +public class ContentBlockTests +{ + [Fact] + public void ResourceLinkBlock_SerializationRoundTrip_PreservesAllProperties() + { + // Arrange + var original = new ResourceLinkBlock + { + Uri = "https://example.com/resource", + Name = "Test Resource", + Description = "A test resource for validation", + MimeType = "text/plain", + Size = 1024 + }; + + // Act - Serialize to JSON + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + + // Act - Deserialize back from JSON + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(deserialized); + var resourceLink = Assert.IsType(deserialized); + + Assert.Equal(original.Uri, resourceLink.Uri); + Assert.Equal(original.Name, resourceLink.Name); + Assert.Equal(original.Description, resourceLink.Description); + Assert.Equal(original.MimeType, resourceLink.MimeType); + Assert.Equal(original.Size, resourceLink.Size); + Assert.Equal("resource_link", resourceLink.Type); + } + + [Fact] + public void ResourceLinkBlock_DeserializationWithMinimalProperties_Succeeds() + { + // Arrange - JSON with only required properties + const string Json = """ + { + "type": "resource_link", + "uri": "https://example.com/minimal", + "name": "Minimal Resource" + } + """; + + // Act + var deserialized = JsonSerializer.Deserialize(Json, McpJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(deserialized); + var resourceLink = Assert.IsType(deserialized); + + Assert.Equal("https://example.com/minimal", resourceLink.Uri); + Assert.Equal("Minimal Resource", resourceLink.Name); + Assert.Null(resourceLink.Description); + Assert.Null(resourceLink.MimeType); + Assert.Null(resourceLink.Size); + Assert.Equal("resource_link", resourceLink.Type); + } + + [Fact] + public void ResourceLinkBlock_DeserializationWithoutName_ThrowsJsonException() + { + // Arrange - JSON missing the required "name" property + const string Json = """ + { + "type": "resource_link", + "uri": "https://example.com/missing-name" + } + """; + + // Act & Assert + var exception = Assert.Throws(() => + JsonSerializer.Deserialize(Json, McpJsonUtilities.DefaultOptions)); + + Assert.Contains("Name must be provided for 'resource_link' type", exception.Message); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs index f44743916..f3ae33ed5 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -67,77 +67,74 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task Can_Elicit_Information() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { - Capabilities = new() + Handlers = new McpClientHandlers() { - Elicitation = new() + ElicitationHandler = async (request, cancellationtoken) => { - ElicitationHandler = async (request, cancellationtoken) => - { - Assert.NotNull(request); - Assert.Equal("Please provide more information.", request.Message); - Assert.Equal(4, request.RequestedSchema.Properties.Count); + Assert.NotNull(request); + Assert.Equal("Please provide more information.", request.Message); + Assert.Equal(4, request.RequestedSchema.Properties.Count); - foreach (var entry in request.RequestedSchema.Properties) + foreach (var entry in request.RequestedSchema.Properties) + { + switch (entry.Key) { - switch (entry.Key) - { - case "prop1": - var primitiveString = Assert.IsType(entry.Value); - Assert.Equal("title1", primitiveString.Title); - Assert.Equal(1, primitiveString.MinLength); - Assert.Equal(100, primitiveString.MaxLength); - break; + case "prop1": + var primitiveString = Assert.IsType(entry.Value); + Assert.Equal("title1", primitiveString.Title); + Assert.Equal(1, primitiveString.MinLength); + Assert.Equal(100, primitiveString.MaxLength); + break; - case "prop2": - var primitiveNumber = Assert.IsType(entry.Value); - Assert.Equal("description2", primitiveNumber.Description); - Assert.Equal(0, primitiveNumber.Minimum); - Assert.Equal(1000, primitiveNumber.Maximum); - break; + case "prop2": + var primitiveNumber = Assert.IsType(entry.Value); + Assert.Equal("description2", primitiveNumber.Description); + Assert.Equal(0, primitiveNumber.Minimum); + Assert.Equal(1000, primitiveNumber.Maximum); + break; - case "prop3": - var primitiveBool = Assert.IsType(entry.Value); - Assert.Equal("title3", primitiveBool.Title); - Assert.Equal("description4", primitiveBool.Description); - Assert.True(primitiveBool.Default); - break; + case "prop3": + var primitiveBool = Assert.IsType(entry.Value); + Assert.Equal("title3", primitiveBool.Title); + Assert.Equal("description4", primitiveBool.Description); + Assert.True(primitiveBool.Default); + break; - case "prop4": - var primitiveEnum = Assert.IsType(entry.Value); - Assert.Equal(["option1", "option2", "option3"], primitiveEnum.Enum); - Assert.Equal(["Name1", "Name2", "Name3"], primitiveEnum.EnumNames); - break; + case "prop4": + var primitiveEnum = Assert.IsType(entry.Value); + Assert.Equal(["option1", "option2", "option3"], primitiveEnum.Enum); + Assert.Equal(["Name1", "Name2", "Name3"], primitiveEnum.EnumNames); + break; - default: - Assert.Fail($"Unknown property: {entry.Key}"); - break; - } + default: + Assert.Fail($"Unknown property: {entry.Key}"); + break; } + } - return new ElicitResult + return new ElicitResult + { + Action = "accept", + Content = new Dictionary { - Action = "accept", - Content = new Dictionary - { - ["prop1"] = (JsonElement)JsonSerializer.Deserialize(""" - "string result" - """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, - ["prop2"] = (JsonElement)JsonSerializer.Deserialize(""" - 42 - """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, - ["prop3"] = (JsonElement)JsonSerializer.Deserialize(""" - true - """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, - ["prop4"] = (JsonElement)JsonSerializer.Deserialize(""" - "option2" - """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, - }, - }; - }, - }, - }, + ["prop1"] = (JsonElement)JsonSerializer.Deserialize(""" + "string result" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop2"] = (JsonElement)JsonSerializer.Deserialize(""" + 42 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop3"] = (JsonElement)JsonSerializer.Deserialize(""" + true + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop4"] = (JsonElement)JsonSerializer.Deserialize(""" + "option2" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + }, + }; + } + } }); var result = await client.CallToolAsync("TestElicitation", cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs new file mode 100644 index 000000000..47da166ca --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs @@ -0,0 +1,363 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Tests.Configuration; + +public partial class ElicitationTypedTests : ClientServerTestBase +{ + public ElicitationTypedTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithCallToolHandler(async (request, cancellationToken) => + { + Assert.NotNull(request.Params); + + if (request.Params!.Name == "TestElicitationTyped") + { + var result = await request.Server.ElicitAsync( + message: "Please provide more information.", + serializerOptions: ElicitationTypedDefaultJsonContext.Default.Options, + cancellationToken: CancellationToken.None); + + Assert.Equal("accept", result.Action); + Assert.NotNull(result.Content); + Assert.Equal("Alice", result.Content!.Name); + Assert.Equal(30, result.Content!.Age); + Assert.True(result.Content!.Active); + Assert.Equal(SampleRole.Admin, result.Content!.Role); + Assert.Equal(99.5, result.Content!.Score); + } + else if (request.Params!.Name == "TestElicitationCamelForm") + { + var result = await request.Server.ElicitAsync( + message: "Please provide more information.", + serializerOptions: ElicitationTypedCamelJsonContext.Default.Options, + cancellationToken: CancellationToken.None); + + Assert.Equal("accept", result.Action); + Assert.NotNull(result.Content); + Assert.Equal("Bob", result.Content!.FirstName); + Assert.Equal(90210, result.Content!.ZipCode); + Assert.False(result.Content!.IsAdmin); + } + else if (request.Params!.Name == "TestElicitationNullablePropertyForm") + { + var result = await request.Server.ElicitAsync( + message: "Please provide more information.", + serializerOptions: ElicitationNullablePropertyJsonContext.Default.Options, + cancellationToken: CancellationToken.None); + + // Should be unreachable + return new CallToolResult + { + Content = [new TextContentBlock { Text = "unexpected" }], + }; + } + else if (request.Params!.Name == "TestElicitationUnsupportedType") + { + await request.Server.ElicitAsync( + message: "Please provide more information.", + serializerOptions: ElicitationUnsupportedJsonContext.Default.Options, + cancellationToken: CancellationToken.None); + + // Should be unreachable + return new CallToolResult + { + Content = [new TextContentBlock { Text = "unexpected" }], + }; + } + else if (request.Params!.Name == "TestElicitationNonObjectGenericType") + { + // This should throw because T is not an object type with properties (string primitive) + await request.Server.ElicitAsync( + message: "Any message", + serializerOptions: McpJsonUtilities.DefaultOptions, + cancellationToken: CancellationToken.None); + + return new CallToolResult + { + Content = [new TextContentBlock { Text = "unexpected" }], + }; + } + else + { + Assert.Fail($"Unexpected tool name: {request.Params!.Name}"); + } + + return new CallToolResult + { + Content = [new TextContentBlock { Text = "success" }], + }; + }); + } + + [Fact] + public async Task Can_Elicit_Typed_Information() + { + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new() + { + ElicitationHandler = async (request, cancellationToken) => + { + Assert.NotNull(request); + Assert.Equal("Please provide more information.", request.Message); + + Assert.Equal(6, request.RequestedSchema.Properties.Count); + + foreach (var entry in request.RequestedSchema.Properties) + { + var key = entry.Key; + var value = entry.Value; + switch (key) + { + case nameof(SampleForm.Name): + var stringSchema = Assert.IsType(value); + Assert.Equal("string", stringSchema.Type); + break; + + case nameof(SampleForm.Age): + var intSchema = Assert.IsType(value); + Assert.Equal("integer", intSchema.Type); + break; + + case nameof(SampleForm.Active): + var boolSchema = Assert.IsType(value); + Assert.Equal("boolean", boolSchema.Type); + break; + + case nameof(SampleForm.Role): + var enumSchema = Assert.IsType(value); + Assert.Equal("string", enumSchema.Type); + Assert.Equal([nameof(SampleRole.User), nameof(SampleRole.Admin)], enumSchema.Enum); + break; + + case nameof(SampleForm.Score): + var numSchema = Assert.IsType(value); + Assert.Equal("number", numSchema.Type); + break; + + case nameof(SampleForm.Created): + var dateTimeSchema = Assert.IsType(value); + Assert.Equal("string", dateTimeSchema.Type); + Assert.Equal("date-time", dateTimeSchema.Format); + + break; + + default: + Assert.Fail($"Unexpected property in schema: {key}"); + break; + } + } + + return new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + [nameof(SampleForm.Name)] = (JsonElement)JsonSerializer.Deserialize(""" + "Alice" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + [nameof(SampleForm.Age)] = (JsonElement)JsonSerializer.Deserialize(""" + 30 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + [nameof(SampleForm.Active)] = (JsonElement)JsonSerializer.Deserialize(""" + true + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + [nameof(SampleForm.Role)] = (JsonElement)JsonSerializer.Deserialize(""" + "Admin" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + [nameof(SampleForm.Score)] = (JsonElement)JsonSerializer.Deserialize(""" + 99.5 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + [nameof(SampleForm.Created)] = (JsonElement)JsonSerializer.Deserialize(""" + "2023-08-27T03:05:00" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + }, + }; + }, + } + }); + + var result = await client.CallToolAsync("TestElicitationTyped", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("success", (result.Content[0] as TextContentBlock)?.Text); + } + + [Fact] + public async Task Elicit_Typed_Respects_NamingPolicy() + { + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new() + { + ElicitationHandler = async (request, cancellationToken) => + { + Assert.NotNull(request); + Assert.Equal("Please provide more information.", request.Message); + + // Expect camelCase names based on serializer options + Assert.Contains("firstName", request.RequestedSchema.Properties.Keys); + Assert.Contains("zipCode", request.RequestedSchema.Properties.Keys); + Assert.Contains("isAdmin", request.RequestedSchema.Properties.Keys); + + return new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["firstName"] = (JsonElement)JsonSerializer.Deserialize(""" + "Bob" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["zipCode"] = (JsonElement)JsonSerializer.Deserialize(""" + 90210 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["isAdmin"] = (JsonElement)JsonSerializer.Deserialize(""" + false + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + }, + }; + }, + }, + }); + + var result = await client.CallToolAsync("TestElicitationCamelForm", cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal("success", (result.Content[0] as TextContentBlock)?.Text); + } + + [Fact] + public async Task Elicit_Typed_With_Unsupported_Property_Type_Throws() + { + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new() + { + // Handler should never be invoked because the exception occurs before the request is sent. + ElicitationHandler = async (req, ct) => + { + Assert.Fail("Elicitation handler should not be called for unsupported schema test."); + return new ElicitResult { Action = "cancel" }; + }, + }, + }); + + var ex = await Assert.ThrowsAsync(async() => + await client.CallToolAsync("TestElicitationUnsupportedType", cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains(typeof(UnsupportedForm.Nested).FullName!, ex.Message); + } + + [Fact] + public async Task Elicit_Typed_With_Nullable_Property_Type_Throws() + { + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new() + { + // Handler should never be invoked because the exception occurs before the request is sent. + ElicitationHandler = async (req, ct) => + { + Assert.Fail("Elicitation handler should not be called for unsupported schema test."); + return new ElicitResult { Action = "cancel" }; + }, + } + }); + + var ex = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("TestElicitationNullablePropertyForm", cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task Elicit_Typed_With_NonObject_Generic_Type_Throws() + { + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new() + { + // Should not be invoked + ElicitationHandler = async (req, ct) => + { + Assert.Fail("Elicitation handler should not be called for non-object generic type test."); + return new ElicitResult { Action = "cancel" }; + }, + } + }); + + var ex = await Assert.ThrowsAsync(async () => + await client.CallToolAsync("TestElicitationNonObjectGenericType", cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains(typeof(string).FullName!, ex.Message); + } + + [JsonConverter(typeof(CustomizableJsonStringEnumConverter))] + + public enum SampleRole + { + User, + Admin, + } + + public sealed class SampleForm + { + public required string Name { get; set; } + public int Age { get; set; } + public bool Active { get; set; } + public SampleRole Role { get; set; } + public double Score { get; set; } + + + public DateTime Created { get; set; } + } + + public sealed class CamelForm + { + public required string FirstName { get; set; } + public int ZipCode { get; set; } + public bool IsAdmin { get; set; } + } + + public sealed class NullablePropertyForm + { + public string? FirstName { get; set; } + public int ZipCode { get; set; } + public bool IsAdmin { get; set; } + } + + [JsonSerializable(typeof(SampleForm))] + [JsonSerializable(typeof(SampleRole))] + [JsonSerializable(typeof(JsonElement))] + internal partial class ElicitationTypedDefaultJsonContext : JsonSerializerContext; + + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(CamelForm))] + [JsonSerializable(typeof(JsonElement))] + internal partial class ElicitationTypedCamelJsonContext : JsonSerializerContext; + + + [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] + [JsonSerializable(typeof(NullablePropertyForm))] + [JsonSerializable(typeof(JsonElement))] + internal partial class ElicitationNullablePropertyJsonContext : JsonSerializerContext; + + public sealed class UnsupportedForm + { + public string? Name { get; set; } + public Nested? NestedProperty { get; set; } // Triggers unsupported (complex object) + public sealed class Nested + { + public string? Value { get; set; } + } + } + + [JsonSerializable(typeof(UnsupportedForm))] + [JsonSerializable(typeof(UnsupportedForm.Nested))] + [JsonSerializable(typeof(JsonElement))] + internal partial class ElicitationUnsupportedJsonContext : JsonSerializerContext; +} diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 0d18667e9..25470650e 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -13,7 +13,7 @@ public NotificationHandlerTests(ITestOutputHelper testOutputHelper) public async Task RegistrationsAreRemovedWhenDisposed() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int Iterations = 10; @@ -40,7 +40,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() public async Task MultipleRegistrationsResultInMultipleCallbacks() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -80,7 +80,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() public async Task MultipleHandlersRunEvenIfOneThrows() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -122,7 +122,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -163,7 +163,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs new file mode 100644 index 000000000..5569f993c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs @@ -0,0 +1,195 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpServerExtensionsTests +{ + [Fact] + public async Task SampleAsync_Request_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + new CreateMessageRequestParams { Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }, + TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Messages_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + [new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public void AsSamplingChatClient_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsSamplingChatClient); + Assert.Contains("Prefer using 'McpServer.AsSamplingChatClient' instead", ex.Message); + } + + [Fact] + public void AsClientLoggerProvider_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsClientLoggerProvider); + Assert.Contains("Prefer using 'McpServer.AsClientLoggerProvider' instead", ex.Message); + } + + [Fact] + public async Task RequestRootsAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.RequestRootsAsync( + new ListRootsRequestParams(), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.RequestRootsAsync' instead", ex.Message); + } + + [Fact] + public async Task ElicitAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.ElicitAsync( + new ElicitRequestParams { Message = "hello" }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.ElicitAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Request_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] + }, TestContext.Current.CancellationToken); + + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("resp", Assert.IsType(result.Content).Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SampleAsync_Messages_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var chatResponse = await server.SampleAsync([new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("test-model", chatResponse.ModelId); + var last = chatResponse.Messages.Last(); + Assert.Equal(ChatRole.Assistant, last.Role); + Assert.Equal("resp", last.Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task RequestRootsAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ListRootsResult { Roots = [new Root { Uri = "root://a" }] }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Roots = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.RequestRootsAsync(new ListRootsRequestParams(), TestContext.Current.CancellationToken); + + Assert.Equal("root://a", result.Roots[0].Uri); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ElicitAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ElicitResult { Action = "accept" }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Elicitation = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.ElicitAsync(new ElicitRequestParams { Message = "hi" }, TestContext.Current.CancellationToken); + + Assert.Equal("accept", result.Action); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs deleted file mode 100644 index 034a30bd7..000000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ /dev/null @@ -1,45 +0,0 @@ -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; - -namespace ModelContextProtocol.Tests.Server; - -public class McpServerFactoryTests : LoggedTest -{ - private readonly McpServerOptions _options; - - public McpServerFactoryTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - _options = new McpServerOptions - { - ProtocolVersion = "1.0", - InitializationTimeout = TimeSpan.FromSeconds(30) - }; - } - - [Fact] - public async Task Create_Should_Initialize_With_Valid_Parameters() - { - // Arrange & Act - await using var transport = new TestServerTransport(); - await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory); - - // Assert - Assert.NotNull(server); - } - - [Fact] - public void Create_Throws_For_Null_ServerTransport() - { - // Arrange, Act & Assert - Assert.Throws("transport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); - } - - [Fact] - public async Task Create_Throws_For_Null_Options() - { - // Arrange, Act & Assert - await using var transport = new TestServerTransport(); - Assert.Throws("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory)); - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index b2e748730..10116e70e 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -25,7 +25,7 @@ public void CanCreateServerWithLoggingLevelHandler() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -39,10 +39,10 @@ public void AddingLoggingLevelHandlerSetsLoggingCapability() var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.NotNull(server.ServerOptions.Capabilities?.Logging); - Assert.NotNull(server.ServerOptions.Capabilities.Logging.SetLoggingLevelHandler); + Assert.NotNull(server.ServerOptions.Handlers.SetLoggingLevelHandler); } [Fact] @@ -52,7 +52,7 @@ public void ServerWithoutCallingLoggingLevelHandlerDoesNotSetLoggingCapability() services.AddMcpServer() .WithStdioServerTransport(); var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.Null(server.ServerOptions.Capabilities?.Logging); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 39e9b72ff..41c26f405 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,11 +1,9 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; using System.ComponentModel; -using System.Diagnostics; using System.Reflection; using System.Runtime.InteropServices; using System.Text.Json; @@ -15,6 +13,16 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerPromptTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerPromptTests() { #if !NET @@ -33,11 +41,11 @@ public void Create_InvalidArgs_Throws() } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerPrompt prompt = McpServerPrompt.Create((IMcpServer server) => + McpServerPrompt prompt = McpServerPrompt.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ChatMessage(ChatRole.User, "Hello"); @@ -46,7 +54,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -63,7 +71,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt)); @@ -75,7 +83,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -86,11 +94,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -125,11 +133,11 @@ public async Task SupportsServiceFromDI() Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -150,7 +158,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -163,7 +171,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -176,7 +184,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() _ => new AsyncDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -189,7 +197,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable _ => new AsyncDisposableAndDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -205,7 +213,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -222,7 +230,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -248,7 +256,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -260,7 +268,7 @@ public async Task CanReturnPromptMessage() [Fact] public async Task CanReturnPromptMessages() { - IList expected = + IList expected = [ new() { @@ -280,7 +288,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -307,7 +315,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -339,7 +347,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -360,7 +368,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } @@ -373,7 +381,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 011c4f2b6..f7f2a7742 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -11,6 +11,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerResourceTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerResourceTests() { #if !NET @@ -50,7 +60,7 @@ public void CanCreateServerWithResource() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } @@ -86,7 +96,7 @@ public void CanCreateServerWithResourceTemplates() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -109,7 +119,7 @@ public void CreatingReadHandlerWithNoListHandlerSucceeds() }); var sp = services.BuildServiceProvider(); - sp.GetRequiredService(); + sp.GetRequiredService(); } [Fact] @@ -133,20 +143,20 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() McpServerResource t; ReadResourceResult? result; - IMcpServer server = new Mock().Object; + McpServer server = new Mock().Object; t = McpServerResource.Create(() => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); - t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); + t = McpServerResource.Create((McpServer server) => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -154,7 +164,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1) => arg1, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("wOrLd", ((TextResourceContents)result.Contents[0]).Text); @@ -162,7 +172,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1, string? arg2 = null) => arg1 + arg2, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1,arg2}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("world", ((TextResourceContents)result.Contents[0]).Text); @@ -170,7 +180,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((object a1, bool a2, char a3, byte a4, sbyte a5) => a1.ToString() + a2 + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("hiTrues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -178,7 +188,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort a1, short a2, uint a3, int a4, ulong a5) => (a1 + a2 + a3 + a4 + (long)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -186,7 +196,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long a1, float a2, double a3, decimal a4, TimeSpan a5) => a5.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -194,7 +204,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime a1, DateTimeOffset a2, Uri a3, Guid a4, Version a5) => a4.ToString("N") + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); @@ -203,7 +213,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -211,7 +221,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr a1, DateOnly a2, TimeOnly a3) => a1.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -220,7 +230,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("Trues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -228,7 +238,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort? a1, short? a2, uint? a3, int? a4, ulong? a5) => (a1 + a2 + a3 + a4 + (long?)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -236,7 +246,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long? a1, float? a2, double? a3, decimal? a4, TimeSpan? a5) => a5?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -244,7 +254,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime? a1, DateTimeOffset? a2, Guid? a4) => a4?.ToString("N"), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a4}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); @@ -253,7 +263,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -261,7 +271,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr? a1, DateOnly? a2, TimeOnly? a3) => a1?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -277,7 +287,7 @@ public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -288,7 +298,7 @@ public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string que { McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); Assert.NotNull(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = queriedUri } }, TestContext.Current.CancellationToken)); } @@ -317,7 +327,7 @@ public async Task UriTemplate_MissingParameter_Throws(string uri) McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -330,43 +340,43 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -381,7 +391,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestResource)); @@ -393,7 +403,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "https://something" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "https://something" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Contents); @@ -404,11 +414,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -467,14 +477,14 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime McpServerResource resource = services.GetRequiredService(); - Mock mockServer = new(); + Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken)); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -496,7 +506,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -512,7 +522,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposableResourceType()); var result = await resource1.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("0", ((TextResourceContents)result.Contents[0]).Text); @@ -523,14 +533,14 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() [Fact] public async Task CanReturnReadResult() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -540,14 +550,14 @@ public async Task CanReturnReadResult() [Fact] public async Task CanReturnResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextResourceContents { Text = "hello" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -557,8 +567,8 @@ public async Task CanReturnResourceContents() [Fact] public async Task CanReturnCollectionOfResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) @@ -568,7 +578,7 @@ public async Task CanReturnCollectionOfResourceContents() ]; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -579,14 +589,14 @@ public async Task CanReturnCollectionOfResourceContents() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -596,14 +606,14 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -614,14 +624,14 @@ public async Task CanReturnCollectionOfStrings() [Fact] public async Task CanReturnDataContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -632,8 +642,8 @@ public async Task CanReturnDataContent() [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List @@ -643,7 +653,7 @@ public async Task CanReturnCollectionOfAIContent() }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 6750b2cad..40461d415 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -32,12 +32,39 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = }; } + [Fact] + public async Task Create_Should_Initialize_With_Valid_Parameters() + { + // Arrange & Act + await using var transport = new TestServerTransport(); + await using McpServer server = McpServer.Create(transport, _options, LoggerFactory); + + // Assert + Assert.NotNull(server); + Assert.Null(server.NegotiatedProtocolVersion); + } + + [Fact] + public void Create_Throws_For_Null_ServerTransport() + { + // Arrange, Act & Assert + Assert.Throws("transport", () => McpServer.Create(null!, _options, LoggerFactory)); + } + + [Fact] + public async Task Create_Throws_For_Null_Options() + { + // Arrange, Act & Assert + await using var transport = new TestServerTransport(); + Assert.Throws("serverOptions", () => McpServer.Create(transport, null!, LoggerFactory)); + } + [Fact] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -47,7 +74,7 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory)); + Assert.Throws(() => McpServer.Create(null!, _options, LoggerFactory)); } [Fact] @@ -55,7 +82,7 @@ public async Task Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert await using var transport = new TestServerTransport(); - Assert.Throws(() => McpServerFactory.Create(transport, null!, LoggerFactory)); + Assert.Throws(() => McpServer.Create(transport, null!, LoggerFactory)); } [Fact] @@ -63,7 +90,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, null); + await using var server = McpServer.Create(transport, _options, null); // Assert Assert.NotNull(server); @@ -74,7 +101,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, null); + await using var server = McpServer.Create(transport, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -85,7 +112,7 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert @@ -100,7 +127,7 @@ public async Task SampleAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); var action = async () => await server.SampleAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -114,7 +141,7 @@ public async Task SampleAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -136,7 +163,7 @@ public async Task RequestRootsAsync_Should_Throw_Exception_If_Client_Does_Not_Su { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -148,7 +175,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -170,7 +197,7 @@ public async Task ElicitAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -182,7 +209,7 @@ public async Task ElicitAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Elicitation = new ElicitationCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -206,7 +233,7 @@ await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Ping, configureOptions: null, - assertResult: response => + assertResult: (_, response) => { JsonObject jObj = Assert.IsType(response); Assert.Empty(jObj); @@ -216,18 +243,19 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Initialize_Requests() { - AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(IMcpServer).Assembly).GetName(); + AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(McpServer).Assembly).GetName(); await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Initialize, configureOptions: null, - assertResult: response => + assertResult: (server, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); Assert.Equal(expectedAssemblyName.Name, result.ServerInfo.Name); Assert.Equal(expectedAssemblyName.Version?.ToString() ?? "1.0.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); + Assert.Equal("2024", server.NegotiatedProtocolVersion); }); } @@ -235,25 +263,25 @@ await Can_Handle_Requests( public async Task Can_Handle_Completion_Requests() { await Can_Handle_Requests( - new() + new ServerCapabilities { Completions = new() - { - CompleteHandler = async (request, ct) => - new CompleteResult + }, + method: RequestMethods.CompletionComplete, + configureOptions: options => + { + options.Handlers.CompleteHandler = async (request, ct) => + new CompleteResult + { + Completion = new() { - Completion = new() - { - Values = ["test"], - Total = 2, - HasMore = true - } + Values = ["test"], + Total = 2, + HasMore = true } - } + }; }, - method: RequestMethods.CompletionComplete, - configureOptions: null, - assertResult: response => + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result?.Completion); @@ -270,27 +298,27 @@ await Can_Handle_Requests( new ServerCapabilities { Resources = new() + }, + RequestMethods.ResourcesTemplatesList, + configureOptions: options => + { + options.Handlers.ListResourceTemplatesHandler = async (request, ct) => { - ListResourceTemplatesHandler = async (request, ct) => + return new ListResourceTemplatesResult { - return new ListResourceTemplatesResult - { - ResourceTemplates = [new() { UriTemplate = "test", Name = "Test Resource" }] - }; - }, - ListResourcesHandler = async (request, ct) => + ResourceTemplates = [new() { UriTemplate = "test", Name = "Test Resource" }] + }; + }; + options.Handlers.ListResourcesHandler = async (request, ct) => + { + return new ListResourcesResult { - return new ListResourcesResult - { - Resources = [new() { Uri = "test", Name = "Test Resource" }] - }; - }, - ReadResourceHandler = (request, ct) => throw new NotImplementedException(), - } + Resources = [new() { Uri = "test", Name = "Test Resource" }] + }; + }; + options.Handlers.ReadResourceHandler = (request, ct) => throw new NotImplementedException(); }, - RequestMethods.ResourcesTemplatesList, - configureOptions: null, - assertResult: response => + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result?.ResourceTemplates); @@ -306,20 +334,20 @@ await Can_Handle_Requests( new ServerCapabilities { Resources = new() + }, + RequestMethods.ResourcesList, + configureOptions: options => + { + options.Handlers.ListResourcesHandler = async (request, ct) => { - ListResourcesHandler = async (request, ct) => + return new ListResourcesResult { - return new ListResourcesResult - { - Resources = [new() { Uri = "test", Name = "Test Resource" }] - }; - }, - ReadResourceHandler = (request, ct) => throw new NotImplementedException(), - } + Resources = [new() { Uri = "test", Name = "Test Resource" }] + }; + }; + options.Handlers.ReadResourceHandler = (request, ct) => throw new NotImplementedException(); }, - RequestMethods.ResourcesList, - configureOptions: null, - assertResult: response => + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result?.Resources); @@ -341,20 +369,20 @@ await Can_Handle_Requests( new ServerCapabilities { Resources = new() - { - ReadResourceHandler = async (request, ct) => - { - return new ReadResourceResult - { - Contents = [new TextResourceContents { Text = "test" }] - }; - }, - ListResourcesHandler = (request, ct) => throw new NotImplementedException(), - } }, method: RequestMethods.ResourcesRead, - configureOptions: null, - assertResult: response => + configureOptions: options => + { + options.Handlers.ReadResourceHandler = async (request, ct) => + { + return new ReadResourceResult + { + Contents = [new TextResourceContents { Text = "test" }] + }; + }; + options.Handlers.ListResourcesHandler = (request, ct) => throw new NotImplementedException(); + }, + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result?.Contents); @@ -378,20 +406,20 @@ await Can_Handle_Requests( new ServerCapabilities { Prompts = new() + }, + method: RequestMethods.PromptsList, + configureOptions: options => + { + options.Handlers.ListPromptsHandler = async (request, ct) => { - ListPromptsHandler = async (request, ct) => + return new ListPromptsResult { - return new ListPromptsResult - { - Prompts = [new() { Name = "test" }] - }; - }, - GetPromptHandler = (request, ct) => throw new NotImplementedException(), - }, + Prompts = [new() { Name = "test" }] + }; + }; + options.Handlers.GetPromptHandler = (request, ct) => throw new NotImplementedException(); }, - method: RequestMethods.PromptsList, - configureOptions: null, - assertResult: response => + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result?.Prompts); @@ -413,14 +441,14 @@ await Can_Handle_Requests( new ServerCapabilities { Prompts = new() - { - GetPromptHandler = async (request, ct) => new GetPromptResult { Description = "test" }, - ListPromptsHandler = (request, ct) => throw new NotImplementedException(), - } }, method: RequestMethods.PromptsGet, - configureOptions: null, - assertResult: response => + configureOptions: options => + { + options.Handlers.GetPromptHandler = async (request, ct) => new GetPromptResult { Description = "test" }; + options.Handlers.ListPromptsHandler = (request, ct) => throw new NotImplementedException(); + }, + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); @@ -441,20 +469,20 @@ await Can_Handle_Requests( new ServerCapabilities { Tools = new() + }, + method: RequestMethods.ToolsList, + configureOptions: options => + { + options.Handlers.ListToolsHandler = async (request, ct) => { - ListToolsHandler = async (request, ct) => + return new ListToolsResult { - return new ListToolsResult - { - Tools = [new() { Name = "test" }] - }; - }, - CallToolHandler = (request, ct) => throw new NotImplementedException(), - } + Tools = [new() { Name = "test" }] + }; + }; + options.Handlers.CallToolHandler = (request, ct) => throw new NotImplementedException(); }, - method: RequestMethods.ToolsList, - configureOptions: null, - assertResult: response => + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); @@ -476,20 +504,20 @@ await Can_Handle_Requests( new ServerCapabilities { Tools = new() - { - CallToolHandler = async (request, ct) => - { - return new CallToolResult - { - Content = [new TextContentBlock { Text = "test" }] - }; - }, - ListToolsHandler = (request, ct) => throw new NotImplementedException(), - } }, method: RequestMethods.ToolsCall, - configureOptions: null, - assertResult: response => + configureOptions: options => + { + options.Handlers.CallToolHandler = async (request, ct) => + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = "test" }] + }; + }; + options.Handlers.ListToolsHandler = (request, ct) => throw new NotImplementedException(); + }, + assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); @@ -504,13 +532,13 @@ public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_A await Succeeds_Even_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, RequestMethods.ToolsCall, "CallTool handler not configured"); } - private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) + private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) { await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, LoggerFactory); + await using var server = McpServer.Create(transport, options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -533,7 +561,7 @@ await transport.SendMessageAsync( var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(5)); Assert.NotNull(response); - assertResult(response.Result); + assertResult(server, response.Result); await transport.DisposeAsync(); await runTask; @@ -544,7 +572,7 @@ private async Task Succeeds_Even_If_No_Handler_Assigned(ServerCapabilities serve await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); await server.DisposeAsync(); } @@ -589,7 +617,7 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() public async Task Can_SendMessage_Before_RunAsync() { await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var logNotification = new JsonRpcNotification { @@ -605,22 +633,22 @@ public async Task Can_SendMessage_Before_RunAsync() Assert.Same(logNotification, transport.SentMessages[0]); } - private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) + private static void SetClientCapabilities(McpServer server, ClientCapabilities capabilities) { - PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); - Assert.NotNull(property); - property.SetValue(server, capabilities); + FieldInfo? field = server.GetType().GetField("_clientCapabilities", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(field); + field.SetValue(server, capabilities); } - private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer + private sealed class TestServerForIChatClient(bool supportsSampling) : McpServer { - public ClientCapabilities? ClientCapabilities => + public override ClientCapabilities? ClientCapabilities => supportsSampling ? new ClientCapabilities { Sampling = new SamplingCapability() } : null; - public McpServerOptions ServerOptions => new(); + public override McpServerOptions ServerOptions => new(); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { CreateMessageRequestParams? rp = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); @@ -653,17 +681,18 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati }); } - public ValueTask DisposeAsync() => default; + public override ValueTask DisposeAsync() => default; - public string? SessionId => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? Services => throw new NotImplementedException(); - public LoggingLevel? LoggingLevel => throw new NotImplementedException(); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => + public override string? SessionId => throw new NotImplementedException(); + public override string? NegotiatedProtocolVersion => throw new NotImplementedException(); + public override Implementation? ClientInfo => throw new NotImplementedException(); + public override IServiceProvider? Services => throw new NotImplementedException(); + public override LoggingLevel? LoggingLevel => throw new NotImplementedException(); + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task RunAsync(CancellationToken cancellationToken = default) => + public override Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => throw new NotImplementedException(); } @@ -674,16 +703,14 @@ public async Task NotifyProgress_Should_Be_Handled() var options = CreateOptions(); var notificationReceived = new TaskCompletionSource(); - options.Capabilities = new() - { - NotificationHandlers = [new(NotificationMethods.ProgressNotification, (notification, cancellationToken) => + options.Handlers.NotificationHandlers = + [new(NotificationMethods.ProgressNotification, (notification, cancellationToken) => { notificationReceived.TrySetResult(notification); return default; - })], - }; + })]; - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index bd0ca5ef9..b9463e18f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,10 +1,8 @@ -using Json.Schema; +using Json.Schema; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Runtime.InteropServices; @@ -18,6 +16,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerToolTests() { #if !NET @@ -40,11 +48,11 @@ public void Create_InvalidArgs_Throws() } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -53,7 +61,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -67,7 +75,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestTool)); @@ -79,7 +87,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Content); @@ -90,11 +98,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -154,15 +162,16 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); - Mock mockServer = new(); + Mock mockServer = new(); - var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), - TestContext.Current.CancellationToken); - Assert.True(result.IsError); + var ex = await Assert.ThrowsAsync(async () => await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), + TestContext.Current.CancellationToken)); - result = await tool.InvokeAsync( - new RequestContext(mockServer.Object) { Services = services }, + mockServer.SetupGet(s => s.Services).Returns(services); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -183,7 +192,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -198,7 +207,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -213,7 +222,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -232,7 +241,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -241,8 +250,8 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { @@ -253,7 +262,7 @@ public async Task CanReturnCollectionOfAIContent() }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(3, result.Content.Count); @@ -273,8 +282,8 @@ public async Task CanReturnCollectionOfAIContent() [InlineData("data:audio/wav;base64,1234", "audio")] public async Task CanReturnSingleAIContent(string data, string type) { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return type switch @@ -287,7 +296,7 @@ public async Task CanReturnSingleAIContent(string data, string type) }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); @@ -316,14 +325,14 @@ public async Task CanReturnSingleAIContent(string data, string type) [Fact] public async Task CanReturnNullAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (string?)null; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Empty(result.Content); } @@ -331,14 +340,14 @@ public async Task CanReturnNullAIContent() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -347,31 +356,30 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); - Assert.Equal(2, result.Content.Count); - Assert.Equal("42", Assert.IsType(result.Content[0]).Text); - Assert.Equal("43", Assert.IsType(result.Content[1]).Text); + Assert.Single(result.Content); + Assert.Equal("""["42","43"]""", Assert.IsType(result.Content[0]).Text); } [Fact] public async Task CanReturnMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextContentBlock { Text = "42" }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -381,18 +389,18 @@ public async Task CanReturnMcpContent() [Fact] public async Task CanReturnCollectionOfMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) [ - new TextContentBlock { Text = "42" }, - new ImageContentBlock { Data = "1234", MimeType = "image/png" } + new TextContentBlock { Text = "42" }, + new ImageContentBlock { Data = "1234", MimeType = "image/png" } ]; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(2, result.Content.Count); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -408,14 +416,14 @@ public async Task CanReturnCallToolResult() Content = new List { new TextContentBlock { Text = "text" }, new ImageContentBlock { Data = "1234", MimeType = "image/png" } } }; - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return response; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(response, result); @@ -428,7 +436,7 @@ public async Task CanReturnCallToolResult() [Fact] public async Task SupportsSchemaCreateOptions() { - AIJsonSchemaCreateOptions schemaCreateOptions = new () + AIJsonSchemaCreateOptions schemaCreateOptions = new() { TransformSchemaNode = (context, node) => { @@ -448,53 +456,14 @@ public async Task SupportsSchemaCreateOptions() ); } - [Fact] - public async Task ToolCallError_LogsErrorMessage() - { - // Arrange - var mockLoggerProvider = new MockLoggerProvider(); - var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); - var services = new ServiceCollection(); - services.AddSingleton(loggerFactory); - var serviceProvider = services.BuildServiceProvider(); - - var toolName = "tool-that-throws"; - var exceptionMessage = "Test exception message"; - - McpServerTool tool = McpServerTool.Create(() => - { - throw new InvalidOperationException(exceptionMessage); - }, new() { Name = toolName, Services = serviceProvider }); - - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) - { - Params = new CallToolRequestParams { Name = toolName }, - Services = serviceProvider - }; - - // Act - var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); - - // Assert - Assert.True(result.IsError); - Assert.Single(result.Content); - Assert.Equal($"An error occurred invoking '{toolName}'.", Assert.IsType(result.Content[0]).Text); - - var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); - Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); - Assert.IsType(errorLog.Exception); - Assert.Equal(exceptionMessage, errorLog.Exception.Message); - } - [Theory] [MemberData(nameof(StructuredOutput_ReturnsExpectedSchema_Inputs))] public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { Name = "tool", UseStructuredContent = true, SerializerOptions = options }); - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -511,8 +480,8 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSchema() { McpServerTool tool = McpServerTool.Create(() => { }); - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -523,7 +492,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => Task.CompletedTask); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -534,7 +503,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => default(ValueTask)); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -551,8 +520,8 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { UseStructuredContent = false, SerializerOptions = options }); - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -593,7 +562,7 @@ public static IEnumerable StructuredOutput_ReturnsExpectedSchema_Input yield return new object[] { new() }; yield return new object[] { new List { "item1", "item2" } }; yield return new object[] { new Dictionary { ["key1"] = 1, ["key2"] = 2 } }; - yield return new object[] { new Person("John", 27) }; + yield return new object[] { new Person("John", 27) }; } private sealed class MyService; diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index f3927be62..d14c376c1 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -35,7 +35,7 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() process.StandardInput.BaseStream, serverName: "TestServerWithHosting"); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new TestClientTransport(streamServerTransport), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs similarity index 89% rename from tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs rename to tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs index 8f6fbff2c..768ebf7ea 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs @@ -4,12 +4,12 @@ namespace ModelContextProtocol.Tests.Transport; -public class SseClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +public class HttpClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { [Fact] public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() { - var options = new SseClientTransportOptions + var options = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost"), TransportMode = HttpTransportMode.AutoDetect, @@ -18,7 +18,7 @@ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory); // Simulate successful Streamable HTTP response for initialize mockHttpHandler.RequestHandler = (request) => @@ -50,7 +50,7 @@ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() [Fact] public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() { - var options = new SseClientTransportOptions + var options = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost"), TransportMode = HttpTransportMode.AutoDetect, @@ -59,7 +59,7 @@ public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory); var requestCount = 0; diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs similarity index 86% rename from tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs rename to tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs index 3ff504304..fc1ac2d88 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs @@ -5,14 +5,14 @@ namespace ModelContextProtocol.Tests.Transport; -public class SseClientTransportTests : LoggedTest +public class HttpClientTransportTests : LoggedTest { - private readonly SseClientTransportOptions _transportOptions; + private readonly HttpClientTransportOptions _transportOptions; - public SseClientTransportTests(ITestOutputHelper testOutputHelper) + public HttpClientTransportTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _transportOptions = new SseClientTransportOptions + _transportOptions = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost:8080"), ConnectionTimeout = TimeSpan.FromSeconds(2), @@ -28,14 +28,14 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) [Fact] public void Constructor_Throws_For_Null_Options() { - var exception = Assert.Throws(() => new SseClientTransport(null!, LoggerFactory)); + var exception = Assert.Throws(() => new HttpClientTransport(null!, LoggerFactory)); Assert.Equal("transportOptions", exception.ParamName); } [Fact] public void Constructor_Throws_For_Null_HttpClient() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); + var exception = Assert.Throws(() => new HttpClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); } @@ -44,7 +44,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); bool firstCall = true; @@ -68,7 +68,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -87,7 +87,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -125,7 +125,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -165,7 +165,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() }); }; - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); await session.DisposeAsync(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 93cbcec82..5394ba30e 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; using System.Runtime.InteropServices; using System.Text; @@ -8,21 +9,21 @@ namespace ModelContextProtocol.Tests.Transport; public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; - + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { string id = Guid.NewGuid().ToString("N"); StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) : - new(new() { Command = "ls", Arguments = [id] }, LoggerFactory); + new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"] }, LoggerFactory) : + new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"] }, LoggerFactory); - IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); - Assert.Contains(id, e.ToString()); + await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } - - [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + + // [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { string id = Guid.NewGuid().ToString("N"); @@ -40,12 +41,92 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() }; StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"], StandardErrorLines = stdErrCallback }, LoggerFactory) : - new(new() { Command = "ls", Arguments = [id], StandardErrorLines = stdErrCallback }, LoggerFactory); + new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"], StandardErrorLines = stdErrCallback }, LoggerFactory) : + new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"], StandardErrorLines = stdErrCallback }, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.InRange(count, 1, int.MaxValue); Assert.Contains(id, sb.ToString()); } + + [Theory] + [InlineData(null)] + [InlineData("argument with spaces")] + [InlineData("&")] + [InlineData("|")] + [InlineData(">")] + [InlineData("<")] + [InlineData("^")] + [InlineData(" & ")] + [InlineData(" | ")] + [InlineData(" > ")] + [InlineData(" < ")] + [InlineData(" ^ ")] + [InlineData("& ")] + [InlineData("| ")] + [InlineData("> ")] + [InlineData("< ")] + [InlineData("^ ")] + [InlineData(" &")] + [InlineData(" |")] + [InlineData(" >")] + [InlineData(" <")] + [InlineData(" ^")] + [InlineData("^&<>|")] + [InlineData("^&<>| ")] + [InlineData(" ^&<>|")] + [InlineData("\t^&<>")] + [InlineData("^&\t<>")] + [InlineData("ls /tmp | grep foo.txt > /dev/null")] + [InlineData("let rec Y f x = f (Y f) x")] + [InlineData("value with \"quotes\" and spaces")] + [InlineData("C:\\Program Files\\Test App\\app.dll")] + [InlineData("C:\\EndsWithBackslash\\")] + [InlineData("--already-looks-like-flag")] + [InlineData("-starts-with-dash")] + [InlineData("name=value=another")] + [InlineData("$(echo injected)")] + [InlineData("value-with-\"quotes\"-and-\\backslashes\\")] + [InlineData("http://localhost:1234/callback?foo=1&bar=2")] + public async Task EscapesCliArgumentsCorrectly(string? cliArgumentValue) + { + if (PlatformDetection.IsMonoRuntime && cliArgumentValue?.EndsWith("\\") is true) + { + Assert.Skip("mono runtime does not handle arguments ending with backslash correctly."); + } + + string cliArgument = $"--cli-arg={cliArgumentValue}"; + + StdioClientTransportOptions options = new() + { + Name = "TestServer", + Command = (PlatformDetection.IsMonoRuntime, PlatformDetection.IsWindows) switch + { + (true, _) => "mono", + (_, true) => "TestServer.exe", + _ => "dotnet", + }, + Arguments = (PlatformDetection.IsMonoRuntime, PlatformDetection.IsWindows) switch + { + (true, _) => ["TestServer.exe", cliArgument], + (_, true) => [cliArgument], + _ => ["TestServer.dll", cliArgument], + }, + }; + + var transport = new StdioClientTransport(options, LoggerFactory); + + // Act: Create client (handshake) and list tools to ensure full round trip works with the argument present. + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(tools); + Assert.NotEmpty(tools); + + var result = await client.CallToolAsync("echoCliArg", cancellationToken: TestContext.Current.CancellationToken); + var content = Assert.IsType(Assert.Single(result.Content)); + Assert.Equal(cliArgumentValue ?? "", content.Text); + } }