diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index d95d48f282..847bc9f263 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -38,7 +38,6 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - if sys.version_info < (3, 10): raise ImportError( 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' @@ -62,12 +61,15 @@ class StdioConnectionParams(BaseModel): server_params: StdioServerParameters timeout: float = 5.0 + class Config: + arbitrary_types_allowed = True + class SseConnectionParams(BaseModel): """Parameters for the MCP SSE connection. See MCP SSE Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py) Attributes: url: URL for the MCP SSE server. @@ -88,7 +90,7 @@ class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP Streamable HTTP connection. See MCP Streamable HTTP Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py) Attributes: url: URL for the MCP Streamable HTTP server. @@ -111,12 +113,21 @@ class StreamableHTTPConnectionParams(BaseModel): def retry_on_closed_resource(func): """Decorator to automatically retry action when MCP session is closed. + CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. + Do NOT use with tool calls that create, update, or delete resources as + retrying can cause duplicate operations or data corruption. + + Only use with read-only, idempotent operations like list_tools, + list_resources, or read_resource. + + Do NOT apply to generic tool execution methods like _run_async_impl. + When MCP session was closed, the decorator will automatically retry the action once. The create_session method will handle creating a new session if the old one was disconnected. Args: - func: The function to decorate. + func: The function to decorate. Must be idempotent and safe to retry. Returns: The decorated function. @@ -177,11 +188,10 @@ def __init__( ) else: self._connection_params = connection_params - self._errlog = errlog + self._errlog = errlog # Session pool: maps session keys to (session, exit_stack) tuples self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} - # Lock to prevent race conditions in session creation self._session_lock = asyncio.Lock() @@ -293,6 +303,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): ' StdioServerParameters or SseServerParams, but got' f' {self._connection_params}' ) + return client async def create_session( @@ -314,7 +325,6 @@ async def create_session( """ # Merge headers once at the beginning merged_headers = self._merge_headers(headers) - # Generate session key using merged headers session_key = self._generate_session_key(merged_headers) @@ -323,7 +333,6 @@ async def create_session( # Check if we have an existing session if session_key in self._sessions: session, exit_stack = self._sessions[session_key] - # Check if the existing session is still connected if not self._is_session_disconnected(session): # Session is still good, return it @@ -369,6 +378,7 @@ async def create_session( ) await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds) + await session.initialize() # Store session and exit stack in the pool self._sessions[session_key] = (session, exit_stack) logger.debug('Created new session: %s', session_key) @@ -404,5 +414,4 @@ async def close(self): SseServerParams = SseConnectionParams - StreamableHTTPServerParams = StreamableHTTPConnectionParams