diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a3f4570b6e..391a4de757 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -81,6 +81,16 @@ ) +def _get_tool_name(tool: Any) -> str | None: + """Extract a tool's name from either an object with a .name attribute or a dict tool definition.""" + if isinstance(tool, dict): + func = tool.get("function") + if isinstance(func, dict): + return func.get("name") + return None + return getattr(tool, "name", None) + + def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Merge two options dicts, with override values taking precedence. @@ -97,8 +107,8 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, continue if key == "tools" and result.get("tools"): # Combine tool lists, avoiding duplicates by name - existing_names = {getattr(t, "name", None) for t in result["tools"]} - unique_new = [t for t in value if getattr(t, "name", None) not in existing_names] + existing_names = {_get_tool_name(t) for t in result["tools"]} - {None} + unique_new = [t for t in value if _get_tool_name(t) not in existing_names] result["tools"] = list(result["tools"]) + unique_new elif key == "logit_bias" and result.get("logit_bias"): # Merge logit_bias dicts diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 627987a1f2..5067a68c87 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -25,7 +25,7 @@ SupportsChatGetResponse, tool, ) -from agent_framework._agents import _merge_options, _sanitize_agent_name +from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool @@ -880,6 +880,152 @@ def __init__(self, name): assert "tool2" in tool_names +def test_merge_options_dict_tools_combined(): + """Test _merge_options combines dict-defined tool lists without duplicates.""" + base = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + ] + } + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_b"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 2 + names = [_get_tool_name(t) for t in result["tools"]] + assert "tool_a" in names + assert "tool_b" in names + + +def test_merge_options_dict_tools_deduplicates(): + """Test _merge_options deduplicates dict-defined tools by function name.""" + base = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + ] + } + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + {"type": "function", "function": {"name": "tool_b"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 2 + names = [_get_tool_name(t) for t in result["tools"]] + assert names.count("tool_a") == 1 + assert "tool_b" in names + + +def test_merge_options_mixed_tools_combined(): + """Test _merge_options combines object and dict-defined tools.""" + + class MockTool: + def __init__(self, name): + self.name = name + + base = {"tools": [MockTool("tool_a")]} + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_b"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 2 + names = [_get_tool_name(t) for t in result["tools"]] + assert "tool_a" in names + assert "tool_b" in names + + +def test_merge_options_mixed_tools_deduplicates(): + """Test _merge_options deduplicates when a dict tool and object tool share the same name.""" + + class MockTool: + def __init__(self, name): + self.name = name + + base = {"tools": [MockTool("tool_a")]} + override = { + "tools": [ + {"type": "function", "function": {"name": "tool_a"}}, + ] + } + + result = _merge_options(base, override) + + assert len(result["tools"]) == 1 + assert _get_tool_name(result["tools"][0]) == "tool_a" + + +def test_merge_options_nameless_tools_not_deduplicated(): + """Test that tools with no extractable name (None) are not falsely deduplicated.""" + base = { + "tools": [ + {"type": "function"}, # no 'function.name' -> _get_tool_name returns None + ] + } + override = { + "tools": [ + {"type": "function"}, # also returns None + ] + } + + result = _merge_options(base, override) + + # Both nameless tools should be kept (None is excluded from dedup set) + assert len(result["tools"]) == 2 + + +def test_get_tool_name_dict_no_function_key(): + """_get_tool_name returns None for a dict without a 'function' key.""" + assert _get_tool_name({"type": "function"}) is None + + +def test_get_tool_name_dict_function_not_dict(): + """_get_tool_name returns None when 'function' value is not a dict.""" + assert _get_tool_name({"function": "not_a_dict"}) is None + + +def test_get_tool_name_dict_function_no_name(): + """_get_tool_name returns None when 'function' dict has no 'name' key.""" + assert _get_tool_name({"function": {"description": "does stuff"}}) is None + + +def test_get_tool_name_object_no_name_attr(): + """_get_tool_name returns None for an object without a 'name' attribute.""" + assert _get_tool_name(object()) is None + + +def test_get_tool_name_non_dict_non_object(): + """_get_tool_name returns None for non-dict inputs like int or string.""" + assert _get_tool_name(42) is None + assert _get_tool_name("tool_name") is None + + +def test_get_tool_name_valid_dict(): + """_get_tool_name extracts name from a well-formed dict tool.""" + tool_dict = {"type": "function", "function": {"name": "my_tool"}} + assert _get_tool_name(tool_dict) == "my_tool" + + +def test_get_tool_name_valid_object(): + """_get_tool_name extracts name from an object with a name attribute.""" + + class MockTool: + def __init__(self, name): + self.name = name + + assert _get_tool_name(MockTool("my_tool")) == "my_tool" + + def test_merge_options_logit_bias_merged(): """Test _merge_options merges logit_bias dicts.""" base = {"logit_bias": {"token1": 1.0}}