Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions astrbot/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PlatformSession,
PlatformStat,
Preference,
ProviderStat,
SessionProjectRelation,
Stats,
)
Expand Down Expand Up @@ -105,6 +106,21 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...

@abc.abstractmethod
async def insert_provider_stat(
self,
*,
umo: str,
provider_id: str,
provider_model: str | None = None,
conversation_id: str | None = None,
status: str = "completed",
stats: dict | None = None,
agent_type: str = "internal",
) -> ProviderStat:
"""Insert a per-response provider stat record."""
...

@abc.abstractmethod
async def get_conversations(
self,
Expand Down
24 changes: 24 additions & 0 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,30 @@ class PlatformStat(SQLModel, table=True):
)


class ProviderStat(TimestampMixin, SQLModel, table=True):
"""Per-response provider stats for internal agent runs."""

__tablename__: str = "provider_stats"

id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
agent_type: str = Field(default="internal", nullable=False, index=True)
status: str = Field(default="completed", nullable=False, index=True)
umo: str = Field(nullable=False, index=True)
conversation_id: str | None = Field(default=None, index=True)
provider_id: str = Field(nullable=False, index=True)
provider_model: str | None = Field(default=None, index=True)
token_input_other: int = Field(default=0, nullable=False)
token_input_cached: int = Field(default=0, nullable=False)
token_output: int = Field(default=0, nullable=False)
start_time: float = Field(default=0.0, nullable=False)
end_time: float = Field(default=0.0, nullable=False)
time_to_first_token: float = Field(default=0.0, nullable=False)


class ConversationV2(TimestampMixin, SQLModel, table=True):
__tablename__: str = "conversations"

Expand Down
46 changes: 46 additions & 0 deletions astrbot/core/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
PlatformSession,
PlatformStat,
Preference,
ProviderStat,
SessionProjectRelation,
SQLModel,
)
Expand Down Expand Up @@ -169,6 +170,51 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat
)
return list(result.scalars().all())

async def insert_provider_stat(
self,
*,
umo: str,
provider_id: str,
provider_model: str | None = None,
conversation_id: str | None = None,
status: str = "completed",
stats: dict | None = None,
agent_type: str = "internal",
) -> ProviderStat:
"""Insert a provider stat record for a single agent response."""
stats = stats or {}
token_usage = stats.get("token_usage", {})

token_input_other = int(token_usage.get("input_other", 0) or 0)
token_input_cached = int(token_usage.get("input_cached", 0) or 0)
token_output = int(token_usage.get("output", 0) or 0)

start_time = float(stats.get("start_time", 0.0) or 0.0)
end_time = float(stats.get("end_time", 0.0) or 0.0)
time_to_first_token = float(stats.get("time_to_first_token", 0.0) or 0.0)

async with self.get_db() as session:
session: AsyncSession
async with session.begin():
record = ProviderStat(
agent_type=agent_type,
status=status,
umo=umo,
conversation_id=conversation_id,
provider_id=provider_id,
provider_model=provider_model,
token_input_other=token_input_other,
token_input_cached=token_input_cached,
token_output=token_output,
start_time=start_time,
end_time=end_time,
time_to_first_token=time_to_first_token,
)
session.add(record)
await session.flush()
await session.refresh(record)
return record
Comment on lines +212 to +216
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Refreshing the record after flush/begin is probably unnecessary overhead.

Here you add(record), flush(), then refresh(record) before returning. For a simple insert where only the autoincrement PK is DB-generated, flush() already populates record.id. The extra refresh() adds an unnecessary DB round-trip. Unless you depend on DB-side defaults/triggers updating other fields, you can remove await session.refresh(record) to avoid that overhead.

Suggested change
)
session.add(record)
await session.flush()
await session.refresh(record)
return record
)
session.add(record)
await session.flush()
return record


# ====
# Conversation Management
# ====
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncGenerator
from dataclasses import replace

from astrbot.core import logger
from astrbot.core import db_helper, logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.astr_main_agent import (
Expand Down Expand Up @@ -350,6 +350,15 @@ async def process(
resp=final_resp.completion_text if final_resp else None,
)

asyncio.create_task(
_record_internal_agent_stats(
event,
req,
agent_runner,
final_resp,
)
)

# 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped() or agent_runner.was_aborted():
await self._save_to_history(
Expand Down Expand Up @@ -462,3 +471,46 @@ async def _save_to_history(
# these hosts are base64 encoded
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED]


async def _record_internal_agent_stats(
event: AstrMessageEvent,
req: ProviderRequest | None,
agent_runner: AgentRunner | None,
final_resp: LLMResponse | None,
) -> None:
"""Persist internal agent stats without affecting the user response flow."""
if agent_runner is None:
return

provider = agent_runner.provider
stats = agent_runner.stats
if provider is None or stats is None:
return

try:
provider_config = getattr(provider, "provider_config", {}) or {}
conversation_id = (
req.conversation.cid
if req is not None and req.conversation is not None
else None
)

if agent_runner.was_aborted():
status = "aborted"
elif final_resp is not None and final_resp.role == "err":
status = "error"
else:
status = "completed"

await db_helper.insert_provider_stat(
umo=event.unified_msg_origin,
conversation_id=conversation_id,
provider_id=provider_config.get("id", "") or provider.meta().id,
provider_model=provider.get_model(),
status=status,
stats=stats.to_dict(),
agent_type="internal",
)
except Exception as e:
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
Loading
Loading