From c055ecf99200b0dd9192e3afd5cbec92b3517579 Mon Sep 17 00:00:00 2001 From: Calvin Giles Date: Thu, 12 Jun 2025 12:03:33 +1200 Subject: [PATCH] feat: Support adding a name prefix to MCPTool and MCPToolset --- src/google/adk/tools/mcp_tool/mcp_tool.py | 9 +- src/google/adk/tools/mcp_tool/mcp_toolset.py | 18 ++-- tests/unittests/tools/mcp_tool/__init__.py | 0 .../unittests/tools/mcp_tool/test_mcp_tool.py | 82 +++++++++++++++++++ 4 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/tools/mcp_tool/__init__.py create mode 100644 tests/unittests/tools/mcp_tool/test_mcp_tool.py diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 463202b18f..1367548128 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -62,6 +62,7 @@ def __init__( mcp_session_manager: MCPSessionManager, auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, + tool_name_prefix: str = "", ): """Initializes a MCPTool. @@ -73,6 +74,8 @@ def __init__( mcp_session_manager: The MCP session manager to use for communication. auth_scheme: The authentication scheme to use. auth_credential: The authentication credential to use. + tool_name_prefix: string to add to the start of the tool name. For example, + `prefix="ns_"` would name `my_tool` as `ns_my_tool`. Raises: ValueError: If mcp_tool or mcp_session_manager is None. @@ -81,8 +84,10 @@ def __init__( raise ValueError("mcp_tool cannot be None") if mcp_session_manager is None: raise ValueError("mcp_session_manager cannot be None") + raw_name = mcp_tool.name + name = tool_name_prefix + raw_name super().__init__( - name=mcp_tool.name, + name=name, description=mcp_tool.description if mcp_tool.description else "", ) self._mcp_tool = mcp_tool @@ -90,6 +95,8 @@ def __init__( # TODO(cheliu): Support passing auth to MCP Server. self._auth_scheme = auth_scheme self._auth_credential = auth_credential + self._tool_name_prefix = tool_name_prefix + self._raw_name = raw_name @override def _get_declaration(self) -> FunctionDeclaration: diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 8076752b48..69c95b5e74 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -36,8 +36,8 @@ # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. try: - from mcp import StdioServerParameters from mcp.types import ListToolsResult + from mcp import StdioServerParameters except ImportError as e: import sys @@ -68,7 +68,8 @@ class MCPToolset(BaseToolset): command='npx', args=["-y", "@modelcontextprotocol/server-filesystem"], ), - tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools + tool_filter=['read_file', 'list_directory'], # Optional: filter specific tools + tool_name_prefix="sfs_", # Optional: add_name_prefix ) # Use in an agent @@ -96,6 +97,7 @@ def __init__( ], tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, errlog: TextIO = sys.stderr, + tool_name_prefix: str = "", ): """Initializes the MCPToolset. @@ -108,10 +110,15 @@ def __init__( mcp server (e.g. using `npx` or `python3` ), but it does not support timeout, and we recommend to use `StdioConnectionParams` instead when timeout is needed. - tool_filter: Optional filter to select specific tools. Can be either: - A - list of tool names to include - A ToolPredicate function for custom - filtering logic + tool_filter: Optional filter to select specific tools. Can be either: + - A list of tool names to include + - A ToolPredicate function for custom filtering logic + In both cases, the tool name WILL include the `tool_name_prefix` when + matching. errlog: TextIO stream for error logging. + tool_name_prefix: string to add to the start of the name of all return tools. + For example, `prefix="ns_"` would change a returned tool name from + `my_tool` to `ns_my_tool`. """ super().__init__(tool_filter=tool_filter) @@ -120,6 +127,7 @@ def __init__( self._connection_params = connection_params self._errlog = errlog + self._tool_name_prefix = tool_name_prefix # Create the session manager that will handle the MCP connection self._mcp_session_manager = MCPSessionManager( diff --git a/tests/unittests/tools/mcp_tool/__init__.py b/tests/unittests/tools/mcp_tool/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py new file mode 100644 index 0000000000..663a7d98bf --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -0,0 +1,82 @@ +from google.adk import Agent +from google.adk.tools import FunctionTool +from google.adk.tools.mcp_tool import MCPTool +from tests.unittests import testing_utils +from google.genai.types import Part +from mcp.types import Tool as McpBaseTool + + +class TestMCPSession(object): + + def __init__(self, function_tool: FunctionTool): + self._function_tool = function_tool + + async def call_tool(self, name, arguments): + return self._function_tool.func(**arguments) + + +class TestMCPSessionManager(object): + + def __init__(self, function_tool: FunctionTool): + self._function_tool = function_tool + + async def create_session(self): + return TestMCPSession(self._function_tool) + + async def close(self): + pass + + +def mcp_tool(function_tool: FunctionTool, prefix=''): + return MCPTool( + mcp_tool=McpBaseTool( + name=function_tool.name, + description=function_tool.description, + inputSchema=function_tool._get_declaration().parameters.json_schema.model_dump( + exclude_none=True + ), + ), + mcp_session_manager=TestMCPSessionManager(function_tool), + tool_name_prefix=prefix, + ) + + +def test_mcp_tool(): + @FunctionTool + def add(a: int, b: int): + """Add a and b and retuirn the result""" + return a + b + + mcp_add = mcp_tool(add, 'mcp_') + + add_call = Part.from_function_call(name='add', args={'a': 1, 'b': 2}) + add_response = Part.from_function_response(name='add', response={'result': 3}) + + mcp_add_call = Part.from_function_call(name='mcp_add', args={'a': 5, 'b': 10}) + mcp_add_response = Part.from_function_response( + name='mcp_add', response={'result': 15} + ) + + mock_model = testing_utils.MockModel.create( + responses=[ + add_call, + mcp_add_call, + 'response1', + ] + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[add, mcp_add], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + + assert testing_utils.simplify_events(runner.run('test1')) == [ + ('root_agent', add_call), + ('root_agent', add_response), + ('root_agent', mcp_add_call), + ('root_agent', mcp_add_response), + ('root_agent', 'response1'), + ]