Reject ambiguous Realtime tool names (function tools vs. handoffs).

What this patch does, hunk by hunk:

* src/agents/realtime/_tool_validation.py (new file)
  validate_realtime_tool_names(tools, handoffs) collects the name of every
  FunctionTool in `tools` and the tool_name of every handoff in `handoffs`.
  If any name was collected more than once it raises UserError. The message
  lists the offending names in sorted order, each followed by where the
  duplicates came from, e.g. "(2 function tools)", "(function tool and
  handoff)" or "(2 handoffs)". Tools that are not FunctionTool instances are
  not considered. _format_sources / _format_source_count build the
  parenthesised part of the message.

* src/agents/realtime/openai_realtime.py
  _tools_to_session_tools() calls the validator before it converts handoffs.

* src/agents/realtime/session.py
  _handle_tool_call() calls the validator before building function_map and
  handoff_map. _get_updated_model_settings_from_agent() calls it on the
  "tools" / "handoffs" entries of updated_settings after starting_settings
  has been merged in.

* tests/realtime/test_openai_realtime_conversions.py, tests/realtime/test_session.py
  Cover duplicate function tools, a function tool colliding with a handoff,
  duplicate handoffs, a disabled handoff that shares a name with a tool (the
  test expects no error), and starting_settings overrides that replace a
  conflicting agent tool/handoff pair (the test expects no error).

NOTE(review): this copy of the patch was restored from a whitespace-collapsed
form. Line breaks follow the +/- markers and the hunk line counts; indentation
(especially on unchanged context lines) is reconstructed from Python
formatting conventions -- confirm against the target tree, or apply with
whitespace tolerance.

diff --git a/src/agents/realtime/_tool_validation.py b/src/agents/realtime/_tool_validation.py
new file mode 100644
index 0000000000..fc4c4daa2b
--- /dev/null
+++ b/src/agents/realtime/_tool_validation.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from collections import Counter
+from collections.abc import Iterable
+from typing import Any
+
+from ..exceptions import UserError
+from ..handoffs import Handoff
+from ..tool import FunctionTool, Tool
+
+
+def validate_realtime_tool_names(
+    tools: Iterable[Tool],
+    handoffs: Iterable[Handoff[Any, Any]],
+) -> None:
+    """Ensure all model-visible Realtime tool names are unambiguous."""
+    sources_by_name: dict[str, list[str]] = {}
+
+    for tool in tools:
+        if isinstance(tool, FunctionTool):
+            sources_by_name.setdefault(tool.name, []).append("function tool")
+
+    for handoff in handoffs:
+        sources_by_name.setdefault(handoff.tool_name, []).append("handoff")
+
+    duplicate_descriptions = [
+        f"{name!r} ({_format_sources(sources)})"
+        for name, sources in sorted(sources_by_name.items())
+        if len(sources) > 1
+    ]
+    if not duplicate_descriptions:
+        return
+
+    plural = "name" if len(duplicate_descriptions) == 1 else "names"
+    raise UserError(
+        f"Duplicate Realtime tool {plural} found: {', '.join(duplicate_descriptions)}. "
+        "Realtime function tool and handoff names must be unique. Rename one of them "
+        "before starting the session."
+    )
+
+
+def _format_sources(sources: list[str]) -> str:
+    parts = [_format_source_count(source, count) for source, count in Counter(sources).items()]
+    if len(parts) == 1:
+        return parts[0]
+    if len(parts) == 2:
+        return f"{parts[0]} and {parts[1]}"
+    return f"{', '.join(parts[:-1])}, and {parts[-1]}"
+
+
+def _format_source_count(source: str, count: int) -> str:
+    if count == 1:
+        return source
+    return f"{count} {source}s"
diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index 6b986c6edc..75287d50b8 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -100,6 +100,7 @@
 from ..logger import logger
 from ..run_context import RunContextWrapper, TContext
 from ..version import __version__
+from ._tool_validation import validate_realtime_tool_names
 from .agent import RealtimeAgent
 from .config import (
     RealtimeModelTracingConfig,
@@ -1509,6 +1510,8 @@ def _tools_to_session_tools(
                 )
             )
 
