diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a3f4570b6e..c8ccf3c9e5 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -81,6 +81,14 @@ ) +def _get_tool_name(tool: Any) -> str | None: + """Extract the name from a tool, supporting both objects and dicts.""" + if isinstance(tool, dict): + name: str | None = tool.get("function", {}).get("name") + return name + return getattr(tool, "name", None) + + def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Merge two options dicts, with override values taking precedence. @@ -97,8 +105,9 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, continue if key == "tools" and result.get("tools"): # Combine tool lists, avoiding duplicates by name - existing_names = {getattr(t, "name", None) for t in result["tools"]} - unique_new = [t for t in value if getattr(t, "name", None) not in existing_names] + existing_names = {_get_tool_name(t) for t in result["tools"]} + existing_names.discard(None) + unique_new = [t for t in value if _get_tool_name(t) not in existing_names] result["tools"] = list(result["tools"]) + unique_new elif key == "logit_bias" and result.get("logit_bias"): # Merge logit_bias dicts diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 627987a1f2..d21fd6e7bb 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -25,7 +25,7 @@ SupportsChatGetResponse, tool, ) -from agent_framework._agents import _merge_options, _sanitize_agent_name +from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool @@ -880,6 +880,71 @@ def __init__(self, name): assert "tool2" in tool_names +def test_merge_options_dict_tools_combined(): + """Test _merge_options combines dict-defined tool lists without dropping tools.""" + base = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + ] + } + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_b"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 2 + names = [_get_tool_name(t) for t in result["tools"]] + assert "tool_a" in names + assert "tool_b" in names + + +def test_merge_options_dict_tools_deduplicates(): + """Test _merge_options deduplicates dict-defined tools by function name.""" + base = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + ] + } + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + {"type": "function", "function": {"name": "tool_b"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 2 + names = [_get_tool_name(t) for t in result["tools"]] + assert names.count("tool_a") == 1 + assert "tool_b" in names + + +def test_get_tool_name_dict(): + """Test _get_tool_name extracts name from dict-defined tools.""" + tool = {"type": "function", "function": {"name": "my_tool"}} + assert _get_tool_name(tool) == "my_tool" + + +def test_get_tool_name_object(): + """Test _get_tool_name extracts name from object-defined tools.""" + + class MockTool: + def __init__(self, name): + self.name = name + + assert _get_tool_name(MockTool("my_tool")) == "my_tool" + + +def test_get_tool_name_no_name(): + """Test _get_tool_name returns None when no name is available.""" + assert _get_tool_name({}) is None + assert _get_tool_name(42) is None + + def test_merge_options_logit_bias_merged(): """Test _merge_options merges logit_bias dicts.""" base = {"logit_bias": {"token1": 1.0}}