Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/agents/realtime/_tool_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Iterable
from typing import Any

from ..exceptions import UserError
from ..handoffs import Handoff
from ..tool import FunctionTool, Tool


def validate_realtime_tool_names(
    tools: Iterable[Tool],
    handoffs: Iterable[Handoff[Any, Any]],
) -> None:
    """Reject Realtime tool sets in which one name is used more than once.

    Function tools (keyed by ``name``) and handoffs (keyed by ``tool_name``)
    are pooled into a single namespace. Entries of ``tools`` that are not
    ``FunctionTool`` instances are ignored.

    Args:
        tools: Tools to inspect; only ``FunctionTool`` instances contribute a name.
        handoffs: Handoffs to inspect; each contributes its ``tool_name``.

    Raises:
        UserError: If any name occurs more than once. The message lists every
            clashing name in sorted order together with the kinds of entries
            that use it.
    """
    # Maps each name to the kind ("function tool" / "handoff") of every entry
    # that claims it, in the order the entries were seen.
    origins: dict[str, list[str]] = {}

    for candidate in tools:
        if not isinstance(candidate, FunctionTool):
            continue
        origins.setdefault(candidate.name, []).append("function tool")

    for item in handoffs:
        origins.setdefault(item.tool_name, []).append("handoff")

    clashes: list[str] = []
    for name, kinds in sorted(origins.items()):
        if len(kinds) > 1:
            clashes.append(f"{name!r} ({_format_sources(kinds)})")

    if clashes:
        noun = "name" if len(clashes) == 1 else "names"
        raise UserError(
            f"Duplicate Realtime tool {noun} found: {', '.join(clashes)}. "
            "Realtime function tool and handoff names must be unique. Rename one of them "
            "before starting the session."
        )


def _format_sources(sources: list[str]) -> str:
    """Summarise a list of source kinds as an English phrase.

    Identical kinds are collapsed into one counted label, and the labels are
    joined as ``"a"``, ``"a and b"`` or ``"a, b, and c"`` depending on how
    many distinct kinds there are. Labels keep first-seen order.
    """
    labels = [
        _format_source_count(kind, occurrences)
        for kind, occurrences in Counter(sources).items()
    ]
    label_count = len(labels)
    if label_count == 1:
        return labels[0]
    if label_count == 2:
        first, second = labels
        return f"{first} and {second}"
    leading = ", ".join(labels[:-1])
    return f"{leading}, and {labels[-1]}"


def _format_source_count(source: str, count: int) -> str:
if count == 1:
return source
return f"{count} {source}s"
3 changes: 3 additions & 0 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from ..logger import logger
from ..run_context import RunContextWrapper, TContext
from ..version import __version__
from ._tool_validation import validate_realtime_tool_names
from .agent import RealtimeAgent
from .config import (
RealtimeModelTracingConfig,
Expand Down Expand Up @@ -1509,6 +1510,8 @@ def _tools_to_session_tools(
)
)

validate_realtime_tool_names(tools, handoffs)

