From 8ea8bf58256aef5f67029ed4b3800d51c7d3e90d Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 25 Feb 2025 18:42:33 +0000 Subject: [PATCH 1/9] refactor: improve typing with memory stream type aliases Move memory stream type definitions to models.py and use them throughout the codebase for better type safety and maintainability. GitHub-Issue:#201 --- src/mcp/server/lowlevel/server.py | 7 +++---- src/mcp/server/models.py | 10 +++++++--- src/mcp/server/session.py | 7 +++---- src/mcp/server/sse.py | 19 +++++++++++-------- src/mcp/server/stdio.py | 15 ++++++++++----- src/mcp/server/websocket.py | 15 ++++++++++----- 6 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 817d1918a..8b759ef6b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -74,12 +74,11 @@ async def main(): from typing import Any, AsyncIterator, Generic, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.models import InitializationOptions +from mcp.server.models import InitializationOptions, ReadStream, WriteStream from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext @@ -474,8 +473,8 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 3b5abba78..ffa15fdfb 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -3,11 +3,15 @@ and tools. """ +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel -from mcp.types import ( - ServerCapabilities, -) +from mcp.types import JSONRPCMessage, ServerCapabilities + +ReadStream = MemoryObjectReceiveStream[JSONRPCMessage | Exception] +ReadStreamWriter = MemoryObjectSendStream[JSONRPCMessage | Exception] +WriteStream = MemoryObjectSendStream[JSONRPCMessage] +WriteStreamReader = MemoryObjectReceiveStream[JSONRPCMessage] class InitializationOptions(BaseModel): diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 788bb9f83..f7e8812dc 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -42,11 +42,10 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types -from mcp.server.models import InitializationOptions +from mcp.server.models import InitializationOptions, ReadStream, WriteStream from mcp.shared.session import ( BaseSession, RequestResponder, @@ -76,8 +75,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, init_options: InitializationOptions, ) -> None: super().__init__( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753d0..9da9a2ffd 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -38,7 +38,6 @@ async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -46,6 +45,12 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.models import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) logger = logging.getLogger(__name__) @@ -63,9 +68,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[ - UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] - ] + _read_stream_writers: dict[UUID, ReadStreamWriter] def __init__(self, endpoint: str) -> None: """ @@ -85,11 +88,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 0e0e49129..d140ffb8d 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -24,9 +24,14 @@ async def run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types +from mcp.server.models import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) @asynccontextmanager @@ -47,11 +52,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index bd3d632ee..51379e8c6 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -2,11 +2,16 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket import mcp.types as types +from mcp.server.models import ( + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) logger = logging.getLogger(__name__) @@ -21,11 +26,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) From 3af37d61d0c84757912150453403b1446b48c4a8 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 26 Feb 2025 18:50:34 +0000 Subject: [PATCH 2/9] refactor: move streams to ParsedMessage --- src/mcp/client/session.py | 7 ++--- src/mcp/client/sse.py | 23 ++++++++++----- src/mcp/server/lowlevel/server.py | 4 +-- src/mcp/server/models.py | 8 +----- src/mcp/server/session.py | 4 ++- src/mcp/server/sse.py | 5 ++-- src/mcp/server/stdio.py | 5 ++-- src/mcp/server/websocket.py | 7 +++-- src/mcp/shared/memory.py | 17 +++++------ src/mcp/shared/session.py | 47 +++++++++++++++++++++++-------- 10 files changed, 79 insertions(+), 48 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cde3103b6..66bf206ec 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,12 +1,11 @@ from datetime import timedelta from typing import Any, Protocol -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl, TypeAdapter import mcp.types as types from mcp.shared.context import RequestContext -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -59,8 +58,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb96..a42f69a43 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -6,10 +6,16 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared.session import ( + ParsedMessage, + ReadStream, + ReadStreamWriter, + WriteStream, + WriteStreamReader, +) logger = logging.getLogger(__name__) @@ -31,11 +37,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: ReadStream + read_stream_writer: ReadStreamWriter - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: WriteStream + write_stream_reader: WriteStreamReader read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -84,8 +90,11 @@ async def sse_reader( case "message": try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data + message = ParsedMessage( + types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data + ), + raw=sse, ) logger.debug( f"Received server message: {message}" diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 8b759ef6b..7ceb103e5 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -78,12 +78,12 @@ async def main(): import mcp.types as types from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.models import InitializationOptions, ReadStream, WriteStream +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.session import RequestResponder +from mcp.shared.session import ReadStream, RequestResponder, WriteStream logger = logging.getLogger(__name__) diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index ffa15fdfb..58a2db1df 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -3,15 +3,9 @@ and tools. """ -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel -from mcp.types import JSONRPCMessage, ServerCapabilities - -ReadStream = MemoryObjectReceiveStream[JSONRPCMessage | Exception] -ReadStreamWriter = MemoryObjectSendStream[JSONRPCMessage | Exception] -WriteStream = MemoryObjectSendStream[JSONRPCMessage] -WriteStreamReader = MemoryObjectReceiveStream[JSONRPCMessage] +from mcp.types import ServerCapabilities class InitializationOptions(BaseModel): diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f7e8812dc..c22dcf871 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -45,10 +45,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from pydantic import AnyUrl import mcp.types as types -from mcp.server.models import InitializationOptions, ReadStream, WriteStream +from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, + ReadStream, RequestResponder, + WriteStream, ) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9da9a2ffd..b5cedc2c8 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -45,7 +45,8 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types -from mcp.server.models import ( +from mcp.shared.session import ( + ParsedMessage, ReadStream, ReadStreamWriter, WriteStream, @@ -175,4 +176,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(message) + await writer.send(ParsedMessage(message, raw=request)) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index d140ffb8d..cf2fd0cbe 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -26,7 +26,8 @@ async def run_server(): import anyio.lowlevel import mcp.types as types -from mcp.server.models import ( +from mcp.shared.session import ( + ParsedMessage, ReadStream, ReadStreamWriter, WriteStream, @@ -71,7 +72,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + await read_stream_writer.send(ParsedMessage(message, raw=line)) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 51379e8c6..58e22adc0 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -6,7 +6,8 @@ from starlette.websockets import WebSocket import mcp.types as types -from mcp.server.models import ( +from mcp.shared.session import ( + ParsedMessage, ReadStream, ReadStreamWriter, WriteStream, @@ -45,7 +46,9 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(client_message) + await read_stream_writer.send( + ParsedMessage(client_message, raw=message) + ) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index ae6b0be53..24358c9cf 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,11 +11,11 @@ from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server -from mcp.types import JSONRPCMessage +from mcp.shared.session import ParsedMessage MessageStream = tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[ParsedMessage | Exception], + MemoryObjectSendStream[ParsedMessage], ] @@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + ParsedMessage | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + ParsedMessage | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) @@ -60,12 +60,9 @@ async def create_connected_server_and_client_session( ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( - client_streams, - server_streams, + (client_read, client_write), + (server_read, server_write), ): - client_read, client_write = client_streams - server_read, server_write = server_streams - # Create a cancel scope for the server task async with anyio.create_task_group() as tg: tg.start_soon( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 31f888246..a4f0ada61 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -7,7 +7,7 @@ import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel +from pydantic import BaseModel, RootModel from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -28,6 +28,22 @@ ServerResult, ) +RawT = TypeVar("RawT") + + +class ParsedMessage(RootModel[JSONRPCMessage], Generic[RawT]): + root: JSONRPCMessage + raw: RawT | None = None + + class Config: + arbitrary_types_allowed = True + + +ReadStream = MemoryObjectReceiveStream[ParsedMessage[RawT] | Exception] +ReadStreamWriter = MemoryObjectSendStream[ParsedMessage[RawT] | Exception] +WriteStream = MemoryObjectSendStream[ParsedMessage[RawT]] +WriteStreamReader = MemoryObjectReceiveStream[ParsedMessage[RawT]] + SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) @@ -165,8 +181,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: ReadStream, + write_stream: WriteStream, receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -242,7 +258,9 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + await self._write_stream.send( + ParsedMessage(JSONRPCMessage(jsonrpc_request), None) + ) try: with anyio.fail_after( @@ -278,14 +296,16 @@ async def send_notification(self, notification: SendNotificationT) -> None: **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + await self._write_stream.send( + ParsedMessage(JSONRPCMessage(jsonrpc_notification)) + ) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send(ParsedMessage(JSONRPCMessage(jsonrpc_error))) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -294,7 +314,9 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send( + ParsedMessage(JSONRPCMessage(jsonrpc_response)) + ) async def _receive_loop(self) -> None: async with ( @@ -302,10 +324,13 @@ async def _receive_loop(self) -> None: self._write_stream, self._incoming_message_stream_writer, ): - async for message in self._read_stream: - if isinstance(message, Exception): - await self._incoming_message_stream_writer.send(message) - elif isinstance(message.root, JSONRPCRequest): + async for raw_message in self._read_stream: + if isinstance(raw_message, Exception): + await self._incoming_message_stream_writer.send(raw_message) + continue + + message = raw_message.root + if isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( by_alias=True, mode="json", exclude_none=True From 5be7a25fe0c52619438b782a2213303dd6a7cf71 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 3 Mar 2025 14:14:23 +0000 Subject: [PATCH 3/9] refactor: update test files to use ParsedMessage Updates test files to work with the ParsedMessage stream type aliases and fixes a line length issue in test_201_client_hangs_on_logging.py. Github-Issue:#201 --- tests/client/test_session.py | 29 ++++++++++++++------------ tests/server/test_lifespan.py | 4 ++-- tests/server/test_session.py | 6 +++--- tests/server/test_stdio.py | 39 +++++++++++++++++++++++------------ 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 7d579cdac..5dbeb5bb2 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,6 +2,7 @@ import pytest from mcp.client.session import ClientSession +from mcp.shared.session import ParsedMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -11,7 +12,6 @@ InitializeRequest, InitializeResult, JSONRPCMessage, - JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, ServerCapabilities, @@ -22,10 +22,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + ParsedMessage[None] ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + ParsedMessage[None] ](1) initialized_notification = None @@ -57,20 +57,23 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - ) + ParsedMessage( + root=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ), + raw=None, ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + assert isinstance(jsonrpc_notification.root, ParsedMessage) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.model_dump( + jsonrpc_notification.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 37a52969a..a434f6030 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -117,7 +117,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() - assert response.root.result["content"][0]["text"] == "true" + assert response.root.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -213,7 +213,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() - assert response.root.result["content"][0]["text"] == "true" + assert response.root.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 333196c96..84a090f1a 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -6,10 +6,10 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.session import ParsedMessage from mcp.types import ( ClientNotification, InitializedNotification, - JSONRPCMessage, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -19,10 +19,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + ParsedMessage[None] ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + ParsedMessage[None] ](1) async def run_client(client: ClientSession): diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 85c5bf219..64d4ab4d6 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,6 +4,7 @@ import pytest from mcp.server.stdio import stdio_server +from mcp.shared.session import ParsedMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -13,8 +14,12 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + ParsedMessage( + root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) + ), + ParsedMessage( + root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) + ), ] for message in messages: @@ -35,17 +40,25 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert received_messages[0] == ParsedMessage( + root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) ) - assert received_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + assert received_messages[1] == ParsedMessage( + root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) ) # Test sending responses from the server responses = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), + ParsedMessage( + root=JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + ) + ), + ParsedMessage( + root=JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + ) + ), ] async with write_stream: @@ -57,12 +70,12 @@ async def test_stdio_server(): assert len(output_lines) == 2 received_responses = [ - JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines + ParsedMessage.model_validate_json(line.strip()) for line in output_lines ] assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + assert received_responses[0] == ParsedMessage( + root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) ) - assert received_responses[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + assert received_responses[1] == ParsedMessage( + root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) ) From 1d98c5c43ff49a9c031238966bf866a995690d3f Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 3 Mar 2025 14:19:53 +0000 Subject: [PATCH 4/9] refactor: rename ParsedMessage to MessageFrame for clarity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/client/sse.py | 4 ++-- src/mcp/server/sse.py | 4 ++-- src/mcp/server/stdio.py | 4 ++-- src/mcp/server/websocket.py | 4 ++-- src/mcp/shared/memory.py | 10 +++++----- src/mcp/shared/session.py | 18 +++++++++--------- tests/client/test_session.py | 10 +++++----- tests/server/test_session.py | 6 +++--- tests/server/test_stdio.py | 20 ++++++++++---------- 9 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index a42f69a43..a500badb4 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,7 +10,7 @@ import mcp.types as types from mcp.shared.session import ( - ParsedMessage, + MessageFrame, ReadStream, ReadStreamWriter, WriteStream, @@ -90,7 +90,7 @@ async def sse_reader( case "message": try: - message = ParsedMessage( + message = MessageFrame( types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ), diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b5cedc2c8..d99a2e8da 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -46,7 +46,7 @@ async def handle_sse(request): import mcp.types as types from mcp.shared.session import ( - ParsedMessage, + MessageFrame, ReadStream, ReadStreamWriter, WriteStream, @@ -176,4 +176,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(ParsedMessage(message, raw=request)) + await writer.send(MessageFrame(message, raw=request)) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index cf2fd0cbe..7db1ed150 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -27,7 +27,7 @@ async def run_server(): import mcp.types as types from mcp.shared.session import ( - ParsedMessage, + MessageFrame, ReadStream, ReadStreamWriter, WriteStream, @@ -72,7 +72,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(ParsedMessage(message, raw=line)) + await read_stream_writer.send(MessageFrame(message, raw=line)) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 58e22adc0..3c4fedc3e 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -7,7 +7,7 @@ import mcp.types as types from mcp.shared.session import ( - ParsedMessage, + MessageFrame, ReadStream, ReadStreamWriter, WriteStream, @@ -47,7 +47,7 @@ async def ws_reader(): continue await read_stream_writer.send( - ParsedMessage(client_message, raw=message) + MessageFrame(client_message, raw=message) ) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 24358c9cf..c619a227a 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,11 +11,11 @@ from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server -from mcp.shared.session import ParsedMessage +from mcp.shared.session import MessageFrame MessageStream = tuple[ - MemoryObjectReceiveStream[ParsedMessage | Exception], - MemoryObjectSendStream[ParsedMessage], + MemoryObjectReceiveStream[MessageFrame | Exception], + MemoryObjectSendStream[MessageFrame], ] @@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - ParsedMessage | Exception + MessageFrame | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - ParsedMessage | Exception + MessageFrame | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index a4f0ada61..7c98e9934 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -31,7 +31,7 @@ RawT = TypeVar("RawT") -class ParsedMessage(RootModel[JSONRPCMessage], Generic[RawT]): +class MessageFrame(RootModel[JSONRPCMessage], Generic[RawT]): root: JSONRPCMessage raw: RawT | None = None @@ -39,10 +39,10 @@ class Config: arbitrary_types_allowed = True -ReadStream = MemoryObjectReceiveStream[ParsedMessage[RawT] | Exception] -ReadStreamWriter = MemoryObjectSendStream[ParsedMessage[RawT] | Exception] -WriteStream = MemoryObjectSendStream[ParsedMessage[RawT]] -WriteStreamReader = MemoryObjectReceiveStream[ParsedMessage[RawT]] +ReadStream = MemoryObjectReceiveStream[MessageFrame[RawT] | Exception] +ReadStreamWriter = MemoryObjectSendStream[MessageFrame[RawT] | Exception] +WriteStream = MemoryObjectSendStream[MessageFrame[RawT]] +WriteStreamReader = MemoryObjectReceiveStream[MessageFrame[RawT]] SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) @@ -259,7 +259,7 @@ async def send_request( # TODO: Support progress callbacks await self._write_stream.send( - ParsedMessage(JSONRPCMessage(jsonrpc_request), None) + MessageFrame(JSONRPCMessage(jsonrpc_request), None) ) try: @@ -297,7 +297,7 @@ async def send_notification(self, notification: SendNotificationT) -> None: ) await self._write_stream.send( - ParsedMessage(JSONRPCMessage(jsonrpc_notification)) + MessageFrame(JSONRPCMessage(jsonrpc_notification)) ) async def _send_response( @@ -305,7 +305,7 @@ async def _send_response( ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(ParsedMessage(JSONRPCMessage(jsonrpc_error))) + await self._write_stream.send(MessageFrame(JSONRPCMessage(jsonrpc_error))) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -315,7 +315,7 @@ async def _send_response( ), ) await self._write_stream.send( - ParsedMessage(JSONRPCMessage(jsonrpc_response)) + MessageFrame(JSONRPCMessage(jsonrpc_response)) ) async def _receive_loop(self) -> None: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 5dbeb5bb2..ec08baeca 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,7 +2,7 @@ import pytest from mcp.client.session import ClientSession -from mcp.shared.session import ParsedMessage +from mcp.shared.session import MessageFrame from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -22,10 +22,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - ParsedMessage[None] + MessageFrame[None] ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - ParsedMessage[None] + MessageFrame[None] ](1) initialized_notification = None @@ -57,7 +57,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - ParsedMessage( + MessageFrame( root=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", @@ -71,7 +71,7 @@ async def mock_server(): ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.root, ParsedMessage) + assert isinstance(jsonrpc_notification.root, MessageFrame) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.root.model_dump( by_alias=True, mode="json", exclude_none=True diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 84a090f1a..526fcfe9d 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -6,7 +6,7 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.session import ParsedMessage +from mcp.shared.session import MessageFrame from mcp.types import ( ClientNotification, InitializedNotification, @@ -19,10 +19,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - ParsedMessage[None] + MessageFrame[None] ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - ParsedMessage[None] + MessageFrame[None] ](1) async def run_client(client: ClientSession): diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 64d4ab4d6..f64a03c0e 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,7 +4,7 @@ import pytest from mcp.server.stdio import stdio_server -from mcp.shared.session import ParsedMessage +from mcp.shared.session import MessageFrame from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -14,10 +14,10 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - ParsedMessage( + MessageFrame( root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) ), - ParsedMessage( + MessageFrame( root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) ), ] @@ -40,21 +40,21 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == ParsedMessage( + assert received_messages[0] == MessageFrame( root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) ) - assert received_messages[1] == ParsedMessage( + assert received_messages[1] == MessageFrame( root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) ) # Test sending responses from the server responses = [ - ParsedMessage( + MessageFrame( root=JSONRPCMessage( root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") ) ), - ParsedMessage( + MessageFrame( root=JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) ) @@ -70,12 +70,12 @@ async def test_stdio_server(): assert len(output_lines) == 2 received_responses = [ - ParsedMessage.model_validate_json(line.strip()) for line in output_lines + MessageFrame.model_validate_json(line.strip()) for line in output_lines ] assert len(received_responses) == 2 - assert received_responses[0] == ParsedMessage( + assert received_responses[0] == MessageFrame( root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) ) - assert received_responses[1] == ParsedMessage( + assert received_responses[1] == MessageFrame( root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) ) From 7a74c0e0f0739f1bc167d15b27e15cefadb9d0a3 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 3 Mar 2025 18:50:31 +0000 Subject: [PATCH 5/9] refactor: move MessageFrame class to types.py for better code organization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/client/sse.py | 4 +- src/mcp/server/sse.py | 4 +- src/mcp/server/stdio.py | 5 +- src/mcp/server/websocket.py | 4 +- src/mcp/shared/memory.py | 2 +- src/mcp/shared/session.py | 32 ++++------ src/mcp/types.py | 15 +++++ tests/client/test_session.py | 9 ++- tests/issues/test_192_request_id.py | 15 +++-- tests/server/test_lifespan.py | 97 +++++++++++++++++------------ tests/server/test_session.py | 2 +- tests/server/test_stdio.py | 48 +++++++------- 12 files changed, 133 insertions(+), 104 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index a500badb4..50b069b1d 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,12 +10,12 @@ import mcp.types as types from mcp.shared.session import ( - MessageFrame, ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader, ) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ async def sse_reader( case "message": try: message = MessageFrame( - types.JSONRPCMessage.model_validate_json( # noqa: E501 + root=types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ), raw=sse, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d99a2e8da..105b2d072 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -46,12 +46,12 @@ async def handle_sse(request): import mcp.types as types from mcp.shared.session import ( - MessageFrame, ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader, ) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -176,4 +176,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(MessageFrame(message, raw=request)) + await writer.send(MessageFrame(root=message, raw=request)) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 7db1ed150..8c357e7f3 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -27,12 +27,12 @@ async def run_server(): import mcp.types as types from mcp.shared.session import ( - MessageFrame, ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader, ) +from mcp.types import MessageFrame @asynccontextmanager @@ -72,7 +72,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(MessageFrame(message, raw=line)) + await read_stream_writer.send(MessageFrame(root=message, raw=line)) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -80,6 +80,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for message in write_stream_reader: + # Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame json = message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3c4fedc3e..fc78f09e3 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -7,12 +7,12 @@ import mcp.types as types from mcp.shared.session import ( - MessageFrame, ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader, ) +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ async def ws_reader(): continue await read_stream_writer.send( - MessageFrame(client_message, raw=message) + MessageFrame(root=client_message, raw=message) ) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index c619a227a..762ff28a4 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,7 +11,7 @@ from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server -from mcp.shared.session import MessageFrame +from mcp.types import MessageFrame MessageStream = tuple[ MemoryObjectReceiveStream[MessageFrame | Exception], diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 7c98e9934..217de38c8 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -7,7 +7,7 @@ import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel, RootModel +from pydantic import BaseModel from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -22,27 +22,17 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + MessageFrame, RequestParams, ServerNotification, ServerRequest, ServerResult, ) -RawT = TypeVar("RawT") - - -class MessageFrame(RootModel[JSONRPCMessage], Generic[RawT]): - root: JSONRPCMessage - raw: RawT | None = None - - class Config: - arbitrary_types_allowed = True - - -ReadStream = MemoryObjectReceiveStream[MessageFrame[RawT] | Exception] -ReadStreamWriter = MemoryObjectSendStream[MessageFrame[RawT] | Exception] -WriteStream = MemoryObjectSendStream[MessageFrame[RawT]] -WriteStreamReader = MemoryObjectReceiveStream[MessageFrame[RawT]] +ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception] +ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception] +WriteStream = MemoryObjectSendStream[MessageFrame] +WriteStreamReader = MemoryObjectReceiveStream[MessageFrame] SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) @@ -259,7 +249,7 @@ async def send_request( # TODO: Support progress callbacks await self._write_stream.send( - MessageFrame(JSONRPCMessage(jsonrpc_request), None) + MessageFrame(root=JSONRPCMessage(jsonrpc_request), raw=None) ) try: @@ -297,7 +287,7 @@ async def send_notification(self, notification: SendNotificationT) -> None: ) await self._write_stream.send( - MessageFrame(JSONRPCMessage(jsonrpc_notification)) + MessageFrame(root=JSONRPCMessage(jsonrpc_notification), raw=None) ) async def _send_response( @@ -305,7 +295,9 @@ async def _send_response( ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(MessageFrame(JSONRPCMessage(jsonrpc_error))) + await self._write_stream.send( + MessageFrame(root=JSONRPCMessage(jsonrpc_error), raw=None) + ) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -315,7 +307,7 @@ async def _send_response( ), ) await self._write_stream.send( - MessageFrame(JSONRPCMessage(jsonrpc_response)) + MessageFrame(root=JSONRPCMessage(jsonrpc_response), raw=None) ) async def _receive_loop(self) -> None: diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd3b..848764a67 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -180,6 +180,21 @@ class JSONRPCMessage( pass +RawT = TypeVar("RawT") + + +class MessageFrame(BaseModel, Generic[RawT]): + root: JSONRPCMessage + raw: RawT | None = None + model_config = ConfigDict(extra="allow") + + def model_dump(self, *args, **kwargs): + return self.root.model_dump(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs): + return self.root.model_dump_json(*args, **kwargs) + + class EmptyResult(Result): """A response that indicates success but carries no data.""" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index ec08baeca..d8cf92127 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,7 +2,6 @@ import pytest from mcp.client.session import ClientSession -from mcp.shared.session import MessageFrame from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -12,8 +11,8 @@ InitializeRequest, InitializeResult, JSONRPCMessage, - JSONRPCRequest, JSONRPCResponse, + MessageFrame, ServerCapabilities, ServerResult, ) @@ -34,7 +33,7 @@ async def mock_server(): nonlocal initialized_notification jsonrpc_request = await client_to_server_receive.receive() - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, MessageFrame) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -61,7 +60,7 @@ async def mock_server(): root=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.id, + id=jsonrpc_request.root.root.id, result=result.model_dump( by_alias=True, mode="json", exclude_none=True ), @@ -71,7 +70,7 @@ async def mock_server(): ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.root, MessageFrame) + assert isinstance(jsonrpc_notification.root, JSONRPCMessage) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.root.model_dump( by_alias=True, mode="json", exclude_none=True diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 00e187895..ac78dab4c 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -11,6 +11,7 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + MessageFrame, NotificationParams, ) @@ -64,7 +65,9 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=init_req)) + await client_writer.send( + MessageFrame(root=JSONRPCMessage(root=init_req), raw=None) + ) await server_reader.receive() # Get init response but don't need to check it # Send initialized notification @@ -73,21 +76,25 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=initialized_notification)) + await client_writer.send( + MessageFrame(root=JSONRPCMessage(root=initialized_notification), raw=None) + ) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(JSONRPCMessage(root=ping_request)) + await client_writer.send( + MessageFrame(root=JSONRPCMessage(root=ping_request), raw=None) + ) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.root.id == custom_request_id + response.root.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a434f6030..67677fae1 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -17,6 +17,7 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + MessageFrame, ) @@ -64,7 +65,7 @@ async def run_server(): send_stream2, InitializationOptions( server_name="test", - server_version="0.1.0", + server_version="1.0.0", capabilities=server.get_capabilities( notification_options=NotificationOptions(), experimental_capabilities={}, @@ -82,36 +83,45 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ), + raw=None, ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ), + raw=None, ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ), + raw=None, ) ) @@ -178,36 +188,45 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ), + raw=None, ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ), + raw=None, ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + MessageFrame( + root=JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ), + raw=None, ) ) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 526fcfe9d..a28fda7fa 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -6,10 +6,10 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.session import MessageFrame from mcp.types import ( ClientNotification, InitializedNotification, + MessageFrame, PromptsCapability, ResourcesCapability, ServerCapabilities, diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index f64a03c0e..df4f165bd 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,8 +4,7 @@ import pytest from mcp.server.stdio import stdio_server -from mcp.shared.session import MessageFrame -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame @pytest.mark.anyio @@ -14,12 +13,8 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - MessageFrame( - root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - ), - MessageFrame( - root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) - ), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] for message in messages: @@ -40,24 +35,28 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == MessageFrame( - root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - ) - assert received_messages[1] == MessageFrame( - root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) - ) + assert isinstance(received_messages[0].root, JSONRPCMessage) + assert isinstance(received_messages[0].root.root, JSONRPCRequest) + assert received_messages[0].root.root.id == 1 + assert received_messages[0].root.root.method == "ping" + + assert isinstance(received_messages[1].root, JSONRPCMessage) + assert isinstance(received_messages[1].root.root, JSONRPCResponse) + assert received_messages[1].root.root.id == 2 # Test sending responses from the server responses = [ MessageFrame( root=JSONRPCMessage( root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") - ) + ), + raw=None, ), MessageFrame( root=JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) - ) + ), + raw=None, ), ] @@ -69,13 +68,10 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [ - MessageFrame.model_validate_json(line.strip()) for line in output_lines - ] - assert len(received_responses) == 2 - assert received_responses[0] == MessageFrame( - root=JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) - ) - assert received_responses[1] == MessageFrame( - root=JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) - ) + # Parse and verify the JSON responses directly + request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip()) + response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip()) + + assert request_json.id == 3 + assert request_json.method == "ping" + assert response_json.id == 4 From 898b3a4f1aa0ee4df9fb2e8bb0378ffd83199f89 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 4 Mar 2025 11:57:09 +0000 Subject: [PATCH 6/9] fix pyright --- tests/client/test_session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d8cf92127..f091d3a97 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -11,6 +11,7 @@ InitializeRequest, InitializeResult, JSONRPCMessage, + JSONRPCRequest, JSONRPCResponse, MessageFrame, ServerCapabilities, @@ -55,6 +56,7 @@ async def mock_server(): ) async with server_to_client_send: + assert isinstance(jsonrpc_request.root.root, JSONRPCRequest) await server_to_client_send.send( MessageFrame( root=JSONRPCMessage( From d3e3763ae827a3a3a69be8d573e8e039fad09590 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Mar 2025 13:32:15 +0000 Subject: [PATCH 7/9] refactor: update websocket client to use MessageFrame Modified the websocket client to work with the new MessageFrame type, preserving raw message text and properly extracting the root JSON-RPC message when sending. Github-Issue:#204 --- src/mcp/client/websocket.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 3e73b0204..0e257674f 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,7 +1,7 @@ import json import logging from contextlib import asynccontextmanager -from typing import AsyncGenerator +from typing import Any, AsyncGenerator import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -10,6 +10,7 @@ from websockets.typing import Subprotocol import mcp.types as types +from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - MemoryObjectSendStream[types.JSONRPCMessage], + MemoryObjectReceiveStream[MessageFrame[Any] | Exception], + MemoryObjectSendStream[MessageFrame[Any]], ], None, ]: @@ -53,7 +54,11 @@ async def ws_reader(): async with read_stream_writer: async for raw_text in ws: try: - message = types.JSONRPCMessage.model_validate_json(raw_text) + json_message = types.JSONRPCMessage.model_validate_json( + raw_text + ) + # Create MessageFrame with JSON message as root + message = MessageFrame(root=json_message, raw=raw_text) await read_stream_writer.send(message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception @@ -66,8 +71,8 @@ async def ws_writer(): """ async with write_stream_reader: async for message in write_stream_reader: - # Convert to a dict, then to JSON - msg_dict = message.model_dump( + # Extract the JSON-RPC message from MessageFrame and convert to JSON + msg_dict = message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) From 4a341a67efa719b381b3a4de8dde4e10b243a5a3 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Mar 2025 13:35:30 +0000 Subject: [PATCH 8/9] fix: use NoneType instead of None for type parameters in MessageFrame MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) --- tests/client/test_session.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f091d3a97..2ee7351d4 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,3 +1,5 @@ +from types import NoneType + import anyio import pytest @@ -22,10 +24,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - MessageFrame[None] + MessageFrame[NoneType] ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - MessageFrame[None] + MessageFrame[NoneType] ](1) initialized_notification = None From 1c53fc208beaa8fa13bccfaa41d2d56908563239 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Mar 2025 13:39:58 +0000 Subject: [PATCH 9/9] refactor: rename root to message --- src/mcp/client/sse.py | 2 +- src/mcp/client/websocket.py | 4 ++-- src/mcp/server/sse.py | 2 +- src/mcp/server/stdio.py | 4 +++- src/mcp/server/websocket.py | 2 +- src/mcp/shared/session.py | 10 ++++----- src/mcp/types.py | 34 ++++++++++++++++++++++++++--- tests/client/test_session.py | 10 ++++----- tests/issues/test_192_request_id.py | 10 +++++---- tests/server/test_lifespan.py | 16 +++++++------- tests/server/test_stdio.py | 18 +++++++-------- 11 files changed, 72 insertions(+), 40 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 50b069b1d..0f3039b55 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -91,7 +91,7 @@ async def sse_reader( case "message": try: message = MessageFrame( - root=types.JSONRPCMessage.model_validate_json( # noqa: E501 + message=types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ), raw=sse, diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 0e257674f..f2107d6ba 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -58,7 +58,7 @@ async def ws_reader(): raw_text ) # Create MessageFrame with JSON message as root - message = MessageFrame(root=json_message, raw=raw_text) + message = MessageFrame(message=json_message, raw=raw_text) await read_stream_writer.send(message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception @@ -72,7 +72,7 @@ async def ws_writer(): async with write_stream_reader: async for message in write_stream_reader: # Extract the JSON-RPC message from MessageFrame and convert to JSON - msg_dict = message.root.model_dump( + msg_dict = message.message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 105b2d072..1e8696858 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -176,4 +176,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(MessageFrame(root=message, raw=request)) + await writer.send(MessageFrame(message=message, raw=request)) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 8c357e7f3..91819a7de 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -72,7 +72,9 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(MessageFrame(root=message, raw=line)) + await read_stream_writer.send( + MessageFrame(message=message, raw=line) + ) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index fc78f09e3..2da93634c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -47,7 +47,7 @@ async def ws_reader(): continue await read_stream_writer.send( - MessageFrame(root=client_message, raw=message) + MessageFrame(message=client_message, raw=message) ) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 217de38c8..7dd6fefc1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -249,7 +249,7 @@ async def send_request( # TODO: Support progress callbacks await self._write_stream.send( - MessageFrame(root=JSONRPCMessage(jsonrpc_request), raw=None) + MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None) ) try: @@ -287,7 +287,7 @@ async def send_notification(self, notification: SendNotificationT) -> None: ) await self._write_stream.send( - MessageFrame(root=JSONRPCMessage(jsonrpc_notification), raw=None) + MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None) ) async def _send_response( @@ -296,7 +296,7 @@ async def _send_response( if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) await self._write_stream.send( - MessageFrame(root=JSONRPCMessage(jsonrpc_error), raw=None) + MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None) ) else: jsonrpc_response = JSONRPCResponse( @@ -307,7 +307,7 @@ async def _send_response( ), ) await self._write_stream.send( - MessageFrame(root=JSONRPCMessage(jsonrpc_response), raw=None) + MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None) ) async def _receive_loop(self) -> None: @@ -321,7 +321,7 @@ async def _receive_loop(self) -> None: await self._incoming_message_stream_writer.send(raw_message) continue - message = raw_message.root + message = raw_message.message if isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( diff --git a/src/mcp/types.py b/src/mcp/types.py index 848764a67..38384dea8 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -184,15 +184,43 @@ class JSONRPCMessage( class MessageFrame(BaseModel, Generic[RawT]): - root: JSONRPCMessage + """ + A wrapper around the general message received that contains both the parsed message + and the raw message. + + This class serves as an encapsulation for JSON-RPC messages, providing access to + both the parsed structure (root) and the original raw data. This design is + particularly useful for Server-Sent Events (SSE) consumers who may need to access + additional metadata or headers associated with the message. + + The 'root' attribute contains the parsed JSONRPCMessage, which could be a request, + notification, response, or error. The 'raw' attribute preserves the original + message as received, allowing access to any additional context or metadata that + might be lost in parsing. + + This dual representation allows for flexible handling of messages, where consumers + can work with the structured data for standard operations, but still have the + option to examine or utilize the raw data when needed, such as for debugging, + logging, or accessing transport-specific information. + """ + + message: JSONRPCMessage raw: RawT | None = None model_config = ConfigDict(extra="allow") def model_dump(self, *args, **kwargs): - return self.root.model_dump(*args, **kwargs) + """ + Dumps the model to a dictionary, delegating to the root JSONRPCMessage. + This method allows for consistent serialization of the parsed message. + """ + return self.message.model_dump(*args, **kwargs) def model_dump_json(self, *args, **kwargs): - return self.root.model_dump_json(*args, **kwargs) + """ + Dumps the model to a JSON string, delegating to the root JSONRPCMessage. + This method provides a convenient way to serialize the parsed message to JSON. + """ + return self.message.model_dump_json(*args, **kwargs) class EmptyResult(Result): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 2ee7351d4..27f02abf7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -58,13 +58,13 @@ async def mock_server(): ) async with server_to_client_send: - assert isinstance(jsonrpc_request.root.root, JSONRPCRequest) + assert isinstance(jsonrpc_request.message.root, JSONRPCRequest) await server_to_client_send.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.root.id, + id=jsonrpc_request.message.root.id, result=result.model_dump( by_alias=True, mode="json", exclude_none=True ), @@ -74,9 +74,9 @@ async def mock_server(): ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.root, JSONRPCMessage) + assert isinstance(jsonrpc_notification.message, JSONRPCMessage) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.root.model_dump( + jsonrpc_notification.message.model_dump( by_alias=True, mode="json", exclude_none=True ) ) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index ac78dab4c..fd05c7737 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -66,7 +66,7 @@ async def run_server(): ) await client_writer.send( - MessageFrame(root=JSONRPCMessage(root=init_req), raw=None) + MessageFrame(message=JSONRPCMessage(root=init_req), raw=None) ) await server_reader.receive() # Get init response but don't need to check it @@ -77,7 +77,9 @@ async def run_server(): jsonrpc="2.0", ) await client_writer.send( - MessageFrame(root=JSONRPCMessage(root=initialized_notification), raw=None) + MessageFrame( + message=JSONRPCMessage(root=initialized_notification), raw=None + ) ) # Send ping request with custom ID @@ -86,7 +88,7 @@ async def run_server(): ) await client_writer.send( - MessageFrame(root=JSONRPCMessage(root=ping_request), raw=None) + MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None) ) # Read response @@ -94,7 +96,7 @@ async def run_server(): # Verify response ID matches request ID assert ( - response.root.root.id == custom_request_id + response.message.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 67677fae1..18d9a4c5b 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -84,7 +84,7 @@ async def run_server(): ) await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -100,7 +100,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -113,7 +113,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, @@ -127,7 +127,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() - assert response.root.root.result["content"][0]["text"] == "true" + assert response.message.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -189,7 +189,7 @@ async def run_server(): ) await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -205,7 +205,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -218,7 +218,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, @@ -232,7 +232,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() - assert response.root.root.result["content"][0]["text"] == "true" + assert response.message.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index df4f165bd..c12c26373 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -35,25 +35,25 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert isinstance(received_messages[0].root, JSONRPCMessage) - assert isinstance(received_messages[0].root.root, JSONRPCRequest) - assert received_messages[0].root.root.id == 1 - assert received_messages[0].root.root.method == "ping" + assert isinstance(received_messages[0].message, JSONRPCMessage) + assert isinstance(received_messages[0].message.root, JSONRPCRequest) + assert received_messages[0].message.root.id == 1 + assert received_messages[0].message.root.method == "ping" - assert isinstance(received_messages[1].root, JSONRPCMessage) - assert isinstance(received_messages[1].root.root, JSONRPCResponse) - assert received_messages[1].root.root.id == 2 + assert isinstance(received_messages[1].message, JSONRPCMessage) + assert isinstance(received_messages[1].message.root, JSONRPCResponse) + assert received_messages[1].message.root.id == 2 # Test sending responses from the server responses = [ MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") ), raw=None, ), MessageFrame( - root=JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) ), raw=None,