+        validate_realtime_tool_names(tools, handoffs)
+
         for handoff in handoffs:
             converted_tools.append(
                 OpenAISessionFunction(
diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py
index ca809dd9c4..cf0e12d6c5 100644
--- a/src/agents/realtime/session.py
+++ b/src/agents/realtime/session.py
@@ -25,6 +25,7 @@
 from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool
 from ..tool_context import ToolContext
 from ..util._approvals import evaluate_needs_approval_setting
+from ._tool_validation import validate_realtime_tool_names
 from .agent import RealtimeAgent
 from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
 from .events import (
@@ -741,6 +742,7 @@ async def _handle_tool_call(
             agent.get_all_tools(self._context_wrapper),
             self._get_handoffs(agent, self._context_wrapper),
         )
+        validate_realtime_tool_names(tools, handoffs)
         function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
         handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
 
@@ -1277,6 +1279,11 @@ async def _get_updated_model_settings_from_agent(
         if starting_settings:
             updated_settings.update(starting_settings)
 
+        validate_realtime_tool_names(
+            updated_settings.get("tools", []),
+            updated_settings.get("handoffs", []),
+        )
+
         disable_tracing = self._run_config.get("tracing_disabled", False)
         if disable_tracing:
             updated_settings["tracing"] = None
diff --git a/tests/realtime/test_openai_realtime_conversions.py b/tests/realtime/test_openai_realtime_conversions.py
index bc98fe3c4e..b5b891e08a 100644
--- a/tests/realtime/test_openai_realtime_conversions.py
+++ b/tests/realtime/test_openai_realtime_conversions.py
@@ -128,6 +128,45 @@ def test_tools_to_session_tools_includes_handoffs():
     assert out[0].name is not None and out[0].name.startswith("transfer_to_")
 
 
+def test_tools_to_session_tools_rejects_duplicate_function_tool_names():
+    tool_one = function_tool(lambda: "one", name_override="lookup_account")
+    tool_two = function_tool(lambda: "two", name_override="lookup_account")
+    m = OpenAIRealtimeWebSocketModel()
+
+    with pytest.raises(
+        UserError,
+        match=("Duplicate Realtime tool name found: 'lookup_account' \\(2 function tools\\)"),
+    ):
+        m._tools_to_session_tools([tool_one, tool_two], [])
+
+
+def test_tools_to_session_tools_rejects_function_handoff_name_conflict():
+    tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
+    h = handoff(Agent(name="billing"), tool_name_override="transfer_to_billing")
+    m = OpenAIRealtimeWebSocketModel()
+
+    with pytest.raises(
+        UserError,
+        match=(
+            "Duplicate Realtime tool name found: "
+            "'transfer_to_billing' \\(function tool and handoff\\)"
+        ),
+    ):
+        m._tools_to_session_tools([tool], [h])
+
+
+def test_tools_to_session_tools_rejects_duplicate_handoff_names():
+    handoff_one = handoff(Agent(name="billing"), tool_name_override="transfer_to_support")
+    handoff_two = handoff(Agent(name="technical"), tool_name_override="transfer_to_support")
+    m = OpenAIRealtimeWebSocketModel()
+
+    with pytest.raises(
+        UserError,
+        match=("Duplicate Realtime tool name found: 'transfer_to_support' \\(2 handoffs\\)"),
+    ):
+        m._tools_to_session_tools([], [handoff_one, handoff_two])
+
+
 def test_tools_to_session_tools_rejects_namespaced_function_tools():
     tool = tool_namespace(
         name="crm",
diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py
index e289bc3c9e..30cf376dbb 100644
--- a/tests/realtime/test_session.py
+++ b/tests/realtime/test_session.py
@@ -62,7 +62,7 @@
 )
 from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output
 from agents.run_context import RunContextWrapper
-from agents.tool import FunctionTool, tool_namespace
+from agents.tool import FunctionTool, function_tool, tool_namespace
 from agents.tool_context import ToolContext
 
 
@@ -221,6 +221,33 @@ async def test_handle_tool_call_handoff_invalid_result_raises():
         )
 
 
+@pytest.mark.asyncio
+async def test_handle_tool_call_rejects_ambiguous_function_handoff_name():
+    model = _DummyModel()
+    target = RealtimeAgent(name="billing")
+    tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
+    handoff = Handoff(
+        tool_name="transfer_to_billing",
+        tool_description="Transfer to billing",
+        input_json_schema={},
+        on_invoke_handoff=AsyncMock(return_value=target),
+        input_filter=None,
+        agent_name=target.name,
+        is_enabled=True,
+    )
+    agent = RealtimeAgent(name="agent", tools=[tool], handoffs=[handoff])
+    session = RealtimeSession(model, agent, None)
+
+    with pytest.raises(UserError, match="function tool and handoff"):
+        await session._handle_tool_call(
+            RealtimeModelToolCallEvent(
+                name="transfer_to_billing",
+                call_id="c1",
+                arguments="{}",
+            )
+        )
+
+
 @pytest.mark.asyncio
 async def test_on_guardrail_task_done_emits_error_event():
     model = _DummyModel()
@@ -275,6 +302,52 @@ async def is_enabled(ctx, agent):
     assert len(enabled) == 2
 
 
+@pytest.mark.asyncio
+async def test_updated_model_settings_ignores_disabled_handoff_name_conflict():
+    tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
+    disabled_handoff = Handoff(
+        tool_name="transfer_to_billing",
+        tool_description="Transfer to billing",
+        input_json_schema={},
+        on_invoke_handoff=AsyncMock(return_value=RealtimeAgent(name="billing")),
+        input_filter=None,
+        agent_name="billing",
+        is_enabled=False,
+    )
+    agent = RealtimeAgent(name="agent", tools=[tool], handoffs=[disabled_handoff])
+    session = RealtimeSession(_DummyModel(), agent, None)
+
+    settings = await session._get_updated_model_settings_from_agent(None, agent)
+
+    assert settings["tools"] == [tool]
+    assert settings["handoffs"] == []
+
+
+@pytest.mark.asyncio
+async def test_updated_model_settings_validates_final_tool_names_after_overrides():
+    agent_tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
+    agent_handoff = Handoff(
+        tool_name="transfer_to_billing",
+        tool_description="Transfer to billing",
+        input_json_schema={},
+        on_invoke_handoff=AsyncMock(return_value=RealtimeAgent(name="billing")),
+        input_filter=None,
+        agent_name="billing",
+        is_enabled=True,
+    )
+    override_tool = function_tool(lambda: "ok", name_override="lookup_account")
+    agent = RealtimeAgent(name="agent", tools=[agent_tool], handoffs=[agent_handoff])
+    session = RealtimeSession(_DummyModel(), agent, None)
+
+    settings = await session._get_updated_model_settings_from_agent(
+        {"tools": [override_tool], "handoffs": []},
+        agent,
+    )
+
+    assert settings["tools"] == [override_tool]
+    assert settings["handoffs"] == []
+
+
 class MockRealtimeModel(RealtimeModel):
     def __init__(self):
         super().__init__()