Commit 0eee6b8

Allow arbitrary kwargs in model (openai#842)
Sometimes users want to provide parameters specific to a model provider. This is an escape hatch.
1 parent 8dfd6ff commit 0eee6b8
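
As a usage illustration, a minimal sketch of the escape hatch from the caller's side. The Agent/Runner wiring comes from the wider SDK rather than from this diff, and the parameter values are arbitrary:

from agents import Agent, ModelSettings, Runner

# Provider-specific parameters that have no first-class ModelSettings field
# ride along in extra_args and are forwarded verbatim to the model API call.
agent = Agent(
    name="Assistant",
    instructions="Reply concisely.",
    model_settings=ModelSettings(
        temperature=0.5,
        extra_args={"seed": 42, "user": "demo-user"},
    ),
)

result = Runner.run_sync(agent, "Say hello.")
print(result.final_output)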

6 files changed: +272 −0 lines changed

src/agents/extensions/models/litellm_model.py

Lines changed: 4 additions & 0 deletions

@@ -284,6 +284,10 @@ async def _fetch_response(
         if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
             extra_kwargs.update(model_settings.extra_body)
 
+        # Add kwargs from model_settings.extra_args when the dict is present and non-empty
+        if model_settings.extra_args:
+            extra_kwargs.update(model_settings.extra_args)
+
         ret = await litellm.acompletion(
             model=self.model,
             messages=converted_messages,
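
For context, a sketch of the LiteLLM path under this change; the model string and parameter values below are illustrative only:

from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings

# Entries in extra_args end up as additional keyword arguments to
# litellm.acompletion, alongside first-class settings such as temperature.
settings = ModelSettings(
    temperature=0.5,
    extra_args={"seed": 42, "logit_bias": {123: -100}},
)
model = LitellmModel(model="anthropic/claude-3-5-sonnet-20240620")
# model.get_response(..., model_settings=settings, ...) then forwards
# seed and logit_bias straight through to litellm.acompletion.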

src/agents/model_settings.py

Lines changed: 15 additions & 0 deletions

@@ -73,6 +73,11 @@ class ModelSettings:
     """Additional headers to provide with the request.
     Defaults to None if not provided."""
 
+    extra_args: dict[str, Any] | None = None
+    """Arbitrary keyword arguments to pass to the model API call.
+    These will be passed directly to the underlying model provider's API.
+    Use with caution as not all models support all parameters."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
         override on top of this instance."""
@@ -84,6 +89,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
             for field in fields(self)
             if getattr(override, field.name) is not None
         }
+
+        # Handle extra_args merging specially - merge dictionaries instead of replacing
+        if self.extra_args is not None or override.extra_args is not None:
+            merged_args = {}
+            if self.extra_args:
+                merged_args.update(self.extra_args)
+            if override.extra_args:
+                merged_args.update(override.extra_args)
+            changes["extra_args"] = merged_args if merged_args else None
+
         return replace(self, **changes)
 
     def to_json_dict(self) -> dict[str, Any]:
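
Worth noting: the merge is shallow. Override keys win key-by-key, and nested dictionaries are replaced wholesale rather than merged. A quick sketch with illustrative values:

from agents.model_settings import ModelSettings

base = ModelSettings(extra_args={"seed": 1, "user": "alice"})
override = ModelSettings(extra_args={"seed": 2})

resolved = base.resolve(override)
# The dicts are merged rather than one replacing the other outright:
assert resolved.extra_args == {"seed": 2, "user": "alice"}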

src/agents/models/openai_chatcompletions.py

Lines changed: 1 addition & 0 deletions

@@ -281,6 +281,7 @@ async def _fetch_response(
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             metadata=self._non_null_or_not_given(model_settings.metadata),
+            **(model_settings.extra_args or {}),
         )
 
         if isinstance(ret, ChatCompletion):
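
A side note on the ** splat used above (general Python call semantics, not behavior introduced by this commit): if an extra_args key collides with a parameter the model class already passes explicitly, the call fails rather than silently preferring one value. A toy example:

def create(temperature=None, **kwargs):
    pass

create(temperature=0.5, **{"temperature": 0.1})
# TypeError: create() got multiple values for keyword argument 'temperature'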

src/agents/models/openai_responses.py

Lines changed: 1 addition & 0 deletions

@@ -266,6 +266,7 @@ async def _fetch_response(
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
             metadata=self._non_null_or_not_given(model_settings.metadata),
+            **(model_settings.extra_args or {}),
         )
 
     def _get_client(self) -> AsyncOpenAI:

tests/model_settings/test_serialization.py

Lines changed: 74 additions & 0 deletions

@@ -47,6 +47,7 @@ def test_all_fields_serialization() -> None:
         extra_query={"foo": "bar"},
         extra_body={"foo": "bar"},
         extra_headers={"foo": "bar"},
+        extra_args={"custom_param": "value", "another_param": 42},
     )
 
     # Verify that every single field is set to a non-None value
@@ -57,3 +58,76 @@ def test_all_fields_serialization() -> None:
 
     # Now, lets serialize the ModelSettings instance to a JSON string
     verify_serialization(model_settings)
+
+
+def test_extra_args_serialization() -> None:
+    """Test that extra_args are properly serialized."""
+    model_settings = ModelSettings(
+        temperature=0.5,
+        extra_args={"custom_param": "value", "another_param": 42, "nested": {"key": "value"}},
+    )
+
+    json_dict = model_settings.to_json_dict()
+    assert json_dict["extra_args"] == {
+        "custom_param": "value",
+        "another_param": 42,
+        "nested": {"key": "value"},
+    }
+
+    # Verify serialization works
+    verify_serialization(model_settings)
+
+
+def test_extra_args_resolve() -> None:
+    """Test that extra_args are properly merged in the resolve method."""
+    base_settings = ModelSettings(
+        temperature=0.5, extra_args={"param1": "base_value", "param2": "base_only"}
+    )
+
+    override_settings = ModelSettings(
+        top_p=0.9, extra_args={"param1": "override_value", "param3": "override_only"}
+    )
+
+    resolved = base_settings.resolve(override_settings)
+
+    # Check that regular fields are properly resolved
+    assert resolved.temperature == 0.5  # from base
+    assert resolved.top_p == 0.9  # from override
+
+    # Check that extra_args are properly merged
+    expected_extra_args = {
+        "param1": "override_value",  # override wins
+        "param2": "base_only",  # from base
+        "param3": "override_only",  # from override
+    }
+    assert resolved.extra_args == expected_extra_args
+
+
+def test_extra_args_resolve_with_none() -> None:
+    """Test that resolve works properly when one side has None extra_args."""
+    # Base with extra_args, override with None
+    base_settings = ModelSettings(extra_args={"param1": "value1"})
+    override_settings = ModelSettings(temperature=0.8)
+
+    resolved = base_settings.resolve(override_settings)
+    assert resolved.extra_args == {"param1": "value1"}
+    assert resolved.temperature == 0.8
+
+    # Base with None, override with extra_args
+    base_settings = ModelSettings(temperature=0.5)
+    override_settings = ModelSettings(extra_args={"param2": "value2"})
+
+    resolved = base_settings.resolve(override_settings)
+    assert resolved.extra_args == {"param2": "value2"}
+    assert resolved.temperature == 0.5
+
+
+def test_extra_args_resolve_both_none() -> None:
+    """Test that resolve works when both sides have None extra_args."""
+    base_settings = ModelSettings(temperature=0.5)
+    override_settings = ModelSettings(top_p=0.9)
+
+    resolved = base_settings.resolve(override_settings)
+    assert resolved.extra_args is None
+    assert resolved.temperature == 0.5
+    assert resolved.top_p == 0.9
Lines changed: 177 additions & 0 deletions

@@ -0,0 +1,177 @@
+import litellm
+import pytest
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.completion_usage import CompletionUsage
+
+from agents.extensions.models.litellm_model import LitellmModel
+from agents.model_settings import ModelSettings
+from agents.models.interface import ModelTracing
+from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_litellm_kwargs_forwarded(monkeypatch):
+    """
+    Test that kwargs from ModelSettings are forwarded to litellm.acompletion.
+    """
+    captured: dict[str, object] = {}
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        captured.update(kwargs)
+        msg = Message(role="assistant", content="test response")
+        choice = Choices(index=0, message=msg)
+        return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+
+    settings = ModelSettings(
+        temperature=0.5,
+        extra_args={
+            "custom_param": "custom_value",
+            "seed": 42,
+            "stop": ["END"],
+            "logit_bias": {123: -100},
+        },
+    )
+    model = LitellmModel(model="test-model")
+
+    await model.get_response(
+        system_instructions=None,
+        input="test input",
+        model_settings=settings,
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that all kwargs were passed through
+    assert captured["custom_param"] == "custom_value"
+    assert captured["seed"] == 42
+    assert captured["stop"] == ["END"]
+    assert captured["logit_bias"] == {123: -100}
+
+    # Verify regular parameters are still passed
+    assert captured["temperature"] == 0.5
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_openai_chatcompletions_kwargs_forwarded(monkeypatch):
+    """
+    Test that kwargs from ModelSettings are forwarded to OpenAI chat completions API.
+    """
+    captured: dict[str, object] = {}
+
+    class MockChatCompletions:
+        async def create(self, **kwargs):
+            captured.update(kwargs)
+            msg = ChatCompletionMessage(role="assistant", content="test response")
+            choice = Choice(index=0, message=msg, finish_reason="stop")
+            return ChatCompletion(
+                id="test-id",
+                created=0,
+                model="gpt-4",
+                object="chat.completion",
+                choices=[choice],
+                usage=CompletionUsage(completion_tokens=5, prompt_tokens=10, total_tokens=15),
+            )
+
+    class MockChat:
+        def __init__(self):
+            self.completions = MockChatCompletions()
+
+    class MockClient:
+        def __init__(self):
+            self.chat = MockChat()
+            self.base_url = "https://api.openai.com/v1"
+
+    settings = ModelSettings(
+        temperature=0.7,
+        extra_args={
+            "seed": 123,
+            "logit_bias": {456: 10},
+            "stop": ["STOP", "END"],
+            "user": "test-user",
+        },
+    )
+
+    mock_client = MockClient()
+    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=mock_client)  # type: ignore
+
+    await model.get_response(
+        system_instructions="Test system",
+        input="test input",
+        model_settings=settings,
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that all kwargs were passed through
+    assert captured["seed"] == 123
+    assert captured["logit_bias"] == {456: 10}
+    assert captured["stop"] == ["STOP", "END"]
+    assert captured["user"] == "test-user"
+
+    # Verify regular parameters are still passed
+    assert captured["temperature"] == 0.7
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_empty_kwargs_handling(monkeypatch):
+    """
+    Test that empty or None kwargs are handled gracefully.
+    """
+    captured: dict[str, object] = {}
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        captured.update(kwargs)
+        msg = Message(role="assistant", content="test response")
+        choice = Choices(index=0, message=msg)
+        return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+
+    # Test with None kwargs
+    settings_none = ModelSettings(temperature=0.5, extra_args=None)
+    model = LitellmModel(model="test-model")
+
+    await model.get_response(
+        system_instructions=None,
+        input="test input",
+        model_settings=settings_none,
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Should work without error and include regular parameters
+    assert captured["temperature"] == 0.5
+
+    # Test with empty dict
+    captured.clear()
+    settings_empty = ModelSettings(temperature=0.3, extra_args={})
+
+    await model.get_response(
+        system_instructions=None,
+        input="test input",
+        model_settings=settings_empty,
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Should work without error and include regular parameters
+    assert captured["temperature"] == 0.3
