Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 7670a80

Browse files
committed
Allow arbitrary kwargs in model
1 parent 8dfd6ff commit 7670a80

File tree

6 files changed

+267
-0
lines changed

6 files changed

+267
-0
lines changed

src/agents/extensions/models/litellm_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ async def _fetch_response(
284284
if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
285285
extra_kwargs.update(model_settings.extra_body)
286286

287+
# Pass through any extra kwargs from model_settings.kwargs when provided
288+
if model_settings.kwargs:
289+
extra_kwargs.update(model_settings.kwargs)
290+
287291
ret = await litellm.acompletion(
288292
model=self.model,
289293
messages=converted_messages,

src/agents/model_settings.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ class ModelSettings:
7373
"""Additional headers to provide with the request.
7474
Defaults to None if not provided."""
7575

76+
kwargs: dict[str, Any] | None = None
77+
"""Arbitrary keyword arguments to pass to the model API call.
78+
These will be passed directly to the underlying model provider's API.
79+
Use with caution as not all models support all parameters."""
80+
7681
def resolve(self, override: ModelSettings | None) -> ModelSettings:
7782
"""Produce a new ModelSettings by overlaying any non-None values from the
7883
override on top of this instance."""
@@ -84,6 +89,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
8489
for field in fields(self)
8590
if getattr(override, field.name) is not None
8691
}
92+
93+
# Handle kwargs merging specially - merge dictionaries instead of replacing
94+
if self.kwargs is not None or override.kwargs is not None:
95+
merged_kwargs = {}
96+
if self.kwargs:
97+
merged_kwargs.update(self.kwargs)
98+
if override.kwargs:
99+
merged_kwargs.update(override.kwargs)
100+
changes["kwargs"] = merged_kwargs if merged_kwargs else None
101+
87102
return replace(self, **changes)
88103

89104
def to_json_dict(self) -> dict[str, Any]:

src/agents/models/openai_chatcompletions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ async def _fetch_response(
281281
extra_query=model_settings.extra_query,
282282
extra_body=model_settings.extra_body,
283283
metadata=self._non_null_or_not_given(model_settings.metadata),
284+
**(model_settings.kwargs or {}),
284285
)
285286

286287
if isinstance(ret, ChatCompletion):

src/agents/models/openai_responses.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ async def _fetch_response(
266266
store=self._non_null_or_not_given(model_settings.store),
267267
reasoning=self._non_null_or_not_given(model_settings.reasoning),
268268
metadata=self._non_null_or_not_given(model_settings.metadata),
269+
**(model_settings.kwargs or {}),
269270
)
270271

271272
def _get_client(self) -> AsyncOpenAI:

tests/model_settings/test_serialization.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_all_fields_serialization() -> None:
4747
extra_query={"foo": "bar"},
4848
extra_body={"foo": "bar"},
4949
extra_headers={"foo": "bar"},
50+
kwargs={"custom_param": "value", "another_param": 42},
5051
)
5152

5253
# Verify that every single field is set to a non-None value
@@ -57,3 +58,76 @@ def test_all_fields_serialization() -> None:
5758

