Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 89 additions & 29 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,32 @@ def _get_persona_custom_error_message(self) -> str | None:
event = getattr(self.run_context.context, "event", None)
return extract_persona_custom_error_message_from_event(event)

async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None:
    """Finalize the current step as a plain assistant response with no tool calls."""
    # Mark the run as finished and remember which response ended it.
    self.final_llm_resp = llm_resp
    self._transition_state(AgentState.DONE)
    self.stats.end_time = time.time()

    # Record the final assistant message: optional reasoning part first, then text.
    content_parts: list = []
    if llm_resp.reasoning_content or llm_resp.reasoning_signature:
        think_part = ThinkPart(
            think=llm_resp.reasoning_content,
            encrypted=llm_resp.reasoning_signature,
        )
        content_parts.append(think_part)
    if llm_resp.completion_text:
        content_parts.append(TextPart(text=llm_resp.completion_text))
    if not content_parts:
        logger.warning("LLM returned empty assistant message with no tool calls.")
    self.run_context.messages.append(
        Message(role="assistant", content=content_parts)
    )

    # Best-effort completion hook; a failing hook must not break the run.
    try:
        await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
    except Exception as e:
        logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
    self._resolve_unconsumed_follow_ups()

@override
async def reset(
self,
Expand Down Expand Up @@ -463,34 +489,7 @@ async def step(self):
return

if not llm_resp.tools_call_name:
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()

# record the final assistant message
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if len(parts) == 0:
logger.warning(
"LLM returned empty assistant message with no tool calls."
)
self.run_context.messages.append(Message(role="assistant", content=parts))

# call the on_agent_done hook
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
self._resolve_unconsumed_follow_ups()
await self._complete_with_assistant_response(llm_resp)

# 返回 LLM 结果
if llm_resp.result_chain:
Expand All @@ -510,6 +509,24 @@ async def step(self):
if llm_resp.tools_call_name:
if self.tool_schema_mode == "skills_like":
llm_resp, _ = await self._resolve_tool_exec(llm_resp)
if not llm_resp.tools_call_name:
logger.warning(
"skills_like tool re-query returned no tool calls; fallback to assistant response."
)
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text),
),
)
await self._complete_with_assistant_response(llm_resp)
return

tool_call_result_blocks = []
cached_images = [] # Collect cached images for LLM visibility
Expand Down Expand Up @@ -873,7 +890,9 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
)

def _build_tool_requery_context(
self, tool_names: list[str]
self,
tool_names: list[str],
extra_instruction: str | None = None,
) -> list[dict[str, T.Any]]:
"""Build contexts for re-querying LLM with param-only tool schemas."""
contexts: list[dict[str, T.Any]] = []
Expand All @@ -888,13 +907,20 @@ def _build_tool_requery_context(
+ ". Now call the tool(s) with required arguments using the tool schema, "
"and follow the existing tool-use rules."
)
if extra_instruction:
instruction = f"{instruction}\n{extra_instruction}"
if contexts and contexts[0].get("role") == "system":
content = contexts[0].get("content") or ""
contexts[0]["content"] = f"{content}\n{instruction}"
else:
contexts.insert(0, {"role": "system", "content": instruction})
return contexts

@staticmethod
def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool:
    """Return True when the response carries non-whitespace assistant text.

    NOTE(review): only ``completion_text`` is inspected; this assumes any
    natural-language assistant reply is mirrored into that field. Confirm
    whether providers can leave it empty while populating other fields
    (e.g. ``result_chain``).
    """
    reply = llm_resp.completion_text
    if reply is None:
        return False
    return reply.strip() != ""
Comment on lines +919 to +922
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: _has_meaningful_assistant_reply only looks at completion_text, which can misclassify valid replies.

The current implementation only inspects completion_text:

@staticmethod
def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool:
    text = (llm_resp.completion_text or "").strip()
    return bool(text)

If a provider returns the assistant reply in another field (e.g. result_chain or message parts) and leaves completion_text empty, this will incorrectly treat a valid reply as missing and trigger an unnecessary second repair attempt. If LLMResponse guarantees that any assistant NL reply is always mirrored to completion_text, this is fine; otherwise, this helper should also consider other fields that can contain assistant text before deciding there was “no explanation.”

