diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index 5190a363ee..6eb82c25ca 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -6,6 +6,10 @@ DEERFLOW_PROVIDER_TYPE, DEERFLOW_THREAD_ID_KEY, ) +from astrbot.core.persona_utils import ( + is_persona_none_marker, + normalize_persona_id, +) from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.message_type import MessageType from astrbot.core.utils.active_event_registry import active_event_registry @@ -37,7 +41,7 @@ async def _get_current_persona_id(self, session_id): ) if not conv: return None - return conv.persona_id + return normalize_persona_id(conv.persona_id) async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" @@ -225,7 +229,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: platform_name=platform_name, provider_settings=provider_settings, ) - if persona_id == "[%None]": + if is_persona_none_marker(persona_id): persona_name = "无" elif persona_id: persona_name = persona_id diff --git a/astrbot/builtin_stars/builtin_commands/commands/persona.py b/astrbot/builtin_stars/builtin_commands/commands/persona.py index 7a7416bbaf..3e8311a600 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/persona.py +++ b/astrbot/builtin_stars/builtin_commands/commands/persona.py @@ -3,6 +3,10 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.persona_utils import ( + PERSONA_NONE_MARKER, + is_persona_none_marker, +) if TYPE_CHECKING: from astrbot.core.db.po import Persona @@ -92,7 +96,7 @@ async def persona(self, message: AstrMessageEvent) -> None: provider_settings=provider_settings, ) - if persona_id == "[%None]": + if is_persona_none_marker(persona_id): curr_persona_name = "无" elif persona_id: curr_persona_name = persona_id @@ -174,7 +178,7 @@ async def persona(self, message: AstrMessageEvent) -> None: return await self.context.conversation_manager.update_conversation_persona_id( message.unified_msg_origin, - "[%None]", + PERSONA_NONE_MARKER, ) message.set_result(MessageEventResult().message("取消人格成功。")) else: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2c282867f9..757df8a69c 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -11,6 +11,7 @@ from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 +from astrbot.core.persona_utils import normalize_persona_id from astrbot.core.utils.datetime_utils import to_utc_timestamp @@ -98,6 +99,12 @@ async def new_conversation( platform_id = parts[0] if not platform_id: platform_id = "unknown" + if persona_id is None: + curr_cid = await self.get_curr_conversation_id(unified_msg_origin) + if curr_cid: + curr_conv = await self.db.get_conversation_by_id(cid=curr_cid) + if curr_conv: + persona_id = normalize_persona_id(curr_conv.persona_id) conv = await self.db.create_conversation( user_id=unified_msg_origin, platform_id=platform_id, diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index 6320ac3bbc..001bdd7675 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -3,6 +3,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Persona, PersonaFolder, Personality +from astrbot.core.persona_utils import is_persona_none_marker from astrbot.core.platform.message_session import MessageSession from astrbot.core.sentinels import NOT_GIVEN @@ -104,7 +105,7 @@ async def resolve_selected_persona( if not persona_id: persona_id = conversation_persona_id - if persona_id == "[%None]": + if is_persona_none_marker(persona_id): pass elif persona_id is None: persona_id = (provider_settings or {}).get("default_personality") @@ -115,7 +116,11 @@ async def resolve_selected_persona( ) use_webchat_special_default = False - if not persona and platform_name == "webchat" and persona_id != "[%None]": + if ( + not persona + and platform_name == "webchat" + and not is_persona_none_marker(persona_id) + ): persona_id = "_chatui_default_" use_webchat_special_default = True diff --git a/astrbot/core/persona_utils.py b/astrbot/core/persona_utils.py new file mode 100644 index 0000000000..146090fd8b --- /dev/null +++ b/astrbot/core/persona_utils.py @@ -0,0 +1,15 @@ +"""Helpers for persona marker handling.""" + +PERSONA_NONE_MARKER = "[%None]" + + +def is_persona_none_marker(persona_id: str | None) -> bool: + """Return whether the persona id is the explicit no-persona marker.""" + return persona_id == PERSONA_NONE_MARKER + + +def normalize_persona_id(persona_id: str | None) -> str | None: + """Normalize the explicit no-persona marker to None.""" + if is_persona_none_marker(persona_id): + return None + return persona_id diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 9a42abd733..30696c7f1a 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -10,6 +10,7 @@ from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Plain, Reply +from astrbot.core.persona_utils import PERSONA_NONE_MARKER from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.platform_metadata import PlatformMetadata from astrbot.core.provider import Provider @@ -504,14 +505,14 @@ async def test_ensure_persona_from_conversation(self, mock_event, mock_context): @pytest.mark.asyncio async def test_ensure_persona_none_explicit(self, mock_event, mock_context): - """Test that [%None] persona is explicitly set to no persona.""" + """Test that the explicit no-persona marker is treated as no persona.""" module = ama mock_context.persona_manager.personas_v3 = [] mock_context.persona_manager.resolve_selected_persona = AsyncMock( - return_value=("[%None]", None, None, False) + return_value=(PERSONA_NONE_MARKER, None, None, False) ) req = ProviderRequest() - req.conversation = MagicMock(persona_id="[%None]") + req.conversation = MagicMock(persona_id=PERSONA_NONE_MARKER) await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) @@ -565,9 +566,10 @@ async def test_subagent_dedupe_uses_default_persona_tools( tmgr = mock_context.get_llm_tool_manager.return_value tmgr.func_list = [tool_a, tool_b] tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b]) - tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get( - name - ) + tmgr.get_func.side_effect = lambda name: { + "tool_a": tool_a, + "tool_b": tool_b, + }.get(name) handoff = MagicMock() handoff.name = "transfer_to_planner" diff --git a/tests/unit/test_conversation_mgr.py b/tests/unit/test_conversation_mgr.py new file mode 100644 index 0000000000..4b09b0d897 --- /dev/null +++ b/tests/unit/test_conversation_mgr.py @@ -0,0 +1,99 @@ +"""Tests for conversation persona inheritance behavior.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.builtin_stars.builtin_commands.commands.conversation import ( + ConversationCommands, +) +from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.persona_utils import PERSONA_NONE_MARKER + + +@pytest.mark.asyncio +async def test_new_conversation_inherits_current_persona_when_not_provided(): + db = MagicMock() + db.get_conversation_by_id = AsyncMock( + return_value=SimpleNamespace(persona_id="psychologist") + ) + db.create_conversation = AsyncMock( + return_value=SimpleNamespace(conversation_id="new-cid") + ) + + manager = ConversationManager(db) + manager.session_conversations["test:private:u1"] = "old-cid" + + with patch( + "astrbot.core.conversation_mgr.sp.session_put", + new=AsyncMock(return_value=None), + ): + await manager.new_conversation("test:private:u1", platform_id="test") + + assert db.create_conversation.await_args.kwargs["persona_id"] == "psychologist" + + +@pytest.mark.asyncio +async def test_new_conversation_does_not_inherit_persona_none_marker(): + db = MagicMock() + db.get_conversation_by_id = AsyncMock( + return_value=SimpleNamespace(persona_id=PERSONA_NONE_MARKER) + ) + db.create_conversation = AsyncMock( + return_value=SimpleNamespace(conversation_id="new-cid") + ) + + manager = ConversationManager(db) + manager.session_conversations["test:private:u1"] = "old-cid" + + with patch( + "astrbot.core.conversation_mgr.sp.session_put", + new=AsyncMock(return_value=None), + ): + await manager.new_conversation("test:private:u1", platform_id="test") + + assert db.create_conversation.await_args.kwargs["persona_id"] is None + + +@pytest.mark.asyncio +async def test_new_conversation_keeps_explicit_persona_id(): + db = MagicMock() + db.get_conversation_by_id = AsyncMock( + return_value=SimpleNamespace(persona_id="psychologist") + ) + db.create_conversation = AsyncMock( + return_value=SimpleNamespace(conversation_id="new-cid") + ) + + manager = ConversationManager(db) + manager.session_conversations["test:private:u1"] = "old-cid" + + with patch( + "astrbot.core.conversation_mgr.sp.session_put", + new=AsyncMock(return_value=None), + ): + await manager.new_conversation( + "test:private:u1", + platform_id="test", + persona_id="teacher", + ) + + assert db.create_conversation.await_args.kwargs["persona_id"] == "teacher" + + +@pytest.mark.asyncio +async def test_get_current_persona_id_returns_none_for_none_marker(): + context = MagicMock() + context.conversation_manager.get_curr_conversation_id = AsyncMock( + return_value="old-cid" + ) + context.conversation_manager.get_conversation = AsyncMock( + return_value=MagicMock(persona_id=PERSONA_NONE_MARKER) + ) + + command = ConversationCommands(context) + + result = await command._get_current_persona_id("test:private:u1") + + assert result is None