Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
DEERFLOW_PROVIDER_TYPE,
DEERFLOW_THREAD_ID_KEY,
)
from astrbot.core.persona_utils import (
is_persona_none_marker,
normalize_persona_id,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
from astrbot.core.utils.active_event_registry import active_event_registry
Expand Down Expand Up @@ -37,7 +41,7 @@ async def _get_current_persona_id(self, session_id):
)
if not conv:
return None
return conv.persona_id
return normalize_persona_id(conv.persona_id)

async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
Expand Down Expand Up @@ -225,7 +229,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
platform_name=platform_name,
provider_settings=provider_settings,
)
if persona_id == "[%None]":
if is_persona_none_marker(persona_id):
persona_name = "无"
elif persona_id:
persona_name = persona_id
Expand Down
8 changes: 6 additions & 2 deletions astrbot/builtin_stars/builtin_commands/commands/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.persona_utils import (
PERSONA_NONE_MARKER,
is_persona_none_marker,
)

if TYPE_CHECKING:
from astrbot.core.db.po import Persona
Expand Down Expand Up @@ -92,7 +96,7 @@ async def persona(self, message: AstrMessageEvent) -> None:
provider_settings=provider_settings,
)

if persona_id == "[%None]":
if is_persona_none_marker(persona_id):
curr_persona_name = "无"
elif persona_id:
curr_persona_name = persona_id
Expand Down Expand Up @@ -174,7 +178,7 @@ async def persona(self, message: AstrMessageEvent) -> None:
return
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
"[%None]",
PERSONA_NONE_MARKER,
)
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
Expand Down
7 changes: 7 additions & 0 deletions astrbot/core/conversation_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
from astrbot.core.persona_utils import normalize_persona_id
from astrbot.core.utils.datetime_utils import to_utc_timestamp


