From c568177bb723d086f68f7227ee4b2aed5cad03c1 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 28 Mar 2026 23:04:53 +0800 Subject: [PATCH] feat(agent-runner): add tool_choice parameter to fix empty tool calls response in "skills-like" tool call mode fixes: #7049 --- .../agent/runners/tool_loop_agent_runner.py | 118 +++++++++++++----- astrbot/core/provider/provider.py | 6 +- .../core/provider/sources/anthropic_source.py | 17 +++ .../core/provider/sources/gemini_source.py | 24 +++- .../core/provider/sources/openai_source.py | 8 ++ 5 files changed, 142 insertions(+), 31 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index cb410ecb02..cb857dc091 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -100,6 +100,32 @@ def _get_persona_custom_error_message(self) -> str | None: event = getattr(self.run_context.context, "event", None) return extract_persona_custom_error_message_from_event(event) + async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None: + """Finalize the current step as a plain assistant response with no tool calls.""" + self.final_llm_resp = llm_resp + self._transition_state(AgentState.DONE) + self.stats.end_time = time.time() + + parts = [] + if llm_resp.reasoning_content or llm_resp.reasoning_signature: + parts.append( + ThinkPart( + think=llm_resp.reasoning_content, + encrypted=llm_resp.reasoning_signature, + ) + ) + if llm_resp.completion_text: + parts.append(TextPart(text=llm_resp.completion_text)) + if len(parts) == 0: + logger.warning("LLM returned empty assistant message with no tool calls.") + self.run_context.messages.append(Message(role="assistant", content=parts)) + + try: + await self.agent_hooks.on_agent_done(self.run_context, llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + self._resolve_unconsumed_follow_ups() + @override async def reset( self, @@ -463,34 +489,7 @@ async def step(self): return if not llm_resp.tools_call_name: - # 如果没有工具调用,转换到完成状态 - self.final_llm_resp = llm_resp - self._transition_state(AgentState.DONE) - self.stats.end_time = time.time() - - # record the final assistant message - parts = [] - if llm_resp.reasoning_content or llm_resp.reasoning_signature: - parts.append( - ThinkPart( - think=llm_resp.reasoning_content, - encrypted=llm_resp.reasoning_signature, - ) - ) - if llm_resp.completion_text: - parts.append(TextPart(text=llm_resp.completion_text)) - if len(parts) == 0: - logger.warning( - "LLM returned empty assistant message with no tool calls." - ) - self.run_context.messages.append(Message(role="assistant", content=parts)) - - # call the on_agent_done hook - try: - await self.agent_hooks.on_agent_done(self.run_context, llm_resp) - except Exception as e: - logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) - self._resolve_unconsumed_follow_ups() + await self._complete_with_assistant_response(llm_resp) # 返回 LLM 结果 if llm_resp.result_chain: @@ -510,6 +509,24 @@ async def step(self): if llm_resp.tools_call_name: if self.tool_schema_mode == "skills_like": llm_resp, _ = await self._resolve_tool_exec(llm_resp) + if not llm_resp.tools_call_name: + logger.warning( + "skills_like tool re-query returned no tool calls; fallback to assistant response." + ) + if llm_resp.result_chain: + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=llm_resp.result_chain), + ) + elif llm_resp.completion_text: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain().message(llm_resp.completion_text), + ), + ) + await self._complete_with_assistant_response(llm_resp) + return tool_call_result_blocks = [] cached_images = [] # Collect cached images for LLM visibility @@ -873,7 +890,9 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: ) def _build_tool_requery_context( - self, tool_names: list[str] + self, + tool_names: list[str], + extra_instruction: str | None = None, ) -> list[dict[str, T.Any]]: """Build contexts for re-querying LLM with param-only tool schemas.""" contexts: list[dict[str, T.Any]] = [] @@ -888,6 +907,8 @@ def _build_tool_requery_context( + ". Now call the tool(s) with required arguments using the tool schema, " "and follow the existing tool-use rules." ) + if extra_instruction: + instruction = f"{instruction}\n{extra_instruction}" if contexts and contexts[0].get("role") == "system": content = contexts[0].get("content") or "" contexts[0]["content"] = f"{content}\n{instruction}" @@ -895,6 +916,11 @@ def _build_tool_requery_context( contexts.insert(0, {"role": "system", "content": instruction}) return contexts + @staticmethod + def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool: + text = (llm_resp.completion_text or "").strip() + return bool(text) + def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet: """Build a subset of tools from the given tool set based on tool names.""" subset = ToolSet() @@ -932,11 +958,45 @@ async def _resolve_tool_exec( model=self.req.model, session_id=self.req.session_id, extra_user_content_parts=self.req.extra_user_content_parts, + tool_choice="required", abort_signal=self._abort_signal, ) if requery_resp: llm_resp = requery_resp + # If the re-query still returns no tool calls, and also does not have a meaningful assistant reply, + # we consider it as a failure of the LLM to follow the tool-use instruction, + # and we will retry once with a stronger instruction that explicitly requires the LLM to either call the tool or give an explanation. + if ( + not llm_resp.tools_call_name + and not self._has_meaningful_assistant_reply(llm_resp) + ): + logger.warning( + "skills_like tool re-query returned no tool calls and no explanation; retrying with stronger instruction." + ) + repair_contexts = self._build_tool_requery_context( + tool_names, + extra_instruction=( + "This is the second-stage tool execution step. " + "You must do exactly one of the following: " + "1. Call one of the selected tools using the provided tool schema. " + "2. If calling a tool is no longer possible or appropriate, reply to the user with a brief explanation of why. " + "Do not return an empty response. " + "Do not ignore the selected tools without explanation." + ), + ) + repair_resp = await self.provider.text_chat( + contexts=repair_contexts, + func_tool=param_subset, + model=self.req.model, + session_id=self.req.session_id, + extra_user_content_parts=self.req.extra_user_content_parts, + tool_choice="required", + abort_signal=self._abort_signal, + ) + if repair_resp: + llm_resp = repair_resp + return llm_resp, subset def done(self) -> bool: diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 345ad7b743..fab3ce6104 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,7 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from typing import Literal, TypeAlias, Union from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet @@ -104,6 +104,7 @@ async def text_chat( tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, extra_user_content_parts: list[ContentPart] | None = None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -113,6 +114,7 @@ async def text_chat( session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: tool set + tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具 contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等) @@ -135,6 +137,7 @@ async def text_chat_stream( system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 @@ -144,6 +147,7 @@ async def text_chat_stream( session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: tool set + tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具 contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 203d0610ff..56600e2126 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,6 +1,7 @@ import base64 import json from collections.abc import AsyncGenerator +from typing import Literal import anthropic import httpx @@ -258,6 +259,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list + payloads["tool_choice"] = { + "type": "any" + if payloads.get("tool_choice") == "required" + else "auto" + } extra_body = self.provider_config.get("custom_extra_body", {}) @@ -334,6 +340,11 @@ async def _query_stream( if tools: if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list + payloads["tool_choice"] = { + "type": "any" + if payloads.get("tool_choice") == "required" + else "auto" + } # 用于累积工具调用信息 tool_use_buffer = {} @@ -483,6 +494,7 @@ async def text_chat( tool_calls_result=None, model=None, extra_user_content_parts=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> LLMResponse: if contexts is None: @@ -516,6 +528,8 @@ async def text_chat( model = model or self.get_model() payloads = {"messages": new_messages, "model": model} + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice # Anthropic has a different way of handling system prompts if system_prompt: @@ -540,6 +554,7 @@ async def text_chat_stream( tool_calls_result=None, model=None, extra_user_content_parts=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ): if contexts is None: @@ -572,6 +587,8 @@ async def text_chat_stream( model = model or self.get_model() payloads = {"messages": new_messages, "model": model} + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice # Anthropic has a different way of handling system prompts if system_prompt: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9557f3dbcd..19936a3c34 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,7 +4,7 @@ import logging import random from collections.abc import AsyncGenerator -from typing import cast +from typing import Literal, cast from google import genai from google.genai import types @@ -131,6 +131,7 @@ async def _prepare_query_config( self, payloads: dict, tools: ToolSet | None = None, + tool_choice: Literal["auto", "required"] = "auto", system_instruction: str | None = None, modalities: list[str] | None = None, temperature: float = 0.7, @@ -207,6 +208,18 @@ async def _prepare_query_config( types.Tool(function_declarations=func_desc["function_declarations"]), ] + tool_config = None + if tools and tool_list: + tool_config = types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=( + types.FunctionCallingConfigMode.ANY + if tool_choice == "required" + else types.FunctionCallingConfigMode.AUTO + ) + ) + ) + # oper thinking config thinking_config = None if model_name in [ @@ -272,6 +285,7 @@ async def _prepare_query_config( seed=payloads.get("seed"), response_modalities=modalities, tools=cast(types.ToolListUnion | None, tool_list), + tool_config=tool_config, safety_settings=self.safety_settings if self.safety_settings else None, thinking_config=thinking_config, automatic_function_calling=types.AutomaticFunctionCallingConfig( @@ -535,6 +549,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: config = await self._prepare_query_config( payloads, tools, + payloads.get("tool_choice", "auto"), system_instruction, modalities, temperature, @@ -616,6 +631,7 @@ async def _query_stream( config = await self._prepare_query_config( payloads, tools, + payloads.get("tool_choice", "auto"), system_instruction, ) result = await self.client.models.generate_content_stream( @@ -728,6 +744,7 @@ async def text_chat( tool_calls_result=None, model=None, extra_user_content_parts=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> LLMResponse: if contexts is None: @@ -758,6 +775,8 @@ async def text_chat( model = model or self.get_model() payloads = {"messages": context_query, "model": model} + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice retry = 10 keys = self.api_keys.copy() @@ -783,6 +802,7 @@ async def text_chat_stream( tool_calls_result=None, model=None, extra_user_content_parts=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> AsyncGenerator[LLMResponse, None]: if contexts is None: @@ -813,6 +833,8 @@ async def text_chat_stream( model = model or self.get_model() payloads = {"messages": context_query, "model": model} + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice retry = 10 keys = self.api_keys.copy() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index cdad66a22f..aa908bcc25 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -436,6 +436,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: ) if tool_list: payloads["tools"] = tool_list + payloads["tool_choice"] = payloads.get("tool_choice", "auto") # 不在默认参数中的参数放在 extra_body 中 extra_body = {} @@ -486,6 +487,7 @@ async def _query_stream( ) if tool_list: payloads["tools"] = tool_list + payloads["tool_choice"] = payloads.get("tool_choice", "auto") # 不在默认参数中的参数放在 extra_body 中 extra_body = {} @@ -965,6 +967,7 @@ async def text_chat( tool_calls_result=None, model=None, extra_user_content_parts=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> LLMResponse: payloads, context_query = await self._prepare_chat_payload( @@ -977,6 +980,8 @@ async def text_chat( extra_user_content_parts=extra_user_content_parts, **kwargs, ) + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice llm_response = None max_retries = 10 @@ -1032,6 +1037,7 @@ async def text_chat_stream( system_prompt=None, tool_calls_result=None, model=None, + tool_choice: Literal["auto", "required"] = "auto", **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" @@ -1044,6 +1050,8 @@ async def text_chat_stream( model=model, **kwargs, ) + if func_tool and not func_tool.empty(): + payloads["tool_choice"] = tool_choice max_retries = 10 available_api_keys = self.api_keys.copy()