Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import base64
import logging
import re
Expand Down Expand Up @@ -370,6 +371,38 @@ def functions(self) -> list[AIFunction[Any, Any]]:
return self._functions
return [func for func in self._functions if func.name in self.allowed_tools]

async def _safe_close_exit_stack(self) -> None:
"""Safely close the exit stack, handling cross-task boundary errors.

anyio's cancel scopes are bound to the task they were created in.
If aclose() is called from a different task (e.g., during streaming reconnection),
anyio will raise a RuntimeError or CancelledError. In this case, we log a warning
and allow garbage collection to clean up the resources.

Known error variants:
- "Attempted to exit cancel scope in a different task than it was entered in"
- "Attempted to exit a cancel scope that isn't the current task's current cancel scope"
- CancelledError from anyio cancel scope cleanup
"""
try:
await self._exit_stack.aclose()
except RuntimeError as e:
error_msg = str(e).lower()
# Check for anyio cancel scope errors (multiple variants exist)
if "cancel scope" in error_msg:
logger.warning(
"Could not cleanly close MCP exit stack due to cancel scope error. "
"Old resources will be garbage collected. Error: %s",
e,
)
else:
raise
except asyncio.CancelledError:
# CancelledError can occur during cleanup when cancel scopes are involved
logger.warning(
"Could not cleanly close MCP exit stack due to cancellation. Old resources will be garbage collected."
)

async def connect(self, *, reset: bool = False) -> None:
"""Connect to the MCP server.

Expand All @@ -383,15 +416,15 @@ async def connect(self, *, reset: bool = False) -> None:
ToolException: If connection or session initialization fails.
"""
if reset:
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
self.session = None
self.is_connected = False
self._exit_stack = AsyncExitStack()
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
command = getattr(self, "command", None)
if command:
error_msg = f"Failed to start MCP server '{command}': {ex}"
Expand All @@ -412,15 +445,15 @@ async def connect(self, *, reset: bool = False) -> None:
)
)
except Exception as ex:
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
raise ToolException(
message="Failed to create MCP session. Please check your configuration.",
inner_exception=ex,
) from ex
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
# Provide context about initialization failure
command = getattr(self, "command", None)
if command:
Expand Down Expand Up @@ -664,7 +697,7 @@ async def close(self) -> None:

Closes the connection and cleans up resources.
"""
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
self.session = None
self.is_connected = False

Expand Down Expand Up @@ -836,7 +869,7 @@ async def __aenter__(self) -> Self:
except ToolException:
raise
except Exception as ex:
await self._exit_stack.aclose()
await self._safe_close_exit_stack()
raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex

async def __aexit__(
Expand Down
151 changes: 151 additions & 0 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,3 +2363,154 @@ async def always_fail(*args, **kwargs):
assert mock_connect.call_count >= 1
# Verify error message indicates reconnection failure
assert "failed to reconnect" in str(exc_info.value).lower()


async def test_mcp_tool_reconnection_handles_cross_task_cancel_scope_error():
"""Test that reconnection gracefully handles anyio cancel scope errors.

This tests the fix for the bug where calling connect(reset=True) from a
different task than where the connection was originally established would
cause: RuntimeError: Attempted to exit cancel scope in a different task
than it was entered in

This happens when using multiple MCP tools with AG-UI streaming - the first
tool call succeeds, but when the connection closes, the second tool call
triggers a reconnection from within the streaming loop (a different task).
"""
from contextlib import AsyncExitStack

from agent_framework._mcp import MCPStdioTool

# Use load_tools=False and load_prompts=False to avoid triggering them during connect()
tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the exit stack to raise the cross-task cancel scope error
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(
side_effect=RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
)
tool._exit_stack = mock_exit_stack
tool.session = Mock()
tool.is_connected = True

# Mock get_mcp_client to return a mock transport
mock_transport = (Mock(), Mock())
mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context.__aexit__ = AsyncMock()

with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession") as mock_session_class,
):
mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
mock_session.set_logging_level = AsyncMock()
mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock()
mock_session_class.return_value = mock_session_context

# This should NOT raise even though aclose() raised the cancel scope error
# The _safe_close_exit_stack method should catch and log the error
await tool.connect(reset=True)

# Verify a new exit stack was created (the old mock was replaced)
assert tool._exit_stack is not mock_exit_stack
assert tool.session is not None
assert tool.is_connected is True


async def test_mcp_tool_safe_close_reraises_other_runtime_errors():
"""Test that _safe_close_exit_stack re-raises RuntimeErrors that aren't cancel scope related."""
from contextlib import AsyncExitStack

from agent_framework._mcp import MCPStdioTool

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=True,
)

# Mock the exit stack to raise a different RuntimeError
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(side_effect=RuntimeError("Some other runtime error"))
tool._exit_stack = mock_exit_stack

# This should re-raise the RuntimeError since it's not about cancel scopes
with pytest.raises(RuntimeError) as exc_info:
await tool._safe_close_exit_stack()

assert "Some other runtime error" in str(exc_info.value)


async def test_mcp_tool_safe_close_handles_alternate_cancel_scope_error():
"""Test that _safe_close_exit_stack handles the alternate cancel scope error message.

anyio has multiple variants of cancel scope errors:
- "Attempted to exit cancel scope in a different task than it was entered in"
- "Attempted to exit a cancel scope that isn't the current task's current cancel scope"
"""
from contextlib import AsyncExitStack

from agent_framework._mcp import MCPStdioTool

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the exit stack to raise the alternate cancel scope error
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(
side_effect=RuntimeError("Attempted to exit a cancel scope that isn't the current task's current cancel scope")
)
tool._exit_stack = mock_exit_stack

# This should NOT raise - the error should be caught and logged
await tool._safe_close_exit_stack()

# Verify aclose was called
mock_exit_stack.aclose.assert_called_once()


async def test_mcp_tool_safe_close_handles_cancelled_error():
"""Test that _safe_close_exit_stack handles asyncio.CancelledError.

CancelledError can occur during cleanup when anyio cancel scopes are involved.
"""
import asyncio
from contextlib import AsyncExitStack

from agent_framework._mcp import MCPStdioTool

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the exit stack to raise CancelledError
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(side_effect=asyncio.CancelledError())
tool._exit_stack = mock_exit_stack

# This should NOT raise - the CancelledError should be caught and logged
await tool._safe_close_exit_stack()

# Verify aclose was called
mock_exit_stack.aclose.assert_called_once()
Loading