diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index df0608e614..3ec167d4f7 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -2164,15 +2164,11 @@ async def _get_response() -> ChatResponse: # Error threshold reached: force a final non-tool turn so # function_call_output items are submitted before exit. mutable_options["tool_choice"] = "none" - elif ( - max_function_calls is not None - and total_function_calls >= max_function_calls - ): + elif max_function_calls is not None and total_function_calls >= max_function_calls: # Best-effort limit: checked after each batch of parallel calls completes, # so the current batch always runs to completion even if it overshoots. logger.info( - "Maximum function calls reached (%d/%d). " - "Stopping further function calls for this request.", + "Maximum function calls reached (%d/%d). Stopping further function calls for this request.", total_function_calls, max_function_calls, ) @@ -2196,9 +2192,15 @@ async def _get_response() -> ChatResponse: prepped_messages.extend(response.messages) continue - if response is not None: - return response - + # Loop exhausted all iterations (or function invocation disabled). + # Make a final model call with tool_choice="none" so the model + # produces a plain text answer instead of leaving orphaned + # function_call items without matching results. + if response is not None and self.function_invocation_configuration["enabled"]: + logger.info( + "Maximum iterations reached (%d). 
Requesting final response without tools.", + self.function_invocation_configuration["max_iterations"], + ) mutable_options["tool_choice"] = "none" response = await super_get_response( messages=prepped_messages, @@ -2302,15 +2304,11 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: mutable_options["tool_choice"] = "none" elif result["action"] != "continue": return - elif ( - max_function_calls is not None - and total_function_calls >= max_function_calls - ): + elif max_function_calls is not None and total_function_calls >= max_function_calls: # Best-effort limit: checked after each batch of parallel calls completes, # so the current batch always runs to completion even if it overshoots. logger.info( - "Maximum function calls reached (%d/%d). " - "Stopping further function calls for this request.", + "Maximum function calls reached (%d/%d). Stopping further function calls for this request.", total_function_calls, max_function_calls, ) @@ -2333,9 +2331,15 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: prepped_messages.extend(response.messages) continue - if response is not None: - return - + # Loop exhausted all iterations (or function invocation disabled). + # Make a final model call with tool_choice="none" so the model + # produces a plain text answer instead of leaving orphaned + # function_call items without matching results. + if response is not None and self.function_invocation_configuration["enabled"]: + logger.info( + "Maximum iterations reached (%d). 
Requesting final response without tools.", + self.function_invocation_configuration["max_iterations"], + ) mutable_options["tool_choice"] = "none" inner_stream = await _ensure_response_stream( super_get_response( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index b4213d6029..f278afaeac 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -831,8 +831,6 @@ def func_with_approval(arg1: str) -> str: assert "rejected" in rejection_result.result.lower() -@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API") -@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API") async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse): """Test that MAX_ITERATIONS in additional_properties limits function call loops.""" exec_counter = 0 @@ -880,6 +878,256 @@ def ai_func(arg1: str) -> str: assert response.messages[-1].text == "I broke out of the function invocation loop..." # Failsafe response +async def test_max_iterations_no_orphaned_function_calls(chat_client_base: SupportsChatGetResponse): + """When max_iterations is reached, verify the returned response has no orphaned + FunctionCallContent (i.e., every function_call has a matching function_result). 
+ """ + exec_counter = 0 + + @tool(name="test_function", approval_mode="never_require") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + # Model keeps requesting tool calls on every iteration + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}') + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}') + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_3", name="test_function", arguments='{"arg1": "v3"}') + ], + ) + ), + ] + + chat_client_base.function_invocation_configuration["max_iterations"] = 2 + + response = await chat_client_base.get_response( + [Message(role="user", text="hello")], + options={"tool_choice": "auto", "tools": [ai_func]}, + ) + + # Collect all function_call and function_result call_ids from response + all_call_ids = set() + all_result_ids = set() + for msg in response.messages: + for content in msg.contents: + if content.type == "function_call": + all_call_ids.add(content.call_id) + elif content.type == "function_result": + all_result_ids.add(content.call_id) + + orphaned_calls = all_call_ids - all_result_ids + assert not orphaned_calls, ( + f"Response contains orphaned FunctionCallContent without matching " + f"FunctionResultContent: {orphaned_calls}." + ) + + +async def test_max_iterations_makes_final_toolchoice_none_call(chat_client_base: SupportsChatGetResponse): + """When max_iterations is reached, verify a final model call is made with + tool_choice='none' to produce a clean text response. 
+ """ + exec_counter = 0 + + @tool(name="test_function", approval_mode="never_require") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}') + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}') + ], + ) + ), + # Not consumed: the mock client answers the tool_choice="none" failsafe call with its hardcoded text + ChatResponse(messages=Message(role="assistant", text="Final answer after giving up on tools.")), + ] + + chat_client_base.function_invocation_configuration["max_iterations"] = 1 + + response = await chat_client_base.get_response( + [Message(role="user", text="hello")], + options={"tool_choice": "auto", "tools": [ai_func]}, + ) + + assert exec_counter == 1, f"Expected 1 function execution, got {exec_counter}" + + # The response should end with a plain text message (from the failsafe call) + last_msg = response.messages[-1] + has_function_calls = any(c.type == "function_call" for c in last_msg.contents) + + assert not has_function_calls, ( + f"Last message in response still contains function_call items. " + f"Expected a clean text response after max_iterations failsafe. " + f"Got message with role={last_msg.role}, contents={[c.type for c in last_msg.contents]}" + ) + + # The mock client returns "I broke out of the function invocation loop..." 
+ # when tool_choice="none" + assert last_msg.text == "I broke out of the function invocation loop...", ( + f"Expected failsafe text response, got: {last_msg.text!r}" + ) + + +async def test_max_iterations_preserves_all_fcc_messages(chat_client_base: SupportsChatGetResponse): + """When max_iterations is reached and a final response is produced, all + intermediate function call/result messages should be included. + """ + exec_counter = 0 + + @tool(name="test_function", approval_mode="never_require") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Result {exec_counter}" + + # Two iterations of function calls, then failsafe + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}') + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}') + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="Done")), + ] + + chat_client_base.function_invocation_configuration["max_iterations"] = 2 + + response = await chat_client_base.get_response( + [Message(role="user", text="hello")], + options={"tool_choice": "auto", "tools": [ai_func]}, + ) + + assert exec_counter == 2, f"Expected 2 function executions, got {exec_counter}" + + # All function calls from both iterations should be present in the response + all_call_ids = set() + all_result_ids = set() + for msg in response.messages: + for content in msg.contents: + if content.type == "function_call": + all_call_ids.add(content.call_id) + elif content.type == "function_result": + all_result_ids.add(content.call_id) + + assert "call_1" in all_call_ids, "First iteration's function call missing from response" + assert "call_2" in all_call_ids, "Second iteration's function call missing from response" + + assert 
all_call_ids == all_result_ids, ( + f"Mismatched function calls and results. Calls: {all_call_ids}, Results: {all_result_ids}" + ) + + +async def test_max_iterations_thread_integrity_with_agent(chat_client_base: SupportsChatGetResponse): + """Verify that agent.run() does not produce orphaned function calls after + max_iterations, which would corrupt the thread and cause API errors on the + next call. + """ + + @tool(name="browser_snapshot", approval_mode="never_require") + def browser_snapshot(url: str) -> str: + return f"Screenshot of {url}" + + # Model keeps requesting tool calls on every iteration. + # The failsafe call (with tool_choice="none") after the loop is handled + # automatically by the mock client, which returns a hardcoded text response + # when tool_choice="none" (see conftest.py ChatClientBase.get_response). + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_abc", name="browser_snapshot", arguments='{"url": "https://example.com"}' + ) + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_xyz", name="browser_snapshot", arguments='{"url": "https://test.com"}' + ) + ], + ) + ), + ] + + chat_client_base.function_invocation_configuration["max_iterations"] = 2 + + agent = Agent( + client=chat_client_base, + name="test-agent", + tools=[browser_snapshot], + ) + + response = await agent.run( + "Take screenshots", + options={"tool_choice": "auto"}, + ) + + # Check for orphaned function calls in the response messages + all_call_ids = set() + all_result_ids = set() + for msg in response.messages: + for content in msg.contents: + if content.type == "function_call": + all_call_ids.add(content.call_id) + elif content.type == "function_result": + all_result_ids.add(content.call_id) + + orphaned_calls = all_call_ids - all_result_ids + assert not orphaned_calls, ( + f"Response contains orphaned 
function calls {orphaned_calls}. " + f"This would cause API errors on the next call." + ) + + @pytest.mark.parametrize("max_iterations", [10]) async def test_max_function_calls_limits_parallel_invocations(chat_client_base: SupportsChatGetResponse): """Test that max_function_calls caps total function invocations across iterations with parallel calls.""" @@ -2248,7 +2496,6 @@ def func_with_approval(arg1: str) -> str: assert exec_counter == 0 # Function not executed yet due to approval requirement -@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API") async def test_streaming_max_iterations_limit(chat_client_base: SupportsChatGetResponse): """Test that MAX_ITERATIONS in streaming mode limits function call loops.""" exec_counter = 0