From a89eebea185c3e19bbf494bb8f4ae21c57b9b37a Mon Sep 17 00:00:00 2001 From: Tobi1chi <2826225051@qq.com> Date: Sat, 28 Mar 2026 22:18:22 +0000 Subject: [PATCH 1/3] fix: sanitize empty optional mcp arguments --- astrbot/core/agent/mcp_client.py | 47 ++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index af969a3fac..ddce6b1fee 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,7 +4,7 @@ import sys from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic +from typing import Any, Generic from tenacity import ( before_sleep_log, @@ -125,6 +125,38 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" +_EMPTY_MCP_ARGUMENT = object() + + +def _sanitize_mcp_arguments(value: Any) -> Any: + """Remove empty optional payload values before sending to MCP tools.""" + if value is None: + return _EMPTY_MCP_ARGUMENT + + if isinstance(value, str): + return value if value != "" else _EMPTY_MCP_ARGUMENT + + if isinstance(value, list): + cleaned_items = [] + for item in value: + cleaned_item = _sanitize_mcp_arguments(item) + if cleaned_item is _EMPTY_MCP_ARGUMENT: + continue + cleaned_items.append(cleaned_item) + return cleaned_items if cleaned_items else _EMPTY_MCP_ARGUMENT + + if isinstance(value, dict): + cleaned_dict = {} + for key, item in value.items(): + cleaned_item = _sanitize_mcp_arguments(item) + if cleaned_item is _EMPTY_MCP_ARGUMENT: + continue + cleaned_dict[key] = cleaned_item + return cleaned_dict + + return value + + class MCPClient: def __init__(self) -> None: # Initialize session and client objects @@ -347,6 +379,17 @@ async def call_tool_with_reconnect( anyio.ClosedResourceError: raised after reconnection failure """ + sanitized_arguments = _sanitize_mcp_arguments(arguments) + if sanitized_arguments is _EMPTY_MCP_ARGUMENT: + sanitized_arguments = {} + if sanitized_arguments != arguments: + logger.debug( + "Sanitized MCP tool %s arguments from %s to %s", + tool_name, + arguments, + sanitized_arguments, + ) + @retry( retry=retry_if_exception_type(anyio.ClosedResourceError), stop=stop_after_attempt(2), @@ -361,7 +404,7 @@ async def _call_with_retry(): try: return await self.session.call_tool( name=tool_name, - arguments=arguments, + arguments=sanitized_arguments, read_timeout_seconds=read_timeout_seconds, ) except anyio.ClosedResourceError: From 4200317e2bf58948fe4452165cbbba438d1bfcd5 Mon Sep 17 00:00:00 2001 From: Tobi1chi <2826225051@qq.com> Date: Sun, 29 Mar 2026 18:47:20 +0100 Subject: [PATCH 2/3] test: align nested empty MCP collection sanitization --- astrbot/core/agent/mcp_client.py | 2 +- tests/unit/test_mcp_client.py | 127 +++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_mcp_client.py diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index ddce6b1fee..2c860baeef 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -152,7 +152,7 @@ def _sanitize_mcp_arguments(value: Any) -> Any: if cleaned_item is _EMPTY_MCP_ARGUMENT: continue cleaned_dict[key] = cleaned_item - return cleaned_dict + return cleaned_dict if cleaned_dict else _EMPTY_MCP_ARGUMENT return value diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py new file mode 100644 index 0000000000..3941b1567a --- /dev/null +++ b/tests/unit/test_mcp_client.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import importlib.util +import logging +import sys +import types +from pathlib import Path +from typing import Generic, TypeVar +from unittest.mock import AsyncMock + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] +MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/mcp_client.py" + + +def load_mcp_client_module(): + package_names = [ + "astrbot", + "astrbot.core", + "astrbot.core.agent", + "astrbot.core.utils", + ] + for name in package_names: + if name not in sys.modules: + module = types.ModuleType(name) + module.__path__ = [] + sys.modules[name] = module + + astrbot_module = sys.modules["astrbot"] + astrbot_module.logger = logging.getLogger("astrbot-test") + + log_pipe_module = types.ModuleType("astrbot.core.utils.log_pipe") + log_pipe_module.LogPipe = type("LogPipe", (), {}) + sys.modules[log_pipe_module.__name__] = log_pipe_module + + run_context_module = types.ModuleType("astrbot.core.agent.run_context") + run_context_module.TContext = TypeVar("TContext") + + class ContextWrapper(Generic[run_context_module.TContext]): + pass + + run_context_module.ContextWrapper = ContextWrapper + sys.modules[run_context_module.__name__] = run_context_module + + tool_module = types.ModuleType("astrbot.core.agent.tool") + tool_module.FunctionTool = type("FunctionTool", (), {}) + sys.modules[tool_module.__name__] = tool_module + + anyio_module = types.ModuleType("anyio") + anyio_module.ClosedResourceError = type("ClosedResourceError", (Exception,), {}) + sys.modules["anyio"] = anyio_module + + mcp_module = types.ModuleType("mcp") + mcp_module.Tool = type("Tool", (), {}) + mcp_module.ClientSession = type("ClientSession", (), {}) + mcp_module.ListToolsResult = type("ListToolsResult", (), {}) + mcp_module.StdioServerParameters = type("StdioServerParameters", (), {}) + mcp_module.stdio_client = lambda *args, **kwargs: None + mcp_module.types = types.SimpleNamespace( + LoggingMessageNotificationParams=type( + "LoggingMessageNotificationParams", (), {} + ), + CallToolResult=type("CallToolResult", (), {}), + ) + sys.modules["mcp"] = mcp_module + + mcp_client_module = types.ModuleType("mcp.client") + sys.modules[mcp_client_module.__name__] = mcp_client_module + + mcp_client_sse_module = types.ModuleType("mcp.client.sse") + mcp_client_sse_module.sse_client = lambda *args, **kwargs: None + sys.modules[mcp_client_sse_module.__name__] = mcp_client_sse_module + + mcp_client_streamable_http_module = types.ModuleType( + "mcp.client.streamable_http" + ) + mcp_client_streamable_http_module.streamablehttp_client = ( + lambda *args, **kwargs: None + ) + sys.modules[mcp_client_streamable_http_module.__name__] = ( + mcp_client_streamable_http_module + ) + + spec = importlib.util.spec_from_file_location( + "astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH + ) + assert spec and spec.loader + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_sanitize_mcp_arguments_removes_nested_empty_collections(): + mcp_client_module = load_mcp_client_module() + + sanitized = mcp_client_module._sanitize_mcp_arguments( + { + "query": "hello", + "filters": {"tags": [], "scope": {}}, + "metadata": {"owner": "", "visibility": None}, + } + ) + + assert sanitized == {"query": "hello"} + + +@pytest.mark.asyncio +async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments(): + mcp_client_module = load_mcp_client_module() + + client = mcp_client_module.MCPClient() + client.session = types.SimpleNamespace(call_tool=AsyncMock(return_value="ok")) + + result = await client.call_tool_with_reconnect( + tool_name="search", + arguments={"filters": {}, "query": ""}, + read_timeout_seconds=mcp_client_module.timedelta(seconds=1), + ) + + assert result == "ok" + client.session.call_tool.assert_awaited_once_with( + name="search", + arguments={}, + read_timeout_seconds=mcp_client_module.timedelta(seconds=1), + ) From 97dc7c87ff5871017a8c2e1826be9c15b268882a Mon Sep 17 00:00:00 2001 From: Tobi1chi <2826225051@qq.com> Date: Mon, 30 Mar 2026 09:04:36 +0000 Subject: [PATCH 3/3] fix: preserve schema-valid empty mcp arguments --- astrbot/core/agent/mcp_client.py | 45 ++++++++++++---- tests/unit/test_mcp_client_sanitization.py | 61 ++++++++++++++++++++++ tests/unit/test_tool_google_schema.py | 22 +++++--- 3 files changed, 111 insertions(+), 17 deletions(-) create mode 100644 tests/unit/test_mcp_client_sanitization.py diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 2c860baeef..2e145abf84 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -128,31 +128,52 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: _EMPTY_MCP_ARGUMENT = object() -def _sanitize_mcp_arguments(value: Any) -> Any: +def _sanitize_mcp_arguments( + value: Any, + schema: dict[str, Any] | None = None, + *, + required: bool = False, +) -> Any: """Remove empty optional payload values before sending to MCP tools.""" if value is None: - return _EMPTY_MCP_ARGUMENT + return value if required else _EMPTY_MCP_ARGUMENT if isinstance(value, str): - return value if value != "" else _EMPTY_MCP_ARGUMENT + return value if value != "" or required else _EMPTY_MCP_ARGUMENT if isinstance(value, list): + if not value: + return value if required else _EMPTY_MCP_ARGUMENT cleaned_items = [] + item_schema = schema.get("items") if isinstance(schema, dict) else None for item in value: - cleaned_item = _sanitize_mcp_arguments(item) + cleaned_item = _sanitize_mcp_arguments(item, item_schema) + # Preserve list positions. If sanitizing an item would remove it, + # keep the original item instead of reindexing the payload. if cleaned_item is _EMPTY_MCP_ARGUMENT: - continue - cleaned_items.append(cleaned_item) - return cleaned_items if cleaned_items else _EMPTY_MCP_ARGUMENT + cleaned_items.append(item) + else: + cleaned_items.append(cleaned_item) + return cleaned_items if isinstance(value, dict): + if not value: + return value if required else _EMPTY_MCP_ARGUMENT + cleaned_dict = {} + properties = schema.get("properties", {}) if isinstance(schema, dict) else {} + required_keys = set(schema.get("required", [])) if isinstance(schema, dict) else set() for key, item in value.items(): - cleaned_item = _sanitize_mcp_arguments(item) + child_schema = properties.get(key) if isinstance(properties, dict) else None + cleaned_item = _sanitize_mcp_arguments( + item, + child_schema, + required=key in required_keys, + ) if cleaned_item is _EMPTY_MCP_ARGUMENT: continue cleaned_dict[key] = cleaned_item - return cleaned_dict if cleaned_dict else _EMPTY_MCP_ARGUMENT + return cleaned_dict if cleaned_dict or required else _EMPTY_MCP_ARGUMENT return value @@ -379,7 +400,11 @@ async def call_tool_with_reconnect( anyio.ClosedResourceError: raised after reconnection failure """ - sanitized_arguments = _sanitize_mcp_arguments(arguments) + tool_schema = next( + (tool.inputSchema for tool in self.tools if tool.name == tool_name), + None, + ) + sanitized_arguments = _sanitize_mcp_arguments(arguments, tool_schema) if sanitized_arguments is _EMPTY_MCP_ARGUMENT: sanitized_arguments = {} if sanitized_arguments != arguments: diff --git a/tests/unit/test_mcp_client_sanitization.py b/tests/unit/test_mcp_client_sanitization.py new file mode 100644 index 0000000000..a0b19a4a16 --- /dev/null +++ b/tests/unit/test_mcp_client_sanitization.py @@ -0,0 +1,61 @@ +from astrbot.core.agent.mcp_client import _sanitize_mcp_arguments + + +def test_sanitize_mcp_arguments_drops_empty_optional_object_fields(): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "area": {"type": "string"}, + "floor": {"type": "string"}, + "domain": {"type": "array", "items": {"type": "string"}}, + "device_class": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["name"], + } + + value = { + "name": "demo", + "area": "", + "floor": "", + "domain": ["light"], + "device_class": [], + } + + assert _sanitize_mcp_arguments(value, schema) == { + "name": "demo", + "domain": ["light"], + } + + +def test_sanitize_mcp_arguments_preserves_required_empty_values(): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "metadata": { + "type": "object", + "properties": { + "note": {"type": "string"}, + }, + }, + }, + "required": ["name", "tags", "metadata"], + } + + value = { + "name": "", + "tags": [], + "metadata": {}, + } + + assert _sanitize_mcp_arguments(value, schema) == value + + +def test_sanitize_mcp_arguments_preserves_list_positions(): + schema = {"type": "array", "items": {"type": "string"}} + + value = ["alpha", "", "omega"] + + assert _sanitize_mcp_arguments(value, schema) == ["alpha", "", "omega"] diff --git a/tests/unit/test_tool_google_schema.py b/tests/unit/test_tool_google_schema.py index f1046e6af3..e808c9a568 100644 --- a/tests/unit/test_tool_google_schema.py +++ b/tests/unit/test_tool_google_schema.py @@ -6,11 +6,13 @@ from pathlib import Path from typing import Generic, TypeVar +import pytest + REPO_ROOT = Path(__file__).resolve().parents[2] TOOL_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/tool.py" -def load_tool_module(): +def load_tool_module(monkeypatch: pytest.MonkeyPatch): package_names = [ "astrbot", "astrbot.core", @@ -21,13 +23,17 @@ def load_tool_module(): if name not in sys.modules: module = types.ModuleType(name) module.__path__ = [] - sys.modules[name] = module + monkeypatch.setitem(sys.modules, name, module) message_result_module = types.ModuleType( "astrbot.core.message.message_event_result" ) message_result_module.MessageEventResult = type("MessageEventResult", (), {}) - sys.modules[message_result_module.__name__] = message_result_module + monkeypatch.setitem( + sys.modules, + message_result_module.__name__, + message_result_module, + ) run_context_module = types.ModuleType("astrbot.core.agent.run_context") run_context_module.TContext = TypeVar("TContext") @@ -36,20 +42,22 @@ class ContextWrapper(Generic[run_context_module.TContext]): pass run_context_module.ContextWrapper = ContextWrapper - sys.modules[run_context_module.__name__] = run_context_module + monkeypatch.setitem(sys.modules, run_context_module.__name__, run_context_module) spec = importlib.util.spec_from_file_location( "astrbot.core.agent.tool", TOOL_MODULE_PATH ) assert spec and spec.loader module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module + monkeypatch.setitem(sys.modules, spec.name, module) spec.loader.exec_module(module) return module -def test_google_schema_fills_missing_array_items_with_string_schema(): - tool_module = load_tool_module() +def test_google_schema_fills_missing_array_items_with_string_schema( + monkeypatch: pytest.MonkeyPatch, +): + tool_module = load_tool_module(monkeypatch) FunctionTool = tool_module.FunctionTool ToolSet = tool_module.ToolSet