diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index a54fd8823..471870533 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,6 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any import anyio import httpx @@ -52,14 +51,10 @@ class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" - pass - class ResumptionError(StreamableHTTPError): """Raised when resumption request is invalid.""" - pass - @dataclass class RequestContext: @@ -71,7 +66,7 @@ class RequestContext: session_message: SessionMessage metadata: ClientMessageMetadata | None read_stream_writer: StreamWriter - sse_read_timeout: timedelta + sse_read_timeout: float class StreamableHTTPTransport: @@ -80,9 +75,9 @@ class StreamableHTTPTransport: def __init__( self, url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, auth: httpx.Auth | None = None, ) -> None: """Initialize the StreamableHTTP transport. @@ -96,10 +91,12 @@ def __init__( """ self.url = url self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.sse_read_timeout = ( + sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + ) self.auth = auth - self.session_id: str | None = None + self.session_id = None self.request_headers = { ACCEPT: f"{JSON}, {SSE}", CONTENT_TYPE: JSON, @@ -160,7 +157,7 @@ async def _handle_sse_event( return isinstance(message.root, JSONRPCResponse | JSONRPCError) except Exception as exc: - logger.error(f"Error parsing SSE message: {exc}") + logger.exception("Error parsing SSE message") await read_stream_writer.send(exc) return False else: @@ -184,10 +181,7 @@ async def handle_get_stream( "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.total_seconds(), - read=self.sse_read_timeout.total_seconds(), - ), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -216,10 +210,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.total_seconds(), - read=ctx.sse_read_timeout.total_seconds(), - ), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") @@ -412,9 +403,9 @@ def get_session_id(self) -> str | None: @asynccontextmanager async def streamablehttp_client( url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, terminate_on_close: bool = True, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, @@ -449,10 +440,7 @@ async def streamablehttp_client( async with httpx_client_factory( headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout.total_seconds(), - read=transport.sse_read_timeout.total_seconds(), - ), + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), auth=transport.auth, ) as client: # Define callbacks that need access to tg