Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 142 additions & 30 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,84 @@ def _normalize(node: Any) -> Any:

class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup

# Each connection runs in its own task so that anyio cancel scopes
# are always exited from the task that entered them, preventing
# RuntimeError: Attempted to exit cancel scope in a different task
self._connection_task: asyncio.Task | None = None
self._old_connection_tasks: list[asyncio.Task] = []

# Internal; managed exclusively by _run_connection.
self.exit_stack: AsyncExitStack | None = None

self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()

# Store connection config for reconnection
self._mcp_server_config: dict | None = None
self._server_name: str | None = None
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
self._reconnecting: bool = False

async def _run_connection(
    self,
    mcp_server_config: dict,
    name: str,
    ready: asyncio.Future,
) -> None:
    """Own the full lifetime of one MCP connection.

    This coroutine is always run inside a dedicated asyncio.Task
    (_connection_task). Because *this task* is the one that enters every
    anyio cancel scope (via sse_client / streamablehttp_client), anyio's
    _host_task check is always satisfied when the stack is later closed —
    either in the task's own finally block (normal path) or when the task
    is cancelled from outside (cleanup / reconnect path).

    This avoids the
        RuntimeError: Attempted to exit cancel scope in a different task
    that previously occurred when aclose() was called from a different task
    or from the asyncio async-generator GC finalizer.

    Args:
        mcp_server_config: Server configuration forwarded to _do_connect.
        name: Server name forwarded to _do_connect and used in log output.
        ready: Future resolved with None once _do_connect succeeds, or with
            the failure if it raises. If the task exits before either
            happened, it is resolved with a RuntimeError instead.
    """
    # Capture the stack in a local variable so that if self.exit_stack is
    # overwritten by a concurrent _run_connection (during reconnect), this
    # task's finally block still closes only the resources it opened.
    stack = self.exit_stack = AsyncExitStack()
    try:
        try:
            await self._do_connect(mcp_server_config, name)
        except Exception as exc:
            if ready.done():
                # Nobody is waiting on `ready` any more, so the error cannot
                # be delivered through it; let it propagate from the task
                # rather than dropping it silently.
                raise
            # The waiter receives the error through `ready`. Ending the task
            # normally (instead of re-raising the same exception) keeps
            # asyncio from later reporting "Task exception was never
            # retrieved" for a task whose result nobody inspects.
            ready.set_exception(exc)
            return
        if not ready.done():
            ready.set_result(None)
        # Hold the connection open until cancelled.
        await asyncio.Event().wait()
    finally:
        try:
            await stack.aclose()
        except Exception as e:
            logger.debug(f"Error closing exit stack for {name}: {e}")
        # Clear the instance reference only if it still points to this task's
        # stack; a concurrent reconnect may have already replaced it.
        if self.exit_stack is stack:
            self.exit_stack = None
        # Guard against the task exiting before ready was resolved.
        if not ready.done():
            ready.set_exception(RuntimeError("Connection task exited early"))

async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
"""Connect to MCP server
"""Connect to MCP server by spawning a dedicated owner task.

The owner task (_connection_task) holds the AsyncExitStack and all
anyio cancel scopes for the lifetime of this connection. To disconnect,
cancel _connection_task — the finally block in _run_connection will call
aclose() from within the correct task context.

If `url` parameter exists:
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
Expand All @@ -412,10 +471,47 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server

"""
# Store config for reconnection
self._mcp_server_config = mcp_server_config
self._server_name = name

ready: asyncio.Future = asyncio.get_running_loop().create_future()

# Defensively cancel any existing connection task that was not cleaned
# up before this call (e.g. if connect_to_server is called twice).
if self._connection_task and not self._connection_task.done():
self._cancel_connection_task(self._connection_task)
self._connection_task = None

self._connection_task = asyncio.create_task(
self._run_connection(mcp_server_config, name, ready),
name=f"mcp-conn:{name}",
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
)

try:
await ready
except asyncio.CancelledError:
# Caller was cancelled while waiting — tear down the connection task.
# cancel() is asynchronous; the task will not finish until the next
# event-loop iteration, so we track it in _old_connection_tasks so
# that cleanup() can await it later.
if self._connection_task and not self._connection_task.done():
self._cancel_connection_task(self._connection_task)
self._connection_task = None
raise
except Exception:
# _do_connect raised; the connection task's finally block may still
# be running (e.g. awaiting stack.aclose()). Track it so that
# cleanup() can await it, but do NOT cancel it — we want the
# finally block to finish cleaning up resources naturally.
if self._connection_task and not self._connection_task.done():
self._old_connection_tasks.append(self._connection_task)
self._connection_task = None
raise

async def _do_connect(self, mcp_server_config: dict, name: str) -> None:
"""Internal: perform the actual connection inside _run_connection's task."""
# exit_stack is always set by _run_connection before _do_connect is called.
assert self.exit_stack is not None
cfg = _prepare_config(mcp_server_config.copy())

