diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index cb857dc091..3fb487cbe6 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -16,11 +16,18 @@ TextContent, TextResourceContents, ) +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from astrbot import logger from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool_image_cache import tool_image_cache +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, @@ -95,6 +102,10 @@ class _ToolExecutionInterrupted(Exception): class ToolLoopAgentRunner(BaseAgentRunner[TContext]): + EMPTY_OUTPUT_RETRY_ATTEMPTS = 3 + EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1 + EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4 + def _get_persona_custom_error_message(self) -> str | None: """Read persona-level custom error message from event extras when available.""" event = getattr(self.run_context.context, "event", None) @@ -279,31 +290,61 @@ async def _iter_llm_responses_with_fallback( candidate_id, ) self.provider = candidate - has_stream_output = False try: - async for resp in self._iter_llm_responses(include_model=idx == 0): - if resp.is_chunk: - has_stream_output = True - yield resp - continue - - if ( - resp.role == "err" - and not has_stream_output - and (not is_last_candidate) - ): - last_err_response = resp - logger.warning( - "Chat Model %s returns error response, trying fallback to next provider.", - candidate_id, - ) - break - - yield resp - return + retrying = AsyncRetrying( + retry=retry_if_exception_type(EmptyModelOutputError), + stop=stop_after_attempt(self.EMPTY_OUTPUT_RETRY_ATTEMPTS), + wait=wait_exponential( + multiplier=1, + min=self.EMPTY_OUTPUT_RETRY_WAIT_MIN_S, + max=self.EMPTY_OUTPUT_RETRY_WAIT_MAX_S, + ), + reraise=True, + ) - if has_stream_output: - return + async for attempt in retrying: + has_stream_output = False + with attempt: + try: + async for resp in self._iter_llm_responses( + include_model=idx == 0 + ): + if resp.is_chunk: + has_stream_output = True + yield resp + continue + + if ( + resp.role == "err" + and not has_stream_output + and (not is_last_candidate) + ): + last_err_response = resp + logger.warning( + "Chat Model %s returns error response, trying fallback to next provider.", + candidate_id, + ) + break + + yield resp + return + + if has_stream_output: + return + except EmptyModelOutputError: + if has_stream_output: + logger.warning( + "Chat Model %s returned empty output after streaming started; skipping empty-output retry.", + candidate_id, + ) + else: + logger.warning( + "Chat Model %s returned empty output on attempt %s/%s.", + candidate_id, + attempt.retry_state.attempt_number, + self.EMPTY_OUTPUT_RETRY_ATTEMPTS, + ) + raise except Exception as exc: # noqa: BLE001 last_exception = exc logger.warning( diff --git a/astrbot/core/exceptions.py b/astrbot/core/exceptions.py index e637d4930f..f10af57ea8 100644 --- a/astrbot/core/exceptions.py +++ b/astrbot/core/exceptions.py @@ -7,3 +7,7 @@ class AstrBotError(Exception): class ProviderNotFoundError(AstrBotError): """Raised when a specified provider is not found.""" + + +class EmptyModelOutputError(AstrBotError): + """Raised when the model response contains no usable assistant output.""" diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 56600e2126..89c33021ae 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -13,6 +13,7 @@ from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url @@ -29,6 +30,23 @@ "Anthropic Claude API 提供商适配器", ) class ProviderAnthropic(Provider): + @staticmethod + def _ensure_usable_response( + llm_response: LLMResponse, + *, + completion_id: str | None = None, + stop_reason: str | None = None, + ) -> None: + has_text_output = bool((llm_response.completion_text or "").strip()) + has_reasoning_output = bool(llm_response.reasoning_content.strip()) + has_tool_output = bool(llm_response.tools_call_args) + if has_text_output or has_reasoning_output or has_tool_output: + return + raise EmptyModelOutputError( + "Anthropic completion has no usable output. " + f"completion_id={completion_id}, stop_reason={stop_reason}" + ) + @staticmethod def _normalize_custom_headers(provider_config: dict) -> dict[str, str] | None: custom_headers = provider_config.get("custom_headers", {}) @@ -289,7 +307,9 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: logger.debug(f"completion: {completion}") if len(completion.content) == 0: - raise Exception("API 返回的 completion 为空。") + raise EmptyModelOutputError( + f"Anthropic completion is empty. completion_id={completion.id}" + ) llm_response = LLMResponse(role="assistant") @@ -317,10 +337,9 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if not llm_response.completion_text and not llm_response.tools_call_args: # Guard clause: raise early if no valid content at all if not llm_response.reasoning_content: - raise ValueError( - f"Anthropic API returned unparsable completion: " - f"no text, tool_use, or thinking content found. " - f"Completion: {completion}" + raise EmptyModelOutputError( + "Anthropic completion has no usable output. " + f"completion_id={completion.id}, stop_reason={completion.stop_reason}" ) # We have reasoning content (ThinkingBlock) - this is valid @@ -330,6 +349,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: ) llm_response.completion_text = "" # Ensure empty string, not None + self._ensure_usable_response( + llm_response, + completion_id=completion.id, + stop_reason=completion.stop_reason, + ) return llm_response async def _query_stream( @@ -481,6 +505,11 @@ async def _query_stream( final_response.tools_call_name = [call["name"] for call in final_tool_calls] final_response.tools_call_ids = [call["id"] for call in final_tool_calls] + self._ensure_usable_response( + final_response, + completion_id=id, + stop_reason=None, + ) yield final_response async def text_chat( diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 19936a3c34..388ab5c91f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -14,6 +14,7 @@ from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet @@ -444,6 +445,23 @@ def _extract_usage( output=usage_metadata.candidates_token_count or 0, ) + @staticmethod + def _ensure_usable_response( + llm_response: LLMResponse, + *, + response_id: str | None = None, + finish_reason: str | None = None, + ) -> None: + has_text_output = bool((llm_response.completion_text or "").strip()) + has_reasoning_output = bool(llm_response.reasoning_content.strip()) + has_tool_output = bool(llm_response.tools_call_args) + if has_text_output or has_reasoning_output or has_tool_output: + return + raise EmptyModelOutputError( + "Gemini completion has no usable output. " + f"response_id={response_id}, finish_reason={finish_reason}" + ) + def _process_content_parts( self, candidate: types.Candidate, @@ -452,7 +470,10 @@ def _process_content_parts( """处理内容部分并构建消息链""" if not candidate.content: logger.warning(f"收到的 candidate.content 为空: {candidate}") - raise Exception("API 返回的 candidate.content 为空。") + raise EmptyModelOutputError( + "Gemini candidate content is empty. " + f"finish_reason={candidate.finish_reason}" + ) finish_reason = candidate.finish_reason result_parts: list[types.Part] | None = candidate.content.parts @@ -474,7 +495,10 @@ def _process_content_parts( if not result_parts: logger.warning(f"收到的 candidate.content.parts 为空: {candidate}") - raise Exception("API 返回的 candidate.content.parts 为空。") + raise EmptyModelOutputError( + "Gemini candidate content parts are empty. " + f"finish_reason={candidate.finish_reason}" + ) # 提取 reasoning content reasoning = self._extract_reasoning_content(candidate) @@ -525,7 +549,14 @@ def _process_content_parts( if ts := part.thought_signature: # only keep the last thinking signature llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8") - return MessageChain(chain=chain) + chain_result = MessageChain(chain=chain) + llm_response.result_chain = chain_result + self._ensure_usable_response( + llm_response, + response_id=None, + finish_reason=str(finish_reason) if finish_reason is not None else None, + ) + return chain_result async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: """非流式请求 Gemini API""" @@ -727,9 +758,12 @@ async def _query_stream( final_response.result_chain = MessageChain( chain=[Comp.Plain(accumulated_text)], ) - elif not final_response.result_chain: - # If no text was accumulated and no final response was set, provide empty space - final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")]) + + self._ensure_usable_response( + final_response, + response_id=getattr(final_response, "id", None), + finish_reason=None, + ) yield final_response diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index aa908bcc25..39bfc69dbf 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -26,6 +26,7 @@ from astrbot.api.provider import Provider from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart from astrbot.core.agent.tool import ToolSet +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult from astrbot.core.utils.io import download_image_by_url @@ -696,7 +697,9 @@ async def _parse_openai_completion( llm_response = LLMResponse("assistant") if not completion.choices: - raise Exception("API 返回的 completion 为空。") + raise EmptyModelOutputError( + f"OpenAI completion has no choices. response_id={completion.id}" + ) choice = completion.choices[0] # parse the text completion @@ -714,6 +717,10 @@ async def _parse_openai_completion( # Also clean up orphan tags that may leak from some models completion_text = re.sub(r"\s*$", "", completion_text).strip() llm_response.result_chain = MessageChain().message(completion_text) + elif refusal := getattr(choice.message, "refusal", None): + refusal_text = self._normalize_content(refusal) + if refusal_text: + llm_response.result_chain = MessageChain().message(refusal_text) # parse the reasoning content if any # the priority is higher than the tag extraction @@ -761,9 +768,18 @@ async def _parse_openai_completion( raise Exception( "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。", ) - if llm_response.completion_text is None and not llm_response.tools_call_args: - logger.error(f"API 返回的 completion 无法解析:{completion}。") - raise Exception(f"API 返回的 completion 无法解析:{completion}。") + has_text_output = bool((llm_response.completion_text or "").strip()) + has_reasoning_output = bool(llm_response.reasoning_content.strip()) + if ( + not has_text_output + and not has_reasoning_output + and not llm_response.tools_call_args + ): + logger.error(f"OpenAI completion has no usable output: {completion}.") + raise EmptyModelOutputError( + "OpenAI completion has no usable output. " + f"response_id={completion.id}, finish_reason={choice.finish_reason}" + ) llm_response.raw_completion = completion llm_response.id = completion.id diff --git a/tests/test_anthropic_kimi_code_provider.py b/tests/test_anthropic_kimi_code_provider.py index ea22e376c3..a46953f22c 100644 --- a/tests/test_anthropic_kimi_code_provider.py +++ b/tests/test_anthropic_kimi_code_provider.py @@ -1,7 +1,10 @@ import httpx +import pytest import astrbot.core.provider.sources.anthropic_source as anthropic_source import astrbot.core.provider.sources.kimi_code_source as kimi_code_source +from astrbot.core.exceptions import EmptyModelOutputError +from astrbot.core.provider.entities import LLMResponse class _FakeAsyncAnthropic: @@ -79,3 +82,14 @@ def test_kimi_code_provider_restores_required_user_agent_when_blank(monkeypatch) assert provider.custom_headers == { "User-Agent": kimi_code_source.KIMI_CODE_USER_AGENT, } + + +def test_anthropic_empty_output_raises_empty_model_output_error(): + llm_response = LLMResponse(role="assistant") + + with pytest.raises(EmptyModelOutputError): + anthropic_source.ProviderAnthropic._ensure_usable_response( + llm_response, + completion_id="msg_empty", + stop_reason="end_turn", + ) diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py new file mode 100644 index 0000000000..4db8e92bfe --- /dev/null +++ b/tests/test_gemini_source.py @@ -0,0 +1,29 @@ +import pytest + +from astrbot.core.exceptions import EmptyModelOutputError +from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI + + +def test_gemini_empty_output_raises_empty_model_output_error(): + llm_response = LLMResponse(role="assistant") + + with pytest.raises(EmptyModelOutputError): + ProviderGoogleGenAI._ensure_usable_response( + llm_response, + response_id="resp_empty", + finish_reason="STOP", + ) + + +def test_gemini_reasoning_only_output_is_allowed(): + llm_response = LLMResponse( + role="assistant", + reasoning_content="chain of thought placeholder", + ) + + ProviderGoogleGenAI._ensure_usable_response( + llm_response, + response_id="resp_reasoning", + finish_reason="STOP", + ) diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 2454c09c1d..39bb6d3810 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -4,6 +4,7 @@ from openai.types.chat.chat_completion import ChatCompletion from PIL import Image as PILImage +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.sources.groq_source import ProviderGroq from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial @@ -1136,3 +1137,39 @@ async def fake_create(**kwargs): assert extra_body["temperature"] == 0.1 finally: await provider.terminate() + + +@pytest.mark.asyncio +async def test_parse_openai_completion_raises_empty_model_output_error(): + provider = _make_provider() + try: + completion = ChatCompletion.model_validate( + { + "id": "chatcmpl-empty", + "object": "chat.completion", + "created": 0, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": None, + "tool_calls": None, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 0, + "total_tokens": 1, + }, + } + ) + + with pytest.raises(EmptyModelOutputError): + await provider._parse_openai_completion(completion, tools=None) + finally: + await provider.terminate() diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 2a27be4c1d..13ea6dac7b 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -17,6 +17,7 @@ from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor +from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage from astrbot.core.provider.provider import Provider @@ -134,6 +135,22 @@ async def text_chat(self, **kwargs) -> LLMResponse: ) +class MockEmptyOutputThenSuccessProvider(MockProvider): + def __init__(self, failures_before_success: int = 1): + super().__init__() + self.failures_before_success = failures_before_success + + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + if self.call_count <= self.failures_before_success: + raise EmptyModelOutputError("model returned no usable output") + return LLMResponse( + role="assistant", + completion_text="这是重试后的最终回答", + usage=TokenUsage(input_other=10, output=5), + ) + + class MockAbortableStreamProvider(MockProvider): async def text_chat_stream(self, **kwargs): abort_signal = kwargs.get("abort_signal") @@ -579,6 +596,67 @@ async def test_fallback_provider_used_when_primary_returns_err( assert fallback_provider.call_count == 1 +@pytest.mark.asyncio +async def test_empty_output_is_retried_before_succeeding( + runner, provider_request, mock_tool_executor, mock_hooks, monkeypatch +): + monkeypatch.setattr(runner, "EMPTY_OUTPUT_RETRY_WAIT_MIN_S", 0) + monkeypatch.setattr(runner, "EMPTY_OUTPUT_RETRY_WAIT_MAX_S", 0) + + provider = MockEmptyOutputThenSuccessProvider(failures_before_success=1) + await runner.reset( + provider=provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async for _ in runner.step_until_done(5): + pass + + final_resp = runner.get_final_llm_resp() + assert final_resp is not None + assert final_resp.role == "assistant" + assert final_resp.completion_text == "这是重试后的最终回答" + assert provider.call_count == 2 + + +@pytest.mark.asyncio +async def test_empty_output_retries_exhausted_then_uses_fallback_provider( + runner, provider_request, mock_tool_executor, mock_hooks, monkeypatch +): + monkeypatch.setattr(runner, "EMPTY_OUTPUT_RETRY_WAIT_MIN_S", 0) + monkeypatch.setattr(runner, "EMPTY_OUTPUT_RETRY_WAIT_MAX_S", 0) + + primary_provider = MockEmptyOutputThenSuccessProvider( + failures_before_success=runner.EMPTY_OUTPUT_RETRY_ATTEMPTS + ) + fallback_provider = MockProvider() + fallback_provider.should_call_tools = False + + await runner.reset( + provider=primary_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + fallback_providers=[fallback_provider], + ) + + async for _ in runner.step_until_done(5): + pass + + final_resp = runner.get_final_llm_resp() + assert final_resp is not None + assert final_resp.role == "assistant" + assert final_resp.completion_text == "这是我的最终回答" + assert primary_provider.call_count == runner.EMPTY_OUTPUT_RETRY_ATTEMPTS + assert fallback_provider.call_count == 1 + + @pytest.mark.asyncio async def test_stop_signal_returns_aborted_and_persists_partial_message( runner, provider_request, mock_tool_executor, mock_hooks