Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,15 +2164,11 @@ async def _get_response() -> ChatResponse:
# Error threshold reached: force a final non-tool turn so
# function_call_output items are submitted before exit.
mutable_options["tool_choice"] = "none"
elif (
max_function_calls is not None
and total_function_calls >= max_function_calls
):
elif max_function_calls is not None and total_function_calls >= max_function_calls:
# Best-effort limit: checked after each batch of parallel calls completes,
# so the current batch always runs to completion even if it overshoots.
logger.info(
"Maximum function calls reached (%d/%d). "
"Stopping further function calls for this request.",
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
Expand All @@ -2196,9 +2192,15 @@ async def _get_response() -> ChatResponse:
prepped_messages.extend(response.messages)
continue

if response is not None:
return response

# Loop exhausted all iterations (or function invocation disabled).
# Make a final model call with tool_choice="none" so the model
# produces a plain text answer instead of leaving orphaned
# function_call items without matching results.
if response is not None and self.function_invocation_configuration["enabled"]:
logger.info(
"Maximum iterations reached (%d). Requesting final response without tools.",
self.function_invocation_configuration["max_iterations"],
)
mutable_options["tool_choice"] = "none"
response = await super_get_response(
messages=prepped_messages,
Expand Down Expand Up @@ -2302,15 +2304,11 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
mutable_options["tool_choice"] = "none"
elif result["action"] != "continue":
return
elif (
max_function_calls is not None
and total_function_calls >= max_function_calls
):
elif max_function_calls is not None and total_function_calls >= max_function_calls:
# Best-effort limit: checked after each batch of parallel calls completes,
# so the current batch always runs to completion even if it overshoots.
logger.info(
"Maximum function calls reached (%d/%d). "
"Stopping further function calls for this request.",
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
Expand All @@ -2333,9 +2331,15 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
prepped_messages.extend(response.messages)
continue

if response is not None:
return

# Loop exhausted all iterations (or function invocation disabled).
# Make a final model call with tool_choice="none" so the model
# produces a plain text answer instead of leaving orphaned
# function_call items without matching results.
if response is not None and self.function_invocation_configuration["enabled"]:
logger.info(
"Maximum iterations reached (%d). Requesting final response without tools.",
self.function_invocation_configuration["max_iterations"],
)
mutable_options["tool_choice"] = "none"
inner_stream = await _ensure_response_stream(
super_get_response(
Expand Down
253 changes: 250 additions & 3 deletions python/packages/core/tests/core/test_function_invocation_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,6 @@ def func_with_approval(arg1: str) -> str:
assert "rejected" in rejection_result.result.lower()


@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
"""Test that MAX_ITERATIONS in additional_properties limits function call loops."""
exec_counter = 0
Expand Down Expand Up @@ -880,6 +878,256 @@ def ai_func(arg1: str) -> str:
assert response.messages[-1].text == "I broke out of the function invocation loop..." # Failsafe response


async def test_max_iterations_no_orphaned_function_calls(chat_client_base: SupportsChatGetResponse):
"""When max_iterations is reached, verify the returned response has no orphaned
FunctionCallContent (i.e., every function_call has a matching function_result).
"""
exec_counter = 0

@tool(name="test_function", approval_mode="never_require")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Processed {arg1}"

# Model keeps requesting tool calls on every iteration
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_3", name="test_function", arguments='{"arg1": "v3"}')
],
)
),
]

chat_client_base.function_invocation_configuration["max_iterations"] = 2

response = await chat_client_base.get_response(
[Message(role="user", text="hello")],
options={"tool_choice": "auto", "tools": [ai_func]},
)

# Collect all function_call and function_result call_ids from response
all_call_ids = set()
all_result_ids = set()
for msg in response.messages:
for content in msg.contents:
if content.type == "function_call":
all_call_ids.add(content.call_id)
elif content.type == "function_result":
all_result_ids.add(content.call_id)

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned FunctionCallContent without matching "
f"FunctionResultContent: {orphaned_calls}."
)


async def test_max_iterations_makes_final_toolchoice_none_call(chat_client_base: SupportsChatGetResponse):
"""When max_iterations is reached, verify a final model call is made with
tool_choice='none' to produce a clean text response.
"""
exec_counter = 0

@tool(name="test_function", approval_mode="never_require")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Processed {arg1}"

chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
],
)
),
# This response should be reached via failsafe (tool_choice="none")
ChatResponse(messages=Message(role="assistant", text="Final answer after giving up on tools.")),
]

