From 9733f32cf926718f24ab40cea7ddf320a4b925ed Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:13:42 -0600 Subject: [PATCH 1/4] Fix max length handling (#1510) --- pydantic_ai_slim/pydantic_ai/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b4e4af31a..addde7901 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -957,7 +957,7 @@ def transform(self, schema: JsonSchema) -> JsonSchema: # Remove incompatible keys, but note their impact in the description provided to the LLM description = schema.get('description') min_length = schema.pop('minLength', None) - max_length = schema.pop('minLength', None) + max_length = schema.pop('maxLength', None) if description is not None: notes = list[str]() if min_length is not None: # pragma: no cover From babdf8254985e8a1c47b26d957a9cef36f283657 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 16 Apr 2025 17:36:24 -0600 Subject: [PATCH 2/4] Do a better job of inferring openai strict mode (#1511) --- pydantic_ai_slim/pydantic_ai/models/openai.py | 86 ++++++++++++++++--- tests/models/test_openai.py | 64 +++++++++++--- 2 files changed, 124 insertions(+), 26 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index addde7901..dc1026f9f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import base64 +import re import warnings from collections.abc import AsyncIterable, AsyncIterator, Sequence from contextlib import asynccontextmanager @@ -932,6 +933,31 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R ) +_STRICT_INCOMPATIBLE_KEYS = [ + 'minLength', + 'maxLength', + 'pattern', + 'format', + 'minimum', + 'maximum', + 'multipleOf', + 'patternProperties', + 'unevaluatedProperties', + 'propertyNames', + 'minProperties', + 'maxProperties', + 'unevaluatedItems', + 'contains', + 'minContains', + 'maxContains', + 'minItems', + 'maxItems', + 'uniqueItems', +] + +_sentinel = object() + + @dataclass class _OpenAIJsonSchema(WalkJsonSchema): """Recursively handle the schema to make it compatible with OpenAI strict mode. @@ -946,28 +972,64 @@ def __init__(self, schema: JsonSchema, strict: bool | None): super().__init__(schema) self.strict = strict self.is_strict_compatible = True + self.root_ref = schema.get('$ref') + + def walk(self) -> JsonSchema: + # Note: OpenAI does not support anyOf at the root in strict mode + # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema + # that the root schema either has type 'object' or is recursive. + result = super().walk() + + # For recursive models, we need to tweak the schema to make it compatible with strict mode. + # Because the following should never change the semantics of the schema we apply it unconditionally. + if self.root_ref is not None: + result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method + root_key = re.sub(r'^#/\$defs/', '', self.root_ref) + result.update(self.defs.get(root_key) or {}) + + return result - def transform(self, schema: JsonSchema) -> JsonSchema: + def transform(self, schema: JsonSchema) -> JsonSchema: # noqa C901 # Remove unnecessary keys schema.pop('title', None) schema.pop('default', None) schema.pop('$schema', None) schema.pop('discriminator', None) - # Remove incompatible keys, but note their impact in the description provided to the LLM + if schema_ref := schema.get('$ref'): + if schema_ref == self.root_ref: + schema['$ref'] = '#' + if len(schema) > 1: + # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf". + # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf": + schema['anyOf'] = [{'$ref': schema.pop('$ref')}] + + # Track strict-incompatible keys + incompatible_values: dict[str, Any] = {} + for key in _STRICT_INCOMPATIBLE_KEYS: + value = schema.get(key, _sentinel) + if value is not _sentinel: + incompatible_values[key] = value description = schema.get('description') - min_length = schema.pop('minLength', None) - max_length = schema.pop('maxLength', None) - if description is not None: - notes = list[str]() - if min_length is not None: # pragma: no cover - notes.append(f'min_length={min_length}') - if max_length is not None: # pragma: no cover - notes.append(f'max_length={max_length}') - if notes: # pragma: no cover - schema['description'] = f'{description} ({", ".join(notes)})' + if incompatible_values: + if self.strict is True: + notes: list[str] = [] + for key, value in incompatible_values.items(): + schema.pop(key) + notes.append(f'{key}={value}') + notes_string = ', '.join(notes) + schema['description'] = notes_string if not description else f'{description} ({notes_string})' + elif self.strict is None: + self.is_strict_compatible = False schema_type = schema.get('type') + if 'oneOf' in schema: + # OpenAI does not support oneOf in strict mode + if self.strict is True: + schema['anyOf'] = schema.pop('oneOf') + else: + self.is_strict_compatible = False + if schema_type == 'object': if self.strict is True: # additional properties are disallowed diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index c594f645c..b19c03b6c 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from datetime import datetime, timezone +from enum import Enum from functools import cached_property from typing import Annotated, Any, Callable, Literal, Union, cast @@ -730,9 +731,15 @@ class MyDefaultDc: x: int = 1 +class MyEnum(Enum): + a = 'a' + b = 'b' + + @dataclass class MyRecursiveDc: field: MyRecursiveDc | None + my_enum: MyEnum = Field(description='my enum') @dataclass @@ -826,9 +833,13 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str: }, 'type': 'object', }, + 'MyEnum': {'enum': ['a', 'b'], 'type': 'string'}, 'MyRecursiveDc': { - 'properties': {'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}}, - 'required': ['field'], + 'properties': { + 'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}, + 'my_enum': {'description': 'my enum', 'anyOf': [{'$ref': '#/$defs/MyEnum'}]}, + }, + 'required': ['field', 'my_enum'], 'type': 'object', }, }, @@ -857,11 +868,15 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str: 'additionalProperties': False, 'required': ['field'], }, + 'MyEnum': {'enum': ['a', 'b'], 'type': 'string'}, 'MyRecursiveDc': { - 'properties': {'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}}, + 'properties': { + 'field': {'anyOf': [{'$ref': '#/$defs/MyRecursiveDc'}, {'type': 'null'}]}, + 'my_enum': {'description': 'my enum', 'anyOf': [{'$ref': '#/$defs/MyEnum'}]}, + }, 'type': 'object', 'additionalProperties': False, - 'required': ['field'], + 'required': ['field', 'my_enum'], }, }, 'additionalProperties': False, @@ -998,7 +1013,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str: } }, 'additionalProperties': False, - 'properties': {'x': {'oneOf': [{'type': 'integer'}, {'$ref': '#/$defs/MyDefaultDc'}]}}, + 'properties': {'x': {'anyOf': [{'type': 'integer'}, {'$ref': '#/$defs/MyDefaultDc'}]}}, 'required': ['x'], 'type': 'object', } @@ -1079,12 +1094,15 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str: { 'additionalProperties': False, 'properties': { - 'x': {'maxItems': 1, 'minItems': 1, 'prefixItems': [{'type': 'integer'}], 'type': 'array'}, + 'x': { + 'prefixItems': [{'type': 'integer'}], + 'type': 'array', + 'description': 'minItems=1, maxItems=1', + }, 'y': { - 'maxItems': 1, - 'minItems': 1, 'prefixItems': [{'type': 'string'}], 'type': 'array', + 'description': 'minItems=1, maxItems=1', }, }, 'required': ['x', 'y'], @@ -1160,28 +1178,46 @@ class MyModel(BaseModel): 'MyModel': { 'additionalProperties': False, 'properties': { - 'my_discriminated_union': {'oneOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]}, + 'my_discriminated_union': {'anyOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]}, 'my_list': {'items': {'type': 'number'}, 'type': 'array'}, 'my_patterns': { 'additionalProperties': False, - 'patternProperties': {'^my-pattern$': {'type': 'string'}}, + 'description': "patternProperties={'^my-pattern$': {'type': 'string'}}", 'type': 'object', 'properties': {}, 'required': [], }, - 'my_recursive': {'anyOf': [{'$ref': '#/$defs/MyModel'}, {'type': 'null'}]}, + 'my_recursive': {'anyOf': [{'$ref': '#'}, {'type': 'null'}]}, 'my_tuple': { - 'maxItems': 1, - 'minItems': 1, 'prefixItems': [{'type': 'integer'}], 'type': 'array', + 'description': 'minItems=1, maxItems=1', }, }, 'required': ['my_recursive', 'my_patterns', 'my_tuple', 'my_list', 'my_discriminated_union'], 'type': 'object', }, }, - '$ref': '#/$defs/MyModel', + 'properties': { + 'my_recursive': {'anyOf': [{'$ref': '#'}, {'type': 'null'}]}, + 'my_patterns': { + 'type': 'object', + 'description': "patternProperties={'^my-pattern$': {'type': 'string'}}", + 'additionalProperties': False, + 'properties': {}, + 'required': [], + }, + 'my_tuple': { + 'prefixItems': [{'type': 'integer'}], + 'type': 'array', + 'description': 'minItems=1, maxItems=1', + }, + 'my_list': {'items': {'type': 'number'}, 'type': 'array'}, + 'my_discriminated_union': {'anyOf': [{'$ref': '#/$defs/Apple'}, {'$ref': '#/$defs/Banana'}]}, + }, + 'required': ['my_recursive', 'my_patterns', 'my_tuple', 'my_list', 'my_discriminated_union'], + 'type': 'object', + 'additionalProperties': False, } ) From 1a33c02fc6a24a4c774c181fd5ccd33a82a14fd9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 17 Apr 2025 01:58:43 -0600 Subject: [PATCH 3/4] Properly validate serialized messages with BinaryContent by decoding base64 (#1513) Co-authored-by: Marcelo Trylesinski --- pydantic_ai_slim/pydantic_ai/messages.py | 2 +- tests/test_agent.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 4786dbc77..cd6839c02 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -589,7 +589,7 @@ def new_event_body(): """Any message sent to or returned by a model.""" ModelMessagesTypeAdapter = pydantic.TypeAdapter( - list[ModelMessage], config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64') + list[ModelMessage], config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64', val_json_bytes='base64') ) """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index c20cb7348..fd823195c 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -15,6 +15,7 @@ from pydantic_ai.messages import ( BinaryContent, ModelMessage, + ModelMessagesTypeAdapter, ModelRequest, ModelResponse, ModelResponsePart, @@ -1675,8 +1676,11 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: # pragma: no cover def test_binary_content_all_messages_json(): agent = Agent('test') - result = agent.run_sync(['Hello', BinaryContent(data=b'Hello', media_type='text/plain')]) - assert json.loads(result.all_messages_json()) == snapshot( + content = BinaryContent(data=b'Hello', media_type='text/plain') + result = agent.run_sync(['Hello', content]) + + serialized = result.all_messages_json() + assert json.loads(serialized) == snapshot( [ { 'parts': [ @@ -1698,6 +1702,10 @@ def test_binary_content_all_messages_json(): ] ) + # We also need to be able to round trip the serialized messages. + messages = ModelMessagesTypeAdapter.validate_json(serialized) + assert messages == result.all_messages() + def test_instructions_raise_error_when_system_prompt_is_set(): agent = Agent('test', instructions='An instructions!') From a45b8e14118257eebfe6d00081b8e890d73714f2 Mon Sep 17 00:00:00 2001 From: Tim Esler Date: Thu, 17 Apr 2025 01:05:14 -0700 Subject: [PATCH 4/4] Expose the StdioServerParameters.cwd param (#1514) --- pydantic_ai_slim/pydantic_ai/mcp.py | 6 +++++- pydantic_ai_slim/pyproject.toml | 2 +- tests/test_mcp.py | 10 ++++++++++ uv.lock | 8 ++++---- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 93e6e018d..35dcb7593 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterator, Sequence from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass +from pathlib import Path from types import TracebackType from typing import Any @@ -150,13 +151,16 @@ async def main(): If you want to inherit the environment variables from the parent process, use `env=os.environ`. """ + cwd: str | Path | None = None + """The working directory to use when spawning the process.""" + @asynccontextmanager async def client_streams( self, ) -> AsyncIterator[ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]] ]: - server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env) + server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd) async with stdio_client(server=server) as (read_stream, write_stream): yield read_stream, write_stream diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 342540603..8f0471aa9 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,7 +69,7 @@ tavily = ["tavily-python>=0.5.0"] # CLI cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"] # MCP -mcp = ["mcp>=1.4.1; python_version >= '3.10'"] +mcp = ["mcp>=1.5.0; python_version >= '3.10'"] # Evals evals = ["pydantic-evals=={{ version }}"] diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 0d77a9f1b..4c7d2dd59 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" +from pathlib import Path + import pytest from dirty_equals import IsInstance from inline_snapshot import snapshot @@ -38,6 +40,14 @@ async def test_stdio_server(): assert result.content == snapshot([TextContent(type='text', text='32.0')]) +async def test_stdio_server_with_cwd(): + test_dir = Path(__file__).parent + server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) + async with server: + tools = await server.list_tools() + assert len(tools) == 1 + + def test_sse_server(): sse_server = MCPServerHTTP(url='http://localhost:8000/sse') assert sse_server.url == 'http://localhost:8000/sse' diff --git a/uv.lock b/uv.lock index 774e0d1d1..a8c9c0b8e 100644 --- a/uv.lock +++ b/uv.lock @@ -1670,7 +1670,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.4.1" +version = "1.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "python_full_version >= '3.10'" }, @@ -1682,9 +1682,9 @@ dependencies = [ { name = "starlette", marker = "python_full_version >= '3.10'" }, { name = "uvicorn", marker = "python_full_version >= '3.10'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/cc/5c5bb19f1a0f8f89a95e25cb608b0b07009e81fd4b031e519335404e1422/mcp-1.4.1.tar.gz", hash = "sha256:b9655d2de6313f9d55a7d1df62b3c3fe27a530100cc85bf23729145b0dba4c7a", size = 154942 } +sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } wheels = [ - { url = "https://files.pythonhosted.org/packages/e8/0e/885f156ade60108e67bf044fada5269da68e29d758a10b0c513f4d85dd76/mcp-1.4.1-py3-none-any.whl", hash = "sha256:a7716b1ec1c054e76f49806f7d96113b99fc1166fc9244c2c6f19867cb75b593", size = 72448 }, + { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, ] [package.optional-dependencies] @@ -2979,7 +2979,7 @@ requires-dist = [ { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, - { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.4.1" }, + { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.5.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.74.0" }, { name = "opentelemetry-api", specifier = ">=1.28.0" },