diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index b75999ea65..cb68814983 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -383,10 +383,16 @@ def _normalize(node: Any) -> Any: class MCPClient: def __init__(self) -> None: - # Initialize session and client objects self.session: mcp.ClientSession | None = None - self.exit_stack = AsyncExitStack() - self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup + + # Each connection runs in its own task so that anyio cancel scopes + # are always exited from the task that entered them, preventing + # RuntimeError: Attempted to exit cancel scope in a different task + self._connection_task: asyncio.Task | None = None + self._old_connection_tasks: list[asyncio.Task] = [] + + # Internal; managed exclusively by _run_connection. + self.exit_stack: AsyncExitStack | None = None self.name: str | None = None self.active: bool = True @@ -394,14 +400,67 @@ def __init__(self) -> None: self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - # Store connection config for reconnection self._mcp_server_config: dict | None = None self._server_name: str | None = None self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection - self._reconnecting: bool = False # For logging and debugging + self._reconnecting: bool = False + + async def _run_connection( + self, + mcp_server_config: dict, + name: str, + ready: asyncio.Future, + ) -> None: + """Own the full lifetime of one MCP connection. + + This coroutine is always run inside a dedicated asyncio.Task + (_connection_task). Because *this task* is the one that enters every + anyio cancel scope (via sse_client / streamablehttp_client), anyio's + _host_task check is always satisfied when the stack is later closed — + either in the task's own finally block (normal path) or when the task + is cancelled from outside (cleanup / reconnect path). + + This avoids the + RuntimeError: Attempted to exit cancel scope in a different task + that previously occurred when aclose() was called from a different task + or from the asyncio async-generator GC finalizer. + """ + # Capture the stack in a local variable so that if self.exit_stack is + # overwritten by a concurrent _run_connection (during reconnect), this + # task's finally block still closes only the resources it opened. + stack = self.exit_stack = AsyncExitStack() + try: + try: + await self._do_connect(mcp_server_config, name) + except Exception as exc: + if not ready.done(): + ready.set_exception(exc) + raise + else: + if not ready.done(): + ready.set_result(None) + # Hold the connection open until cancelled. + await asyncio.Event().wait() + finally: + try: + await stack.aclose() + except Exception as e: + logger.debug(f"Error closing exit stack for {name}: {e}") + # Clear the instance reference only if it still points to this task's + # stack; a concurrent reconnect may have already replaced it. + if self.exit_stack is stack: + self.exit_stack = None + # Guard against the task exiting before ready was resolved. + if not ready.done(): + ready.set_exception(RuntimeError("Connection task exited early")) async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: - """Connect to MCP server + """Connect to MCP server by spawning a dedicated owner task. + + The owner task (_connection_task) holds the AsyncExitStack and all + anyio cancel scopes for the lifetime of this connection. To disconnect, + cancel _connection_task — the finally block in _run_connection will call + aclose() from within the correct task context. If `url` parameter exists: 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. @@ -412,10 +471,47 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ - # Store config for reconnection self._mcp_server_config = mcp_server_config self._server_name = name + ready: asyncio.Future = asyncio.get_running_loop().create_future() + + # Defensively cancel any existing connection task that was not cleaned + # up before this call (e.g. if connect_to_server is called twice). + if self._connection_task and not self._connection_task.done(): + self._cancel_connection_task(self._connection_task) + self._connection_task = None + + self._connection_task = asyncio.create_task( + self._run_connection(mcp_server_config, name, ready), + name=f"mcp-conn:{name}", + ) + + try: + await ready + except asyncio.CancelledError: + # Caller was cancelled while waiting — tear down the connection task. + # cancel() is asynchronous; the task will not finish until the next + # event-loop iteration, so we track it in _old_connection_tasks so + # that cleanup() can await it later. + if self._connection_task and not self._connection_task.done(): + self._cancel_connection_task(self._connection_task) + self._connection_task = None + raise + except Exception: + # _do_connect raised; the connection task's finally block may still + # be running (e.g. awaiting stack.aclose()). Track it so that + # cleanup() can await it, but do NOT cancel it — we want the + # finally block to finish cleaning up resources naturally. + if self._connection_task and not self._connection_task.done(): + self._old_connection_tasks.append(self._connection_task) + self._connection_task = None + raise + + async def _do_connect(self, mcp_server_config: dict, name: str) -> None: + """Internal: perform the actual connection inside _run_connection's task.""" + # exit_stack is always set by _run_connection before _do_connect is called. + assert self.exit_stack is not None cfg = _prepare_config(mcp_server_config.copy()) def logging_callback( @@ -533,16 +629,32 @@ async def list_tools_and_save(self) -> mcp.ListToolsResult: self.tools = response.tools return response + def _cancel_connection_task(self, task: asyncio.Task) -> None: + """Cancel a connection owner task and track it until it finishes.""" + # Prune already-finished tasks to avoid accumulating references over + # many reconnections in a long-running process. + self._old_connection_tasks = [ + t for t in self._old_connection_tasks if not t.done() + ] + if task.done(): + return + task.cancel() + self._old_connection_tasks.append(task) + async def _reconnect(self) -> None: """Reconnect to the MCP server using the stored configuration. + Cancels the current _connection_task (which owns the exit_stack and all + anyio cancel scopes) and starts a fresh one. Because each connection + task enters and exits its own anyio cancel scope, there is no + cross-task cancel-scope violation and no GC finalizer surprise. + Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. Raises: Exception: raised when reconnection fails """ async with self._reconnect_lock: - # Check if already reconnecting (useful for logging) if self._reconnecting: logger.debug( f"MCP Client {self._server_name} is already reconnecting, skipping" @@ -558,17 +670,16 @@ async def _reconnect(self) -> None: f"Attempting to reconnect to MCP server {self._server_name}..." ) - # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) - if self.exit_stack: - self._old_exit_stacks.append(self.exit_stack) - - # Mark old session as invalid + # Cancel the old connection task. Its finally block will call + # exit_stack.aclose() from within the correct task context, so + # anyio cancel scopes are exited cleanly without triggering the + # GC-finalizer busy-spin bug. + if self._connection_task and not self._connection_task.done(): + self._cancel_connection_task(self._connection_task) + self._connection_task = None self.session = None - # Create new exit stack for new connection - self.exit_stack = AsyncExitStack() - - # Reconnect using stored config + # Reconnect — this creates a new _connection_task. await self.connect_to_server(self._mcp_server_config, self._server_name) await self.list_tools_and_save() @@ -633,19 +744,20 @@ async def _call_with_retry(): return await _call_with_retry() async def cleanup(self) -> None: - """Clean up resources including old exit stacks from reconnections""" - # Close current exit stack - try: - await self.exit_stack.aclose() - except Exception as e: - logger.debug(f"Error closing current exit stack: {e}") - - # Don't close old exit stacks as they may be in different task contexts - # They will be garbage collected naturally - # Just clear the list to release references - self._old_exit_stacks.clear() - - # Set running_event first to unblock any waiting tasks + """Clean up resources by cancelling the connection owner task.""" + # Cancel current and any old connection tasks via the shared helper so + # all cancellation + tracking behaviour goes through one code path. + if self._connection_task: + self._cancel_connection_task(self._connection_task) + self._connection_task = None + + if self._old_connection_tasks: + pending = [t for t in self._old_connection_tasks if not t.done()] + if pending: + await asyncio.gather(*pending, return_exceptions=True) + self._old_connection_tasks.clear() + + # Set running_event to unblock any waiting tasks self.running_event.set()