chat_client_base.function_invocation_configuration["max_iterations"] = 1

response = await chat_client_base.get_response(
[Message(role="user", text="hello")],
options={"tool_choice": "auto", "tools": [ai_func]},
)

assert exec_counter == 1, f"Expected 1 function execution, got {exec_counter}"

# The response should end with a plain text message (from the failsafe call)
last_msg = response.messages[-1]
has_function_calls = any(c.type == "function_call" for c in last_msg.contents)

assert not has_function_calls, (
f"Last message in response still contains function_call items. "
f"Expected a clean text response after max_iterations failsafe. "
f"Got message with role={last_msg.role}, contents={[c.type for c in last_msg.contents]}"
)

# The mock client returns "I broke out of the function invocation loop..."
# when tool_choice="none"
assert last_msg.text == "I broke out of the function invocation loop...", (
f"Expected failsafe text response, got: {last_msg.text!r}"
)


async def test_max_iterations_preserves_all_fcc_messages(chat_client_base: SupportsChatGetResponse):
"""When max_iterations is reached and a final response is produced, all
intermediate function call/result messages should be included.
"""
exec_counter = 0

@tool(name="test_function", approval_mode="never_require")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Result {exec_counter}"

# Two iterations of function calls, then failsafe
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
],
)
),
ChatResponse(messages=Message(role="assistant", text="Done")),
]

chat_client_base.function_invocation_configuration["max_iterations"] = 2

response = await chat_client_base.get_response(
[Message(role="user", text="hello")],
options={"tool_choice": "auto", "tools": [ai_func]},
)

assert exec_counter == 2, f"Expected 2 function executions, got {exec_counter}"

# All function calls from both iterations should be present in the response
all_call_ids = set()
all_result_ids = set()
for msg in response.messages:
for content in msg.contents:
if content.type == "function_call":
all_call_ids.add(content.call_id)
elif content.type == "function_result":
all_result_ids.add(content.call_id)

assert "call_1" in all_call_ids, "First iteration's function call missing from response"
assert "call_2" in all_call_ids, "Second iteration's function call missing from response"

assert all_call_ids == all_result_ids, (
f"Mismatched function calls and results. Calls: {all_call_ids}, Results: {all_result_ids}"
)


async def test_max_iterations_thread_integrity_with_agent(chat_client_base: SupportsChatGetResponse):
"""Verify that agent.run() does not produce orphaned function calls after
max_iterations, which would corrupt the thread and cause API errors on the
next call.
"""

@tool(name="browser_snapshot", approval_mode="never_require")
def browser_snapshot(url: str) -> str:
return f"Screenshot of {url}"

# Model keeps requesting tool calls on every iteration.
# The failsafe call (with tool_choice="none") after the loop is handled
# automatically by the mock client, which returns a hardcoded text response
# when tool_choice="none" (see conftest.py ChatClientBase.get_response).
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_abc", name="browser_snapshot", arguments='{"url": "https://example.com"}'
)
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_xyz", name="browser_snapshot", arguments='{"url": "https://test.com"}'
)
],
)
),
]

chat_client_base.function_invocation_configuration["max_iterations"] = 2

agent = Agent(
client=chat_client_base,
name="test-agent",
tools=[browser_snapshot],
)

response = await agent.run(
"Take screenshots",
options={"tool_choice": "auto"},
)

# Check for orphaned function calls in the response messages
all_call_ids = set()
all_result_ids = set()
for msg in response.messages:
for content in msg.contents:
if content.type == "function_call":
all_call_ids.add(content.call_id)
elif content.type == "function_result":
all_result_ids.add(content.call_id)

orphaned_calls = all_call_ids - all_result_ids
assert not orphaned_calls, (
f"Response contains orphaned function calls {orphaned_calls}. "
f"This would cause API errors on the next call."
)


@pytest.mark.parametrize("max_iterations", [10])
async def test_max_function_calls_limits_parallel_invocations(chat_client_base: SupportsChatGetResponse):
"""Test that max_function_calls caps total function invocations across iterations with parallel calls."""
Expand Down Expand Up @@ -2248,7 +2496,6 @@ def func_with_approval(arg1: str) -> str:
assert exec_counter == 0 # Function not executed yet due to approval requirement


@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
async def test_streaming_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
"""Test that MAX_ITERATIONS in streaming mode limits function call loops."""
exec_counter = 0
Expand Down