Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from typing import Any, Generic

from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -125,6 +125,38 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"


_EMPTY_MCP_ARGUMENT = object()


def _sanitize_mcp_arguments(value: Any) -> Any:
"""Remove empty optional payload values before sending to MCP tools."""
if value is None:
return _EMPTY_MCP_ARGUMENT

if isinstance(value, str):
return value if value != "" else _EMPTY_MCP_ARGUMENT

if isinstance(value, list):
cleaned_items = []
for item in value:
cleaned_item = _sanitize_mcp_arguments(item)
if cleaned_item is _EMPTY_MCP_ARGUMENT:
continue
cleaned_items.append(cleaned_item)
return cleaned_items if cleaned_items else _EMPTY_MCP_ARGUMENT

if isinstance(value, dict):
cleaned_dict = {}
for key, item in value.items():
cleaned_item = _sanitize_mcp_arguments(item)
if cleaned_item is _EMPTY_MCP_ARGUMENT:
continue
cleaned_dict[key] = cleaned_item
return cleaned_dict if cleaned_dict else _EMPTY_MCP_ARGUMENT

return value


class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
Expand Down Expand Up @@ -347,6 +379,17 @@ async def call_tool_with_reconnect(
anyio.ClosedResourceError: raised after reconnection failure
"""

sanitized_arguments = _sanitize_mcp_arguments(arguments)
if sanitized_arguments is _EMPTY_MCP_ARGUMENT:
sanitized_arguments = {}
if sanitized_arguments != arguments:
logger.debug(
"Sanitized MCP tool %s arguments from %s to %s",
tool_name,
arguments,
sanitized_arguments,
)

@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
Expand All @@ -361,7 +404,7 @@ async def _call_with_retry():
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
arguments=sanitized_arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
Expand Down
127 changes: 127 additions & 0 deletions tests/unit/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import importlib.util
import logging
import sys
import types
from pathlib import Path
from typing import Generic, TypeVar
from unittest.mock import AsyncMock

import pytest

REPO_ROOT = Path(__file__).resolve().parents[2]
MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/mcp_client.py"


def load_mcp_client_module():
package_names = [
"astrbot",
"astrbot.core",
"astrbot.core.agent",
"astrbot.core.utils",
]
for name in package_names:
if name not in sys.modules:
module = types.ModuleType(name)
module.__path__ = []
sys.modules[name] = module

astrbot_module = sys.modules["astrbot"]
astrbot_module.logger = logging.getLogger("astrbot-test")

log_pipe_module = types.ModuleType("astrbot.core.utils.log_pipe")
log_pipe_module.LogPipe = type("LogPipe", (), {})
sys.modules[log_pipe_module.__name__] = log_pipe_module

run_context_module = types.ModuleType("astrbot.core.agent.run_context")
run_context_module.TContext = TypeVar("TContext")

class ContextWrapper(Generic[run_context_module.TContext]):
pass

run_context_module.ContextWrapper = ContextWrapper
sys.modules[run_context_module.__name__] = run_context_module

tool_module = types.ModuleType("astrbot.core.agent.tool")
tool_module.FunctionTool = type("FunctionTool", (), {})
sys.modules[tool_module.__name__] = tool_module

anyio_module = types.ModuleType("anyio")
anyio_module.ClosedResourceError = type("ClosedResourceError", (Exception,), {})
sys.modules["anyio"] = anyio_module

mcp_module = types.ModuleType("mcp")
mcp_module.Tool = type("Tool", (), {})
mcp_module.ClientSession = type("ClientSession", (), {})
mcp_module.ListToolsResult = type("ListToolsResult", (), {})
mcp_module.StdioServerParameters = type("StdioServerParameters", (), {})
mcp_module.stdio_client = lambda *args, **kwargs: None
mcp_module.types = types.SimpleNamespace(
LoggingMessageNotificationParams=type(
"LoggingMessageNotificationParams", (), {}
),
CallToolResult=type("CallToolResult", (), {}),
)
sys.modules["mcp"] = mcp_module

mcp_client_module = types.ModuleType("mcp.client")
sys.modules[mcp_client_module.__name__] = mcp_client_module

mcp_client_sse_module = types.ModuleType("mcp.client.sse")
mcp_client_sse_module.sse_client = lambda *args, **kwargs: None
sys.modules[mcp_client_sse_module.__name__] = mcp_client_sse_module

mcp_client_streamable_http_module = types.ModuleType(
"mcp.client.streamable_http"
)
mcp_client_streamable_http_module.streamablehttp_client = (
lambda *args, **kwargs: None
)
sys.modules[mcp_client_streamable_http_module.__name__] = (
mcp_client_streamable_http_module
)

spec = importlib.util.spec_from_file_location(
"astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH
)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module


def test_sanitize_mcp_arguments_removes_nested_empty_collections():
mcp_client_module = load_mcp_client_module()

sanitized = mcp_client_module._sanitize_mcp_arguments(
{
"query": "hello",
"filters": {"tags": [], "scope": {}},
"metadata": {"owner": "", "visibility": None},
}
)

assert sanitized == {"query": "hello"}


@pytest.mark.asyncio
async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments():
mcp_client_module = load_mcp_client_module()

client = mcp_client_module.MCPClient()
client.session = types.SimpleNamespace(call_tool=AsyncMock(return_value="ok"))

result = await client.call_tool_with_reconnect(
tool_name="search",
arguments={"filters": {}, "query": ""},
read_timeout_seconds=mcp_client_module.timedelta(seconds=1),
)

assert result == "ok"
client.session.call_tool.assert_awaited_once_with(
name="search",
arguments={},
read_timeout_seconds=mcp_client_module.timedelta(seconds=1),
)