diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 310fc48f11..54e5dcd15f 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -67,6 +67,7 @@ def __init__( mcp_session_manager: MCPSessionManager, auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, + tool_name_prefix: str = "", ): """Initializes an MCPTool. @@ -78,12 +79,20 @@ def __init__( mcp_session_manager: The MCP session manager to use for communication. auth_scheme: The authentication scheme to use. auth_credential: The authentication credential to use. + tool_name_prefix: string to add to the start of the tool name. For example, + `prefix="ns_"` would name `my_tool` as `ns_my_tool`. Raises: ValueError: If mcp_tool or mcp_session_manager is None. """ + if mcp_tool is None: + raise ValueError("mcp_tool cannot be None") + if mcp_session_manager is None: + raise ValueError("mcp_session_manager cannot be None") + raw_name = mcp_tool.name + name = tool_name_prefix + raw_name super().__init__( - name=mcp_tool.name, + name=name, description=mcp_tool.description if mcp_tool.description else "", auth_config=AuthConfig( auth_scheme=auth_scheme, raw_auth_credential=auth_credential @@ -93,6 +102,8 @@ def __init__( ) self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager + self._tool_name_prefix = tool_name_prefix + self._raw_name = raw_name @override def _get_declaration(self) -> FunctionDeclaration: @@ -128,7 +139,7 @@ async def _run_async_impl( # Get the session from the session manager session = await self._mcp_session_manager.create_session(headers=headers) - response = await session.call_tool(self.name, arguments=args) + response = await session.call_tool(self._mcp_tool.name, arguments=args) return response async def _get_headers( diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index c01b0cec28..edbdcb8cd6 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -36,8 +36,8 @@ # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. try: - from mcp import StdioServerParameters from mcp.types import ListToolsResult + from mcp import StdioServerParameters except ImportError as e: import sys @@ -68,7 +68,8 @@ class MCPToolset(BaseToolset): command='npx', args=["-y", "@modelcontextprotocol/server-filesystem"], ), - tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools + tool_filter=['read_file', 'list_directory'], # Optional: filter specific tools + tool_name_prefix="sfs_", # Optional: add_name_prefix ) # Use in an agent @@ -98,6 +99,7 @@ def __init__( errlog: TextIO = sys.stderr, auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, + tool_name_prefix: str = "", ): """Initializes the MCPToolset. @@ -110,12 +112,17 @@ def __init__( mcp server (e.g. using `npx` or `python3` ), but it does not support timeout, and we recommend to use `StdioConnectionParams` instead when timeout is needed. - tool_filter: Optional filter to select specific tools. Can be either: - A - list of tool names to include - A ToolPredicate function for custom - filtering logic + tool_filter: Optional filter to select specific tools. Can be either: + - A list of tool names to include + - A ToolPredicate function for custom filtering logic + In both cases, the tool name WILL include the `tool_name_prefix` when + matching. errlog: TextIO stream for error logging. auth_scheme: The auth scheme of the tool for tool calling auth_credential: The auth credential of the tool for tool calling + tool_name_prefix: string to add to the start of the name of all return tools. + For example, `prefix="ns_"` would change a returned tool name from + `my_tool` to `ns_my_tool`. """ super().__init__(tool_filter=tool_filter) @@ -124,6 +131,7 @@ def __init__( self._connection_params = connection_params self._errlog = errlog + self._tool_name_prefix = tool_name_prefix # Create the session manager that will handle the MCP connection self._mcp_session_manager = MCPSessionManager( @@ -161,6 +169,7 @@ async def get_tools( mcp_session_manager=self._mcp_session_manager, auth_scheme=self._auth_scheme, auth_credential=self._auth_credential, + tool_name_prefix=self._tool_name_prefix, ) if self._is_tool_selected(mcp_tool, readonly_context): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 82e3f2234a..b475218aad 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -13,12 +13,12 @@ # limitations under the License. import sys -from typing import Any -from typing import Dict +from google.genai.types import Part from unittest.mock import AsyncMock from unittest.mock import Mock -from unittest.mock import patch +from google.adk import Agent +from google.adk.tools import FunctionTool from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import HttpAuth @@ -26,6 +26,8 @@ from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount import pytest +from tests.unittests import testing_utils + # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( @@ -38,6 +40,7 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.tool_context import ToolContext from google.genai.types import FunctionDeclaration + from mcp.types import Tool as McpBaseTool except ImportError as e: if sys.version_info < (3, 10): # Create dummy classes to prevent NameError during test collection @@ -49,6 +52,7 @@ class DummyClass: MCPTool = DummyClass ToolContext = DummyClass FunctionDeclaration = DummyClass + McpBaseTool = DummyClass else: raise e @@ -358,3 +362,79 @@ def test_init_validation(self): with pytest.raises(TypeError): MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager + + +class TestMCPSession(object): + + def __init__(self, function_tool: FunctionTool): + self._function_tool = function_tool + + async def call_tool(self, name, arguments): + return self._function_tool.func(**arguments) + + +class TestMCPSessionManager(object): + + def __init__(self, function_tool: FunctionTool): + self._function_tool = function_tool + + async def create_session(self, headers=None): + return TestMCPSession(self._function_tool) + + async def close(self): + pass + + +def mcp_tool(function_tool: FunctionTool, prefix=""): + return MCPTool( + mcp_tool=McpBaseTool( + name=function_tool.name, + description=function_tool.description, + inputSchema=function_tool._get_declaration().parameters.json_schema.model_dump( + exclude_none=True + ), + ), + mcp_session_manager=TestMCPSessionManager(function_tool), + tool_name_prefix=prefix, + ) + + +def test_mcp_tool(): + @FunctionTool + def add(a: int, b: int): + """Add a and b and retuirn the result""" + return a + b + + mcp_add = mcp_tool(add, "mcp_") + + add_call = Part.from_function_call(name="add", args={"a": 1, "b": 2}) + add_response = Part.from_function_response(name="add", response={"result": 3}) + + mcp_add_call = Part.from_function_call(name="mcp_add", args={"a": 5, "b": 10}) + mcp_add_response = Part.from_function_response( + name="mcp_add", response={"result": 15} + ) + + mock_model = testing_utils.MockModel.create( + responses=[ + add_call, + mcp_add_call, + "response1", + ] + ) + + root_agent = Agent( + name="root_agent", + model=mock_model, + tools=[add, mcp_add], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + + assert testing_utils.simplify_events(runner.run("test1")) == [ + ("root_agent", add_call), + ("root_agent", add_response), + ("root_agent", mcp_add_call), + ("root_agent", mcp_add_response), + ("root_agent", "response1"), + ]