Suggested implementation:

    @staticmethod
    def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool:
        """
        Return True if the LLMResponse appears to contain a non‑empty natural‑language
        assistant reply in any of the known fields, not just `completion_text`.

        This is intentionally conservative: it only returns False when we are reasonably
        sure there is no assistant text at all, to avoid triggering unnecessary repair
        attempts when providers populate alternative fields.
        """
        # Primary / normalized path: most providers should mirror assistant text here.
        text = (getattr(llm_resp, "completion_text", None) or "").strip()
        if text:
            return True

        # Fallback 1: some providers may populate `result_chain` instead of `completion_text`.
        result_chain = getattr(llm_resp, "result_chain", None)
        if result_chain:
            # If it's already a string, check it directly.
            if isinstance(result_chain, str):
                if result_chain.strip():
                    return True
            # If it's a list/sequence of "messages" or nodes, look for any non‑empty content.
            elif isinstance(result_chain, (list, tuple)):
                for node in result_chain:
                    content = ""

                    # Common patterns: dicts with "content" or "parts".
                    if isinstance(node, dict):
                        if "content" in node:
                            content = node.get("content") or ""
                        elif "parts" in node and isinstance(node["parts"], list):
                            # Join parts that are simple strings.
                            str_parts = [
                                p for p in node["parts"] if isinstance(p, str)
                            ]
                            content = " ".join(str_parts)
                    # Or objects with a `content` attribute.
                    elif hasattr(node, "content"):
                        content = getattr(node, "content") or ""

                    if isinstance(content, str) and content.strip():
                        return True

        # Fallback 2: generic `message` / `messages` fields if present.
        message = getattr(llm_resp, "message", None)
        if isinstance(message, str) and message.strip():
            return True

        messages = getattr(llm_resp, "messages", None)
        if isinstance(messages, (list, tuple)):
            for msg in messages:
                content = ""
                if isinstance(msg, dict):
                    content = msg.get("content") or ""
                elif hasattr(msg, "content"):
                    content = getattr(msg, "content") or ""
                if isinstance(content, str) and content.strip():
                    return True

        # If none of the known fields contain text, treat it as no meaningful reply.
        return False

This implementation assumes LLMResponse may expose assistant content via result_chain, message, or messages fields, and that these follow typical "message with content" conventions. If your actual LLMResponse type uses different field names or shapes (e.g., output_messages, assistant_message, etc.), you should:

  1. Update _has_meaningful_assistant_reply to explicitly inspect those real fields.
  2. Optionally, centralize the “extract assistant text” logic in a helper on LLMResponse (e.g., a get_assistant_text() method) and have _has_meaningful_assistant_reply call that instead, to keep this runner decoupled from provider‑specific details.


def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet:
"""Build a subset of tools from the given tool set based on tool names."""
subset = ToolSet()
Expand Down Expand Up @@ -932,11 +958,45 @@ async def _resolve_tool_exec(
model=self.req.model,
session_id=self.req.session_id,
extra_user_content_parts=self.req.extra_user_content_parts,
tool_choice="required",
abort_signal=self._abort_signal,
)
if requery_resp:
llm_resp = requery_resp

# If the re-query still returns no tool calls, and also does not have a meaningful assistant reply,
# we consider it as a failure of the LLM to follow the tool-use instruction,
# and we will retry once with a stronger instruction that explicitly requires the LLM to either call the tool or give an explanation.
if (
not llm_resp.tools_call_name
and not self._has_meaningful_assistant_reply(llm_resp)
):
logger.warning(
"skills_like tool re-query returned no tool calls and no explanation; retrying with stronger instruction."
)
repair_contexts = self._build_tool_requery_context(
tool_names,
extra_instruction=(
"This is the second-stage tool execution step. "
"You must do exactly one of the following: "
"1. Call one of the selected tools using the provided tool schema. "
"2. If calling a tool is no longer possible or appropriate, reply to the user with a brief explanation of why. "
"Do not return an empty response. "
"Do not ignore the selected tools without explanation."
),
)
repair_resp = await self.provider.text_chat(
contexts=repair_contexts,
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,
extra_user_content_parts=self.req.extra_user_content_parts,
tool_choice="required",
abort_signal=self._abort_signal,
)
Comment on lines +988 to +996
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Second-stage repair still enforces tool_choice="required", which conflicts with allowing a no-tool explanation.

In the second-stage repair call you instruct the model that it may either call a tool or return an explanation, but you still pass tool_choice="required":

repair_resp = await self.provider.text_chat(
    contexts=repair_contexts,
    func_tool=param_subset,
    model=self.req.model,
    session_id=self.req.session_id,
    extra_user_content_parts=self.req.extra_user_content_parts,
    tool_choice="required",
    abort_signal=self._abort_signal,
)

For Gemini/Anthropic this maps to ANY/any and for OpenAI to required, all of which bias strongly toward always calling a tool. This conflicts with the intent to allow a no-tool explanation and may weaken the fallback behavior. Consider using tool_choice="auto" (or omitting it) here so the API constraint matches the instruction that a pure explanation is allowed.

Suggested change
repair_resp = await self.provider.text_chat(
contexts=repair_contexts,
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,
extra_user_content_parts=self.req.extra_user_content_parts,
tool_choice="required",
abort_signal=self._abort_signal,
)
repair_resp = await self.provider.text_chat(
contexts=repair_contexts,
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,
extra_user_content_parts=self.req.extra_user_content_parts,
tool_choice="auto",
abort_signal=self._abort_signal,
)