def logging_callback(
Expand Down Expand Up @@ -533,16 +629,32 @@ async def list_tools_and_save(self) -> mcp.ListToolsResult:
self.tools = response.tools
return response

def _cancel_connection_task(self, task: asyncio.Task) -> None:
    """Request cancellation of a connection owner task and remember it.

    Finished entries are dropped from ``_old_connection_tasks`` on every
    call so the list does not grow across many reconnections. A task that
    has already finished is neither cancelled nor tracked.

    Args:
        task: The connection owner task to cancel.
    """
    still_running: list[asyncio.Task] = []
    for previous in self._old_connection_tasks:
        if not previous.done():
            still_running.append(previous)
    self._old_connection_tasks = still_running

    if not task.done():
        # cancel() only requests cancellation; keep a reference so the task
        # stays tracked until it has actually finished.
        task.cancel()
        self._old_connection_tasks.append(task)
Comment thread
lingyun14beta marked this conversation as resolved.

async def _reconnect(self) -> None:
"""Reconnect to the MCP server using the stored configuration.

Cancels the current _connection_task (which owns the exit_stack and all
anyio cancel scopes) and starts a fresh one. Because each connection
task enters and exits its own anyio cancel scope, there is no
cross-task cancel-scope violation and no GC finalizer surprise.

Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.

Raises:
Exception: raised when reconnection fails
"""
async with self._reconnect_lock:
# Check if already reconnecting (useful for logging)
if self._reconnecting:
logger.debug(
f"MCP Client {self._server_name} is already reconnecting, skipping"
Expand All @@ -558,17 +670,16 @@ async def _reconnect(self) -> None:
f"Attempting to reconnect to MCP server {self._server_name}..."
)

# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
if self.exit_stack:
self._old_exit_stacks.append(self.exit_stack)

# Mark old session as invalid
# Cancel the old connection task. Its finally block will call
# exit_stack.aclose() from within the correct task context, so
# anyio cancel scopes are exited cleanly without triggering the
# GC-finalizer busy-spin bug.
if self._connection_task and not self._connection_task.done():
self._cancel_connection_task(self._connection_task)
self._connection_task = None
self.session = None

# Create new exit stack for new connection
self.exit_stack = AsyncExitStack()

# Reconnect using stored config
# Reconnect — this creates a new _connection_task.
await self.connect_to_server(self._mcp_server_config, self._server_name)
await self.list_tools_and_save()

Expand Down Expand Up @@ -633,19 +744,20 @@ async def _call_with_retry():
return await _call_with_retry()

async def cleanup(self) -> None:
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
await self.exit_stack.aclose()
except Exception as e:
logger.debug(f"Error closing current exit stack: {e}")

# Don't close old exit stacks as they may be in different task contexts
# They will be garbage collected naturally
# Just clear the list to release references
self._old_exit_stacks.clear()

# Set running_event first to unblock any waiting tasks
"""Clean up resources by cancelling the connection owner task."""
# Cancel current and any old connection tasks via the shared helper so
# all cancellation + tracking behaviour goes through one code path.
if self._connection_task:
self._cancel_connection_task(self._connection_task)
self._connection_task = None

if self._old_connection_tasks:
pending = [t for t in self._old_connection_tasks if not t.done()]
if pending:
await asyncio.gather(*pending, return_exceptions=True)
self._old_connection_tasks.clear()
Comment on lines +754 to +758
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Be careful about calling cleanup from within a connection task, as it will await on the task it just cancelled

Since cleanup cancels self._connection_task and then awaits _old_connection_tasks, if cleanup is ever invoked from inside the currently running _connection_task you’ll end up awaiting on the current task, which raises RuntimeError. If that call path is possible, consider avoiding adding the current task to _old_connection_tasks when cleanup is called from it, or skip awaiting any task equal to asyncio.current_task().


# Set running_event to unblock any waiting tasks
self.running_event.set()


Expand Down
Loading