for handoff in handoffs:
converted_tools.append(
OpenAISessionFunction(
Expand Down
7 changes: 7 additions & 0 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool
from ..tool_context import ToolContext
from ..util._approvals import evaluate_needs_approval_setting
from ._tool_validation import validate_realtime_tool_names
from .agent import RealtimeAgent
from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
from .events import (
Expand Down Expand Up @@ -741,6 +742,7 @@ async def _handle_tool_call(
agent.get_all_tools(self._context_wrapper),
self._get_handoffs(agent, self._context_wrapper),
)
validate_realtime_tool_names(tools, handoffs)
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}

Expand Down Expand Up @@ -1277,6 +1279,11 @@ async def _get_updated_model_settings_from_agent(
if starting_settings:
updated_settings.update(starting_settings)

validate_realtime_tool_names(
updated_settings.get("tools", []),
updated_settings.get("handoffs", []),
)

disable_tracing = self._run_config.get("tracing_disabled", False)
if disable_tracing:
updated_settings["tracing"] = None
Expand Down
39 changes: 39 additions & 0 deletions tests/realtime/test_openai_realtime_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,45 @@ def test_tools_to_session_tools_includes_handoffs():
assert out[0].name is not None and out[0].name.startswith("transfer_to_")


def test_tools_to_session_tools_rejects_duplicate_function_tool_names():
    """Two function tools sharing one name must raise ``UserError``."""
    first = function_tool(lambda: "one", name_override="lookup_account")
    second = function_tool(lambda: "two", name_override="lookup_account")
    model = OpenAIRealtimeWebSocketModel()
    expected_message = (
        "Duplicate Realtime tool name found: 'lookup_account' \\(2 function tools\\)"
    )

    with pytest.raises(UserError, match=expected_message):
        model._tools_to_session_tools([first, second], [])


def test_tools_to_session_tools_rejects_function_handoff_name_conflict():
    """A function tool and a handoff sharing one name must raise ``UserError``."""
    clashing_tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
    clashing_handoff = handoff(Agent(name="billing"), tool_name_override="transfer_to_billing")
    model = OpenAIRealtimeWebSocketModel()
    expected_message = (
        "Duplicate Realtime tool name found: "
        "'transfer_to_billing' \\(function tool and handoff\\)"
    )

    with pytest.raises(UserError, match=expected_message):
        model._tools_to_session_tools([clashing_tool], [clashing_handoff])


def test_tools_to_session_tools_rejects_duplicate_handoff_names():
    """Two handoffs overridden to the same tool name must raise ``UserError``."""
    to_billing = handoff(Agent(name="billing"), tool_name_override="transfer_to_support")
    to_technical = handoff(Agent(name="technical"), tool_name_override="transfer_to_support")
    model = OpenAIRealtimeWebSocketModel()
    expected_message = (
        "Duplicate Realtime tool name found: 'transfer_to_support' \\(2 handoffs\\)"
    )

    with pytest.raises(UserError, match=expected_message):
        model._tools_to_session_tools([], [to_billing, to_technical])


def test_tools_to_session_tools_rejects_namespaced_function_tools():
tool = tool_namespace(
name="crm",
Expand Down
75 changes: 74 additions & 1 deletion tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output
from agents.run_context import RunContextWrapper
from agents.tool import FunctionTool, tool_namespace
from agents.tool import FunctionTool, function_tool, tool_namespace
from agents.tool_context import ToolContext


Expand Down Expand Up @@ -221,6 +221,33 @@ async def test_handle_tool_call_handoff_invalid_result_raises():
)


@pytest.mark.asyncio
async def test_handle_tool_call_rejects_ambiguous_function_handoff_name():
    """A tool call must raise ``UserError`` when its name matches both a tool and a handoff."""
    model = _DummyModel()
    billing_agent = RealtimeAgent(name="billing")
    clashing_tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
    clashing_handoff = Handoff(
        tool_name="transfer_to_billing",
        tool_description="Transfer to billing",
        input_json_schema={},
        on_invoke_handoff=AsyncMock(return_value=billing_agent),
        input_filter=None,
        agent_name=billing_agent.name,
        is_enabled=True,
    )
    agent = RealtimeAgent(name="agent", tools=[clashing_tool], handoffs=[clashing_handoff])
    session = RealtimeSession(model, agent, None)
    call_event = RealtimeModelToolCallEvent(
        name="transfer_to_billing",
        call_id="c1",
        arguments="{}",
    )

    with pytest.raises(UserError, match="function tool and handoff"):
        await session._handle_tool_call(call_event)


@pytest.mark.asyncio
async def test_on_guardrail_task_done_emits_error_event():
model = _DummyModel()
Expand Down Expand Up @@ -275,6 +302,52 @@ async def is_enabled(ctx, agent):
assert len(enabled) == 2


@pytest.mark.asyncio
async def test_updated_model_settings_ignores_disabled_handoff_name_conflict():
    """A handoff with ``is_enabled=False`` must not trigger a name clash with a tool."""
    active_tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
    inactive_handoff = Handoff(
        tool_name="transfer_to_billing",
        tool_description="Transfer to billing",
        input_json_schema={},
        on_invoke_handoff=AsyncMock(return_value=RealtimeAgent(name="billing")),
        input_filter=None,
        agent_name="billing",
        is_enabled=False,
    )
    agent = RealtimeAgent(name="agent", tools=[active_tool], handoffs=[inactive_handoff])
    session = RealtimeSession(_DummyModel(), agent, None)

    result = await session._get_updated_model_settings_from_agent(None, agent)

    # Only the tool is expected in the settings; the disabled handoff is absent.
    assert result["tools"] == [active_tool]
    assert result["handoffs"] == []


@pytest.mark.asyncio
async def test_updated_model_settings_validates_final_tool_names_after_overrides():
    """Starting settings that replace a clashing tool/handoff pair must be accepted.

    The agent's own tool and handoff share a name, but the starting settings
    passed in supply a different tool list and no handoffs; the returned
    settings are expected to reflect those overrides without an error.
    """
    clashing_tool = function_tool(lambda: "ok", name_override="transfer_to_billing")
    clashing_handoff = Handoff(
        tool_name="transfer_to_billing",
        tool_description="Transfer to billing",
        input_json_schema={},
        on_invoke_handoff=AsyncMock(return_value=RealtimeAgent(name="billing")),
        input_filter=None,
        agent_name="billing",
        is_enabled=True,
    )
    replacement_tool = function_tool(lambda: "ok", name_override="lookup_account")
    agent = RealtimeAgent(name="agent", tools=[clashing_tool], handoffs=[clashing_handoff])
    session = RealtimeSession(_DummyModel(), agent, None)
    starting_settings = {"tools": [replacement_tool], "handoffs": []}

    result = await session._get_updated_model_settings_from_agent(starting_settings, agent)

    assert result["tools"] == [replacement_tool]
    assert result["handoffs"] == []


class MockRealtimeModel(RealtimeModel):
def __init__(self):
super().__init__()
Expand Down