Expand Down Expand Up @@ -98,6 +99,12 @@ async def new_conversation(
platform_id = parts[0]
if not platform_id:
platform_id = "unknown"
if persona_id is None:
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid:
curr_conv = await self.db.get_conversation_by_id(cid=curr_cid)
if curr_conv:
persona_id = normalize_persona_id(curr_conv.persona_id)
Comment on lines +102 to +107
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这段代码正确地实现了在创建新对话时继承当前人格的逻辑。

然而,我注意到这可能会导致一些调用路径(例如 conversation.py 中的 new_conv)出现效率问题。new_conv 会先调用 _get_current_persona_id 来解析当前的人格,然后将结果传给 new_conversation。如果当前人格是 [%None]_get_current_persona_id 会返回 None,然后 new_conversation 会因为 persona_id is None 而再次执行几乎相同的逻辑来解析当前人格,这导致了重复的数据库查询。

虽然这不影响功能的正确性,但为了提高性能和代码的简洁性,建议在未来的重构中,让调用方(如 new_conv)直接调用 new_conversation() 而不传递 persona_id,完全依赖此处的继承逻辑。

由于相关调用方的代码不在此次变更范围内,这可以作为一个后续的改进点。

conv = await self.db.create_conversation(
user_id=unified_msg_origin,
platform_id=platform_id,
Expand Down
9 changes: 7 additions & 2 deletions astrbot/core/persona_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, PersonaFolder, Personality
from astrbot.core.persona_utils import is_persona_none_marker
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.sentinels import NOT_GIVEN

Expand Down Expand Up @@ -104,7 +105,7 @@ async def resolve_selected_persona(

if not persona_id:
persona_id = conversation_persona_id
if persona_id == "[%None]":
if is_persona_none_marker(persona_id):
pass
elif persona_id is None:
persona_id = (provider_settings or {}).get("default_personality")
Expand All @@ -115,7 +116,11 @@ async def resolve_selected_persona(
)

use_webchat_special_default = False
if not persona and platform_name == "webchat" and persona_id != "[%None]":
if (
not persona
and platform_name == "webchat"
and not is_persona_none_marker(persona_id)
):
persona_id = "_chatui_default_"
use_webchat_special_default = True

Expand Down
15 changes: 15 additions & 0 deletions astrbot/core/persona_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Helpers for persona marker handling."""

PERSONA_NONE_MARKER = "[%None]"


def is_persona_none_marker(persona_id: str | None) -> bool:
"""Return whether the persona id is the explicit no-persona marker."""
return persona_id == PERSONA_NONE_MARKER


def normalize_persona_id(persona_id: str | None) -> str | None:
"""Normalize the explicit no-persona marker to None."""
if is_persona_none_marker(persona_id):
return None
return persona_id
14 changes: 8 additions & 6 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Plain, Reply
from astrbot.core.persona_utils import PERSONA_NONE_MARKER
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.provider import Provider
Expand Down Expand Up @@ -504,14 +505,14 @@ async def test_ensure_persona_from_conversation(self, mock_event, mock_context):

@pytest.mark.asyncio
async def test_ensure_persona_none_explicit(self, mock_event, mock_context):
"""Test that [%None] persona is explicitly set to no persona."""
"""Test that the explicit no-persona marker is treated as no persona."""
module = ama
mock_context.persona_manager.personas_v3 = []
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("[%None]", None, None, False)
return_value=(PERSONA_NONE_MARKER, None, None, False)
)
req = ProviderRequest()
req.conversation = MagicMock(persona_id="[%None]")
req.conversation = MagicMock(persona_id=PERSONA_NONE_MARKER)

await module._ensure_persona_and_skills(req, {}, mock_context, mock_event)

Expand Down Expand Up @@ -565,9 +566,10 @@ async def test_subagent_dedupe_uses_default_persona_tools(
tmgr = mock_context.get_llm_tool_manager.return_value
tmgr.func_list = [tool_a, tool_b]
tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b])
tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get(
name
)
tmgr.get_func.side_effect = lambda name: {
"tool_a": tool_a,
"tool_b": tool_b,
}.get(name)

handoff = MagicMock()
handoff.name = "transfer_to_planner"
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/test_conversation_mgr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Tests for conversation persona inheritance behavior."""

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from astrbot.builtin_stars.builtin_commands.commands.conversation import (
ConversationCommands,
)
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.persona_utils import PERSONA_NONE_MARKER


@pytest.mark.asyncio
async def test_new_conversation_inherits_current_persona_when_not_provided():
db = MagicMock()
db.get_conversation_by_id = AsyncMock(
return_value=SimpleNamespace(persona_id="psychologist")
)
db.create_conversation = AsyncMock(
return_value=SimpleNamespace(conversation_id="new-cid")
)

manager = ConversationManager(db)
manager.session_conversations["test:private:u1"] = "old-cid"

with patch(
"astrbot.core.conversation_mgr.sp.session_put",
new=AsyncMock(return_value=None),
):
await manager.new_conversation("test:private:u1", platform_id="test")

assert db.create_conversation.await_args.kwargs["persona_id"] == "psychologist"


@pytest.mark.asyncio
async def test_new_conversation_does_not_inherit_persona_none_marker():
db = MagicMock()
db.get_conversation_by_id = AsyncMock(
return_value=SimpleNamespace(persona_id=PERSONA_NONE_MARKER)
)
db.create_conversation = AsyncMock(
return_value=SimpleNamespace(conversation_id="new-cid")
)

manager = ConversationManager(db)
manager.session_conversations["test:private:u1"] = "old-cid"

with patch(
"astrbot.core.conversation_mgr.sp.session_put",
new=AsyncMock(return_value=None),
):
await manager.new_conversation("test:private:u1", platform_id="test")

assert db.create_conversation.await_args.kwargs["persona_id"] is None


@pytest.mark.asyncio
async def test_new_conversation_keeps_explicit_persona_id():
db = MagicMock()
db.get_conversation_by_id = AsyncMock(
return_value=SimpleNamespace(persona_id="psychologist")
)
db.create_conversation = AsyncMock(
return_value=SimpleNamespace(conversation_id="new-cid")
)

manager = ConversationManager(db)
manager.session_conversations["test:private:u1"] = "old-cid"

with patch(
"astrbot.core.conversation_mgr.sp.session_put",
new=AsyncMock(return_value=None),
):
await manager.new_conversation(
"test:private:u1",
platform_id="test",
persona_id="teacher",
)

assert db.create_conversation.await_args.kwargs["persona_id"] == "teacher"


@pytest.mark.asyncio
async def test_get_current_persona_id_returns_none_for_none_marker():
context = MagicMock()
context.conversation_manager.get_curr_conversation_id = AsyncMock(
return_value="old-cid"
)
context.conversation_manager.get_conversation = AsyncMock(
return_value=MagicMock(persona_id=PERSONA_NONE_MARKER)
)

command = ConversationCommands(context)

result = await command._get_current_persona_id("test:private:u1")

assert result is None
Loading