Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix mcp prefix #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""

def __init__(self, cache_tools_list: bool):
def __init__(self, cache_tools_list: bool, append_prefix: str | None):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
Expand All @@ -68,6 +68,7 @@ def __init__(self, cache_tools_list: bool):
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
self.append_prefix = append_prefix

# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
Expand Down Expand Up @@ -122,14 +123,25 @@ async def list_tools(self) -> list[MCPTool]:
self._cache_dirty = False

# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
tools = (await self.session.list_tools()).tools
# Add prefix to tool names if provided
if self.append_prefix:
tools = [
MCPTool(name=f"{self.append_prefix}{tool.name}", description=tool.description)
for tool in tools
]
self._tools_list = tools
return tools

async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
"""Invoke a tool on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

# Remove prefix from tool name before calling the tool
if self.append_prefix:
tool_name = tool_name[len(self.append_prefix) :]

return await self.session.call_tool(tool_name, arguments)

async def cleanup(self):
Expand Down Expand Up @@ -170,6 +182,11 @@ class MCPServerStdioParams(TypedDict):
explanations of possible values.
"""

append_prefix: NotRequired[str]
"""If provided, this string will be prepended to tool names when calling them. Used to avoid
collisions between tool names.
"""


class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
Expand Down Expand Up @@ -199,7 +216,7 @@ def __init__(
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
super().__init__(cache_tools_list, params.get("append_prefix", None))

self.params = StdioServerParameters(
command=params["command"],
Expand Down Expand Up @@ -244,6 +261,11 @@ class MCPServerSseParams(TypedDict):
sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""

append_prefix: NotRequired[str]
"""If provided, this string will be prepended to tool names when calling them. Used to avoid
collisions between tool names.
"""


class MCPServerSse(_MCPServerWithClientSession):
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
Expand Down Expand Up @@ -274,7 +296,7 @@ def __init__(
name: A readable name for the server. If not provided, we'll create one from the
URL.
"""
super().__init__(cache_tools_list)
super().__init__(cache_tools_list, params.get("append_prefix", None))

self.params = params
self._name = name or f"sse: {self.params['url']}"
Expand Down
39 changes: 39 additions & 0 deletions tests/mcp/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,42 @@ async def test_server_caching_works(
# Without invalidating the cache, calling list_tools() again should return the cached value
tools = await server.list_tools()
assert tools == tools


@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_with_prefix(
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
"""Test that caching works and tool names are prefixed when append_prefix is set."""
server = MCPServerStdio(
params={
"command": tee,
"append_prefix": "myprefix_",
},
cache_tools_list=True,
)

tools = [
MCPTool(name="tool1", inputSchema={}),
MCPTool(name="tool2", inputSchema={}),
]

mock_list_tools.return_value = ListToolsResult(tools=tools)

async with server:
tools1 = await server.list_tools()
assert [t.name for t in tools1] == ["myprefix_tool1", "myprefix_tool2"]

# Should be cached
tools2 = await server.list_tools()
assert tools1 is tools2
assert mock_list_tools.call_count == 1

# Invalidate cache and check again
server.invalidate_tools_cache()
tools3 = await server.list_tools()
assert [t.name for t in tools3] == ["myprefix_tool1", "myprefix_tool2"]
assert mock_list_tools.call_count == 2
36 changes: 36 additions & 0 deletions tests/mcp/test_mcp_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,39 @@ async def test_mcp_tracing():
}
]
)


@pytest.mark.asyncio
async def test_mcp_tracing_with_prefix():
model = FakeModel()
server = FakeMCPServer()
server.append_prefix = "prefix_"
server.add_tool("toolA", {})
agent = Agent(
name="test",
model=model,
mcp_servers=[server],
tools=[get_function_tool("non_mcp_tool", "tool_result")],
)

model.add_multiple_turn_outputs(
[
[get_text_message("msg"), get_function_tool_call("prefix_toolA", "")],
[get_text_message("done")],
]
)

x = Runner.run_streamed(agent, input="first_test")
async for _ in x.stream_events():
pass

assert x.final_output == "done"
spans = fetch_normalized_spans()
# The tool name in the trace should be prefixed
assert any(
c.get("data", {}).get("name") == "prefix_toolA"
for span in spans
for agent_span in span.get("children", [])
for c in agent_span.get("children", [])
if c.get("type") == "function"
)
11 changes: 11 additions & 0 deletions tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,14 @@ async def test_util_adds_properties():
assert tool.params_json_schema == snapshot(
{"type": "object", "description": "Test tool", "properties": {}}
)


@pytest.mark.asyncio
async def test_get_all_function_tools_with_prefix():
"""Test that get_all_function_tools returns prefixed tool names if append_prefix is set."""
server = FakeMCPServer()
server.append_prefix = "utilprefix_"
server.add_tool("foo", {})
server.add_tool("bar", {})
tools = await MCPUtil.get_all_function_tools([server], convert_schemas_to_strict=False)
assert [tool.name for tool in tools] == ["utilprefix_foo", "utilprefix_bar"]
29 changes: 29 additions & 0 deletions tests/mcp/test_runner_calls_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,32 @@ async def test_runner_calls_mcp_tool_with_args(streaming: bool):
assert server.tool_results == [f"result_test_tool_2_{json_args}"]

await server.cleanup()


@pytest.mark.asyncio
@pytest.mark.parametrize("streaming", [False, True])
async def test_runner_calls_mcp_tool_with_prefix(streaming: bool):
"""Test that the runner calls the correct tool when append_prefix is set."""
server = FakeMCPServer()
server.append_prefix = "myprefix_"
server.add_tool("tool1", {})
model = FakeModel()
agent = Agent(
name="test",
model=model,
mcp_servers=[server],
)
model.add_multiple_turn_outputs(
[
[get_text_message("msg"), get_function_tool_call("myprefix_tool1", "")],
[get_text_message("done")],
]
)
if streaming:
result = Runner.run_streamed(agent, input="user_message")
async for _ in result.stream_events():
pass
else:
await Runner.run(agent, input="user_message")
# Should call the tool with the unprefixed name
assert server.tool_calls == ["tool1"]
Loading