diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index af969a3fac..2e145abf84 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,7 +4,7 @@ import sys from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic +from typing import Any, Generic from tenacity import ( before_sleep_log, @@ -125,6 +125,59 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" +_EMPTY_MCP_ARGUMENT = object() + + +def _sanitize_mcp_arguments( + value: Any, + schema: dict[str, Any] | None = None, + *, + required: bool = False, +) -> Any: + """Remove empty optional payload values before sending to MCP tools.""" + if value is None: + return value if required else _EMPTY_MCP_ARGUMENT + + if isinstance(value, str): + return value if value != "" or required else _EMPTY_MCP_ARGUMENT + + if isinstance(value, list): + if not value: + return value if required else _EMPTY_MCP_ARGUMENT + cleaned_items = [] + item_schema = schema.get("items") if isinstance(schema, dict) else None + for item in value: + cleaned_item = _sanitize_mcp_arguments(item, item_schema) + # Preserve list positions. If sanitizing an item would remove it, + # keep the original item instead of reindexing the payload. + if cleaned_item is _EMPTY_MCP_ARGUMENT: + cleaned_items.append(item) + else: + cleaned_items.append(cleaned_item) + return cleaned_items + + if isinstance(value, dict): + if not value: + return value if required else _EMPTY_MCP_ARGUMENT + + cleaned_dict = {} + properties = schema.get("properties", {}) if isinstance(schema, dict) else {} + required_keys = set(schema.get("required", [])) if isinstance(schema, dict) else set() + for key, item in value.items(): + child_schema = properties.get(key) if isinstance(properties, dict) else None + cleaned_item = _sanitize_mcp_arguments( + item, + child_schema, + required=key in required_keys, + ) + if cleaned_item is _EMPTY_MCP_ARGUMENT: + continue + cleaned_dict[key] = cleaned_item + return cleaned_dict if cleaned_dict or required else _EMPTY_MCP_ARGUMENT + + return value + + class MCPClient: def __init__(self) -> None: # Initialize session and client objects @@ -347,6 +400,21 @@ async def call_tool_with_reconnect( anyio.ClosedResourceError: raised after reconnection failure """ + tool_schema = next( + (tool.inputSchema for tool in self.tools if tool.name == tool_name), + None, + ) + sanitized_arguments = _sanitize_mcp_arguments(arguments, tool_schema) + if sanitized_arguments is _EMPTY_MCP_ARGUMENT: + sanitized_arguments = {} + if sanitized_arguments != arguments: + logger.debug( + "Sanitized MCP tool %s arguments from %s to %s", + tool_name, + arguments, + sanitized_arguments, + ) + @retry( retry=retry_if_exception_type(anyio.ClosedResourceError), stop=stop_after_attempt(2), @@ -361,7 +429,7 @@ async def _call_with_retry(): try: return await self.session.call_tool( name=tool_name, - arguments=arguments, + arguments=sanitized_arguments, read_timeout_seconds=read_timeout_seconds, ) except anyio.ClosedResourceError: diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py new file mode 100644 index 0000000000..3941b1567a --- /dev/null +++ b/tests/unit/test_mcp_client.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import importlib.util +import logging +import sys +import types +from pathlib import Path +from typing import Generic, TypeVar +from unittest.mock import AsyncMock + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] +MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/mcp_client.py" + + +def load_mcp_client_module(): + package_names = [ + "astrbot", + "astrbot.core", + "astrbot.core.agent", + "astrbot.core.utils", + ] + for name in package_names: + if name not in sys.modules: + module = types.ModuleType(name) + module.__path__ = [] + sys.modules[name] = module + + astrbot_module = sys.modules["astrbot"] + astrbot_module.logger = logging.getLogger("astrbot-test") + + log_pipe_module = types.ModuleType("astrbot.core.utils.log_pipe") + log_pipe_module.LogPipe = type("LogPipe", (), {}) + sys.modules[log_pipe_module.__name__] = log_pipe_module + + run_context_module = types.ModuleType("astrbot.core.agent.run_context") + run_context_module.TContext = TypeVar("TContext") + + class ContextWrapper(Generic[run_context_module.TContext]): + pass + + run_context_module.ContextWrapper = ContextWrapper + sys.modules[run_context_module.__name__] = run_context_module + + tool_module = types.ModuleType("astrbot.core.agent.tool") + tool_module.FunctionTool = type("FunctionTool", (), {}) + sys.modules[tool_module.__name__] = tool_module + + anyio_module = types.ModuleType("anyio") + anyio_module.ClosedResourceError = type("ClosedResourceError", (Exception,), {}) + sys.modules["anyio"] = anyio_module + + mcp_module = types.ModuleType("mcp") + mcp_module.Tool = type("Tool", (), {}) + mcp_module.ClientSession = type("ClientSession", (), {}) + mcp_module.ListToolsResult = type("ListToolsResult", (), {}) + mcp_module.StdioServerParameters = type("StdioServerParameters", (), {}) + mcp_module.stdio_client = lambda *args, **kwargs: None + mcp_module.types = types.SimpleNamespace( + LoggingMessageNotificationParams=type( + "LoggingMessageNotificationParams", (), {} + ), + CallToolResult=type("CallToolResult", (), {}), + ) + sys.modules["mcp"] = mcp_module + + mcp_client_module = types.ModuleType("mcp.client") + sys.modules[mcp_client_module.__name__] = mcp_client_module + + mcp_client_sse_module = types.ModuleType("mcp.client.sse") + mcp_client_sse_module.sse_client = lambda *args, **kwargs: None + sys.modules[mcp_client_sse_module.__name__] = mcp_client_sse_module + + mcp_client_streamable_http_module = types.ModuleType( + "mcp.client.streamable_http" + ) + mcp_client_streamable_http_module.streamablehttp_client = ( + lambda *args, **kwargs: None + ) + sys.modules[mcp_client_streamable_http_module.__name__] = ( + mcp_client_streamable_http_module + ) + + spec = importlib.util.spec_from_file_location( + "astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH + ) + assert spec and spec.loader + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_sanitize_mcp_arguments_removes_nested_empty_collections(): + mcp_client_module = load_mcp_client_module() + + sanitized = mcp_client_module._sanitize_mcp_arguments( + { + "query": "hello", + "filters": {"tags": [], "scope": {}}, + "metadata": {"owner": "", "visibility": None}, + } + ) + + assert sanitized == {"query": "hello"} + + +@pytest.mark.asyncio +async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments(): + mcp_client_module = load_mcp_client_module() + + client = mcp_client_module.MCPClient() + client.session = types.SimpleNamespace(call_tool=AsyncMock(return_value="ok")) + + result = await client.call_tool_with_reconnect( + tool_name="search", + arguments={"filters": {}, "query": ""}, + read_timeout_seconds=mcp_client_module.timedelta(seconds=1), + ) + + assert result == "ok" + client.session.call_tool.assert_awaited_once_with( + name="search", + arguments={}, + read_timeout_seconds=mcp_client_module.timedelta(seconds=1), + ) diff --git a/tests/unit/test_mcp_client_sanitization.py b/tests/unit/test_mcp_client_sanitization.py new file mode 100644 index 0000000000..a0b19a4a16 --- /dev/null +++ b/tests/unit/test_mcp_client_sanitization.py @@ -0,0 +1,61 @@ +from astrbot.core.agent.mcp_client import _sanitize_mcp_arguments + + +def test_sanitize_mcp_arguments_drops_empty_optional_object_fields(): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "area": {"type": "string"}, + "floor": {"type": "string"}, + "domain": {"type": "array", "items": {"type": "string"}}, + "device_class": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["name"], + } + + value = { + "name": "demo", + "area": "", + "floor": "", + "domain": ["light"], + "device_class": [], + } + + assert _sanitize_mcp_arguments(value, schema) == { + "name": "demo", + "domain": ["light"], + } + + +def test_sanitize_mcp_arguments_preserves_required_empty_values(): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "metadata": { + "type": "object", + "properties": { + "note": {"type": "string"}, + }, + }, + }, + "required": ["name", "tags", "metadata"], + } + + value = { + "name": "", + "tags": [], + "metadata": {}, + } + + assert _sanitize_mcp_arguments(value, schema) == value + + +def test_sanitize_mcp_arguments_preserves_list_positions(): + schema = {"type": "array", "items": {"type": "string"}} + + value = ["alpha", "", "omega"] + + assert _sanitize_mcp_arguments(value, schema) == ["alpha", "", "omega"] diff --git a/tests/unit/test_tool_google_schema.py b/tests/unit/test_tool_google_schema.py index f1046e6af3..e808c9a568 100644 --- a/tests/unit/test_tool_google_schema.py +++ b/tests/unit/test_tool_google_schema.py @@ -6,11 +6,13 @@ from pathlib import Path from typing import Generic, TypeVar +import pytest + REPO_ROOT = Path(__file__).resolve().parents[2] TOOL_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/tool.py" -def load_tool_module(): +def load_tool_module(monkeypatch: pytest.MonkeyPatch): package_names = [ "astrbot", "astrbot.core", @@ -21,13 +23,17 @@ def load_tool_module(): if name not in sys.modules: module = types.ModuleType(name) module.__path__ = [] - sys.modules[name] = module + monkeypatch.setitem(sys.modules, name, module) message_result_module = types.ModuleType( "astrbot.core.message.message_event_result" ) message_result_module.MessageEventResult = type("MessageEventResult", (), {}) - sys.modules[message_result_module.__name__] = message_result_module + monkeypatch.setitem( + sys.modules, + message_result_module.__name__, + message_result_module, + ) run_context_module = types.ModuleType("astrbot.core.agent.run_context") run_context_module.TContext = TypeVar("TContext") @@ -36,20 +42,22 @@ class ContextWrapper(Generic[run_context_module.TContext]): pass run_context_module.ContextWrapper = ContextWrapper - sys.modules[run_context_module.__name__] = run_context_module + monkeypatch.setitem(sys.modules, run_context_module.__name__, run_context_module) spec = importlib.util.spec_from_file_location( "astrbot.core.agent.tool", TOOL_MODULE_PATH ) assert spec and spec.loader module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module + monkeypatch.setitem(sys.modules, spec.name, module) spec.loader.exec_module(module) return module -def test_google_schema_fills_missing_array_items_with_string_schema(): - tool_module = load_tool_module() +def test_google_schema_fills_missing_array_items_with_string_schema( + monkeypatch: pytest.MonkeyPatch, +): + tool_module = load_tool_module(monkeypatch) FunctionTool = tool_module.FunctionTool ToolSet = tool_module.ToolSet