From bed3f238174219df363b2c1268bd66dd96593280 Mon Sep 17 00:00:00 2001 From: Utkarsh <49205471+utkarshkingh@users.noreply.github.com> Date: Thu, 11 Sep 2025 20:47:00 +0000 Subject: [PATCH 1/4] fix: Enhance error handling for closed and broken resources in retry logic --- src/google/adk/tools/mcp_tool/mcp_session_manager.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index fbe843a510..e840a50cb8 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -126,10 +126,10 @@ def retry_on_closed_resource(func): async def wrapper(self, *args, **kwargs): try: return await func(self, *args, **kwargs) - except anyio.ClosedResourceError: + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # Simply retry the function - create_session will handle # detecting and replacing disconnected sessions - logger.info('Retrying %s due to closed resource', func.__name__) + logger.info('Retrying %s due to closed/broken resource', func.__name__) return await func(self, *args, **kwargs) return wrapper @@ -369,7 +369,12 @@ async def create_session( except Exception: # If session creation fails, clean up the exit stack if exit_stack: - await exit_stack.aclose() + try: + await exit_stack.aclose() + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: + # Resources might already be broken/closed due to connection issues + # Log the error but don't let it mask the original exception + logger.warning('Error during exit stack cleanup: %s', e) raise async def close(self): From adff656fd743ffd78b8f4aa988850ce72759eed1 Mon Sep 17 00:00:00 2001 From: Utkarsh <49205471+utkarshkingh@users.noreply.github.com> Date: Thu, 11 Sep 2025 21:22:28 +0000 Subject: [PATCH 2/4] refactor: Improve code formatting and enhance docstrings for clarity --- .../adk/tools/mcp_tool/mcp_session_manager.py | 615 +++++++++--------- 1 file changed, 293 insertions(+), 322 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index e840a50cb8..9876235a33 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -1,11 +1,11 @@ # Copyright 2025 Google LLC -# + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# + +# http://www.apache.org/licenses/LICENSE-2.0 + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -32,369 +32,340 @@ from pydantic import BaseModel try: - from mcp import ClientSession - from mcp import StdioServerParameters - from mcp.client.sse import sse_client - from mcp.client.stdio import stdio_client - from mcp.client.streamable_http import streamablehttp_client + from mcp import ClientSession + from mcp import StdioServerParameters + from mcp.client.sse import sse_client + from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - - if sys.version_info < (3, 10): - raise ImportError( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) from e - else: - raise e + if sys.version_info < (3, 10): + raise ImportError( + 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' + ) from e + else: + raise e logger = logging.getLogger('google_adk.' + __name__) - class StdioConnectionParams(BaseModel): - """Parameters for the MCP Stdio connection. + """Parameters for the MCP Stdio connection. - Attributes: - server_params: Parameters for the MCP Stdio server. - timeout: Timeout in seconds for establishing the connection to the MCP - stdio server. - """ - - server_params: StdioServerParameters - timeout: float = 5.0 + Attributes: + server_params: Parameters for the MCP Stdio server. + timeout: Timeout in seconds for establishing the connection to the MCP + stdio server. + """ + server_params: StdioServerParameters + timeout: float = 5.0 + class Config: + arbitrary_types_allowed = True class SseConnectionParams(BaseModel): - """Parameters for the MCP SSE connection. - - See MCP SSE Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py - - Attributes: - url: URL for the MCP SSE server. - headers: Headers for the MCP SSE connection. - timeout: Timeout in seconds for establishing the connection to the MCP SSE - server. - sse_read_timeout: Timeout in seconds for reading data from the MCP SSE - server. - """ - - url: str - headers: dict[str, Any] | None = None - timeout: float = 5.0 - sse_read_timeout: float = 60 * 5.0 - + """Parameters for the MCP SSE connection. + + See MCP SSE Client documentation for more details. + https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py + + Attributes: + url: URL for the MCP SSE server. + headers: Headers for the MCP SSE connection. + timeout: Timeout in seconds for establishing the connection to the MCP SSE + server. + sse_read_timeout: Timeout in seconds for reading data from the MCP SSE + server. + """ + url: str + headers: dict[str, Any] | None = None + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 class StreamableHTTPConnectionParams(BaseModel): - """Parameters for the MCP Streamable HTTP connection. - - See MCP Streamable HTTP Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py - - Attributes: - url: URL for the MCP Streamable HTTP server. - headers: Headers for the MCP Streamable HTTP connection. - timeout: Timeout in seconds for establishing the connection to the MCP - Streamable HTTP server. - sse_read_timeout: Timeout in seconds for reading data from the MCP - Streamable HTTP server. - terminate_on_close: Whether to terminate the MCP Streamable HTTP server - when the connection is closed. - """ + """Parameters for the MCP Streamable HTTP connection. + + See MCP Streamable HTTP Client documentation for more details. + https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py + + Attributes: + url: URL for the MCP Streamable HTTP server. + headers: Headers for the MCP Streamable HTTP connection. + timeout: Timeout in seconds for establishing the connection to the MCP + Streamable HTTP server. + sse_read_timeout: Timeout in seconds for reading data from the MCP + Streamable HTTP server. + terminate_on_close: Whether to terminate the MCP Streamable HTTP server + when the connection is closed. + """ + url: str + headers: dict[str, Any] | None = None + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 + terminate_on_close: bool = True - url: str - headers: dict[str, Any] | None = None - timeout: float = 5.0 - sse_read_timeout: float = 60 * 5.0 - terminate_on_close: bool = True +def retry_on_closed_resource(func): + """Decorator to automatically retry action when MCP session is closed. + CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. + Do NOT use with tool calls that create, update, or delete resources as + retrying can cause duplicate operations or data corruption. + + Only use with read-only, idempotent operations like list_tools, + list_resources, or read_resource. + + Do NOT apply to generic tool execution methods like _run_async_impl. -def retry_on_closed_resource(func): - """Decorator to automatically retry action when MCP session is closed. + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. - When MCP session was closed, the decorator will automatically retry the - action once. The create_session method will handle creating a new session - if the old one was disconnected. + Args: + func: The function to decorate. Must be idempotent and safe to retry. - Args: - func: The function to decorate. + Returns: + The decorated function. + """ + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + logger.info('Retrying %s due to closed/broken resource', func.__name__) + return await func(self, *args, **kwargs) + return wrapper - Returns: - The decorated function. - """ +class MCPSessionManager: + """Manages MCP client sessions. - @functools.wraps(func) # Preserves original function metadata - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - # Simply retry the function - create_session will handle - # detecting and replacing disconnected sessions - logger.info('Retrying %s due to closed/broken resource', func.__name__) - return await func(self, *args, **kwargs) + This class provides methods for creating and initializing MCP client sessions, + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. + """ - return wrapper + def __init__( + self, + connection_params: Union[ + StdioServerParameters, + StdioConnectionParams, + SseConnectionParams, + StreamableHTTPConnectionParams, + ], + errlog: TextIO = sys.stderr, + ): + """Initializes the MCP session manager. + + Args: + connection_params: Parameters for the MCP connection (Stdio, SSE or + Streamable HTTP). Stdio by default also has a 5s read timeout as other + parameters but it's not configurable for now. + errlog: (Optional) TextIO stream for error logging. Use only for + initializing a local stdio MCP session. + """ + if isinstance(connection_params, StdioServerParameters): + logger.warning( + 'StdioServerParameters is not recommended. Please use' + ' StdioConnectionParams.' + ) + self._connection_params = StdioConnectionParams( + server_params=connection_params, + timeout=5, + ) + else: + self._connection_params = connection_params + self._errlog = errlog + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + self._session_lock = asyncio.Lock() -class MCPSessionManager: - """Manages MCP client sessions. - - This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE) and supporting - session pooling based on authentication headers. - """ - - def __init__( - self, - connection_params: Union[ - StdioServerParameters, - StdioConnectionParams, - SseConnectionParams, - StreamableHTTPConnectionParams, - ], - errlog: TextIO = sys.stderr, - ): - """Initializes the MCP session manager. + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. - Args: - connection_params: Parameters for the MCP connection (Stdio, SSE or - Streamable HTTP). Stdio by default also has a 5s read timeout as other - parameters but it's not configurable for now. - errlog: (Optional) TextIO stream for error logging. Use only for - initializing a local stdio MCP session. - """ - if isinstance(connection_params, StdioServerParameters): - # So far timeout is not configurable. Given MCP is still evolving, we - # would expect stdio_client to evolve to accept timeout parameter like - # other client. - logger.warning( - 'StdioServerParameters is not recommended. Please use' - ' StdioConnectionParams.' - ) - self._connection_params = StdioConnectionParams( - server_params=connection_params, - timeout=5, - ) - else: - self._connection_params = connection_params - self._errlog = errlog + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. - # Session pool: maps session keys to (session, exit_stack) tuples - self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + Args: + merged_headers: Already merged headers (base + additional). - # Lock to prevent race conditions in session creation - self._session_lock = asyncio.Lock() + Returns: + A unique session key string. + """ + if isinstance(self._connection_params, StdioConnectionParams): + return 'stdio_session' - def _generate_session_key( - self, merged_headers: Optional[Dict[str, str]] = None - ) -> str: - """Generates a session key based on connection params and merged headers. + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' - For StdioConnectionParams, returns a constant key since headers are not - supported. For SSE and StreamableHTTP connections, generates a key based - on the provided merged headers. + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. - Args: - merged_headers: Already merged headers (base + additional). + Args: + additional_headers: Optional headers to merge with connection headers. - Returns: - A unique session key string. - """ - if isinstance(self._connection_params, StdioConnectionParams): - # For stdio connections, headers are not supported, so use constant key - return 'stdio_session' - - # For SSE and StreamableHTTP connections, use merged headers - if merged_headers: - headers_json = json.dumps(merged_headers, sort_keys=True) - headers_hash = hashlib.md5(headers_json.encode()).hexdigest() - return f'session_{headers_hash}' - else: - return 'session_no_headers' + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + return None - def _merge_headers( - self, additional_headers: Optional[Dict[str, str]] = None - ) -> Optional[Dict[str, str]]: - """Merges base connection headers with additional headers. + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() - Args: - additional_headers: Optional headers to merge with connection headers. + if additional_headers: + base_headers.update(additional_headers) - Returns: - Merged headers dictionary, or None if no headers are provided. - """ - if isinstance(self._connection_params, StdioConnectionParams) or isinstance( - self._connection_params, StdioServerParameters - ): - # Stdio connections don't support headers - return None + return base_headers - base_headers = {} - if ( - hasattr(self._connection_params, 'headers') - and self._connection_params.headers - ): - base_headers = self._connection_params.headers.copy() + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. - if additional_headers: - base_headers.update(additional_headers) + Args: + session: The ClientSession to check. - return base_headers + Returns: + True if the session is disconnected, False otherwise. + """ + return session._read_stream._closed or session._write_stream._closed - def _is_session_disconnected(self, session: ClientSession) -> bool: - """Checks if a session is disconnected or closed. + def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): + """Creates an MCP client based on the connection parameters. - Args: - session: The ClientSession to check. + Args: + merged_headers: Optional headers to include in the connection. + Only applicable for SSE and StreamableHTTP connections. - Returns: - True if the session is disconnected, False otherwise. - """ - return session._read_stream._closed or session._write_stream._closed + Returns: + The appropriate MCP client instance. - def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): - """Creates an MCP client based on the connection parameters. + Raises: + ValueError: If the connection parameters are not supported. + """ + if isinstance(self._connection_params, StdioConnectionParams): + client = stdio_client( + server=self._connection_params.server_params, + errlog=self._errlog, + ) + elif isinstance(self._connection_params, SseConnectionParams): + client = sse_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + elif isinstance(self._connection_params, StreamableHTTPConnectionParams): + client = streamablehttp_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=timedelta(seconds=self._connection_params.timeout), + sse_read_timeout=timedelta( + seconds=self._connection_params.sse_read_timeout + ), + terminate_on_close=self._connection_params.terminate_on_close, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) - Args: - merged_headers: Optional headers to include in the connection. - Only applicable for SSE and StreamableHTTP connections. + return client - Returns: - The appropriate MCP client instance. + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: + """Creates and initializes an MCP client session. - Raises: - ValueError: If the connection parameters are not supported. - """ - if isinstance(self._connection_params, StdioConnectionParams): - client = stdio_client( - server=self._connection_params.server_params, - errlog=self._errlog, - ) - elif isinstance(self._connection_params, SseConnectionParams): - client = sse_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=self._connection_params.timeout, - sse_read_timeout=self._connection_params.sse_read_timeout, - ) - elif isinstance(self._connection_params, StreamableHTTPConnectionParams): - client = streamablehttp_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=timedelta(seconds=self._connection_params.timeout), - sse_read_timeout=timedelta( - seconds=self._connection_params.sse_read_timeout - ), - terminate_on_close=self._connection_params.terminate_on_close, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {self._connection_params}' - ) - return client - - async def create_session( - self, headers: Optional[Dict[str, str]] = None - ) -> ClientSession: - """Creates and initializes an MCP client session. - - This method will check if an existing session for the given headers - is still connected. If it's disconnected, it will be cleaned up and - a new session will be created. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. - Args: - headers: Optional headers to include in the session. These will be + Args: + headers: Optional headers to include in the session. These will be merged with any existing connection headers. Only applicable for SSE and StreamableHTTP connections. - Returns: - ClientSession: The initialized MCP client session. - """ - # Merge headers once at the beginning - merged_headers = self._merge_headers(headers) - - # Generate session key using merged headers - session_key = self._generate_session_key(merged_headers) - - # Use async lock to prevent race conditions - async with self._session_lock: - # Check if we have an existing session - if session_key in self._sessions: - session, exit_stack = self._sessions[session_key] - - # Check if the existing session is still connected - if not self._is_session_disconnected(session): - # Session is still good, return it - return session - else: - # Session is disconnected, clean it up - logger.info('Cleaning up disconnected session: %s', session_key) - try: - await exit_stack.aclose() - except Exception as e: - logger.warning('Error during disconnected session cleanup: %s', e) - finally: - del self._sessions[session_key] - - # Create a new session (either first time or replacing disconnected one) - exit_stack = AsyncExitStack() - - try: - client = self._create_client(merged_headers) - - transports = await exit_stack.enter_async_context(client) - # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams - # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. - if isinstance(self._connection_params, StdioConnectionParams): - session = await exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta( - seconds=self._connection_params.timeout - ), - ) - ) - else: - session = await exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - await session.initialize() - - # Store session and exit stack in the pool - self._sessions[session_key] = (session, exit_stack) - logger.debug('Created new session: %s', session_key) - return session - - except Exception: - # If session creation fails, clean up the exit stack - if exit_stack: - try: - await exit_stack.aclose() - except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: - # Resources might already be broken/closed due to connection issues - # Log the error but don't let it mask the original exception - logger.warning('Error during exit stack cleanup: %s', e) - raise - - async def close(self): - """Closes all sessions and cleans up resources.""" - async with self._session_lock: - for session_key in list(self._sessions.keys()): - _, exit_stack = self._sessions[session_key] - try: - await exit_stack.aclose() - except Exception as e: - # Log the error but don't re-raise to avoid blocking shutdown - print( - 'Warning: Error during MCP session cleanup for' - f' {session_key}: {e}', - file=self._errlog, - ) - finally: - del self._sessions[session_key] - + Returns: + ClientSession: The initialized MCP client session. + """ + merged_headers = self._merge_headers(headers) + session_key = self._generate_session_key(merged_headers) + + async with self._session_lock: + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + if not self._is_session_disconnected(session): + return session + else: + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + exit_stack = AsyncExitStack() + try: + client = self._create_client(merged_headers) + transports = await exit_stack.enter_async_context(client) + + if isinstance(self._connection_params, StdioConnectionParams): + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta( + seconds=self._connection_params.timeout + ), + ) + ) + else: + session = await exit_stack.enter_async_context( + ClientSession(*transports[:2]) + ) + + await session.initialize() + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) + return session + + except Exception: + if exit_stack: + try: + await exit_stack.aclose() + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: + logger.warning('Error during exit stack cleanup: %s', e) + raise + + async def close(self): + """Closes all sessions and cleans up resources.""" + async with self._session_lock: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] + try: + await exit_stack.aclose() + except Exception as e: + print( + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', + file=self._errlog, + ) + finally: + del self._sessions[session_key] SseServerParams = SseConnectionParams - StreamableHTTPServerParams = StreamableHTTPConnectionParams From 4b1528a7383e2b62b37e23dda4eac22ea2f4146b Mon Sep 17 00:00:00 2001 From: Utkarsh <49205471+utkarshkingh@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:00:36 +0000 Subject: [PATCH 3/4] keep header the same as before, keep all original comments - only additions no removals --- .../adk/tools/mcp_tool/mcp_session_manager.py | 615 +++++++++--------- 1 file changed, 319 insertions(+), 296 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 9876235a33..b919b76613 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -1,11 +1,11 @@ # Copyright 2025 Google LLC - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - +# +# http://www.apache.org/licenses/LICENSE-2.0 +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -32,340 +32,363 @@ from pydantic import BaseModel try: - from mcp import ClientSession - from mcp import StdioServerParameters - from mcp.client.sse import sse_client - from mcp.client.stdio import stdio_client - from mcp.client.streamable_http import streamablehttp_client + from mcp import ClientSession + from mcp import StdioServerParameters + from mcp.client.sse import sse_client + from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) from e - else: - raise e + if sys.version_info < (3, 10): + raise ImportError( + 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' + ' version.' + ) from e + else: + raise e logger = logging.getLogger('google_adk.' + __name__) class StdioConnectionParams(BaseModel): - """Parameters for the MCP Stdio connection. + """Parameters for the MCP Stdio connection. - Attributes: - server_params: Parameters for the MCP Stdio server. - timeout: Timeout in seconds for establishing the connection to the MCP - stdio server. - """ - server_params: StdioServerParameters - timeout: float = 5.0 + Attributes: + server_params: Parameters for the MCP Stdio server. + timeout: Timeout in seconds for establishing the connection to the MCP + stdio server. + """ + server_params: StdioServerParameters + timeout: float = 5.0 - class Config: - arbitrary_types_allowed = True + class Config: + arbitrary_types_allowed = True class SseConnectionParams(BaseModel): - """Parameters for the MCP SSE connection. - - See MCP SSE Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py - - Attributes: - url: URL for the MCP SSE server. - headers: Headers for the MCP SSE connection. - timeout: Timeout in seconds for establishing the connection to the MCP SSE - server. - sse_read_timeout: Timeout in seconds for reading data from the MCP SSE - server. - """ - url: str - headers: dict[str, Any] | None = None - timeout: float = 5.0 - sse_read_timeout: float = 60 * 5.0 + """Parameters for the MCP SSE connection. + + See MCP SSE Client documentation for more details. + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py) + + Attributes: + url: URL for the MCP SSE server. + headers: Headers for the MCP SSE connection. + timeout: Timeout in seconds for establishing the connection to the MCP SSE + server. + sse_read_timeout: Timeout in seconds for reading data from the MCP SSE + server. + """ + url: str + headers: dict[str, Any] | None = None + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 class StreamableHTTPConnectionParams(BaseModel): - """Parameters for the MCP Streamable HTTP connection. - - See MCP Streamable HTTP Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py - - Attributes: - url: URL for the MCP Streamable HTTP server. - headers: Headers for the MCP Streamable HTTP connection. - timeout: Timeout in seconds for establishing the connection to the MCP - Streamable HTTP server. - sse_read_timeout: Timeout in seconds for reading data from the MCP - Streamable HTTP server. - terminate_on_close: Whether to terminate the MCP Streamable HTTP server - when the connection is closed. - """ - url: str - headers: dict[str, Any] | None = None - timeout: float = 5.0 - sse_read_timeout: float = 60 * 5.0 - terminate_on_close: bool = True + """Parameters for the MCP Streamable HTTP connection. + + See MCP Streamable HTTP Client documentation for more details. + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py) + + Attributes: + url: URL for the MCP Streamable HTTP server. + headers: Headers for the MCP Streamable HTTP connection. + timeout: Timeout in seconds for establishing the connection to the MCP + Streamable HTTP server. + sse_read_timeout: Timeout in seconds for reading data from the MCP + Streamable HTTP server. + terminate_on_close: Whether to terminate the MCP Streamable HTTP server + when the connection is closed. + """ + url: str + headers: dict[str, Any] | None = None + timeout: float = 5.0 + sse_read_timeout: float = 60 * 5.0 + terminate_on_close: bool = True def retry_on_closed_resource(func): - """Decorator to automatically retry action when MCP session is closed. - - CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. - Do NOT use with tool calls that create, update, or delete resources as - retrying can cause duplicate operations or data corruption. - - Only use with read-only, idempotent operations like list_tools, - list_resources, or read_resource. - - Do NOT apply to generic tool execution methods like _run_async_impl. - - When MCP session was closed, the decorator will automatically retry the - action once. The create_session method will handle creating a new session - if the old one was disconnected. - - Args: - func: The function to decorate. Must be idempotent and safe to retry. - - Returns: - The decorated function. - """ - @functools.wraps(func) - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - logger.info('Retrying %s due to closed/broken resource', func.__name__) - return await func(self, *args, **kwargs) - return wrapper + """Decorator to automatically retry action when MCP session is closed. + + CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. + Do NOT use with tool calls that create, update, or delete resources as + retrying can cause duplicate operations or data corruption. + + Only use with read-only, idempotent operations like list_tools, + list_resources, or read_resource. + + Do NOT apply to generic tool execution methods like _run_async_impl. + + When MCP session was closed, the decorator will automatically retry the + action once. The create_session method will handle creating a new session + if the old one was disconnected. + + Args: + func: The function to decorate. Must be idempotent and safe to retry. + + Returns: + The decorated function. + """ + @functools.wraps(func) # Preserves original function metadata + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + # Simply retry the function - create_session will handle + # detecting and replacing disconnected sessions + logger.info('Retrying %s due to closed/broken resource', func.__name__) + return await func(self, *args, **kwargs) + return wrapper class MCPSessionManager: - """Manages MCP client sessions. + """Manages MCP client sessions. + + This class provides methods for creating and initializing MCP client sessions, + handling different connection parameters (Stdio and SSE) and supporting + session pooling based on authentication headers. + """ + + def __init__( + self, + connection_params: Union[ + StdioServerParameters, + StdioConnectionParams, + SseConnectionParams, + StreamableHTTPConnectionParams, + ], + errlog: TextIO = sys.stderr, + ): + """Initializes the MCP session manager. - This class provides methods for creating and initializing MCP client sessions, - handling different connection parameters (Stdio and SSE) and supporting - session pooling based on authentication headers. + Args: + connection_params: Parameters for the MCP connection (Stdio, SSE or + Streamable HTTP). Stdio by default also has a 5s read timeout as other + parameters but it's not configurable for now. + errlog: (Optional) TextIO stream for error logging. Use only for + initializing a local stdio MCP session. """ + if isinstance(connection_params, StdioServerParameters): + # So far timeout is not configurable. Given MCP is still evolving, we + # would expect stdio_client to evolve to accept timeout parameter like + # other client. + logger.warning( + 'StdioServerParameters is not recommended. Please use' + ' StdioConnectionParams.' + ) + self._connection_params = StdioConnectionParams( + server_params=connection_params, + timeout=5, + ) + else: + self._connection_params = connection_params - def __init__( - self, - connection_params: Union[ - StdioServerParameters, - StdioConnectionParams, - SseConnectionParams, - StreamableHTTPConnectionParams, - ], - errlog: TextIO = sys.stderr, - ): - """Initializes the MCP session manager. - - Args: - connection_params: Parameters for the MCP connection (Stdio, SSE or - Streamable HTTP). Stdio by default also has a 5s read timeout as other - parameters but it's not configurable for now. - errlog: (Optional) TextIO stream for error logging. Use only for - initializing a local stdio MCP session. - """ - if isinstance(connection_params, StdioServerParameters): - logger.warning( - 'StdioServerParameters is not recommended. Please use' - ' StdioConnectionParams.' - ) - self._connection_params = StdioConnectionParams( - server_params=connection_params, - timeout=5, - ) - else: - self._connection_params = connection_params - - self._errlog = errlog - self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} - self._session_lock = asyncio.Lock() - - def _generate_session_key( - self, merged_headers: Optional[Dict[str, str]] = None - ) -> str: - """Generates a session key based on connection params and merged headers. + self._errlog = errlog + # Session pool: maps session keys to (session, exit_stack) tuples + self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} + # Lock to prevent race conditions in session creation + self._session_lock = asyncio.Lock() - For StdioConnectionParams, returns a constant key since headers are not - supported. For SSE and StreamableHTTP connections, generates a key based - on the provided merged headers. + def _generate_session_key( + self, merged_headers: Optional[Dict[str, str]] = None + ) -> str: + """Generates a session key based on connection params and merged headers. - Args: - merged_headers: Already merged headers (base + additional). + For StdioConnectionParams, returns a constant key since headers are not + supported. For SSE and StreamableHTTP connections, generates a key based + on the provided merged headers. - Returns: - A unique session key string. - """ - if isinstance(self._connection_params, StdioConnectionParams): - return 'stdio_session' + Args: + merged_headers: Already merged headers (base + additional). - if merged_headers: - headers_json = json.dumps(merged_headers, sort_keys=True) - headers_hash = hashlib.md5(headers_json.encode()).hexdigest() - return f'session_{headers_hash}' - else: - return 'session_no_headers' + Returns: + A unique session key string. + """ + if isinstance(self._connection_params, StdioConnectionParams): + # For stdio connections, headers are not supported, so use constant key + return 'stdio_session' + + # For SSE and StreamableHTTP connections, use merged headers + if merged_headers: + headers_json = json.dumps(merged_headers, sort_keys=True) + headers_hash = hashlib.md5(headers_json.encode()).hexdigest() + return f'session_{headers_hash}' + else: + return 'session_no_headers' - def _merge_headers( - self, additional_headers: Optional[Dict[str, str]] = None - ) -> Optional[Dict[str, str]]: - """Merges base connection headers with additional headers. + def _merge_headers( + self, additional_headers: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """Merges base connection headers with additional headers. - Args: - additional_headers: Optional headers to merge with connection headers. + Args: + additional_headers: Optional headers to merge with connection headers. - Returns: - Merged headers dictionary, or None if no headers are provided. - """ - if isinstance(self._connection_params, StdioConnectionParams) or isinstance( - self._connection_params, StdioServerParameters - ): - return None + Returns: + Merged headers dictionary, or None if no headers are provided. + """ + if isinstance(self._connection_params, StdioConnectionParams) or isinstance( + self._connection_params, StdioServerParameters + ): + # Stdio connections don't support headers + return None - base_headers = {} - if ( - hasattr(self._connection_params, 'headers') - and self._connection_params.headers - ): - base_headers = self._connection_params.headers.copy() + base_headers = {} + if ( + hasattr(self._connection_params, 'headers') + and self._connection_params.headers + ): + base_headers = self._connection_params.headers.copy() - if additional_headers: - base_headers.update(additional_headers) + if additional_headers: + base_headers.update(additional_headers) - return base_headers + return base_headers - def _is_session_disconnected(self, session: ClientSession) -> bool: - """Checks if a session is disconnected or closed. + def _is_session_disconnected(self, session: ClientSession) -> bool: + """Checks if a session is disconnected or closed. - Args: - session: The ClientSession to check. + Args: + session: The ClientSession to check. - Returns: - True if the session is disconnected, False otherwise. - """ - return session._read_stream._closed or session._write_stream._closed + Returns: + True if the session is disconnected, False otherwise. + """ + return session._read_stream._closed or session._write_stream._closed - def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): - """Creates an MCP client based on the connection parameters. + def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): + """Creates an MCP client based on the connection parameters. - Args: - merged_headers: Optional headers to include in the connection. - Only applicable for SSE and StreamableHTTP connections. + Args: + merged_headers: Optional headers to include in the connection. + Only applicable for SSE and StreamableHTTP connections. - Returns: - The appropriate MCP client instance. + Returns: + The appropriate MCP client instance. - Raises: - ValueError: If the connection parameters are not supported. - """ - if isinstance(self._connection_params, StdioConnectionParams): - client = stdio_client( - server=self._connection_params.server_params, - errlog=self._errlog, - ) - elif isinstance(self._connection_params, SseConnectionParams): - client = sse_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=self._connection_params.timeout, - sse_read_timeout=self._connection_params.sse_read_timeout, - ) - elif isinstance(self._connection_params, StreamableHTTPConnectionParams): - client = streamablehttp_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=timedelta(seconds=self._connection_params.timeout), - sse_read_timeout=timedelta( - seconds=self._connection_params.sse_read_timeout - ), - terminate_on_close=self._connection_params.terminate_on_close, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {self._connection_params}' - ) + Raises: + ValueError: If the connection parameters are not supported. + """ + if isinstance(self._connection_params, StdioConnectionParams): + client = stdio_client( + server=self._connection_params.server_params, + errlog=self._errlog, + ) + elif isinstance(self._connection_params, SseConnectionParams): + client = sse_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + elif isinstance(self._connection_params, StreamableHTTPConnectionParams): + client = streamablehttp_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=timedelta(seconds=self._connection_params.timeout), + sse_read_timeout=timedelta( + seconds=self._connection_params.sse_read_timeout + ), + terminate_on_close=self._connection_params.terminate_on_close, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) - return client + return client - async def create_session( - self, headers: Optional[Dict[str, str]] = None - ) -> ClientSession: - """Creates and initializes an MCP client session. + async def create_session( + self, headers: Optional[Dict[str, str]] = None + ) -> ClientSession: + """Creates and initializes an MCP client session. - This method will check if an existing session for the given headers - is still connected. If it's disconnected, it will be cleaned up and - a new session will be created. + This method will check if an existing session for the given headers + is still connected. If it's disconnected, it will be cleaned up and + a new session will be created. - Args: - headers: Optional headers to include in the session. These will be + Args: + headers: Optional headers to include in the session. These will be merged with any existing connection headers. Only applicable for SSE and StreamableHTTP connections. - Returns: - ClientSession: The initialized MCP client session. - """ - merged_headers = self._merge_headers(headers) - session_key = self._generate_session_key(merged_headers) - - async with self._session_lock: - if session_key in self._sessions: - session, exit_stack = self._sessions[session_key] - if not self._is_session_disconnected(session): - return session - else: - logger.info('Cleaning up disconnected session: %s', session_key) - try: - await exit_stack.aclose() - except Exception as e: - logger.warning('Error during disconnected session cleanup: %s', e) - finally: - del self._sessions[session_key] - - exit_stack = AsyncExitStack() - try: - client = self._create_client(merged_headers) - transports = await exit_stack.enter_async_context(client) - - if isinstance(self._connection_params, StdioConnectionParams): - session = await exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta( - seconds=self._connection_params.timeout - ), - ) - ) - else: - session = await exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - - await session.initialize() - self._sessions[session_key] = (session, exit_stack) - logger.debug('Created new session: %s', session_key) - return session - - except Exception: - if exit_stack: - try: - await exit_stack.aclose() - except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: - logger.warning('Error during exit stack cleanup: %s', e) - raise - - async def close(self): - """Closes all sessions and cleans up resources.""" - async with self._session_lock: - for session_key in list(self._sessions.keys()): - _, exit_stack = self._sessions[session_key] - try: - await exit_stack.aclose() - except Exception as e: - print( - 'Warning: Error during MCP session cleanup for' - f' {session_key}: {e}', - file=self._errlog, - ) - finally: - del self._sessions[session_key] + Returns: + ClientSession: The initialized MCP client session. + """ + # Merge headers once at the beginning + merged_headers = self._merge_headers(headers) + # Generate session key using merged headers + session_key = self._generate_session_key(merged_headers) + + # Use async lock to prevent race conditions + async with self._session_lock: + # Check if we have an existing session + if session_key in self._sessions: + session, exit_stack = self._sessions[session_key] + # Check if the existing session is still connected + if not self._is_session_disconnected(session): + # Session is still good, return it + return session + else: + # Session is disconnected, clean it up + logger.info('Cleaning up disconnected session: %s', session_key) + try: + await exit_stack.aclose() + except Exception as e: + logger.warning('Error during disconnected session cleanup: %s', e) + finally: + del self._sessions[session_key] + + # Create a new session (either first time or replacing disconnected one) + exit_stack = AsyncExitStack() + try: + client = self._create_client(merged_headers) + transports = await exit_stack.enter_async_context(client) + + # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams + # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. + if isinstance(self._connection_params, StdioConnectionParams): + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta( + seconds=self._connection_params.timeout + ), + ) + ) + else: + session = await exit_stack.enter_async_context( + ClientSession(*transports[:2]) + ) + + await session.initialize() + # Store session and exit stack in the pool + self._sessions[session_key] = (session, exit_stack) + logger.debug('Created new session: %s', session_key) + return session + + except Exception: + # If session creation fails, clean up the exit stack + if exit_stack: + try: + await exit_stack.aclose() + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: + logger.warning('Error during exit stack cleanup: %s', e) + raise + + async def close(self): + """Closes all sessions and cleans up resources.""" + async with self._session_lock: + for session_key in list(self._sessions.keys()): + _, exit_stack = self._sessions[session_key] + try: + await exit_stack.aclose() + except Exception as e: + # Log the error but don't re-raise to avoid blocking shutdown + print( + 'Warning: Error during MCP session cleanup for' + f' {session_key}: {e}', + file=self._errlog, + ) + finally: + del self._sessions[session_key] SseServerParams = SseConnectionParams StreamableHTTPServerParams = StreamableHTTPConnectionParams From 11bc6449666fe3bae847437c5f314d9e6cf37552 Mon Sep 17 00:00:00 2001 From: Utkarsh <49205471+utkarshkingh@users.noreply.github.com> Date: Tue, 4 Nov 2025 21:03:26 +0000 Subject: [PATCH 4/4] style: Clean up whitespace and formatting in mcp_session_manager.py --- .../adk/tools/mcp_tool/mcp_session_manager.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index b919b76613..2f908f5a0e 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -48,6 +48,7 @@ logger = logging.getLogger('google_adk.' + __name__) + class StdioConnectionParams(BaseModel): """Parameters for the MCP Stdio connection. @@ -56,12 +57,14 @@ class StdioConnectionParams(BaseModel): timeout: Timeout in seconds for establishing the connection to the MCP stdio server. """ + server_params: StdioServerParameters timeout: float = 5.0 class Config: arbitrary_types_allowed = True + class SseConnectionParams(BaseModel): """Parameters for the MCP SSE connection. @@ -76,11 +79,13 @@ class SseConnectionParams(BaseModel): sse_read_timeout: Timeout in seconds for reading data from the MCP SSE server. """ + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 + class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP Streamable HTTP connection. @@ -97,22 +102,24 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: Whether to terminate the MCP Streamable HTTP server when the connection is closed. """ + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 terminate_on_close: bool = True + def retry_on_closed_resource(func): """Decorator to automatically retry action when MCP session is closed. CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. - Do NOT use with tool calls that create, update, or delete resources as + Do NOT use with tool calls that create, update, or delete resources as retrying can cause duplicate operations or data corruption. - - Only use with read-only, idempotent operations like list_tools, + + Only use with read-only, idempotent operations like list_tools, list_resources, or read_resource. - + Do NOT apply to generic tool execution methods like _run_async_impl. When MCP session was closed, the decorator will automatically retry the @@ -125,6 +132,7 @@ def retry_on_closed_resource(func): Returns: The decorated function. """ + @functools.wraps(func) # Preserves original function metadata async def wrapper(self, *args, **kwargs): try: @@ -134,8 +142,10 @@ async def wrapper(self, *args, **kwargs): # detecting and replacing disconnected sessions logger.info('Retrying %s due to closed/broken resource', func.__name__) return await func(self, *args, **kwargs) + return wrapper + class MCPSessionManager: """Manages MCP client sessions. @@ -390,5 +400,6 @@ async def close(self): finally: del self._sessions[session_key] + SseServerParams = SseConnectionParams StreamableHTTPServerParams = StreamableHTTPConnectionParams