diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cde3103b6..66bf206ec 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,12 +1,11 @@ from datetime import timedelta from typing import Any, Protocol -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl, TypeAdapter import mcp.types as types from mcp.shared.context import RequestContext -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -59,8 +58,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb96..0f3039b55 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -6,10 +6,16 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared.session import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -31,11 +37,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -84,8 +90,11 @@ async def sse_reader( case "message": try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data + message = MessageFrame( + message=types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data + ), + raw=sse, ) logger.debug( f"Received server message: {message}" diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 3e73b0204..f2107d6ba 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,7 +1,7 @@ import json import logging from contextlib import asynccontextmanager -from typing import AsyncGenerator +from typing import Any, AsyncGenerator import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -10,6 +10,7 @@ from websockets.typing import Subprotocol import mcp.types as types +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - MemoryObjectSendStream[types.JSONRPCMessage], + MemoryObjectReceiveStream[MessageFrame[Any] | Exception], + MemoryObjectSendStream[MessageFrame[Any]], ], None, ]: @@ -53,7 +54,11 @@ async def ws_reader(): async with read_stream_writer: async for raw_text in ws: try: - message = types.JSONRPCMessage.model_validate_json(raw_text) + json_message = types.JSONRPCMessage.model_validate_json( + raw_text + ) + # Create MessageFrame with JSON message as root + message = MessageFrame(message=json_message, raw=raw_text) await read_stream_writer.send(message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception @@ -66,8 +71,8 @@ async def ws_writer(): """ async with write_stream_reader: async for message in write_stream_reader: - # Convert to a dict, then to JSON - msg_dict = message.model_dump( + # Extract the JSON-RPC message from MessageFrame and convert to JSON + msg_dict = message.message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 817d1918a..7ceb103e5 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -74,7 +74,6 @@ async def main(): from typing import Any, AsyncIterator, Generic, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types @@ -84,7 +83,7 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.session import RequestResponder +from mcp.shared.session import ReadStream, RequestResponder, WriteStream logger = logging.getLogger(__name__) @@ -474,8 +473,8 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 3b5abba78..58a2db1df 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -5,9 +5,7 @@ from pydantic import BaseModel -from mcp.types import ( - ServerCapabilities, -) +from mcp.types import ServerCapabilities class InitializationOptions(BaseModel): diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 788bb9f83..c22dcf871 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -42,14 +42,15 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, + ReadStream, RequestResponder, + WriteStream, ) @@ -76,8 +77,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, init_options: InitializationOptions, ) -> None: super().__init__( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753d0..1e8696858 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -38,7 +38,6 @@ async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -46,6 +45,13 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.shared.session import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -63,9 +69,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[ - UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] - ] + _read_stream_writers: dict[UUID, ReadStreamWriter] def __init__(self, endpoint: str) -> None: """ @@ -85,11 +89,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -172,4 +176,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(message) + await writer.send(MessageFrame(message=message, raw=request)) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 0e0e49129..91819a7de 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -24,9 +24,15 @@ async def run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types +from mcp.shared.session import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) +from mcp.types import MessageFrame @asynccontextmanager @@ -47,11 +53,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -66,7 +72,9 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + await read_stream_writer.send( + MessageFrame(message=message, raw=line) + ) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -74,6 +82,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for message in write_stream_reader: + # Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame json = message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index bd3d632ee..2da93634c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -2,11 +2,17 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket import mcp.types as types +from mcp.shared.session import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -21,11 +27,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -40,7 +46,9 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(client_message) + await read_stream_writer.send( + MessageFrame(message=client_message, raw=message) + ) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index ae6b0be53..762ff28a4 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,11 +11,11 @@ from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server -from mcp.types import JSONRPCMessage +from mcp.types import MessageFrame MessageStream = tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[MessageFrame | Exception], + MemoryObjectSendStream[MessageFrame], ] @@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + MessageFrame | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + MessageFrame | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) @@ -60,12 +60,9 @@ async def create_connected_server_and_client_session( ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( - client_streams, - server_streams, + (client_read, client_write), + (server_read, server_write), ): - client_read, client_write = client_streams - server_read, server_write = server_streams - # Create a cancel scope for the server task async with anyio.create_task_group() as tg: tg.start_soon( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 31f888246..7dd6fefc1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -22,12 +22,18 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + MessageFrame, RequestParams, ServerNotification, ServerRequest, ServerResult, ) +ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception] +ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception] +WriteStream = MemoryObjectSendStream[MessageFrame] +WriteStreamReader = MemoryObjectReceiveStream[MessageFrame] + SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) @@ -165,8 +171,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -242,7 +248,9 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + await self._write_stream.send( + MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None) + ) try: with anyio.fail_after( @@ -278,14 +286,18 @@ async def send_notification(self, notification: SendNotificationT) -> None: **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + await self._write_stream.send( + MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None) + ) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send( + MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None) + ) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -294,7 +306,9 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send( + MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None) + ) async def _receive_loop(self) -> None: async with ( @@ -302,10 +316,13 @@ async def _receive_loop(self) -> None: self._write_stream, self._incoming_message_stream_writer, ): - async for message in self._read_stream: - if isinstance(message, Exception): - await self._incoming_message_stream_writer.send(message) - elif isinstance(message.root, JSONRPCRequest): + async for raw_message in self._read_stream: + if isinstance(raw_message, Exception): + await self._incoming_message_stream_writer.send(raw_message) + continue + + message = raw_message.message + if isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( by_alias=True, mode="json", exclude_none=True diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd3b..38384dea8 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -180,6 +180,49 @@ class JSONRPCMessage( pass +RawT = TypeVar("RawT") + + +class MessageFrame(BaseModel, Generic[RawT]): + """ + A wrapper around the general message received that contains both the parsed message + and the raw message. + + This class serves as an encapsulation for JSON-RPC messages, providing access to + both the parsed structure (root) and the original raw data. This design is + particularly useful for Server-Sent Events (SSE) consumers who may need to access + additional metadata or headers associated with the message. + + The 'root' attribute contains the parsed JSONRPCMessage, which could be a request, + notification, response, or error. The 'raw' attribute preserves the original + message as received, allowing access to any additional context or metadata that + might be lost in parsing. + + This dual representation allows for flexible handling of messages, where consumers + can work with the structured data for standard operations, but still have the + option to examine or utilize the raw data when needed, such as for debugging, + logging, or accessing transport-specific information. + """ + + message: JSONRPCMessage + raw: RawT | None = None + model_config = ConfigDict(extra="allow") + + def model_dump(self, *args, **kwargs): + """ + Dumps the model to a dictionary, delegating to the root JSONRPCMessage. + This method allows for consistent serialization of the parsed message. + """ + return self.message.model_dump(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs): + """ + Dumps the model to a JSON string, delegating to the root JSONRPCMessage. + This method provides a convenient way to serialize the parsed message to JSON. + """ + return self.message.model_dump_json(*args, **kwargs) + + class EmptyResult(Result): """A response that indicates success but carries no data.""" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 7d579cdac..27f02abf7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,3 +1,5 @@ +from types import NoneType + import anyio import pytest @@ -11,9 +13,9 @@ InitializeRequest, InitializeResult, JSONRPCMessage, - JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + MessageFrame, ServerCapabilities, ServerResult, ) @@ -22,10 +24,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + MessageFrame[NoneType] ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + MessageFrame[NoneType] ](1) initialized_notification = None @@ -34,7 +36,7 @@ async def mock_server(): nonlocal initialized_notification jsonrpc_request = await client_to_server_receive.receive() - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, MessageFrame) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -56,21 +58,25 @@ async def mock_server(): ) async with server_to_client_send: + assert isinstance(jsonrpc_request.message.root, JSONRPCRequest) await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - ) + MessageFrame( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.message.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ), + raw=None, ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + assert isinstance(jsonrpc_notification.message, JSONRPCMessage) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.model_dump( + jsonrpc_notification.message.model_dump( by_alias=True, mode="json", exclude_none=True ) ) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 00e187895..fd05c7737 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -11,6 +11,7 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + MessageFrame, NotificationParams, ) @@ -64,7 +65,9 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=init_req)) + await client_writer.send( + MessageFrame(message=JSONRPCMessage(root=init_req), raw=None) + ) await server_reader.receive() # Get init response but don't need to check it # Send initialized notification @@ -73,21 +76,27 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=initialized_notification)) + await client_writer.send( + MessageFrame( + message=JSONRPCMessage(root=initialized_notification), raw=None + ) + ) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(JSONRPCMessage(root=ping_request)) + await client_writer.send( + MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None) + ) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.root.id == custom_request_id + response.message.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 37a52969a..18d9a4c5b 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -17,6 +17,7 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + MessageFrame, ) @@ -64,7 +65,7 @@ async def run_server(): send_stream2, InitializationOptions( server_name="test", - server_version="0.1.0", + server_version="1.0.0", capabilities=server.get_capabilities( notification_options=NotificationOptions(), experimental_capabilities={}, @@ -82,42 +83,51 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ), + raw=None, ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ), + raw=None, ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ), + raw=None, ) ) # Get response and verify response = await receive_stream2.receive() - assert response.root.result["content"][0]["text"] == "true" + assert response.message.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -178,42 +188,51 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ), + raw=None, ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ), + raw=None, ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ), + raw=None, ) ) # Get response and verify response = await receive_stream2.receive() - assert response.root.result["content"][0]["text"] == "true" + assert response.message.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 333196c96..a28fda7fa 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -9,7 +9,7 @@ from mcp.types import ( ClientNotification, InitializedNotification, - JSONRPCMessage, + MessageFrame, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -19,10 +19,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + MessageFrame[None] ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + MessageFrame[None] ](1) async def run_client(client: ClientSession): diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 85c5bf219..c12c26373 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,7 +4,7 @@ import pytest from mcp.server.stdio import stdio_server -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame @pytest.mark.anyio @@ -13,8 +13,8 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] for message in messages: @@ -35,17 +35,29 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - ) - assert received_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - ) + assert isinstance(received_messages[0].message, JSONRPCMessage) + assert isinstance(received_messages[0].message.root, JSONRPCRequest) + assert received_messages[0].message.root.id == 1 + assert received_messages[0].message.root.method == "ping" + + assert isinstance(received_messages[1].message, JSONRPCMessage) + assert isinstance(received_messages[1].message.root, JSONRPCResponse) + assert received_messages[1].message.root.id == 2 # Test sending responses from the server responses = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + ), + raw=None, + ), + MessageFrame( + message=JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + ), + raw=None, + ), ] async with write_stream: @@ -56,13 +68,10 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [ - JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines - ] - assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") - ) - assert received_responses[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) - ) + # Parse and verify the JSON responses directly + request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip()) + response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip()) + + assert request_json.id == 3 + assert request_json.method == "ping" + assert response_json.id == 4