From e1b373ec93a31cf8d1fb34d2b3908bbd3eb7df35 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Wed, 25 Mar 2026 00:53:39 -0500 Subject: [PATCH 1/3] fix(agent): add configurable repeated-reply convergence guard --- astrbot/core/astr_agent_run_util.py | 58 +++++++ astrbot/core/config/default.py | 12 ++ .../method/agent_sub_stages/internal.py | 14 ++ tests/unit/test_astr_agent_run_util.py | 148 ++++++++++++++++++ 4 files changed, 232 insertions(+) create mode 100644 tests/unit/test_astr_agent_run_util.py diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index eca24699ae..d239aed1ab 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -87,6 +87,21 @@ def _build_tool_result_status_message( return status_msg +def _normalize_repeat_reply_guard_threshold(value: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + return 0 + return max(0, parsed) + + +def _build_chain_signature(msg_chain: MessageChain) -> str: + signature = msg_chain.get_plain_text(with_other_comps_mark=True).strip() + if not signature: + return "" + return re.sub(r"\s+", " ", signature) + + async def run_agent( agent_runner: AgentRunner, max_step: int = 30, @@ -94,10 +109,16 @@ async def run_agent( show_tool_call_result: bool = False, stream_to_general: bool = False, show_reasoning: bool = False, + repeat_reply_guard_threshold: int = 3, ) -> AsyncGenerator[MessageChain | None, None]: step_idx = 0 astr_event = agent_runner.run_context.context.event tool_name_by_call_id: dict[str, str] = {} + guard_threshold = _normalize_repeat_reply_guard_threshold( + repeat_reply_guard_threshold + ) + guard_last_signature = "" + guard_repeat_count = 0 while step_idx < max_step + 1: step_idx += 1 @@ -193,6 +214,38 @@ async def run_agent( await astr_event.send(chain) continue + if resp.type == "llm_result" and guard_threshold > 0: + chain_signature = _build_chain_signature(resp.data["chain"]) + if chain_signature: + if chain_signature == guard_last_signature: + guard_repeat_count += 1 + else: + guard_last_signature = chain_signature + guard_repeat_count = 1 + + if guard_repeat_count >= guard_threshold: + logger.warning( + "Agent repeated identical llm_result %d times; forcing convergence. threshold=%d", + guard_repeat_count, + guard_threshold, + ) + if not agent_runner.done(): + if agent_runner.req: + agent_runner.req.func_tool = None + agent_runner.run_context.messages.append( + Message( + role="user", + content=( + "检测到你连续多次输出相同回复。" + "请停止重复,基于已有信息给出最终答复," + "不要再次调用工具。" + ), + ) + ) + # Jump to the same convergence path as max-step limit. + step_idx = max_step + continue + if stream_to_general and resp.type == "streaming_delta": continue @@ -288,6 +341,7 @@ async def run_live_agent( show_tool_use: bool = True, show_tool_call_result: bool = False, show_reasoning: bool = False, + repeat_reply_guard_threshold: int = 3, ) -> AsyncGenerator[MessageChain | None, None]: """Live Mode 的 Agent 运行器,支持流式 TTS @@ -311,6 +365,7 @@ async def run_live_agent( show_tool_call_result=show_tool_call_result, stream_to_general=False, show_reasoning=show_reasoning, + repeat_reply_guard_threshold=repeat_reply_guard_threshold, ): yield chain return @@ -343,6 +398,7 @@ async def run_live_agent( show_tool_use, show_tool_call_result, show_reasoning, + repeat_reply_guard_threshold, ) ) @@ -430,6 +486,7 @@ async def _run_agent_feeder( show_tool_use: bool, show_tool_call_result: bool, show_reasoning: bool, + repeat_reply_guard_threshold: int, ) -> None: """运行 Agent 并将文本输出分句放入队列""" buffer = "" @@ -441,6 +498,7 @@ async def _run_agent_feeder( show_tool_call_result=show_tool_call_result, stream_to_general=False, show_reasoning=show_reasoning, + repeat_reply_guard_threshold=repeat_reply_guard_threshold, ): if chain is None: continue diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 0f43dbd06d..ac5a5d64a6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -149,6 +149,7 @@ "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, + "repeat_reply_guard_threshold": 3, "tool_call_timeout": 120, "tool_schema_mode": "full", "llm_safety_mode": True, @@ -2685,6 +2686,9 @@ class ChatProviderTemplate(TypedDict): "max_agent_step": { "type": "int", }, + "repeat_reply_guard_threshold": { + "type": "int", + }, "tool_call_timeout": { "type": "int", }, @@ -3430,6 +3434,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.repeat_reply_guard_threshold": { + "description": "连续相同回复拦截阈值", + "type": "int", + "hint": "同一轮 Agent 运行中连续出现相同回复达到该次数时,将触发防循环收敛。设置为 0 可关闭。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.tool_call_timeout": { "description": "工具调用超时时间(秒)", "type": "int", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..2476b2a936 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -64,6 +64,17 @@ async def initialize(self, ctx: PipelineContext) -> None: self.tool_schema_mode = "full" if isinstance(self.max_step, bool): # workaround: #2622 self.max_step = 30 + self.repeat_reply_guard_threshold: int = settings.get( + "repeat_reply_guard_threshold", 3 + ) + if isinstance(self.repeat_reply_guard_threshold, bool): + self.repeat_reply_guard_threshold = 3 + try: + self.repeat_reply_guard_threshold = int(self.repeat_reply_guard_threshold) + except (TypeError, ValueError): + self.repeat_reply_guard_threshold = 3 + if self.repeat_reply_guard_threshold < 0: + self.repeat_reply_guard_threshold = 0 self.show_tool_use: bool = settings.get("show_tool_use_status", True) self.show_tool_call_result: bool = settings.get("show_tool_call_result", False) self.show_reasoning = settings.get("display_reasoning_text", False) @@ -274,6 +285,7 @@ async def process( self.show_tool_use, self.show_tool_call_result, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ), ), ) @@ -304,6 +316,7 @@ async def process( self.show_tool_use, self.show_tool_call_result, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ), ), ) @@ -334,6 +347,7 @@ async def process( self.show_tool_call_result, stream_to_general, show_reasoning=self.show_reasoning, + repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, ): yield diff --git a/tests/unit/test_astr_agent_run_util.py b/tests/unit/test_astr_agent_run_util.py new file mode 100644 index 0000000000..e3dd485014 --- /dev/null +++ b/tests/unit/test_astr_agent_run_util.py @@ -0,0 +1,148 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.core.astr_agent_run_util import run_agent +from astrbot.core.message.message_event_result import MessageChain + + +def _llm_result_response(text: str): + return SimpleNamespace( + type="llm_result", + data={"chain": MessageChain().message(text)}, + ) + + +class _DummyTrace: + def record(self, *args, **kwargs) -> None: + return None + + +class _DummyEvent: + def __init__(self) -> None: + self._extras: dict = {} + self._stopped = False + self.result_texts: list[str] = [] + self.trace = _DummyTrace() + + def is_stopped(self) -> bool: + return self._stopped + + def get_extra(self, key: str, default=None): + return self._extras.get(key, default) + + def set_extra(self, key: str, value) -> None: + self._extras[key] = value + + def set_result(self, result) -> None: + self.result_texts.append(result.get_plain_text(with_other_comps_mark=True)) + + def clear_result(self) -> None: + return None + + def get_platform_name(self) -> str: + return "slack" + + def get_platform_id(self) -> str: + return "slack" + + async def send(self, _msg_chain) -> None: + return None + + +class _FakeRunner: + def __init__(self, steps: list[list[SimpleNamespace]]) -> None: + self._steps = steps + self._step_idx = 0 + self._done = False + self.streaming = False + self.req = SimpleNamespace(func_tool=object()) + self.run_context = SimpleNamespace( + context=SimpleNamespace(event=_DummyEvent()), + messages=[], + ) + self.stats = SimpleNamespace(to_dict=lambda: {}) + + def done(self) -> bool: + return self._done + + def request_stop(self) -> None: + self.run_context.context.event.set_extra("agent_stop_requested", True) + + def was_aborted(self) -> bool: + return False + + async def step(self): + if self._step_idx >= len(self._steps): + self._done = True + return + + current = self._steps[self._step_idx] + self._step_idx += 1 + for resp in current: + yield resp + + if self._step_idx >= len(self._steps): + self._done = True + + +@pytest.mark.asyncio +async def test_repeat_reply_guard_forces_convergence(): + runner = _FakeRunner( + [ + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("最终答案")], + ] + ) + + async for _ in run_agent( + runner, + max_step=8, + show_tool_use=False, + show_tool_call_result=False, + repeat_reply_guard_threshold=3, + ): + pass + + assert runner.run_context.context.event.result_texts == [ + "重复输出", + "重复输出", + "最终答案", + ] + assert runner.req.func_tool is None + assert any( + msg.role == "user" and "检测到你连续多次输出相同回复" in str(msg.content) + for msg in runner.run_context.messages + ) + + +@pytest.mark.asyncio +async def test_repeat_reply_guard_can_be_disabled_with_zero_threshold(): + runner = _FakeRunner( + [ + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("重复输出")], + [_llm_result_response("最终答案")], + ] + ) + original_func_tool = runner.req.func_tool + + async for _ in run_agent( + runner, + max_step=8, + show_tool_use=False, + show_tool_call_result=False, + repeat_reply_guard_threshold=0, + ): + pass + + assert runner.run_context.context.event.result_texts == [ + "重复输出", + "重复输出", + "重复输出", + "最终答案", + ] + assert runner.req.func_tool is original_func_tool From 9955f5ed795b6c550225bd8cf253f0ddee409b8e Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Wed, 25 Mar 2026 06:50:38 -0500 Subject: [PATCH 2/3] refactor(agent): centralize repeat-reply guard settings --- astrbot/core/astr_agent_run_util.py | 18 +++----- astrbot/core/config/default.py | 3 +- .../method/agent_sub_stages/internal.py | 20 ++++---- astrbot/core/repeat_reply_guard.py | 18 ++++++++ tests/unit/test_repeat_reply_guard.py | 46 +++++++++++++++++++ 5 files changed, 84 insertions(+), 21 deletions(-) create mode 100644 astrbot/core/repeat_reply_guard.py create mode 100644 tests/unit/test_repeat_reply_guard.py diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index d239aed1ab..48da08c73b 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -19,6 +19,10 @@ ) from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.provider import TTSProvider +from astrbot.core.repeat_reply_guard import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_repeat_reply_guard_threshold, +) AgentRunner = ToolLoopAgentRunner[AstrAgentContext] @@ -87,14 +91,6 @@ def _build_tool_result_status_message( return status_msg -def _normalize_repeat_reply_guard_threshold(value: int) -> int: - try: - parsed = int(value) - except (TypeError, ValueError): - return 0 - return max(0, parsed) - - def _build_chain_signature(msg_chain: MessageChain) -> str: signature = msg_chain.get_plain_text(with_other_comps_mark=True).strip() if not signature: @@ -109,12 +105,12 @@ async def run_agent( show_tool_call_result: bool = False, stream_to_general: bool = False, show_reasoning: bool = False, - repeat_reply_guard_threshold: int = 3, + repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, ) -> AsyncGenerator[MessageChain | None, None]: step_idx = 0 astr_event = agent_runner.run_context.context.event tool_name_by_call_id: dict[str, str] = {} - guard_threshold = _normalize_repeat_reply_guard_threshold( + guard_threshold = normalize_repeat_reply_guard_threshold( repeat_reply_guard_threshold ) guard_last_signature = "" @@ -341,7 +337,7 @@ async def run_live_agent( show_tool_use: bool = True, show_tool_call_result: bool = False, show_reasoning: bool = False, - repeat_reply_guard_threshold: int = 3, + repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, ) -> AsyncGenerator[MessageChain | None, None]: """Live Mode 的 Agent 运行器,支持流式 TTS diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ac5a5d64a6..2b0a6b4fbb 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -3,6 +3,7 @@ import os from typing import Any, TypedDict +from astrbot.core.repeat_reply_guard import DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "4.22.0" @@ -149,7 +150,7 @@ "unsupported_streaming_strategy": "realtime_segmenting", "reachability_check": False, "max_agent_step": 30, - "repeat_reply_guard_threshold": 3, + "repeat_reply_guard_threshold": DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, "tool_call_timeout": 120, "tool_schema_mode": "full", "llm_safety_mode": True, diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 2476b2a936..9d9e58c6b9 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -28,6 +28,10 @@ LLMResponse, ProviderRequest, ) +from astrbot.core.repeat_reply_guard import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_config_repeat_reply_guard_threshold, +) from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager @@ -65,16 +69,14 @@ async def initialize(self, ctx: PipelineContext) -> None: if isinstance(self.max_step, bool): # workaround: #2622 self.max_step = 30 self.repeat_reply_guard_threshold: int = settings.get( - "repeat_reply_guard_threshold", 3 + "repeat_reply_guard_threshold", + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + ) + self.repeat_reply_guard_threshold = ( + normalize_config_repeat_reply_guard_threshold( + self.repeat_reply_guard_threshold + ) ) - if isinstance(self.repeat_reply_guard_threshold, bool): - self.repeat_reply_guard_threshold = 3 - try: - self.repeat_reply_guard_threshold = int(self.repeat_reply_guard_threshold) - except (TypeError, ValueError): - self.repeat_reply_guard_threshold = 3 - if self.repeat_reply_guard_threshold < 0: - self.repeat_reply_guard_threshold = 0 self.show_tool_use: bool = settings.get("show_tool_use_status", True) self.show_tool_call_result: bool = settings.get("show_tool_call_result", False) self.show_reasoning = settings.get("display_reasoning_text", False) diff --git a/astrbot/core/repeat_reply_guard.py b/astrbot/core/repeat_reply_guard.py new file mode 100644 index 0000000000..6b5d4f7f65 --- /dev/null +++ b/astrbot/core/repeat_reply_guard.py @@ -0,0 +1,18 @@ +DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD = 3 + + +def normalize_repeat_reply_guard_threshold(value, *, invalid_fallback: int = 0) -> int: + if isinstance(value, bool): + return invalid_fallback + try: + parsed = int(value) + except (TypeError, ValueError): + return invalid_fallback + return max(0, parsed) + + +def normalize_config_repeat_reply_guard_threshold(value) -> int: + return normalize_repeat_reply_guard_threshold( + value, + invalid_fallback=DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + ) diff --git a/tests/unit/test_repeat_reply_guard.py b/tests/unit/test_repeat_reply_guard.py new file mode 100644 index 0000000000..371dc79bd2 --- /dev/null +++ b/tests/unit/test_repeat_reply_guard.py @@ -0,0 +1,46 @@ +import inspect + +from astrbot.core.astr_agent_run_util import run_agent, run_live_agent +from astrbot.core.config.default import DEFAULT_CONFIG +from astrbot.core.repeat_reply_guard import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_config_repeat_reply_guard_threshold, + normalize_repeat_reply_guard_threshold, +) + + +def test_runtime_repeat_reply_guard_threshold_normalization(): + assert normalize_repeat_reply_guard_threshold("2") == 2 + assert normalize_repeat_reply_guard_threshold(-1) == 0 + assert normalize_repeat_reply_guard_threshold(None) == 0 + assert normalize_repeat_reply_guard_threshold(True) == 0 + + +def test_config_repeat_reply_guard_threshold_normalization(): + assert normalize_config_repeat_reply_guard_threshold("4") == 4 + assert normalize_config_repeat_reply_guard_threshold(-1) == 0 + assert ( + normalize_config_repeat_reply_guard_threshold(None) + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + normalize_config_repeat_reply_guard_threshold(True) + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + + +def test_repeat_reply_guard_default_is_shared(): + assert ( + DEFAULT_CONFIG["provider_settings"]["repeat_reply_guard_threshold"] + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + inspect.signature(run_agent).parameters["repeat_reply_guard_threshold"].default + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) + assert ( + inspect.signature(run_live_agent) + .parameters["repeat_reply_guard_threshold"] + .default + == DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD + ) From 69b8fea31eafebcb9f4fd100b0e49b217e1e9744 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Sun, 29 Mar 2026 10:03:23 -0500 Subject: [PATCH 3/3] refactor(agent): inline repeat-reply guard helpers --- astrbot/core/astr_agent_run_util.py | 29 ++++++++++++++----- astrbot/core/config/default.py | 2 +- .../method/agent_sub_stages/internal.py | 8 ++--- astrbot/core/repeat_reply_guard.py | 18 ------------ tests/unit/test_astr_agent_run_util.py | 3 +- tests/unit/test_repeat_reply_guard.py | 7 +++-- 6 files changed, 33 insertions(+), 34 deletions(-) delete mode 100644 astrbot/core/repeat_reply_guard.py diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 48da08c73b..d21b8f0028 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -8,6 +8,7 @@ from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.config.default import DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD from astrbot.core.message.components import BaseMessageComponent, Json, Plain from astrbot.core.message.message_event_result import ( MessageChain, @@ -19,14 +20,27 @@ ) from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.provider import TTSProvider -from astrbot.core.repeat_reply_guard import ( - DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, - normalize_repeat_reply_guard_threshold, -) AgentRunner = ToolLoopAgentRunner[AstrAgentContext] +def normalize_repeat_reply_guard_threshold(value, *, invalid_fallback: int = 0) -> int: + if isinstance(value, bool): + return invalid_fallback + try: + parsed = int(value) + except (TypeError, ValueError): + return invalid_fallback + return max(0, parsed) + + +def normalize_config_repeat_reply_guard_threshold(value) -> int: + return normalize_repeat_reply_guard_threshold( + value, + invalid_fallback=DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + ) + + def _should_stop_agent(astr_event) -> bool: return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested")) @@ -232,9 +246,10 @@ async def run_agent( Message( role="user", content=( - "检测到你连续多次输出相同回复。" - "请停止重复,基于已有信息给出最终答复," - "不要再次调用工具。" + "You have repeated the same reply multiple times. " + "Stop repeating yourself, provide a final answer " + "based on the information you already have, and do " + "not call tools again." ), ) ) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 2b0a6b4fbb..86371dc7f9 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -3,11 +3,11 @@ import os from typing import Any, TypedDict -from astrbot.core.repeat_reply_guard import DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "4.22.0" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") +DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD = 3 PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { "description": "Base URL", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 9d9e58c6b9..8ef76b7ad7 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -8,6 +8,10 @@ from astrbot.core import logger from astrbot.core.agent.message import Message from astrbot.core.agent.response import AgentStats +from astrbot.core.astr_agent_run_util import ( + DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, + normalize_config_repeat_reply_guard_threshold, +) from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, MainAgentBuildResult, @@ -28,10 +32,6 @@ LLMResponse, ProviderRequest, ) -from astrbot.core.repeat_reply_guard import ( - DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, - normalize_config_repeat_reply_guard_threshold, -) from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager diff --git a/astrbot/core/repeat_reply_guard.py b/astrbot/core/repeat_reply_guard.py deleted file mode 100644 index 6b5d4f7f65..0000000000 --- a/astrbot/core/repeat_reply_guard.py +++ /dev/null @@ -1,18 +0,0 @@ -DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD = 3 - - -def normalize_repeat_reply_guard_threshold(value, *, invalid_fallback: int = 0) -> int: - if isinstance(value, bool): - return invalid_fallback - try: - parsed = int(value) - except (TypeError, ValueError): - return invalid_fallback - return max(0, parsed) - - -def normalize_config_repeat_reply_guard_threshold(value) -> int: - return normalize_repeat_reply_guard_threshold( - value, - invalid_fallback=DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, - ) diff --git a/tests/unit/test_astr_agent_run_util.py b/tests/unit/test_astr_agent_run_util.py index e3dd485014..0ab0e776d5 100644 --- a/tests/unit/test_astr_agent_run_util.py +++ b/tests/unit/test_astr_agent_run_util.py @@ -113,7 +113,8 @@ async def test_repeat_reply_guard_forces_convergence(): ] assert runner.req.func_tool is None assert any( - msg.role == "user" and "检测到你连续多次输出相同回复" in str(msg.content) + msg.role == "user" + and "You have repeated the same reply multiple times." in str(msg.content) for msg in runner.run_context.messages ) diff --git a/tests/unit/test_repeat_reply_guard.py b/tests/unit/test_repeat_reply_guard.py index 371dc79bd2..91ade0651f 100644 --- a/tests/unit/test_repeat_reply_guard.py +++ b/tests/unit/test_repeat_reply_guard.py @@ -1,12 +1,13 @@ import inspect -from astrbot.core.astr_agent_run_util import run_agent, run_live_agent -from astrbot.core.config.default import DEFAULT_CONFIG -from astrbot.core.repeat_reply_guard import ( +from astrbot.core.astr_agent_run_util import ( DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD, normalize_config_repeat_reply_guard_threshold, normalize_repeat_reply_guard_threshold, + run_agent, + run_live_agent, ) +from astrbot.core.config.default import DEFAULT_CONFIG def test_runtime_repeat_reply_guard_threshold_normalization():