Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit da0cf22

Browse files
authored
Wrap JSONRPC messages with SessionMessage for metadata support (#590)
1 parent 3978c6e commit da0cf22

22 files changed

+286
-173
lines changed

src/mcp/client/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from mcp.client.session import ClientSession
1212
from mcp.client.sse import sse_client
1313
from mcp.client.stdio import StdioServerParameters, stdio_client
14+
from mcp.shared.message import SessionMessage
1415
from mcp.shared.session import RequestResponder
15-
from mcp.types import JSONRPCMessage
1616

1717
if not sys.warnoptions:
1818
import warnings
@@ -36,8 +36,8 @@ async def message_handler(
3636

3737

3838
async def run_session(
39-
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
40-
write_stream: MemoryObjectSendStream[JSONRPCMessage],
39+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
40+
write_stream: MemoryObjectSendStream[SessionMessage],
4141
client_info: types.Implementation | None = None,
4242
):
4343
async with ClientSession(

src/mcp/client/session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import mcp.types as types
99
from mcp.shared.context import RequestContext
10+
from mcp.shared.message import SessionMessage
1011
from mcp.shared.session import BaseSession, RequestResponder
1112
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1213

@@ -92,8 +93,8 @@ class ClientSession(
9293
):
9394
def __init__(
9495
self,
95-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
96-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
96+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
97+
write_stream: MemoryObjectSendStream[SessionMessage],
9798
read_timeout_seconds: timedelta | None = None,
9899
sampling_callback: SamplingFnT | None = None,
99100
list_roots_callback: ListRootsFnT | None = None,

src/mcp/client/sse.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from httpx_sse import aconnect_sse
1111

1212
import mcp.types as types
13+
from mcp.shared.message import SessionMessage
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -31,11 +32,11 @@ async def sse_client(
3132
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3233
event before disconnecting. All other HTTP operations are controlled by `timeout`.
3334
"""
34-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
35-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
35+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
36+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
3637

37-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
38-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
38+
write_stream: MemoryObjectSendStream[SessionMessage]
39+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
3940

4041
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4142
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -97,7 +98,8 @@ async def sse_reader(
9798
await read_stream_writer.send(exc)
9899
continue
99100

100-
await read_stream_writer.send(message)
101+
session_message = SessionMessage(message)
102+
await read_stream_writer.send(session_message)
101103
case _:
102104
logger.warning(
103105
f"Unknown SSE event: {sse.event}"
@@ -111,11 +113,13 @@ async def sse_reader(
111113
async def post_writer(endpoint_url: str):
112114
try:
113115
async with write_stream_reader:
114-
async for message in write_stream_reader:
115-
logger.debug(f"Sending client message: {message}")
116+
async for session_message in write_stream_reader:
117+
logger.debug(
118+
f"Sending client message: {session_message}"
119+
)
116120
response = await client.post(
117121
endpoint_url,
118-
json=message.model_dump(
122+
json=session_message.message.model_dump(
119123
by_alias=True,
120124
mode="json",
121125
exclude_none=True,

src/mcp/client/stdio/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic import BaseModel, Field
1212

1313
import mcp.types as types
14+
from mcp.shared.message import SessionMessage
1415

1516
from .win32 import (
1617
create_windows_process,
@@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
9899
Client transport for stdio: this will connect to a server by spawning a
99100
process and communicating with it over stdin/stdout.
100101
"""
101-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
102-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
102+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
103+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
103104

104-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
105-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
105+
write_stream: MemoryObjectSendStream[SessionMessage]
106+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
106107

107108
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
108109
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -143,7 +144,8 @@ async def stdout_reader():
143144
await read_stream_writer.send(exc)
144145
continue
145146

146-
await read_stream_writer.send(message)
147+
session_message = SessionMessage(message)
148+
await read_stream_writer.send(session_message)
147149
except anyio.ClosedResourceError:
148150
await anyio.lowlevel.checkpoint()
149151

@@ -152,8 +154,10 @@ async def stdin_writer():
152154

153155
try:
154156
async with write_stream_reader:
155-
async for message in write_stream_reader:
156-
json = message.model_dump_json(by_alias=True, exclude_none=True)
157+
async for session_message in write_stream_reader:
158+
json = session_message.message.model_dump_json(
159+
by_alias=True, exclude_none=True
160+
)
157161
await process.stdin.send(
158162
(json + "\n").encode(
159163
encoding=server.encoding,

src/mcp/client/streamable_http.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import httpx
1616
from httpx_sse import EventSource, aconnect_sse
1717

18+
from mcp.shared.message import SessionMessage
1819
from mcp.types import (
1920
ErrorData,
2021
JSONRPCError,
@@ -52,10 +53,10 @@ async def streamablehttp_client(
5253
"""
5354

5455
read_stream_writer, read_stream = anyio.create_memory_object_stream[
55-
JSONRPCMessage | Exception
56+
SessionMessage | Exception
5657
](0)
5758
write_stream, write_stream_reader = anyio.create_memory_object_stream[
58-
JSONRPCMessage
59+
SessionMessage
5960
](0)
6061

6162
async def get_stream():
@@ -86,7 +87,8 @@ async def get_stream():
8687
try:
8788
message = JSONRPCMessage.model_validate_json(sse.data)
8889
logger.debug(f"GET message: {message}")
89-
await read_stream_writer.send(message)
90+
session_message = SessionMessage(message)
91+
await read_stream_writer.send(session_message)
9092
except Exception as exc:
9193
logger.error(f"Error parsing GET message: {exc}")
9294
await read_stream_writer.send(exc)
@@ -100,7 +102,8 @@ async def post_writer(client: httpx.AsyncClient):
100102
nonlocal session_id
101103
try:
102104
async with write_stream_reader:
103-
async for message in write_stream_reader:
105+
async for session_message in write_stream_reader:
106+
message = session_message.message
104107
# Add session ID to headers if we have one
105108
post_headers = request_headers.copy()
106109
if session_id:
@@ -141,9 +144,10 @@ async def post_writer(client: httpx.AsyncClient):
141144
message="Session terminated",
142145
),
143146
)
144-
await read_stream_writer.send(
147+
session_message = SessionMessage(
145148
JSONRPCMessage(jsonrpc_error)
146149
)
150+
await read_stream_writer.send(session_message)
147151
continue
148152
response.raise_for_status()
149153

@@ -163,7 +167,8 @@ async def post_writer(client: httpx.AsyncClient):
163167
json_message = JSONRPCMessage.model_validate_json(
164168
content
165169
)
166-
await read_stream_writer.send(json_message)
170+
session_message = SessionMessage(json_message)
171+
await read_stream_writer.send(session_message)
167172
except Exception as exc:
168173
logger.error(f"Error parsing JSON response: {exc}")
169174
await read_stream_writer.send(exc)
@@ -175,11 +180,15 @@ async def post_writer(client: httpx.AsyncClient):
175180
async for sse in event_source.aiter_sse():
176181
if sse.event == "message":
177182
try:
178-
await read_stream_writer.send(
183+
message = (
179184
JSONRPCMessage.model_validate_json(
180185
sse.data
181186
)
182187
)
188+
session_message = SessionMessage(message)
189+
await read_stream_writer.send(
190+
session_message
191+
)
183192
except Exception as exc:
184193
logger.exception("Error parsing message")
185194
await read_stream_writer.send(exc)

src/mcp/client/websocket.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from websockets.typing import Subprotocol
1111

1212
import mcp.types as types
13+
from mcp.shared.message import SessionMessage
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -19,8 +20,8 @@ async def websocket_client(
1920
url: str,
2021
) -> AsyncGenerator[
2122
tuple[
22-
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
23-
MemoryObjectSendStream[types.JSONRPCMessage],
23+
MemoryObjectReceiveStream[SessionMessage | Exception],
24+
MemoryObjectSendStream[SessionMessage],
2425
],
2526
None,
2627
]:
@@ -39,10 +40,10 @@ async def websocket_client(
3940
# Create two in-memory streams:
4041
# - One for incoming messages (read_stream, written by ws_reader)
4142
# - One for outgoing messages (write_stream, read by ws_writer)
42-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
43-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
44-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
45-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
43+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
44+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
45+
write_stream: MemoryObjectSendStream[SessionMessage]
46+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
4647

4748
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4849
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -59,7 +60,8 @@ async def ws_reader():
5960
async for raw_text in ws:
6061
try:
6162
message = types.JSONRPCMessage.model_validate_json(raw_text)
62-
await read_stream_writer.send(message)
63+
session_message = SessionMessage(message)
64+
await read_stream_writer.send(session_message)
6365
except ValidationError as exc:
6466
# If JSON parse or model validation fails, send the exception
6567
await read_stream_writer.send(exc)
@@ -70,9 +72,9 @@ async def ws_writer():
7072
sends them to the server.
7173
"""
7274
async with write_stream_reader:
73-
async for message in write_stream_reader:
75+
async for session_message in write_stream_reader:
7476
# Convert to a dict, then to JSON
75-
msg_dict = message.model_dump(
77+
msg_dict = session_message.message.model_dump(
7678
by_alias=True, mode="json", exclude_none=True
7779
)
7880
await ws.send(json.dumps(msg_dict))

src/mcp/server/lowlevel/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ async def main():
8484
from mcp.server.stdio import stdio_server as stdio_server
8585
from mcp.shared.context import RequestContext
8686
from mcp.shared.exceptions import McpError
87+
from mcp.shared.message import SessionMessage
8788
from mcp.shared.session import RequestResponder
8889

8990
logger = logging.getLogger(__name__)
@@ -471,8 +472,8 @@ async def handler(req: types.CompleteRequest):
471472

472473
async def run(
473474
self,
474-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
475-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
475+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
476+
write_stream: MemoryObjectSendStream[SessionMessage],
476477
initialization_options: InitializationOptions,
477478
# When False, exceptions are returned as messages to the client.
478479
# When True, exceptions are raised, which will cause the server to shut down

src/mcp/server/session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50+
from mcp.shared.message import SessionMessage
5051
from mcp.shared.session import (
5152
BaseSession,
5253
RequestResponder,
@@ -82,8 +83,8 @@ class ServerSession(
8283

8384
def __init__(
8485
self,
85-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
86-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
86+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
87+
write_stream: MemoryObjectSendStream[SessionMessage],
8788
init_options: InitializationOptions,
8889
stateless: bool = False,
8990
) -> None:

src/mcp/server/sse.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def handle_sse(request):
4646
from starlette.types import Receive, Scope, Send
4747

4848
import mcp.types as types
49+
from mcp.shared.message import SessionMessage
4950

5051
logger = logging.getLogger(__name__)
5152

@@ -63,9 +64,7 @@ class SseServerTransport:
6364
"""
6465

6566
_endpoint: str
66-
_read_stream_writers: dict[
67-
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
68-
]
67+
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
6968

7069
def __init__(self, endpoint: str) -> None:
7170
"""
@@ -85,11 +84,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8584
raise ValueError("connect_sse can only handle HTTP requests")
8685

8786
logger.debug("Setting up SSE connection")
88-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
89-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
87+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
88+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
9089

91-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
92-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
90+
write_stream: MemoryObjectSendStream[SessionMessage]
91+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
9392

9493
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9594
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -109,12 +108,12 @@ async def sse_writer():
109108
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
110109
logger.debug(f"Sent endpoint event: {session_uri}")
111110

112-
async for message in write_stream_reader:
113-
logger.debug(f"Sending message via SSE: {message}")
111+
async for session_message in write_stream_reader:
112+
logger.debug(f"Sending message via SSE: {session_message}")
114113
await sse_stream_writer.send(
115114
{
116115
"event": "message",
117-
"data": message.model_dump_json(
116+
"data": session_message.message.model_dump_json(
118117
by_alias=True, exclude_none=True
119118
),
120119
}
@@ -169,7 +168,8 @@ async def handle_post_message(
169168
await writer.send(err)
170169
return
171170

172-
logger.debug(f"Sending message to writer: {message}")
171+
session_message = SessionMessage(message)
172+
logger.debug(f"Sending session message to writer: {session_message}")
173173
response = Response("Accepted", status_code=202)
174174
await response(scope, receive, send)
175-
await writer.send(message)
175+
await writer.send(session_message)

0 commit comments

Comments
 (0)