diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index eca24699ae..d21b8f0028 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -8,6 +8,7 @@ from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.config.default import DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD from astrbot.core.message.components import BaseMessageComponent, Json, Plain from astrbot.core.message.message_event_result import ( MessageChain, @@ -23,6 +24,23 @@ AgentRunner = ToolLoopAgentRunner[AstrAgentContext] +def normalize_repeat_reply_guard_threshold(value, *, invalid_fallback: int = 0) -> int: + if isinstance(value, bool): + return invalid_fallback + try: + parsed = int(value) + except (TypeError, ValueError): + return invalid_fallback + return max(0, parsed) + + +def normalize_config_repeat_reply_guard_threshold(value) -> int: + return normalize_repeat_reply_guard_threshold( + value, + invalid_fallback=DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + ) + + def _should_stop_agent(astr_event) -> bool: return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested")) @@ -87,6 +105,13 @@ def _build_tool_result_status_message( return status_msg +def _build_chain_signature(msg_chain: MessageChain) -> str: + signature = msg_chain.get_plain_text(with_other_comps_mark=True).strip() + if not signature: + return "" + return re.sub(r"\s+", " ", signature) + + async def run_agent( agent_runner: AgentRunner, max_step: int = 30, @@ -94,10 +119,16 @@ async def run_agent( show_tool_call_result: bool = False, stream_to_general: bool = False, show_reasoning: bool = False, + repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, ) -> AsyncGenerator[MessageChain | None, None]: step_idx = 0 astr_event = agent_runner.run_context.context.event tool_name_by_call_id: dict[str, str] = {} + guard_threshold = normalize_repeat_reply_guard_threshold( + repeat_reply_guard_threshold + ) + guard_last_signature = "" + guard_repeat_count = 0 while step_idx < max_step + 1: step_idx += 1 @@ -193,6 +224,39 @@ async def run_agent( await astr_event.send(chain) continue + if resp.type == "llm_result" and guard_threshold > 0: + chain_signature = _build_chain_signature(resp.data["chain"]) + if chain_signature: + if chain_signature == guard_last_signature: + guard_repeat_count += 1 + else: + guard_last_signature = chain_signature + guard_repeat_count = 1 + + if guard_repeat_count >= guard_threshold: + logger.warning( + "Agent repeated identical llm_result %d times; forcing convergence. threshold=%d", + guard_repeat_count, + guard_threshold, + ) + if not agent_runner.done(): + if agent_runner.req: + agent_runner.req.func_tool = None + agent_runner.run_context.messages.append( + Message( + role="user", + content=( + "You have repeated the same reply multiple times. " + "Stop repeating yourself, provide a final answer " + "based on the information you already have, and do " + "not call tools again." + ), + ) + ) + # Jump to the same convergence path as max-step limit. + step_idx = max_step + continue + if stream_to_general and resp.type == "streaming_delta": continue @@ -288,6 +352,7 @@ async def run_live_agent( show_tool_use: bool = True, show_tool_call_result: bool = False, show_reasoning: bool = False, + repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, ) -> AsyncGenerator[MessageChain | None, None]: """Live Mode 的 Agent 运行器,支持流式 TTS @@ -311,6 +376,7 @@ async def run_live_agent( show_tool_call_result=show_tool_call_result, stream_to_general=False, show_reasoning=show_reasoning, + repeat_reply_guard_threshold=repeat_reply_guard_threshold, ): yield chain return @@ -343,6 +409,7 @@ async def run_live_agent( show_tool_use, show_tool_call_result, show_reasoning, + repeat_reply_guard_threshold, ) ) @@ -430,6 +497,7 @@ async def _run_agent_feeder( show_tool_use: bool, show_tool_call_result: bool, show_reasoning: bool, + repeat_reply_guard_threshold: int, ) -> None: """运行 Agent 并将文本输出分句放入队列""" buffer = "" @@ -441,6 +509,7 @@ async def _run_agent_feeder( show_tool_call_result=show_tool_call_result, stream_to_general=False, show_reasoning=show_reasoning, + repeat_reply_guard_threshold=repeat_reply_guard_threshold, ): if chain is None: continue diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 0f43dbd06d..86371dc7f9 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -7,6 +7,7 @@ VERSION = "4.22.0" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") +DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD = 3 PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { "description": "Base URL", @@ -149,6 +150,7 @@ "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, + "repeat_reply_guard_threshold": DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, "tool_call_timeout": 120, "tool_schema_mode": "full", "llm_safety_mode": True, @@ -2685,6 +2687,9 @@ class ChatProviderTemplate(TypedDict): "max_agent_step": { "type": "int", }, + "repeat_reply_guard_threshold": { + "type": "int", + }, "tool_call_timeout": { "type": "int", }, @@ -3430,6 +3435,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.repeat_reply_guard_threshold": { + "description": "连续相同回复拦截阈值", + "type": "int", + "hint": "同一轮 Agent 运行中连续出现相同回复达到该次数时,将触发防循环收敛。设置为 0 可关闭。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.tool_call_timeout": { "description": "工具调用超时时间(秒)", "type": "int", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..8ef76b7ad7 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -8,6 +8,10 @@ from astrbot.core import logger from astrbot.core.agent.message import Message from astrbot.core.agent.response import AgentStats +from astrbot.core.astr_agent_run_util import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_config_repeat_reply_guard_threshold, +) from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, MainAgentBuildResult, @@ -64,6 +68,15 @@ async def initialize(self, ctx: PipelineContext) -> None: self.tool_schema_mode = "full" if isinstance(self.max_step, bool): # workaround: #2622 self.max_step = 30 + self.repeat_reply_guard_threshold: int = settings.get( + "repeat_reply_guard_threshold", + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + ) + self.repeat_reply_guard_threshold = ( + normalize_config_repeat_reply_guard_threshold( + self.repeat_reply_guard_threshold + ) + ) self.show_tool_use: bool = settings.get("show_tool_use_status", True) self.show_tool_call_result: bool = settings.get("show_tool_call_result", False) self.show_reasoning = settings.get("display_reasoning_text", False) @@ -274,6 +287,7 @@ async def process( self.show_tool_use, self.show_tool_call_result, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ), ), ) @@ -304,6 +318,7 @@ async def process( self.show_tool_use, self.show_tool_call_result, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ), ), ) @@ -334,6 +349,7 @@ async def process( self.show_tool_call_result, stream_to_general, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ): yield diff --git a/tests/unit/test_astr_agent_run_util.py b/tests/unit/test_astr_agent_run_util.py new file mode 100644 index 0000000000..0ab0e776d5 --- /dev/null +++ b/tests/unit/test_astr_agent_run_util.py @@ -0,0 +1,149 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.core.astr_agent_run_util import run_agent +from astrbot.core.message.message_event_result import MessageChain + + +def _llm_result_response(text: str): + return SimpleNamespace( + type="llm_result", + data={"chain": MessageChain().message(text)}, + ) + + +class _DummyTrace: + def record(self, *args, **kwargs) -> None: + return None + + +class _DummyEvent: + def __init__(self) -> None: + self._extras: dict = {} + self._stopped = False + self.result_texts: list[str] = [] + self.trace = _DummyTrace() + + def is_stopped(self) -> bool: + return self._stopped + + def get_extra(self, key: str, default=None): + return self._extras.get(key, default) + + def set_extra(self, key: str, value) -> None: + self._extras[key] = value + + def set_result(self, result) -> None: + self.result_texts.append(result.get_plain_text(with_other_comps_mark=True)) + + def clear_result(self) -> None: + return None + + def get_platform_name(self) -> str: + return "slack" + + def get_platform_id(self) -> str: + return "slack" + + async def send(self, _msg_chain) -> None: + return None + + +class _FakeRunner: + def __init__(self, steps: list[list[SimpleNamespace]]) -> None: + self._steps = steps + self._step_idx = 0 + self._done = False + self.streaming = False + self.req = SimpleNamespace(func_tool=object()) + self.run_context = SimpleNamespace( + context=SimpleNamespace(event=_DummyEvent()), + messages=[], + ) + self.stats = SimpleNamespace(to_dict=lambda: {}) + + def done(self) -> bool: + return self._done + + def request_stop(self) -> None: + self.run_context.context.event.set_extra("agent_stop_requested", True) + + def was_aborted(self) -> bool: + return False + + async def step(self): + if self._step_idx >= len(self._steps): + self._done = True + return + + current = self._steps[self._step_idx] + self._step_idx += 1 + for resp in current: + yield resp + + if self._step_idx >= len(self._steps): + self._done = True + + +@pytest.mark.asyncio +async def test_repeat_reply_guard_forces_convergence(): + runner = _FakeRunner( + [ + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("最终答案")], + ] + ) + + async for _ in run_agent( + runner, + max_step=8, + show_tool_use=False, + show_tool_call_result=False, + repeat_reply_guard_threshold=3, + ): + pass + + assert runner.run_context.context.event.result_texts == [ + "重复输出", + "重复输出", + "最终答案", + ] + assert runner.req.func_tool is None + assert any( + msg.role == "user" + and "You have repeated the same reply multiple times." in str(msg.content) + for msg in runner.run_context.messages + ) + + +@pytest.mark.asyncio +async def test_repeat_reply_guard_can_be_disabled_with_zero_threshold(): + runner = _FakeRunner( + [ + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("最终答案")], + ] + ) + original_func_tool = runner.req.func_tool + + async for _ in run_agent( + runner, + max_step=8, + show_tool_use=False, + show_tool_call_result=False, + repeat_reply_guard_threshold=0, + ): + pass + + assert runner.run_context.context.event.result_texts == [ + "重复输出", + "重复输出", + "重复输出", + "最终答案", + ] + assert runner.req.func_tool is original_func_tool diff --git a/tests/unit/test_repeat_reply_guard.py b/tests/unit/test_repeat_reply_guard.py new file mode 100644 index 0000000000..91ade0651f --- /dev/null +++ b/tests/unit/test_repeat_reply_guard.py @@ -0,0 +1,47 @@ +import inspect + +from astrbot.core.astr_agent_run_util import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_config_repeat_reply_guard_threshold, + normalize_repeat_reply_guard_threshold, + run_agent, + run_live_agent, +) +from astrbot.core.config.default import DEFAULT_CONFIG + + +def test_runtime_repeat_reply_guard_threshold_normalization(): + assert normalize_repeat_reply_guard_threshold("2") == 2 + assert normalize_repeat_reply_guard_threshold(-1) == 0 + assert normalize_repeat_reply_guard_threshold(None) == 0 + assert normalize_repeat_reply_guard_threshold(True) == 0 + + +def test_config_repeat_reply_guard_threshold_normalization(): + assert normalize_config_repeat_reply_guard_threshold("4") == 4 + assert normalize_config_repeat_reply_guard_threshold(-1) == 0 + assert ( + normalize_config_repeat_reply_guard_threshold(None) + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + normalize_config_repeat_reply_guard_threshold(True) + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + + +def test_repeat_reply_guard_default_is_shared(): + assert ( + DEFAULT_CONFIG["provider_settings"]["repeat_reply_guard_threshold"] + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + inspect.signature(run_agent).parameters["repeat_reply_guard_threshold"].default + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + inspect.signature(run_live_agent) + .parameters["repeat_reply_guard_threshold"] + .default + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + )