diff --git a/.github/set_docs_main_preview_url.py b/.github/set_docs_main_preview_url.py
new file mode 100644
index 000000000..4e423703d
--- /dev/null
+++ b/.github/set_docs_main_preview_url.py
@@ -0,0 +1,57 @@
+import json
+import os
+import re
+import typing
+
+import httpx
+
+DEPLOY_OUTPUT = os.environ['DEPLOY_OUTPUT']
+GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
+REPOSITORY = os.environ['REPOSITORY']
+REF = os.environ['REF']
+ENVIRONMENT = 'deploy-docs-preview'
+
+m = re.search(r'https://(\S+)\.workers\.dev', DEPLOY_OUTPUT)
+assert m, f'Could not find worker URL in {DEPLOY_OUTPUT!r}'
+
+worker_name = m.group(1)
+m = re.search(r'Current Version ID: ([^-]+)', DEPLOY_OUTPUT)
+assert m, f'Could not find version ID in {DEPLOY_OUTPUT!r}'
+
+version_id = m.group(1)
+preview_url = f'https://{version_id}-{worker_name}.workers.dev'
+print('CloudFlare worker preview URL:', preview_url, flush=True)
+
+gh_headers = {
+    'Accept': 'application/vnd.github+json',
+    'Authorization': f'Bearer {GITHUB_TOKEN}',
+    'X-GitHub-Api-Version': '2022-11-28',
+}
+
+deployment_url = f'https://api.github.com/repos/{REPOSITORY}/deployments'
+deployment_data: dict[str, typing.Any] = {
+    'ref': REF,
+    'task': 'docs preview',
+    'environment': ENVIRONMENT,
+    'auto_merge': False,
+    'required_contexts': [],
+    'payload': json.dumps({
+        'preview_url': preview_url,
+        'worker_name': worker_name,
+        'version_id': version_id,
+    })
+}
+r = httpx.post(deployment_url, headers=gh_headers, json=deployment_data)
+print(f'POST {deployment_url} {r.status_code} {r.text}', flush=True)
+r.raise_for_status()
+deployment_id = r.json()['id']
+
+status_url = f'https://api.github.com/repos/{REPOSITORY}/deployments/{deployment_id}/statuses'
+status_data = {
+    'environment': ENVIRONMENT,
+    'environment_url': preview_url,
+    'state': 'success',
+}
+r = httpx.post(status_url, headers=gh_headers, json=status_data)
+print(f'POST {status_url} {r.status_code} {r.text}', flush=True)
+r.raise_for_status()
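A note on the parsing above: both regexes key off fixed substrings in the `wrangler deploy` output. A minimal sketch of the extraction, with sample text shaped to match the two regexes rather than copied from a real wrangler log:

```python
# Minimal sketch of how set_docs_main_preview_url.py extracts the preview URL.
# The sample text is shaped to satisfy the two regexes above; it is not a
# verbatim wrangler log.
import re

sample_output = 'https://docs-site-previews.example.workers.dev\nCurrent Version ID: 0a1b2c3d-4e5f-...\n'

worker = re.search(r'https://(\S+)\.workers\.dev', sample_output)
version = re.search(r'Current Version ID: ([^-]+)', sample_output)
assert worker and version
# Note: `([^-]+)` stops at the first hyphen, so only the leading segment of
# the version UUID is used in the preview hostname.
print(f'https://{version.group(1)}-{worker.group(1)}.workers.dev')
# -> https://0a1b2c3d-docs-site-previews.example.workers.dev
```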
diff --git a/.github/set_docs_preview_url.py b/.github/set_docs_pr_preview_url.py
similarity index 82%
rename from .github/set_docs_preview_url.py
rename to .github/set_docs_pr_preview_url.py
index 8c46fbf12..f92f29cbe 100644
--- a/.github/set_docs_preview_url.py
+++ b/.github/set_docs_pr_preview_url.py
@@ -18,7 +18,7 @@
 version_id = m.group(1)
 preview_url = f'https://{version_id}-{worker_name}.workers.dev'
 
-print('Docs preview URL:', preview_url)
+print('Docs preview URL:', preview_url, flush=True)
 
 gh_headers = {
     'Accept': 'application/vnd.github+json',
@@ -28,14 +28,14 @@
 
 # now create or update a comment on the PR with the preview URL
 if not PULL_REQUEST_NUMBER:
-    print('Pull request number not set')
+    print('Pull request number not set', flush=True)
     exit(1)
 
 comments_url = f'https://api.github.com/repos/{REPOSITORY}/issues/{PULL_REQUEST_NUMBER}/comments'
 r = httpx.get(comments_url, headers=gh_headers)
-print(f'{r.request.method} {r.request.url} {r.status_code}')
+print(f'{r.request.method} {r.request.url} {r.status_code}', flush=True)
 if r.status_code != 200:
-    print(f'Failed to get comments, status {r.status_code}, response:\n{r.text}')
+    print(f'Failed to get comments, status {r.status_code}, response:\n{r.text}', flush=True)
     exit(1)
 
 comment_update_url = None
@@ -62,11 +62,11 @@
 comment_data = {'body': body}
 
 if comment_update_url:
-    print('Updating existing comment...')
+    print('Updating existing comment...', flush=True)
     r = httpx.patch(comment_update_url, headers=gh_headers, json=comment_data)
 else:
-    print('Creating new comment...')
+    print('Creating new comment...', flush=True)
     r = httpx.post(comments_url, headers=gh_headers, json=comment_data)
 
-print(f'{r.request.method} {r.request.url} {r.status_code}')
+print(f'{r.request.method} {r.request.url} {r.status_code}', flush=True)
 r.raise_for_status()
diff --git a/.github/workflows/after-ci.yml b/.github/workflows/after-ci.yml
index 712b99a33..1cf5a5e16 100644
--- a/.github/workflows/after-ci.yml
+++ b/.github/workflows/after-ci.yml
@@ -53,7 +53,7 @@ jobs:
     runs-on: ubuntu-latest
     if: github.event.workflow_run.event == 'pull_request'
    environment:
-      name: deploy-preview
+      name: deploy-docs-preview
 
     steps:
       - uses: actions/checkout@v4
@@ -92,7 +92,7 @@ jobs:
           GITHUB_EVENT_JSON: ${{ toJSON(github.event) }}
 
       - name: Set preview URL
-        run: uv run --with httpx .github/set_docs_preview_url.py
+        run: uv run --no-project --with httpx .github/set_docs_pr_preview_url.py
         env:
           DEPLOY_OUTPUT: ${{ steps.deploy.outputs.command-output }}
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c230935fb..c1ecc3d9f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -311,6 +311,52 @@ jobs:
         env:
           ALGOLIA_WRITE_API_KEY: ${{ secrets.ALGOLIA_WRITE_API_KEY }}
 
+  deploy-docs-preview:
+    needs: [check]
+    if: success() && github.ref == 'refs/heads/main'
+    runs-on: ubuntu-latest
+    environment:
+      name: deploy-docs-preview
+
+    permissions:
+      deployments: write
+      statuses: write
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-node@v4
+      - run: npm install
+        working-directory: docs-site
+
+      - uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: true
+
+      - uses: actions/download-artifact@v4
+        with:
+          name: site
+          path: site
+
+      - uses: cloudflare/wrangler-action@v3
+        id: deploy
+        with:
+          apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
+          environment: previews
+          workingDirectory: docs-site
+          command: >
+            deploy
+            --var GIT_COMMIT_SHA:${{ github.sha }}
+            --var GIT_BRANCH:main
+
+      - name: Set preview URL
+        run: uv run --no-project --with httpx .github/set_docs_main_preview_url.py
+        env:
+          DEPLOY_OUTPUT: ${{ steps.deploy.outputs.command-output }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          REPOSITORY: ${{ github.repository }}
+          REF: ${{ github.sha }}
+
   # TODO(Marcelo): We need to split this into two jobs: `build` and `release`.
   release:
     needs: [check]
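Since both preview scripts are configured entirely through environment variables, they can be exercised outside CI. A hypothetical local smoke test (every value below is a placeholder; a real run needs a GitHub token with `deployments: write` on the target repository):

```python
# Hypothetical local smoke test for set_docs_main_preview_url.py; all values
# are placeholders, not real credentials or deploy output.
import os
import subprocess

env = {
    **os.environ,
    'DEPLOY_OUTPUT': 'https://w.workers.dev\nCurrent Version ID: abc123-def',
    'GITHUB_TOKEN': 'ghp_placeholder',
    'REPOSITORY': 'your-org/your-fork',
    'REF': '0000000000000000000000000000000000000000',
}
# Mirrors the workflow step: uv supplies httpx without touching the project env.
subprocess.run(
    ['uv', 'run', '--no-project', '--with', 'httpx', '.github/set_docs_main_preview_url.py'],
    env=env,
    check=True,
)
```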
diff --git a/docs-site/src/index.ts b/docs-site/src/index.ts
index 261cf86f0..223410b3e 100644
--- a/docs-site/src/index.ts
+++ b/docs-site/src/index.ts
@@ -66,20 +66,24 @@ async function getChangelog(kv: KVNamespace, commitSha: string): Promise<string>
 
 interface Release {
   name: string
+  tag_name: string
   body: string
   html_url: string
 }
 
+const githubIcon = ``
+
 function prepRelease(release: Release): string {
   const body = release.body
     .replace(/(#+)/g, (m) => `##${m}`)
    .replace(/https:\/\/github.com\/pydantic\/pydantic-ai\/pull\/(\d+)/g, (url, id) => `[#${id}](${url})`)
-    .replace(/\*\*Full Changelog\*\*: (\S+)/, (_, url) => `[Compare diff](${url})`)
+    .replace(/(\s)@([\w\-]+)/g, (_, s, u) => `${s}[@${u}](https://github.com/${u})`)
+    .replace(/\*\*Full Changelog\*\*: (\S+)/, (_, url) => `[${githubIcon} Compare diff](${url}).`)
 
   return `
 ### ${release.name}
 
 ${body}
 
-[View on GitHub](${release.html_url})
+[${githubIcon} View ${release.tag_name} release](${release.html_url}).
 `
 }
diff --git a/docs/changelog.md b/docs/changelog.md
index eb088e265..50378d872 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -3,9 +3,9 @@
 PydanticAI is still pre-version 1, so breaking changes will occur, however:
 
 - We try to minimize them as much as possible.
-- We will use minor version bumps to signify breaking changes.
-- Wherever possible we'll deprecate old features so code continues to work with deprecation warnings when changing the public API.
-- We intend to release V1 in summer 2025, and then follow strict semantic versioning, e.g. no intentional breaking changes except in major versions.
+- We use minor version bumps to signify breaking changes.
+- Wherever possible we deprecate old features so code continues to work with deprecation warnings when changing the public API.
+- We intend to release V1 in summer 2025, and then follow strict semantic versioning, i.e. no intentional breaking changes except in major versions.
 
 ## Breaking Changes
 
@@ -27,11 +27,11 @@
 See [#1484](https://github.com/pydantic/pydantic-ai/pull/1484) — `format_as_xml`
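For readers less familiar with the worker code: `prepRelease` is a chain of regex substitutions over the GitHub release body. A rough Python equivalent of the new chain (the sample release body is invented and the inline `githubIcon` markup is omitted), where order matters because the PR-link step introduces `#` characters that must not be demoted by the heading step:

```python
# Rough Python equivalent of the regex rewrites in prepRelease, for
# illustration only; the sample release body is invented.
import re

body = (
    '## What Changed\n'
    'Fix streaming by @samuelcolvin in https://github.com/pydantic/pydantic-ai/pull/123\n'
    '**Full Changelog**: https://github.com/pydantic/pydantic-ai/compare/v0.1.0...v0.1.1'
)

# 1. Demote headings two levels; runs before PR links are inserted, since
#    that step adds literal '#' characters that must stay untouched.
body = re.sub(r'#+', lambda m: f'##{m.group(0)}', body)
# 2. Turn bare PR URLs into markdown links.
body = re.sub(
    r'https://github.com/pydantic/pydantic-ai/pull/(\d+)',
    lambda m: f'[#{m.group(1)}]({m.group(0)})',
    body,
)
# 3. Link @mentions to GitHub profiles (the new step in this diff).
body = re.sub(r'(\s)@([\w\-]+)', lambda m: f'{m.group(1)}[@{m.group(2)}](https://github.com/{m.group(2)})', body)
# 4. Rewrite the "Full Changelog" line as a compare link.
body = re.sub(r'\*\*Full Changelog\*\*: (\S+)', lambda m: f'[Compare diff]({m.group(1)}).', body, count=1)
print(body)
```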
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 1e4fe52fc..61d32f0ed 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -3,11 +3,11 @@
 import asyncio
 import dataclasses
 import json
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
 from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -87,6 +87,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     usage_limits: _usage.UsageLimits
     max_result_retries: int
     end_strategy: EndStrategy
+    get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
 
     output_schema: _output.OutputSchema[OutputDataT] | None
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
@@ -141,7 +142,9 @@ async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
         run_context = build_run_context(ctx)
-        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
+        history, next_message = await self._prepare_messages(
+            self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
+        )
         ctx.state.message_history = history
         run_context.messages = history
 
@@ -155,6 +158,7 @@ async def _prepare_messages(
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
+        get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
@@ -169,7 +173,7 @@ async def _prepare_messages(
                 ctx_messages.used = True
 
         parts: list[_messages.ModelRequestPart] = []
-        instructions = await self._instructions(run_context)
+        instructions = await get_instructions(run_context)
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
@@ -210,15 +214,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
         messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
-    async def _instructions(self, run_context: RunContext[DepsT]) -> str | None:
-        if self.instructions is None and not self.instructions_functions:
-            return None
-
-        instructions = self.instructions or ''
-        for instructions_runner in self.instructions_functions:
-            instructions += await instructions_runner.run(run_context)
-        return instructions
-
 
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -479,7 +474,11 @@ async def _handle_tool_calls(
         else:
             if tool_responses:
                 parts.extend(tool_responses)
-            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+            run_context = build_run_context(ctx)
+            instructions = await ctx.deps.get_instructions(run_context)
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                _messages.ModelRequest(parts=parts, instructions=instructions)
+            )
 
     def _handle_final_result(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index cc6d68dfc..d9eb1000e 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -620,6 +620,15 @@ async def main():
             },
         )
 
+        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+            if self._instructions is None and not self._instructions_functions:
+                return None
+
+            instructions = self._instructions or ''
+            for instructions_runner in self._instructions_functions:
+                instructions += await instructions_runner.run(run_context)
+            return instructions
+
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -635,6 +644,7 @@ async def main():
             mcp_servers=self._mcp_servers,
             run_span=run_span,
             tracer=tracer,
+            get_instructions=get_instructions,
        )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
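The net effect of the `_agent_graph.py` and `agent.py` changes: instruction resolution moves out of `UserPromptNode` into a closure the agent injects via `GraphAgentDeps`, and it is awaited again when building the follow-up request after tool calls, so instructions are no longer dropped there. A stripped-down sketch of the pattern, with simplified stand-ins for the real classes:

```python
# Stripped-down sketch of the new instructions flow. Deps and Request are
# simplified stand-ins for GraphAgentDeps and ModelRequest, not the real API.
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass
class Request:
    parts: list[str]
    instructions: str | None = None


@dataclass
class Deps:
    # Injected by the agent; awaited before *every* model request.
    get_instructions: Callable[[object], Awaitable[str | None]]


async def handle_tool_calls(deps: Deps, tool_responses: list[str]) -> Request:
    # Previously the follow-up request after tool calls was built without
    # instructions; now they are recomputed and attached here too.
    instructions = await deps.get_instructions(object())  # stand-in RunContext
    return Request(parts=tool_responses, instructions=instructions)


async def main() -> None:
    async def get_instructions(run_context: object) -> str | None:
        return 'You are a helpful assistant.'

    request = await handle_tool_calls(Deps(get_instructions), ['tool result: 20.0'])
    assert request.instructions == 'You are a helpful assistant.'


asyncio.run(main())
```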
diff --git a/pydantic_ai_slim/pydantic_ai/models/_json_schema.py b/pydantic_ai_slim/pydantic_ai/models/_json_schema.py
index d35154fdc..501313bd3 100644
--- a/pydantic_ai_slim/pydantic_ai/models/_json_schema.py
+++ b/pydantic_ai_slim/pydantic_ai/models/_json_schema.py
@@ -20,11 +20,11 @@ class WalkJsonSchema(ABC):
     def __init__(
         self, schema: JsonSchema, *, prefer_inlined_defs: bool = False, simplify_nullable_unions: bool = False
     ):
-        self.schema = deepcopy(schema)
+        self.schema = schema
         self.prefer_inlined_defs = prefer_inlined_defs
         self.simplify_nullable_unions = simplify_nullable_unions
 
-        self.defs: dict[str, JsonSchema] = self.schema.pop('$defs', {})
+        self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
         self.refs_stack = tuple[str, ...]()
         self.recursive_refs = set[str]()
 
@@ -34,7 +34,11 @@ def transform(self, schema: JsonSchema) -> JsonSchema:
         return schema
 
     def walk(self) -> JsonSchema:
-        handled = self._handle(deepcopy(self.schema))
+        schema = deepcopy(self.schema)
+
+        # First, handle everything but $defs:
+        schema.pop('$defs', None)
+        handled = self._handle(schema)
 
         if not self.prefer_inlined_defs and self.defs:
             handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 073b99698..9f4ad2c91 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import base64
+import warnings
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
@@ -776,6 +777,22 @@ def __init__(self, schema: JsonSchema):
         super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
 
     def transform(self, schema: JsonSchema) -> JsonSchema:
+        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
+        additional_properties = schema.pop(
+            'additionalProperties', None
+        )  # popped here, but re-added to `original_schema` below so the warning can still show it
+        if additional_properties:  # pragma: no cover
+            original_schema = {**schema, 'additionalProperties': additional_properties}
+            warnings.warn(
+                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
+                f' Full schema: {self.schema}\n\n'
+                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
+                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
+                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
+                ' and we will fix this behavior.',
+                UserWarning,
+            )
+
         schema.pop('title', None)
         schema.pop('default', None)
         schema.pop('$schema', None)
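These two schema changes work together: `WalkJsonSchema` no longer mutates `self.schema` up front (the `deepcopy` and `$defs` pop move into `walk()`), which is what lets the Gemini transformer quote the full original schema in its warning. A condensed illustration of the `additionalProperties` handling, using a simplified stand-in for the real transformer class:

```python
# Condensed illustration of the Gemini `additionalProperties` handling;
# `transform` here is a simplified stand-in for the real transformer method.
import warnings


def transform(schema: dict, full_schema: dict) -> dict:
    additional_properties = schema.pop('additionalProperties', None)
    if additional_properties:
        # Re-add the popped value so the warning shows where it came from.
        original_schema = {**schema, 'additionalProperties': additional_properties}
        warnings.warn(
            f'`additionalProperties` removed from tool schema. Full schema: {full_schema}; '
            f'source: {original_schema}',
            UserWarning,
        )
    return schema


# A field typed like `dict[str, SomeModel]` yields a truthy value and warns;
# `additionalProperties: false` (e.g. from `extra='forbid'`) is dropped silently.
schema = {'type': 'object', 'additionalProperties': {'$ref': '#/$defs/CurrentLocation'}}
print(transform(dict(schema), schema))  # {'type': 'object'}, plus a UserWarning
```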
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index 326065571..8905bd867 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -333,7 +333,7 @@ async def _run(
     ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
         try:
             if isinstance(message.args, str):
-                args_dict = self._validator.validate_json(message.args)
+                args_dict = self._validator.validate_json(message.args or '{}')
             else:
                 args_dict = self._validator.validate_python(message.args)
         except ValidationError as e:
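The `message.args or '{}'` fallback covers models that emit a tool call whose arguments payload is an empty string (exercised by the new `no_args_tool` test further down): an empty string is not valid JSON, while `'{}'` validates cleanly against a zero-argument tool schema. A tiny demonstration of the failure mode, using `json.loads` in place of the real pydantic-core validator:

```python
# Why the `or '{}'` fallback matters: '' is not valid JSON, but '{}' is.
# json.loads stands in for the pydantic-core validator; the failure mode is the same.
import json

args = ''  # a model calling a zero-argument tool may send empty args
try:
    json.loads(args)
except json.JSONDecodeError as exc:
    print('empty string rejected:', exc)

print(json.loads(args or '{}'))  # {}
```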
diff --git a/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_false.yaml b/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_false.yaml
new file mode 100644
index 000000000..5720c6b03
--- /dev/null
+++ b/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_false.yaml
@@ -0,0 +1,76 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '296'
+      content-type:
+      - application/json
+      host:
+      - generativelanguage.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: What is the temperature in Tokyo?
+        role: user
+      tools:
+        function_declarations:
+        - description: null
+          name: get_temperature
+          parameters:
+            properties:
+              city:
+                type: string
+              country:
+                type: string
+            required:
+            - city
+            - country
+            type: object
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - '748'
+      content-type:
+      - application/json; charset=UTF-8
+      server-timing:
+      - gfet4t7; dur=523
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - avgLogprobs: -0.12538558465463143
+        content:
+          parts:
+          - text: |
+              The available tools lack the ability to access real-time information, including current temperature. Therefore, I cannot answer your question.
+          role: model
+        finishReason: STOP
+      modelVersion: gemini-1.5-flash
+      usageMetadata:
+        candidatesTokenCount: 27
+        candidatesTokensDetails:
+        - modality: TEXT
+          tokenCount: 27
+        promptTokenCount: 14
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 14
+        totalTokenCount: 41
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_true.yaml b/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_true.yaml
new file mode 100644
index 000000000..2c08a300d
--- /dev/null
+++ b/tests/models/cassettes/test_gemini/test_gemini_additional_properties_is_true.yaml
@@ -0,0 +1,73 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '264'
+      content-type:
+      - application/json
+      host:
+      - generativelanguage.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: What is the temperature in Tokyo?
+        role: user
+      tools:
+        function_declarations:
+        - description: ''
+          name: get_temperature
+          parameters:
+            properties:
+              location:
+                type: object
+            required:
+            - location
+            type: object
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - '741'
+      content-type:
+      - application/json; charset=UTF-8
+      server-timing:
+      - gfet4t7; dur=534
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - avgLogprobs: -0.15060695580073766
+        content:
+          parts:
+          - text: |
+              I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.
+          role: model
+        finishReason: STOP
+      modelVersion: gemini-1.5-flash
+      usageMetadata:
+        candidatesTokenCount: 28
+        candidatesTokensDetails:
+        - modality: TEXT
+          tokenCount: 28
+        promptTokenCount: 12
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 12
+        totalTokenCount: 40
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/tests/models/cassettes/test_openai/test_openai_instructions_with_tool_calls_keep_instructions.yaml b/tests/models/cassettes/test_openai/test_openai_instructions_with_tool_calls_keep_instructions.yaml
new file mode 100644
index 000000000..c3f2b2d86
--- /dev/null
+++ b/tests/models/cassettes/test_openai/test_openai_instructions_with_tool_calls_keep_instructions.yaml
@@ -0,0 +1,207 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '419'
+      content-type:
+      - application/json
+      host:
+      - api.openai.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: You are a helpful assistant.
+        role: system
+      - content: What is the temperature in Tokyo?
+        role: user
+      model: gpt-4.1-mini
+      n: 1
+      stream: false
+      tool_choice: auto
+      tools:
+      - function:
+          description: ''
+          name: get_temperature
+          parameters:
+            additionalProperties: false
+            properties:
+              city:
+                type: string
+            required:
+            - city
+            type: object
+          strict: true
+        type: function
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    headers:
+      access-control-expose-headers:
+      - X-Request-ID
+      alt-svc:
+      - h3=":443"; ma=86400
+      connection:
+      - keep-alive
+      content-length:
+      - '1089'
+      content-type:
+      - application/json
+      openai-organization:
+      - pydantic-28gund
+      openai-processing-ms:
+      - '490'
+      openai-version:
+      - '2020-10-01'
+      strict-transport-security:
+      - max-age=31536000; includeSubDomains; preload
+      transfer-encoding:
+      - chunked
+    parsed_body:
+      choices:
+      - finish_reason: tool_calls
+        index: 0
+        logprobs: null
+        message:
+          annotations: []
+          content: null
+          refusal: null
+          role: assistant
+          tool_calls:
+          - function:
+              arguments: '{"city":"Tokyo"}'
+              name: get_temperature
+            id: call_bhZkmIKKItNGJ41whHUHB7p9
+            type: function
+      created: 1744810634
+      id: chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq
+      model: gpt-4.1-mini-2025-04-14
+      object: chat.completion
+      service_tier: default
+      system_fingerprint: fp_38647f5e19
+      usage:
+        completion_tokens: 15
+        completion_tokens_details:
+          accepted_prediction_tokens: 0
+          audio_tokens: 0
+          reasoning_tokens: 0
+          rejected_prediction_tokens: 0
+        prompt_tokens: 50
+        prompt_tokens_details:
+          audio_tokens: 0
+          cached_tokens: 0
+        total_tokens: 65
+    status:
+      code: 200
+      message: OK
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '665'
+      content-type:
+      - application/json
+      cookie:
+      - __cf_bm=x.H2GlMeh.t_Q.gVlCXrh3.ggn9lKjhmUeG_ToNThLs-1744810635-1.0.1.1-tiHwqGvBw3eEy_y9_q5nx7B.7YCbLb9cXdDj6DklLmtFllOFe708mKwYvGd8fY2y5bO2NOagULipA7MxfwW9P0hlnRSiJZbZBO9tjrUweFc;
+        _cfuvid=VlHcJdsIsxGEt2lddKu_5Am_lfyYndl9JB2Ezy.aygo-1744810635187-0.0.1.1-604800000
+      host:
+      - api.openai.com
+    method: POST
+    parsed_body:
+      messages:
+      - content: You are a helpful assistant.
+        role: system
+      - content: What is the temperature in Tokyo?
+        role: user
+      - role: assistant
+        tool_calls:
+        - function:
+            arguments: '{"city":"Tokyo"}'
+            name: get_temperature
+          id: call_bhZkmIKKItNGJ41whHUHB7p9
+          type: function
+      - content: '20.0'
+        role: tool
+        tool_call_id: call_bhZkmIKKItNGJ41whHUHB7p9
+      model: gpt-4.1-mini
+      n: 1
+      stream: false
+      tool_choice: auto
+      tools:
+      - function:
+          description: ''
+          name: get_temperature
+          parameters:
+            additionalProperties: false
+            properties:
+              city:
+                type: string
+            required:
+            - city
+            type: object
+          strict: true
+        type: function
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    headers:
+      access-control-expose-headers:
+      - X-Request-ID
+      alt-svc:
+      - h3=":443"; ma=86400
+      connection:
+      - keep-alive
+      content-length:
+      - '867'
+      content-type:
+      - application/json
+      openai-organization:
+      - pydantic-28gund
+      openai-processing-ms:
+      - '949'
+      openai-version:
+      - '2020-10-01'
+      strict-transport-security:
+      - max-age=31536000; includeSubDomains; preload
+      transfer-encoding:
+      - chunked
+    parsed_body:
+      choices:
+      - finish_reason: stop
+        index: 0
+        logprobs: null
+        message:
+          annotations: []
+          content: The temperature in Tokyo is currently 20.0 degrees Celsius.
+          refusal: null
+          role: assistant
+      created: 1744810635
+      id: chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP
+      model: gpt-4.1-mini-2025-04-14
+      object: chat.completion
+      service_tier: default
+      system_fingerprint: fp_38647f5e19
+      usage:
+        completion_tokens: 15
+        completion_tokens_details:
+          accepted_prediction_tokens: 0
+          audio_tokens: 0
+          reasoning_tokens: 0
+          rejected_prediction_tokens: 0
+        prompt_tokens: 75
+        prompt_tokens_details:
+          audio_tokens: 0
+          cached_tokens: 0
+        total_tokens: 90
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index 00c7043c1..ddf9a91bf 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -1029,3 +1029,40 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
             ),
         ]
     )
+
+
+class CurrentLocation(BaseModel, extra='forbid'):
+    city: str
+    country: str
+
+
+@pytest.mark.vcr()
+async def test_gemini_additional_properties_is_false(allow_model_requests: None, gemini_api_key: str):
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key=gemini_api_key))
+    agent = Agent(m)
+
+    @agent.tool_plain
+    async def get_temperature(location: CurrentLocation) -> float:  # pragma: no cover
+        return 20.0
+
+    result = await agent.run('What is the temperature in Tokyo?')
+    assert result.output == snapshot(
+        'The available tools lack the ability to access real-time information, including current temperature. Therefore, I cannot answer your question.\n'
+    )
+
+
+@pytest.mark.vcr()
+async def test_gemini_additional_properties_is_true(allow_model_requests: None, gemini_api_key: str):
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key=gemini_api_key))
+    agent = Agent(m)
+
+    with pytest.warns(UserWarning, match='.*additionalProperties.*'):
+
+        @agent.tool_plain
+        async def get_temperature(location: dict[str, CurrentLocation]) -> float:  # pragma: no cover
+            return 20.0
+
+    result = await agent.run('What is the temperature in Tokyo?')
+    assert result.output == snapshot(
+        'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n'
+    )
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 00cff284c..c594f645c 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -1215,3 +1215,41 @@ async def test_openai_model_without_system_prompt(allow_model_requests: None, op
     assert result.output == snapshot(
         "That's right—I am a potato! A spud of many talents, here to help you out. How can this humble potato be of service today?"
     )
+
+
+@pytest.mark.vcr()
+async def test_openai_instructions_with_tool_calls_keep_instructions(allow_model_requests: None, openai_api_key: str):
+    m = OpenAIModel('gpt-4.1-mini', provider=OpenAIProvider(api_key=openai_api_key))
+    agent = Agent(m, instructions='You are a helpful assistant.')
+
+    @agent.tool_plain
+    async def get_temperature(city: str) -> float:
+        return 20.0
+
+    result = await agent.run('What is the temperature in Tokyo?')
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[UserPromptPart(content='What is the temperature in Tokyo?', timestamp=IsDatetime())],
+                instructions='You are a helpful assistant.',
+            ),
+            ModelResponse(
+                parts=[ToolCallPart(tool_name='get_temperature', args='{"city":"Tokyo"}', tool_call_id=IsStr())],
+                model_name='gpt-4.1-mini-2025-04-14',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='get_temperature', content=20.0, tool_call_id=IsStr(), timestamp=IsDatetime()
+                    )
+                ],
+                instructions='You are a helpful assistant.',
+            ),
+            ModelResponse(
+                parts=[TextPart(content='The temperature in Tokyo is currently 20.0 degrees Celsius.')],
+                model_name='gpt-4.1-mini-2025-04-14',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 3d499eae4..7677121a1 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -829,6 +829,7 @@ async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> Mod
                 ToolCallPart(tool_name='my_tool', args={'a': 13, 'b': 4}),
                 ToolCallPart(tool_name='my_tool_plain', args={'b': 17}),
                 ToolCallPart(tool_name='my_tool_plain', args={'a': 4, 'b': 17}),
+                ToolCallPart(tool_name='no_args_tool', args=''),
             ]
         )
     else:
@@ -836,6 +837,10 @@ async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> Mod
 
     agent = Agent(FunctionModel(call_tools_first))
 
+    @agent.tool_plain
+    def no_args_tool() -> None:
+        return None
+
     @agent.tool
     def my_tool(ctx: RunContext[None], a: int, b: int = 2) -> int:
         return a + b
@@ -858,9 +863,10 @@ def my_tool_plain(*, a: int = 3, b: int) -> int:
             {'a': 13, 'b': 4},
             {'b': 17},
             {'a': 4, 'b': 17},
+            '',
         ]
     )
-    assert tool_returns == snapshot([15, 17, 51, 68])
+    assert tool_returns == snapshot([15, 17, 51, 68, None])
 
 
 def test_schema_generator():