Skip to content
134 changes: 103 additions & 31 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,23 @@ def __init__(
description: A description of the function.
approval_mode: Whether or not approval is required to run this tool.
Default is that approval is NOT required (``"never_require"``).
max_invocations: The maximum number of times this function can be invoked.
If None, there is no limit. Should be at least 1.
max_invocations: The maximum number of times this function can be invoked
across the **lifetime of this tool instance**. If None (default),
there is no limit. Should be at least 1. If the tool is called multiple
times in one iteration, those will execute, after that it will stop working. For example,
if max_invocations is 3 and the tool is called 5 times in a single iteration,
these will complete, but any subsequent calls to the tool (in the same or future iterations)
will raise a ToolException.

.. note::
This counter lives on the tool instance and is never automatically
reset. For module-level or singleton tools in long-running
applications, the counter accumulates across all requests. Use
:attr:`invocation_count` to inspect or reset the counter manually,
or consider using
``FunctionInvocationConfiguration["max_function_calls"]``
for per-request limits instead.

max_invocation_exceptions: The maximum number of exceptions allowed during invocations.
If None, there is no limit. Should be at least 1.
additional_properties: Additional properties to set on the function.
Expand Down Expand Up @@ -1130,8 +1145,10 @@ def tool(
function's signature. Defaults to ``None`` (infer from signature).
approval_mode: Whether or not approval is required to run this tool.
Default is that approval is NOT required (``"never_require"``).
max_invocations: The maximum number of times this function can be invoked.
If None, there is no limit, should be at least 1.
max_invocations: The maximum number of times this function can be invoked
across the **lifetime of this tool instance**. If None (default), there is
no limit. Should be at least 1. For per-request limits, use
``FunctionInvocationConfiguration["max_function_calls"]`` instead.
max_invocation_exceptions: The maximum number of exceptions allowed during invocations.
If None, there is no limit, should be at least 1.
additional_properties: Additional properties to set on the function.
Expand Down Expand Up @@ -1247,43 +1264,54 @@ def wrapper(f: Callable[..., Any]) -> FunctionTool:
class FunctionInvocationConfiguration(TypedDict, total=False):
"""Configuration for function invocation in chat clients.

The configuration controls the tool execution loop that runs when the model
requests function calls. Key settings:

- ``enabled``: Master switch for the function invocation loop.
- ``max_iterations``: Limits the number of **LLM roundtrips** (iterations).
Each iteration may execute one or more function calls in parallel, so
this does *not* directly limit the total number of function executions.
- ``max_function_calls``: Limits the **total number of individual function
invocations** across all iterations within a single request. This is the
primary knob for controlling cost and preventing runaway tool usage. When
the limit is reached, the loop stops invoking tools and forces the model
to produce a text response. Default is ``None`` (unlimited).

This is a **best-effort** limit: it is checked *after* each batch of
parallel tool calls completes, not before. If the model requests 20
parallel calls in a single iteration and the limit is 10, all 20 will
execute before the loop stops.
- ``max_consecutive_errors_per_request``: How many consecutive errors
before abandoning the tool loop for this request.
- ``terminate_on_unknown_calls``: Whether to raise an error when the model
requests a function that is not in the tool map.
- ``additional_tools``: Extra tools available during execution but not
advertised to the model in the tool list.
- ``include_detailed_errors``: Whether to include exception details in the
function result returned to the model.

Note:
``max_iterations`` and ``max_function_calls`` serve complementary purposes.
``max_iterations`` caps the number of model round-trips regardless of how
many tools are called per trip. ``max_function_calls`` caps the cumulative
number of individual tool executions regardless of how they are distributed
across iterations.

Example:
.. code-block:: python

from agent_framework.openai import OpenAIChatClient

# Create an OpenAI chat client
client = OpenAIChatClient(api_key="your_api_key")

# Disable function invocation
client.function_invocation_configuration["enabled"] = False

# Set maximum iterations to 10
client.function_invocation_configuration["max_iterations"] = 10

# Enable termination on unknown function calls
client.function_invocation_configuration["terminate_on_unknown_calls"] = True

# Add additional tools for function execution
client.function_invocation_configuration["additional_tools"] = [my_custom_tool]

# Enable detailed error information in function results
client.function_invocation_configuration["include_detailed_errors"] = True

# You can also create a new configuration dict if needed
new_config: FunctionInvocationConfiguration = {
"enabled": True,
"max_iterations": 20,
"terminate_on_unknown_calls": False,
"additional_tools": [another_tool],
"include_detailed_errors": False,
}

# and then assign it to the client
client.function_invocation_configuration = new_config
# Limit to 5 LLM roundtrips and 20 total function executions
client.function_invocation_configuration["max_iterations"] = 5
client.function_invocation_configuration["max_function_calls"] = 20
"""

enabled: bool
max_iterations: int
max_function_calls: int | None
max_consecutive_errors_per_request: int
terminate_on_unknown_calls: bool
additional_tools: Sequence[FunctionTool]
Expand All @@ -1296,6 +1324,7 @@ def normalize_function_invocation_configuration(
normalized: FunctionInvocationConfiguration = {
"enabled": True,
"max_iterations": DEFAULT_MAX_ITERATIONS,
"max_function_calls": None,
"max_consecutive_errors_per_request": DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST,
"terminate_on_unknown_calls": False,
"additional_tools": [],
Expand All @@ -1305,6 +1334,8 @@ def normalize_function_invocation_configuration(
normalized.update(config)
if normalized["max_iterations"] < 1:
raise ValueError("max_iterations must be at least 1.")
if normalized["max_function_calls"] is not None and normalized["max_function_calls"] < 1:
raise ValueError("max_function_calls must be at least 1 or None.")
if normalized["max_consecutive_errors_per_request"] < 0:
raise ValueError("max_consecutive_errors_per_request must be 0 or more.")
if normalized["additional_tools"] is None:
Expand Down Expand Up @@ -1816,13 +1847,15 @@ class FunctionRequestResult(TypedDict, total=False):
result_message: The message containing function call results, if any.
update_role: The role to update for the next message, if any.
function_call_results: The list of function call results, if any.
function_call_count: The number of function calls executed in this processing step.
"""

action: Literal["return", "continue", "stop"]
errors_in_a_row: int
result_message: Message | None
update_role: Literal["assistant", "tool"] | None
function_call_results: list[Content] | None
function_call_count: int


def _handle_function_call_results(
Expand Down Expand Up @@ -1913,6 +1946,7 @@ async def _process_function_requests(
max_errors,
)
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
executed_count = sum(1 for r in approved_function_results if r.type == "function_result")
# Continue to call chat client with updated messages (containing function results)
# so it can generate the final response
return {
Expand All @@ -1921,6 +1955,7 @@ async def _process_function_requests(
"result_message": None,
"update_role": None,
"function_call_results": None,
"function_call_count": executed_count,
}

if response is None or fcc_messages is None:
Expand All @@ -1930,6 +1965,7 @@ async def _process_function_requests(
"result_message": None,
"update_role": None,
"function_call_results": None,
"function_call_count": 0,
}

tools = _extract_tools(tool_options)
Expand All @@ -1942,6 +1978,7 @@ async def _process_function_requests(
"result_message": None,
"update_role": None,
"function_call_results": None,
"function_call_count": 0,
}

function_call_results, should_terminate, had_errors = await execute_function_calls(
Expand All @@ -1958,6 +1995,7 @@ async def _process_function_requests(
max_errors=max_errors,
)
result["function_call_results"] = list(function_call_results)
result["function_call_count"] = sum(1 for r in function_call_results if r.type == "function_result")
# If middleware requested termination, change action to return
if should_terminate:
result["action"] = "return"
Expand Down Expand Up @@ -2071,6 +2109,8 @@ async def _get_response() -> ChatResponse:
nonlocal mutable_options
nonlocal filtered_kwargs
errors_in_a_row: int = 0
total_function_calls: int = 0
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
Expand All @@ -2094,6 +2134,7 @@ async def _get_response() -> ChatResponse:
response = ChatResponse(messages=prepped_messages)
break
errors_in_a_row = approval_result["errors_in_a_row"]
total_function_calls += approval_result.get("function_call_count", 0)

response = await super_get_response(
messages=prepped_messages,
Expand All @@ -2118,10 +2159,24 @@ async def _get_response() -> ChatResponse:
)
if result["action"] == "return":
return response
total_function_calls += result.get("function_call_count", 0)
if result["action"] == "stop":
# Error threshold reached: force a final non-tool turn so
# function_call_output items are submitted before exit.
mutable_options["tool_choice"] = "none"
elif (
max_function_calls is not None
and total_function_calls >= max_function_calls
):
# Best-effort limit: checked after each batch of parallel calls completes,
# so the current batch always runs to completion even if it overshoots.
logger.info(
"Maximum function calls reached (%d/%d). "
"Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
mutable_options["tool_choice"] = "none"
errors_in_a_row = result["errors_in_a_row"]

# When tool_choice is 'required', reset tool_choice after one iteration to avoid infinite loops
Expand Down Expand Up @@ -2167,6 +2222,8 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
nonlocal mutable_options
nonlocal stream_result_hooks
errors_in_a_row: int = 0
total_function_calls: int = 0
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
Expand All @@ -2187,6 +2244,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
execute_function_calls=execute_function_calls,
)
errors_in_a_row = approval_result["errors_in_a_row"]
total_function_calls += approval_result.get("function_call_count", 0)
if approval_result["action"] == "stop":
mutable_options["tool_choice"] = "none"
return
Expand Down Expand Up @@ -2232,6 +2290,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
execute_function_calls=execute_function_calls,
)
errors_in_a_row = result["errors_in_a_row"]
total_function_calls += result.get("function_call_count", 0)
if role := result["update_role"]:
yield ChatResponseUpdate(
contents=result["function_call_results"] or [],
Expand All @@ -2243,6 +2302,19 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
mutable_options["tool_choice"] = "none"
elif result["action"] != "continue":
return
elif (
max_function_calls is not None
and total_function_calls >= max_function_calls
):
# Best-effort limit: checked after each batch of parallel calls completes,
# so the current batch always runs to completion even if it overshoots.
logger.info(
"Maximum function calls reached (%d/%d). "
"Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
mutable_options["tool_choice"] = "none"

# When tool_choice is 'required', reset the tool_choice after one iteration to avoid infinite loops
if mutable_options.get("tool_choice") == "required" or (
Expand Down
Loading