if repair_resp:
llm_resp = repair_resp

return llm_resp, subset

def done(self) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from typing import Literal, TypeAlias, Union

from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
Expand Down Expand Up @@ -104,6 +104,7 @@ async def text_chat(
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
Expand All @@ -113,6 +114,7 @@ async def text_chat(
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: tool set
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
Expand All @@ -135,6 +137,7 @@ async def text_chat_stream(
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
Expand All @@ -144,6 +147,7 @@ async def text_chat_stream(
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: tool set
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
Expand Down
17 changes: 17 additions & 0 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import json
from collections.abc import AsyncGenerator
from typing import Literal

import anthropic
import httpx
Expand Down Expand Up @@ -258,6 +259,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
payloads["tool_choice"] = {
"type": "any"
if payloads.get("tool_choice") == "required"
else "auto"
}

extra_body = self.provider_config.get("custom_extra_body", {})

Expand Down Expand Up @@ -334,6 +340,11 @@ async def _query_stream(
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
payloads["tool_choice"] = {
"type": "any"
if payloads.get("tool_choice") == "required"
else "auto"
}

# 用于累积工具调用信息
tool_use_buffer = {}
Expand Down Expand Up @@ -483,6 +494,7 @@ async def text_chat(
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
) -> LLMResponse:
if contexts is None:
Expand Down Expand Up @@ -516,6 +528,8 @@ async def text_chat(
model = model or self.get_model()

payloads = {"messages": new_messages, "model": model}
if func_tool and not func_tool.empty():
payloads["tool_choice"] = tool_choice

# Anthropic has a different way of handling system prompts
if system_prompt:
Expand All @@ -540,6 +554,7 @@ async def text_chat_stream(
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
):
if contexts is None:
Expand Down Expand Up @@ -572,6 +587,8 @@ async def text_chat_stream(
model = model or self.get_model()

payloads = {"messages": new_messages, "model": model}
if func_tool and not func_tool.empty():
payloads["tool_choice"] = tool_choice

# Anthropic has a different way of handling system prompts
if system_prompt:
Expand Down
24 changes: 23 additions & 1 deletion astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import random
from collections.abc import AsyncGenerator
from typing import cast
from typing import Literal, cast

from google import genai
from google.genai import types
Expand Down Expand Up @@ -131,6 +131,7 @@ async def _prepare_query_config(
self,
payloads: dict,
tools: ToolSet | None = None,
tool_choice: Literal["auto", "required"] = "auto",
system_instruction: str | None = None,
modalities: list[str] | None = None,
temperature: float = 0.7,
Expand Down Expand Up @@ -207,6 +208,18 @@ async def _prepare_query_config(
types.Tool(function_declarations=func_desc["function_declarations"]),
]

tool_config = None
if tools and tool_list:
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=(
types.FunctionCallingConfigMode.ANY
if tool_choice == "required"
else types.FunctionCallingConfigMode.AUTO
)
)
)

# oper thinking config
thinking_config = None
if model_name in [
Expand Down Expand Up @@ -272,6 +285,7 @@ async def _prepare_query_config(
seed=payloads.get("seed"),
response_modalities=modalities,
tools=cast(types.ToolListUnion | None, tool_list),
tool_config=tool_config,
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=thinking_config,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
Expand Down Expand Up @@ -535,6 +549,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
config = await self._prepare_query_config(
payloads,
tools,
payloads.get("tool_choice", "auto"),
system_instruction,
modalities,
temperature,
Expand Down Expand Up @@ -616,6 +631,7 @@ async def _query_stream(
config = await self._prepare_query_config(
payloads,
tools,
payloads.get("tool_choice", "auto"),
system_instruction,
)
result = await self.client.models.generate_content_stream(
Expand Down Expand Up @@ -728,6 +744,7 @@ async def text_chat(
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
) -> LLMResponse:
if contexts is None:
Expand Down Expand Up @@ -758,6 +775,8 @@ async def text_chat(
model = model or self.get_model()

payloads = {"messages": context_query, "model": model}
if func_tool and not func_tool.empty():
payloads["tool_choice"] = tool_choice

retry = 10
keys = self.api_keys.copy()
Expand All @@ -783,6 +802,7 @@ async def text_chat_stream(
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
Expand Down Expand Up @@ -813,6 +833,8 @@ async def text_chat_stream(
model = model or self.get_model()

payloads = {"messages": context_query, "model": model}
if func_tool and not func_tool.empty():
payloads["tool_choice"] = tool_choice

retry = 10
keys = self.api_keys.copy()
Expand Down
Loading
Loading