5859
# Now, lets serialize the ModelSettings instance to a JSON string
5960
verify_serialization(model_settings)
61+
62+
63+
def test_kwargs_serialization() -> None:
64+
"""Test that kwargs are properly serialized."""
65+
model_settings = ModelSettings(
66+
temperature=0.5,
67+
kwargs={"custom_param": "value", "another_param": 42, "nested": {"key": "value"}},
68+
)
69+
70+
json_dict = model_settings.to_json_dict()
71+
assert json_dict["kwargs"] == {
72+
"custom_param": "value",
73+
"another_param": 42,
74+
"nested": {"key": "value"},
75+
}
76+
77+
# Verify serialization works
78+
verify_serialization(model_settings)
79+
80+
81+
def test_kwargs_resolve() -> None:
82+
"""Test that kwargs are properly merged in the resolve method."""
83+
base_settings = ModelSettings(
84+
temperature=0.5, kwargs={"param1": "base_value", "param2": "base_only"}
85+
)
86+
87+
override_settings = ModelSettings(
88+
top_p=0.9, kwargs={"param1": "override_value", "param3": "override_only"}
89+
)
90+
91+
resolved = base_settings.resolve(override_settings)
92+
93+
# Check that regular fields are properly resolved
94+
assert resolved.temperature == 0.5 # from base
95+
assert resolved.top_p == 0.9 # from override
96+
97+
# Check that kwargs are properly merged
98+
expected_kwargs = {
99+
"param1": "override_value", # override wins
100+
"param2": "base_only", # from base
101+
"param3": "override_only", # from override
102+
}
103+
assert resolved.kwargs == expected_kwargs
104+
105+
106+
def test_kwargs_resolve_with_none() -> None:
107+
"""Test that resolve works properly when one side has None kwargs."""
108+
# Base with kwargs, override with None
109+
base_settings = ModelSettings(kwargs={"param1": "value1"})
110+
override_settings = ModelSettings(temperature=0.8)
111+
112+
resolved = base_settings.resolve(override_settings)
113+
assert resolved.kwargs == {"param1": "value1"}
114+
assert resolved.temperature == 0.8
115+
116+
# Base with None, override with kwargs
117+
base_settings = ModelSettings(temperature=0.5)
118+
override_settings = ModelSettings(kwargs={"param2": "value2"})
119+
120+
resolved = base_settings.resolve(override_settings)
121+
assert resolved.kwargs == {"param2": "value2"}
122+
assert resolved.temperature == 0.5
123+
124+
125+
def test_kwargs_resolve_both_none() -> None:
126+
"""Test that resolve works when both sides have None kwargs."""
127+
base_settings = ModelSettings(temperature=0.5)
128+
override_settings = ModelSettings(top_p=0.9)
129+
130+
resolved = base_settings.resolve(override_settings)
131+
assert resolved.kwargs is None
132+
assert resolved.temperature == 0.5
133+
assert resolved.top_p == 0.9
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import litellm
2+
import pytest
3+
from litellm.types.utils import Choices, Message, ModelResponse, Usage
4+
from openai.types.chat.chat_completion import ChatCompletion, Choice
5+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
6+
from openai.types.completion_usage import CompletionUsage
7+
8+
from agents.extensions.models.litellm_model import LitellmModel
9+
from agents.model_settings import ModelSettings
10+
from agents.models.interface import ModelTracing
11+
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
12+
13+
14+
@pytest.mark.allow_call_model_methods
15+
@pytest.mark.asyncio
16+
async def test_litellm_kwargs_forwarded(monkeypatch):
17+
"""
18+
Test that kwargs from ModelSettings are forwarded to litellm.acompletion.
19+
"""
20+
captured: dict[str, object] = {}
21+
22+
async def fake_acompletion(model, messages=None, **kwargs):
23+
captured.update(kwargs)
24+
msg = Message(role="assistant", content="test response")
25+
choice = Choices(index=0, message=msg)
26+
return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
27+
28+
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
29+
30+
settings = ModelSettings(
31+
temperature=0.5,
32+
kwargs={
33+
"custom_param": "custom_value",
34+
"seed": 42,
35+
"stop": ["END"],
36+
"logit_bias": {123: -100},
37+
},
38+
)
39+
model = LitellmModel(model="test-model")
40+
41+
await model.get_response(
42+
system_instructions=None,
43+
input="test input",
44+
model_settings=settings,
45+
tools=[],
46+
output_schema=None,
47+
handoffs=[],
48+
tracing=ModelTracing.DISABLED,
49+
previous_response_id=None,
50+
)
51+
52+
# Verify that all kwargs were passed through
53+
assert captured["custom_param"] == "custom_value"
54+
assert captured["seed"] == 42
55+
assert captured["stop"] == ["END"]
56+
assert captured["logit_bias"] == {123: -100}
57+
58+
# Verify regular parameters are still passed
59+
assert captured["temperature"] == 0.5
60+
61+
62+
@pytest.mark.allow_call_model_methods
63+
@pytest.mark.asyncio
64+
async def test_openai_chatcompletions_kwargs_forwarded(monkeypatch):
65+
"""
66+
Test that kwargs from ModelSettings are forwarded to OpenAI chat completions API.
67+
"""
68+
captured: dict[str, object] = {}
69+
70+
class MockChatCompletions:
71+
async def create(self, **kwargs):
72+
captured.update(kwargs)
73+
msg = ChatCompletionMessage(role="assistant", content="test response")
74+
choice = Choice(index=0, message=msg, finish_reason="stop")
75+
return ChatCompletion(
76+
id="test-id",
77+
created=0,
78+
model="gpt-4",
79+
object="chat.completion",
80+
choices=[choice],
81+
usage=CompletionUsage(completion_tokens=5, prompt_tokens=10, total_tokens=15),
82+
)
83+
84+
class MockChat:
85+
def __init__(self):
86+
self.completions = MockChatCompletions()
87+
88+
class MockClient:
89+
def __init__(self):
90+
self.chat = MockChat()
91+
self.base_url = "https://api.openai.com/v1"
92+
93+
settings = ModelSettings(
94+
temperature=0.7,
95+
kwargs={"seed": 123, "logit_bias": {456: 10}, "stop": ["STOP", "END"], "user": "test-user"},
96+
)
97+
98+
mock_client = MockClient()
99+
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=mock_client) # type: ignore
100+
101+
await model.get_response(
102+
system_instructions="Test system",
103+
input="test input",
104+
model_settings=settings,
105+
tools=[],
106+
output_schema=None,
107+
handoffs=[],
108+
tracing=ModelTracing.DISABLED,
109+
previous_response_id=None,
110+
)
111+
112+
# Verify that all kwargs were passed through
113+
assert captured["seed"] == 123
114+
assert captured["logit_bias"] == {456: 10}
115+
assert captured["stop"] == ["STOP", "END"]
116+
assert captured["user"] == "test-user"
117+
118+
# Verify regular parameters are still passed
119+
assert captured["temperature"] == 0.7
120+
121+
122+
@pytest.mark.allow_call_model_methods
123+
@pytest.mark.asyncio
124+
async def test_empty_kwargs_handling(monkeypatch):
125+
"""
126+
Test that empty or None kwargs are handled gracefully.
127+
"""
128+
captured: dict[str, object] = {}
129+
130+
async def fake_acompletion(model, messages=None, **kwargs):
131+
captured.update(kwargs)
132+
msg = Message(role="assistant", content="test response")
133+
choice = Choices(index=0, message=msg)
134+
return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
135+
136+
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
137+
138+
# Test with None kwargs
139+
settings_none = ModelSettings(temperature=0.5, kwargs=None)
140+
model = LitellmModel(model="test-model")
141+
142+
await model.get_response(
143+
system_instructions=None,
144+
input="test input",
145+
model_settings=settings_none,
146+
tools=[],
147+
output_schema=None,
148+
handoffs=[],
149+
tracing=ModelTracing.DISABLED,
150+
previous_response_id=None,
151+
)
152+
153+
# Should work without error and include regular parameters
154+
assert captured["temperature"] == 0.5
155+
156+
# Test with empty dict
157+
captured.clear()
158+
settings_empty = ModelSettings(temperature=0.3, kwargs={})
159+
160+
await model.get_response(
161+
system_instructions=None,
162+
input="test input",
163+
model_settings=settings_empty,
164+
tools=[],
165+
output_schema=None,
166+
handoffs=[],
167+
tracing=ModelTracing.DISABLED,
168+
previous_response_id=None,
169+
)
170+
171+
# Should work without error and include regular parameters
172+
assert captured["temperature"] == 0.3

0 commit comments

Comments
 (0)