diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 0363e2d55d..2139a2c8f0 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -39,6 +39,8 @@ def __init__( def default_parameters(self) -> dict: return { "type": "object", + "required": ["input"], + "additionalProperties": False, "properties": { "input": { "type": "string", diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 1fb4b03368..325627c0f1 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -44,6 +44,113 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + @classmethod + def _build_handoff_error_result( + cls, + *, + tool_name: str, + error_type: str, + fix_hint: str, + action_hint: str, + ) -> mcp.types.CallToolResult: + guidance = ( + "[handoff CALL FAILED - IMMEDIATE RETRY REQUIRED]\n" + f"error_type: {error_type}\n" + f"fix: {fix_hint}\n" + f"action: {action_hint}\n" + "example:\n" + "{\n" + ' "input": "Summarize the user request, constraints, and expected output.",\n' + ' "background_task": false\n' + "}" + ) + return mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=f"error: {tool_name} rejected invalid handoff request.\n{guidance}", + ) + ] + ) + + @classmethod + def _parse_background_task_arg( + cls, + tool_name: str, + value: T.Any, + ) -> tuple[bool, mcp.types.CallToolResult | None]: + if value is None: + return False, None + if isinstance(value, bool): + return value, None + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True, None + if normalized in {"false", "0", "no", "off", ""}: + return False, None + + return False, cls._build_handoff_error_result( + tool_name=tool_name, + error_type="invalid_background_task", + fix_hint=( + "`background_task` must be a boolean (`true` or `false`) or a string " + 'equivalent such as `"1"`/`"0"`, `"yes"`/`"no"`, or ' + '`"on"`/`"off"`.' + ), + action_hint=( + "Retry the same handoff with `background_task` set to a boolean or one " + 'of the supported string equivalents (`"true"`, `"false"`, ' + '`"1"`, `"0"`, `"yes"`, `"no"`, `"on"`, `"off"`).' + ), + ) + + @classmethod + def _normalize_handoff_input( + cls, + tool_name: str, + input_value: T.Any, + ) -> tuple[str | None, mcp.types.CallToolResult | None]: + if not isinstance(input_value, str) or not input_value.strip(): + return None, cls._build_handoff_error_result( + tool_name=tool_name, + error_type="missing_or_empty_input", + fix_hint=( + "Provide a non-empty `input` string that clearly describes the delegated task." + ), + action_hint=("Retry now with a concise task statement in `input`."), + ) + return input_value.strip(), None + + @classmethod + async def _resolve_handoff_provider_id( + cls, + tool: HandoffTool, + *, + ctx: T.Any, + umo: str, + ) -> str: + configured_provider_id = str(getattr(tool, "provider_id", "") or "").strip() + if not configured_provider_id: + return await ctx.get_current_chat_provider_id(umo) + + provider_mgr = getattr(ctx, "provider_manager", None) + if provider_mgr is None or not hasattr(provider_mgr, "get_provider_by_id"): + return configured_provider_id + + provider_inst = await provider_mgr.get_provider_by_id(configured_provider_id) + if provider_inst is not None: + return configured_provider_id + + fallback_provider_id = await ctx.get_current_chat_provider_id(umo) + logger.warning( + "Subagent %s configured provider `%s` not found, fallback to `%s`.", + tool.name, + configured_provider_id, + fallback_provider_id, + ) + return fallback_provider_id + @classmethod def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: if image_urls_raw is None: @@ -130,7 +237,13 @@ async def execute(cls, tool, run_context, **tool_args): """ if isinstance(tool, HandoffTool): - is_bg = tool_args.pop("background_task", False) + is_bg, bg_error = cls._parse_background_task_arg( + tool.name, + tool_args.pop("background_task", None), + ) + if bg_error is not None: + yield bg_error + return if is_bg: async for r in cls._execute_handoff_background( tool, run_context, **tool_args @@ -245,7 +358,14 @@ async def _execute_handoff( **tool_args: T.Any, ): tool_args = dict(tool_args) - input_ = tool_args.get("input") + input_, input_error = cls._normalize_handoff_input( + tool.name, + tool_args.get("input"), + ) + if input_error is not None: + yield input_error + return + tool_args["input"] = input_ if image_urls_prepared: prepared_image_urls = tool_args.get("image_urls") if isinstance(prepared_image_urls, list): @@ -272,9 +392,11 @@ async def _execute_handoff( # Use per-subagent provider override if configured; otherwise fall back # to the current/default provider resolution. - prov_id = getattr( - tool, "provider_id", None - ) or await ctx.get_current_chat_provider_id(umo) + prov_id = await cls._resolve_handoff_provider_id( + tool, + ctx=ctx, + umo=umo, + ) # prepare begin dialogs contexts = None diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..c146874645 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -343,3 +343,96 @@ async def _fake_convert_to_file_path(self): ) assert image_urls == [] + + +@pytest.mark.asyncio +async def test_execute_handoff_rejects_empty_input(): + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_tool_loop_agent(**_kwargs): + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = SimpleNamespace( + name="transfer_to_subagent", + provider_id=None, + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), + ) + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input=" ", + image_urls=[], + ): + results.append(result) + + assert len(results) == 1 + assert isinstance(results[0], mcp.types.CallToolResult) + text_content = results[0].content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "missing_or_empty_input" in text_content.text + + +@pytest.mark.asyncio +async def test_execute_handoff_falls_back_to_current_provider_when_configured_missing(): + captured: dict = {} + + class _DummyProviderManager: + async def get_provider_by_id(self, _provider_id: str): + return None + + async def _fake_get_current_chat_provider_id(_umo): + return "fallback-provider" + + async def _fake_tool_loop_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + provider_manager=_DummyProviderManager(), + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = SimpleNamespace( + name="transfer_to_subagent", + provider_id="missing-provider-id", + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), + ) + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input="hello", + image_urls=[], + ): + results.append(result) + + assert len(results) == 1 + assert captured["chat_provider_id"] == "fallback-provider" diff --git a/tests/unit/test_handoff_background_task_arg.py b/tests/unit/test_handoff_background_task_arg.py new file mode 100644 index 0000000000..8866c3e3bc --- /dev/null +++ b/tests/unit/test_handoff_background_task_arg.py @@ -0,0 +1,258 @@ +from types import SimpleNamespace + +import mcp +import pytest + +from astrbot.core.agent.agent import Agent +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + +def _build_tool_and_run_context(): + tool = HandoffTool(agent=Agent(name="subagent")) + event = SimpleNamespace( + unified_msg_origin="webchat:FriendMessage:webchat!user!session", + message_obj=SimpleNamespace(message=[]), + ) + run_context = ContextWrapper( + context=SimpleNamespace(event=event, context=SimpleNamespace()) + ) + return tool, run_context + + +@pytest.mark.parametrize( + ("value", "expected_bool", "expect_error"), + [ + (True, True, False), + ("true", True, False), + ("1", True, False), + ("yes", True, False), + ("on", True, False), + (" TRUE ", True, False), + (False, False, False), + ("false", False, False), + ("0", False, False), + ("no", False, False), + ("off", False, False), + ("", False, False), + (" FALSE ", False, False), + (None, False, False), + ("not-a-bool", False, True), + ("y", False, True), + ("t", False, True), + (123, False, True), + ({}, False, True), + ], +) +def test_parse_background_task_arg(value, expected_bool, expect_error): + is_bg, error = FunctionToolExecutor._parse_background_task_arg( + "transfer_to_subagent", + value, + ) + + assert is_bg is expected_bool + if expect_error: + assert error is not None + assert isinstance(error, mcp.types.CallToolResult) + text_content = error.content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "invalid_background_task" in text_content.text + else: + assert error is None + + +def test_handoff_default_parameters_align_with_background_task_handling(): + params = HandoffTool(agent=Agent(name="subagent")).parameters + assert params["required"] == ["input"] + assert params["additionalProperties"] is False + assert set(params["properties"]) == {"input", "image_urls", "background_task"} + assert params["properties"]["background_task"]["type"] == "boolean" + + +@pytest.mark.asyncio +async def test_execute_invalid_background_task_early_error(monkeypatch): + call_count = {"handoff": 0, "handoff_bg": 0} + + async def _fake_execute_handoff(cls, tool, run_context, **tool_args): + call_count["handoff"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="unexpected")] + ) + + async def _fake_execute_handoff_bg(cls, tool, run_context, **tool_args): + call_count["handoff_bg"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="unexpected")] + ) + + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff", + classmethod(_fake_execute_handoff), + ) + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff_background", + classmethod(_fake_execute_handoff_bg), + ) + + tool, run_context = _build_tool_and_run_context() + + results = [] + async for result in FunctionToolExecutor.execute( + tool, + run_context, + input="hello", + background_task="not-a-bool", + ): + results.append(result) + + assert len(results) == 1 + assert isinstance(results[0], mcp.types.CallToolResult) + text_content = results[0].content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "invalid_background_task" in text_content.text + assert call_count["handoff"] == 0 + assert call_count["handoff_bg"] == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("background_value", [True, "true"]) +async def test_execute_truthy_background_task_routes_to_background_handoff( + monkeypatch, + background_value, +): + call_count = {"handoff": 0, "handoff_bg": 0} + + async def _fake_execute_handoff(cls, tool, run_context, **tool_args): + call_count["handoff"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="foreground")] + ) + + async def _fake_execute_handoff_bg(cls, tool, run_context, **tool_args): + call_count["handoff_bg"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="background")] + ) + + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff", + classmethod(_fake_execute_handoff), + ) + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff_background", + classmethod(_fake_execute_handoff_bg), + ) + + tool, run_context = _build_tool_and_run_context() + + results = [] + async for result in FunctionToolExecutor.execute( + tool, + run_context, + input="hello", + background_task=background_value, + ): + results.append(result) + + assert len(results) == 1 + assert call_count["handoff"] == 0 + assert call_count["handoff_bg"] == 1 + text_content = results[0].content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "invalid_background_task" not in text_content.text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("background_value", [False, "false", "0"]) +async def test_execute_falsey_background_task_routes_to_foreground_handoff( + monkeypatch, + background_value, +): + call_count = {"handoff": 0, "handoff_bg": 0} + + async def _fake_execute_handoff(cls, tool, run_context, **tool_args): + call_count["handoff"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="foreground")] + ) + + async def _fake_execute_handoff_bg(cls, tool, run_context, **tool_args): + call_count["handoff_bg"] += 1 + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="background")] + ) + + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff", + classmethod(_fake_execute_handoff), + ) + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff_background", + classmethod(_fake_execute_handoff_bg), + ) + + tool, run_context = _build_tool_and_run_context() + + results = [] + async for result in FunctionToolExecutor.execute( + tool, + run_context, + input="hello", + background_task=background_value, + ): + results.append(result) + + assert len(results) == 1 + assert call_count["handoff"] == 1 + assert call_count["handoff_bg"] == 0 + text_content = results[0].content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "invalid_background_task" not in text_content.text + + +@pytest.mark.asyncio +async def test_execute_handoff_rejects_empty_input_without_downstream_tool_loop(): + call_count = {"tool_loop": 0} + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_tool_loop_agent(**_kwargs): + call_count["tool_loop"] += 1 + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = SimpleNamespace( + unified_msg_origin="webchat:FriendMessage:webchat!user!session", + message_obj=SimpleNamespace(message=[]), + ) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = HandoffTool(agent=Agent(name="subagent")) + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input=" ", + image_urls=[], + ): + results.append(result) + + assert len(results) == 1 + assert isinstance(results[0], mcp.types.CallToolResult) + text_content = results[0].content[0] + assert isinstance(text_content, mcp.types.TextContent) + assert "missing_or_empty_input" in text_content.text + assert call_count["tool_loop"] == 0