@@ -199,20 +199,21 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
199199 with Client () as client :
200200 llm = client .llm .model (model_id )
201201 chat = Chat ()
202- chat .add_user_message ("What is the sum of 123 and 3210?" )
202+ # Ensure the first response is a combination of text and tool use requests
203+ chat .add_user_message ("First say 'Hi'. Then calculate 1 + 3 with the tool." )
203204 tools = [ADDITION_TOOL_SPEC ]
204205 round_starts : list [int ] = []
205206 round_ends : list [int ] = []
206207 first_tokens : list [int ] = []
207208 predictions : list [PredictionRoundResult ] = []
208209 fragments : list [LlmPredictionFragment ] = []
209- last_fragment_round_index = 0
210+ fragment_round_indices : set [ int ] = set ()
210211
211212 def _append_fragment (f : LlmPredictionFragment , round_index : int ) -> None :
212- nonlocal last_fragment_round_index
213+ last_fragment_round_index = max ( fragment_round_indices , default = - 1 )
213214 assert round_index >= last_fragment_round_index
214215 fragments .append (f )
215- last_fragment_round_index = round_index
216+ fragment_round_indices . add ( round_index )
216217
217218 # TODO: Also check on_prompt_processing_progress and handling invalid messages
218219 # (although it isn't clear how to provoke calls to the latter without mocking)
@@ -233,8 +234,9 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
233234 assert round_starts == sequential_round_indices
234235 assert round_ends == sequential_round_indices
235236 expected_token_indices = [p .round_index for p in predictions if p .content ]
237+ assert expected_token_indices == sequential_round_indices
236238 assert first_tokens == expected_token_indices
237- assert last_fragment_round_index == num_rounds - 1
239+ assert fragment_round_indices == set ( expected_token_indices )
238240 assert len (chat ._messages ) == 2 * num_rounds # No tool results in last round
239241
240242 cloned_chat = chat .copy ()
0 commit comments