Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/agent/handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(
def default_parameters(self) -> dict:
return {
"type": "object",
"required": ["input"],
"additionalProperties": False,
Comment on lines +42 to +43
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Making input required and disallowing additional properties conflicts with the background_task argument handling.

With input required and additionalProperties=False, any background_task key in tool_args is now invalid JSON Schema, even though FunctionToolExecutor.execute still reads it and _build_handoff_error_result shows it in the example. Either add background_task (with the correct type) to properties or relax additionalProperties to allow it; if it’s no longer supported, remove it from both the error example and the tool_args handling to keep schema and behavior aligned.

"properties": {
"input": {
"type": "string",
Expand Down
132 changes: 127 additions & 5 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,113 @@


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
def _build_handoff_error_result(
cls,
*,
tool_name: str,
error_type: str,
fix_hint: str,
action_hint: str,
) -> mcp.types.CallToolResult:
guidance = (
"[handoff CALL FAILED - IMMEDIATE RETRY REQUIRED]\n"
f"error_type: {error_type}\n"
f"fix: {fix_hint}\n"
f"action: {action_hint}\n"
"example:\n"
"{\n"
' "input": "Summarize the user request, constraints, and expected output.",\n'
' "background_task": false\n'
"}"
)
return mcp.types.CallToolResult(
content=[
mcp.types.TextContent(
type="text",
text=f"error: {tool_name} rejected invalid handoff request.\n{guidance}",
)
]
)

@classmethod
def _parse_background_task_arg(
cls,
tool_name: str,
value: T.Any,
) -> tuple[bool, mcp.types.CallToolResult | None]:
if value is None:
return False, None
if isinstance(value, bool):
return value, None
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes", "on"}:
return True, None
if normalized in {"false", "0", "no", "off", ""}:
return False, None

return False, cls._build_handoff_error_result(
tool_name=tool_name,
error_type="invalid_background_task",
fix_hint=(
"`background_task` must be a boolean (`true` or `false`) or a string "
'equivalent such as `"1"`/`"0"`, `"yes"`/`"no"`, or '
'`"on"`/`"off"`.'
),
action_hint=(
"Retry the same handoff with `background_task` set to a boolean or one "
'of the supported string equivalents (`"true"`, `"false"`, '
'`"1"`, `"0"`, `"yes"`, `"no"`, `"on"`, `"off"`).'
),
)

@classmethod
def _normalize_handoff_input(
cls,
tool_name: str,
input_value: T.Any,
) -> tuple[str | None, mcp.types.CallToolResult | None]:
if not isinstance(input_value, str) or not input_value.strip():
return None, cls._build_handoff_error_result(
tool_name=tool_name,
error_type="missing_or_empty_input",
fix_hint=(
"Provide a non-empty `input` string that clearly describes the delegated task."
),
action_hint=("Retry now with a concise task statement in `input`."),
)
return input_value.strip(), None

@classmethod
async def _resolve_handoff_provider_id(
cls,
tool: HandoffTool,
*,
ctx: T.Any,
umo: str,
) -> str:
configured_provider_id = str(getattr(tool, "provider_id", "") or "").strip()
if not configured_provider_id:
return await ctx.get_current_chat_provider_id(umo)

provider_mgr = getattr(ctx, "provider_manager", None)
if provider_mgr is None or not hasattr(provider_mgr, "get_provider_by_id"):
return configured_provider_id

provider_inst = await provider_mgr.get_provider_by_id(configured_provider_id)
if provider_inst is not None:
return configured_provider_id

fallback_provider_id = await ctx.get_current_chat_provider_id(umo)
logger.warning(
"Subagent %s configured provider `%s` not found, fallback to `%s`.",
tool.name,
configured_provider_id,
fallback_provider_id,
)
return fallback_provider_id

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
Expand Down Expand Up @@ -130,7 +237,13 @@ async def execute(cls, tool, run_context, **tool_args):

"""
if isinstance(tool, HandoffTool):
is_bg = tool_args.pop("background_task", False)
is_bg, bg_error = cls._parse_background_task_arg(
tool.name,
tool_args.pop("background_task", None),
)
if bg_error is not None:
yield bg_error
return
if is_bg:
async for r in cls._execute_handoff_background(
tool, run_context, **tool_args
Expand Down Expand Up @@ -245,7 +358,14 @@ async def _execute_handoff(
**tool_args: T.Any,
):
tool_args = dict(tool_args)
input_ = tool_args.get("input")
input_, input_error = cls._normalize_handoff_input(
tool.name,
tool_args.get("input"),
)
if input_error is not None:
yield input_error
return
tool_args["input"] = input_
if image_urls_prepared:
prepared_image_urls = tool_args.get("image_urls")
if isinstance(prepared_image_urls, list):
Expand All @@ -272,9 +392,11 @@ async def _execute_handoff(

# Use per-subagent provider override if configured; otherwise fall back
# to the current/default provider resolution.
prov_id = getattr(
tool, "provider_id", None
) or await ctx.get_current_chat_provider_id(umo)
prov_id = await cls._resolve_handoff_provider_id(
tool,
ctx=ctx,
umo=umo,
)

# prepare begin dialogs
contexts = None
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,96 @@ async def _fake_convert_to_file_path(self):
)

assert image_urls == []


@pytest.mark.asyncio
async def test_execute_handoff_rejects_empty_input():
async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"

async def _fake_tool_loop_agent(**_kwargs):
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
Comment on lines +348 to +357
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen empty-input test by asserting the handoff agent is never invoked

To fully validate the early-rejection behavior, also assert that _fake_tool_loop_agent is never invoked (e.g., via a flag or counter). This will confirm the error is raised before any downstream handoff execution occurs for empty input.

Suggested implementation:

@pytest.mark.asyncio
async def test_execute_handoff_rejects_empty_input():
    tool_loop_called = False

    async def _fake_get_current_chat_provider_id(_umo):
        return "provider-id"

    async def _fake_tool_loop_agent(**_kwargs):
        nonlocal tool_loop_called
        tool_loop_called = True
        return SimpleNamespace(completion_text="ok")

    context = SimpleNamespace(
        get_current_chat_provider_id=_fake_get_current_chat_provider_id,
        tool_loop_agent=_fake_tool_loop_agent,

To fully implement the strengthened test, you should also:

  1. After the code that triggers the empty-input rejection (likely the call to the handoff execution function and the corresponding pytest.raises block), add:
    assert tool_loop_called is False
  2. Ensure this assertion is placed at the end of test_execute_handoff_rejects_empty_input, after all other assertions that verify the rejection behavior.

tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {"provider_settings": {}},
)
event = _DummyEvent([])
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id=None,
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input=" ",
image_urls=[],
):
results.append(result)

assert len(results) == 1
assert isinstance(results[0], mcp.types.CallToolResult)
text_content = results[0].content[0]
assert isinstance(text_content, mcp.types.TextContent)
assert "missing_or_empty_input" in text_content.text


@pytest.mark.asyncio
async def test_execute_handoff_falls_back_to_current_provider_when_configured_missing():
captured: dict = {}

class _DummyProviderManager:
async def get_provider_by_id(self, _provider_id: str):
return None

async def _fake_get_current_chat_provider_id(_umo):
return "fallback-provider"

async def _fake_tool_loop_agent(**kwargs):
captured.update(kwargs)
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
provider_manager=_DummyProviderManager(),
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {"provider_settings": {}},
)
event = _DummyEvent([])
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id="missing-provider-id",
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input="hello",
image_urls=[],
):
results.append(result)

assert len(results) == 1
assert captured["chat_provider_id"] == "fallback-provider"
Loading