diff --git a/.gitignore b/.gitignore index 5eb9616c8c..afd1659b8f 100644 --- a/.gitignore +++ b/.gitignore @@ -63,4 +63,5 @@ GenieData/ .kilocode/ .worktrees/ +.astrbot_sdk_testing/ dashboard/bun.lock diff --git a/AGENTS.md b/AGENTS.md index 9f3617ce9c..d13284dca5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,9 +26,9 @@ Runs on `http://localhost:3000` by default. 3. After finishing, use `ruff format .` and `ruff check .` to format and check the code. 4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`. 5. Use English for all new comments. -6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory. +6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.astrbot_path` helpers to get the AstrBot data and temp directory. ## PR instructions 1. Title format: use conventional commit messages -2. Use English to write PR title and descriptions. +2. Use English to write PR title and descriptions. \ No newline at end of file diff --git a/astrbot-sdk/LICENSE b/astrbot-sdk/LICENSE new file mode 100644 index 0000000000..51d7fd4c87 --- /dev/null +++ b/astrbot-sdk/LICENSE @@ -0,0 +1,11 @@ +AstrBot SDK repository notice +============================= + +This repository does not currently publish a standalone open-source license text. + +This file exists so the source repository and its `vendor/` subtree snapshot carry +the same notice instead of silently omitting licensing information. + +Unless the maintainers publish different licensing terms, do not assume this +repository grants redistribution or modification rights beyond applicable law and +explicit permission from the maintainers. diff --git a/astrbot-sdk/README.md b/astrbot-sdk/README.md new file mode 100644 index 0000000000..9cd71c50f0 --- /dev/null +++ b/astrbot-sdk/README.md @@ -0,0 +1,14 @@ +# AstrBot SDK Vendor Snapshot + +This directory is the minimized subtree payload consumed by the AstrBot main +repository. + +- `src/astrbot_sdk/` keeps the runtime SDK package plus the minimal testing + helpers that AstrBot and SDK-generated templates still treat as part of the + vendored contract +- agent skill templates and embedded markdown reference files are excluded +- root project-note templates for `astr init` stay vendored because the CLI + still generates `AGENTS.md` / `CLAUDE.md` by default +- `pyproject.toml` keeps the src-layout package discovery but drops dev/test-only metadata +- `VENDORED.md` describes the vendoring contract +- tests, docs, CI files, and other source-repo-only content stay outside this directory diff --git a/astrbot-sdk/VENDORED.md b/astrbot-sdk/VENDORED.md new file mode 100644 index 0000000000..a332777bcb --- /dev/null +++ b/astrbot-sdk/VENDORED.md @@ -0,0 +1,20 @@ +# Vendored Snapshot Notes + +This directory is a minimized snapshot for the AstrBot main repository to import +via `git subtree`. + +- The source of truth is this `astrbot-sdk` repository. +- `vendor/src/astrbot_sdk/` is synchronized from `src/astrbot_sdk/`. +- Vendored snapshots keep the runtime SDK plus the minimal testing helpers + (`testing.py`, `_testing_support.py`, `_internal/testing_support.py`) because + AstrBot and SDK-generated test templates still depend on them. +- Vendored snapshots exclude agent skill templates and markdown reference + assets that are not needed by the subtree consumer, but retain the default + `AGENTS.md` / `CLAUDE.md` project-note templates used by `astr init`. +- `vendor/pyproject.toml` keeps src-layout package discovery, but strips + test/dev-only sections so the subtree stays runtime-focused. +- Do not edit vendored files directly inside the AstrBot main repository. +- Tests and documentation remain only in the SDK source repository and are not + copied into the vendored snapshot. +- If the vendored copy needs changes, update the SDK source repository first and + regenerate the `vendor/` snapshot. diff --git a/astrbot-sdk/pyproject.toml b/astrbot-sdk/pyproject.toml new file mode 100644 index 0000000000..db6eff3658 --- /dev/null +++ b/astrbot-sdk/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "astrbot-sdk" +version = "0.1.0" +description = "AstrBot SDK with s5r runtime, worker protocol, and plugin tooling" +readme = "README.md" +requires-python = ">=3.12" +classifiers = [ + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", +] +dependencies = [ + "aiohttp>=3.13.2", + "anthropic>=0.72.1", + "certifi>=2025.10.5", + "click>=8.3.0", + "docstring-parser>=0.17.0", + "google-genai>=1.50.0", + "loguru>=0.7.3", + "msgpack>=1.1.1", + "openai>=2.7.2", + "pydantic>=2.12.3", + "pyyaml>=6.0.3", + "uv>=0.9.17", +] + +[project.scripts] +astr = "astrbot_sdk.cli:cli" + +[tool.hatch.build.targets.wheel] +packages = ["src/astrbot_sdk"] +exclude = ["/src/astrbot_sdk/AGENTS.md"] + +[tool.hatch.build.targets.sdist] +include = [ + "/src", + "/README.md", + "/LICENSE", +] + +# ============================================================ +# Optional Dependencies +# ============================================================ diff --git a/astrbot-sdk/src/astrbot_sdk/__init__.py b/astrbot-sdk/src/astrbot_sdk/__init__.py new file mode 100644 index 0000000000..da30b663e3 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/__init__.py @@ -0,0 +1,222 @@ +"""AstrBot SDK 的顶层公共 API。 + +这里仅重新导出 astrbot-sdk 推荐直接导入的稳定入口。 + +新插件应直接使用此模块的导出: + from astrbot_sdk import Star, Context, MessageEvent + from astrbot_sdk.decorators import on_command, on_message + +迁移期适配入口位于独立模块;此处只暴露 astrbot-sdk 原生主入口。 +""" + +from .clients.managers import ( + ConversationCreateParams, + ConversationManagerClient, + ConversationRecord, + ConversationUpdateParams, + KnowledgeBaseCreateParams, + KnowledgeBaseDocumentRecord, + KnowledgeBaseDocumentUploadParams, + KnowledgeBaseManagerClient, + KnowledgeBaseRecord, + KnowledgeBaseRetrieveResult, + KnowledgeBaseRetrieveResultItem, + KnowledgeBaseUpdateParams, + MessageHistoryManagerClient, + MessageHistoryPage, + MessageHistoryRecord, + MessageHistorySender, + PersonaCreateParams, + PersonaManagerClient, + PersonaRecord, + PersonaUpdateParams, +) +from .clients.mcp import MCPManagerClient, MCPServerRecord, MCPServerScope, MCPSession +from .clients.metadata import PluginMetadata, StarMetadata +from .clients.permission import ( + PermissionCheckResult, + PermissionClient, + PermissionManagerClient, +) +from .clients.platform import PlatformError, PlatformStats, PlatformStatus +from .clients.provider import ( + ManagedProviderRecord, + ProviderChangeEvent, + ProviderManagerClient, +) +from .clients.session import SessionPluginManager, SessionServiceManager +from .commands import CommandGroup, command_group, print_cmd_tree +from .context import Context +from .conversation import ( + ConversationClosed, + ConversationReplaced, + ConversationSession, + ConversationState, +) +from .decorators import ( + acknowledge_global_mcp_risk, + admin_only, + background_task, + conversation_command, + cooldown, + group_only, + http_api, + mcp_server, + message_types, + on_command, + on_event, + on_message, + on_provider_change, + on_schedule, + platforms, + priority, + private_only, + provide_capability, + rate_limit, + register_skill, + require_admin, + require_permission, + validate_config, +) +from .errors import AstrBotError +from .events import MessageEvent +from .filters import ( + CustomFilter, + MessageTypeFilter, + PlatformFilter, + all_of, + any_of, + custom_filter, +) +from .message.components import ( + At, + AtAll, + BaseMessageComponent, + File, + Forward, + Image, + MediaHelper, + Plain, + Poke, + Record, + Reply, + UnknownComponent, + Video, +) +from .message.result import ( + EventResultType, + MessageBuilder, + MessageChain, + MessageEventResult, +) +from .message.session import MessageSession +from .plugin_kv import PluginKVStoreMixin +from .schedule import ScheduleContext +from .session_waiter import SessionController, session_waiter +from .star import Star +from .star_tools import StarTools +from .types import GreedyStr + +__all__ = [ + "AstrBotError", + "At", + "AtAll", + "BaseMessageComponent", + "CommandGroup", + "ConversationClosed", + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationReplaced", + "ConversationRecord", + "ConversationSession", + "ConversationState", + "ConversationUpdateParams", + "Context", + "CustomFilter", + "EventResultType", + "File", + "Forward", + "GreedyStr", + "Image", + "KnowledgeBaseCreateParams", + "KnowledgeBaseDocumentRecord", + "KnowledgeBaseDocumentUploadParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "KnowledgeBaseRetrieveResult", + "KnowledgeBaseRetrieveResultItem", + "KnowledgeBaseUpdateParams", + "ManagedProviderRecord", + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", + "MediaHelper", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + "MessageHistorySender", + "MessageEvent", + "MessageEventResult", + "MessageChain", + "MessageBuilder", + "MessageSession", + "MessageTypeFilter", + "Plain", + "PluginKVStoreMixin", + "PluginMetadata", + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", + "PlatformFilter", + "PlatformError", + "PlatformStats", + "PlatformStatus", + "Poke", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", + "ProviderChangeEvent", + "ProviderManagerClient", + "Record", + "Reply", + "ScheduleContext", + "SessionPluginManager", + "SessionServiceManager", + "SessionController", + "Star", + "StarMetadata", + "StarTools", + "UnknownComponent", + "Video", + "acknowledge_global_mcp_risk", + "admin_only", + "all_of", + "any_of", + "background_task", + "cooldown", + "conversation_command", + "command_group", + "custom_filter", + "group_only", + "http_api", + "mcp_server", + "message_types", + "on_command", + "on_event", + "on_message", + "on_provider_change", + "on_schedule", + "platforms", + "print_cmd_tree", + "priority", + "provide_capability", + "private_only", + "rate_limit", + "require_admin", + "require_permission", + "register_skill", + "session_waiter", + "validate_config", +] diff --git a/astrbot-sdk/src/astrbot_sdk/__main__.py b/astrbot-sdk/src/astrbot_sdk/__main__.py new file mode 100644 index 0000000000..624fd22f4c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/__main__.py @@ -0,0 +1,11 @@ +"""`python -m astrbot_sdk` 的 CLI 入口。""" + +from .cli import cli + + +def main() -> None: + cli() + + +if __name__ == "__main__": + main() diff --git a/astrbot-sdk/src/astrbot_sdk/_command_model.py b/astrbot-sdk/src/astrbot_sdk/_command_model.py new file mode 100644 index 0000000000..fd8f1ad851 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_command_model.py @@ -0,0 +1,17 @@ +from ._internal.command_model import ( + COMMAND_MODEL_DOCS_URL, + CommandModelParseResult, + ResolvedCommandModelParam, + format_command_model_help, + parse_command_model_remainder, + resolve_command_model_param, +) + +__all__ = [ + "COMMAND_MODEL_DOCS_URL", + "CommandModelParseResult", + "ResolvedCommandModelParam", + "format_command_model_help", + "parse_command_model_remainder", + "resolve_command_model_param", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py new file mode 100644 index 0000000000..6ccc0d22e9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py @@ -0,0 +1,7 @@ +"""Internal implementation modules for astrbot_sdk. + +This package groups private helpers that are not part of the public SDK API. +Imports outside the SDK should avoid depending on these modules directly. +""" + +__all__: list[str] = [] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py new file mode 100644 index 0000000000..664947f7af --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + +from ..errors import AstrBotError +from ..runtime._command_matching import split_command_remainder +from .injected_params import is_framework_injected_parameter +from .typing_utils import unwrap_optional + +# TODO:文档内容喵 +COMMAND_MODEL_DOCS_URL = "https://docs.astrbot.org/sdk/parameter-injection" + + +@dataclass(slots=True) +class ResolvedCommandModelParam: + name: str + model_cls: type[BaseModel] + + +@dataclass(slots=True) +class CommandModelParseResult: + model: BaseModel | None = None + help_text: str | None = None + + +def resolve_command_model_param(handler: Any) -> ResolvedCommandModelParam | None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return None + try: + type_hints = inspect.get_annotations(handler, eval_str=True) + except Exception: + type_hints = {} + + candidates: list[ResolvedCommandModelParam] = [] + other_names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = type_hints.get(parameter.name) + if _is_injected_parameter(parameter.name, annotation): + continue + normalized, _is_optional = unwrap_optional(annotation) + if isinstance(normalized, type) and issubclass(normalized, BaseModel): + candidates.append( + ResolvedCommandModelParam( + name=parameter.name, + model_cls=normalized, + ) + ) + continue + other_names.append(parameter.name) + + if not candidates: + return None + if len(candidates) > 1 or other_names: + names = [item.name for item in candidates] + raise ValueError( + "Command BaseModel injection requires exactly one non-injected BaseModel " + f"parameter, got models={names!r} others={other_names!r}" + ) + _validate_supported_model(candidates[0].model_cls) + return candidates[0] + + +def parse_command_model_remainder( + *, + remainder: str, + model_param: ResolvedCommandModelParam, + command_name: str, +) -> CommandModelParseResult: + tokens = split_command_remainder(remainder) + if any(token in {"-h", "--help"} for token in tokens): + return CommandModelParseResult( + help_text=format_command_model_help(command_name, model_param.model_cls) + ) + + fields = model_param.model_cls.model_fields + explicit_values: dict[str, Any] = {} + positional_values: dict[str, Any] = {} + positional_field_names = [ + name + for name, field in fields.items() + if _supported_scalar_type(field.annotation)[0] is not bool + ] + positional_index = 0 + index = 0 + while index < len(tokens): + token = tokens[index] + if not token.startswith("--"): + assigned = False + while positional_index < len(positional_field_names): + field_name = positional_field_names[positional_index] + positional_index += 1 + if field_name in explicit_values or field_name in positional_values: + continue + positional_values[field_name] = token + assigned = True + break + if not assigned: + raise _command_parse_error("Too many positional arguments") + index += 1 + continue + + raw_name = token[2:] + if not raw_name: + raise _command_parse_error("Invalid option '--'") + explicit_value: str | None = None + if "=" in raw_name: + raw_name, explicit_value = raw_name.split("=", 1) + negated = raw_name.startswith("no-") + # 与 argparse/click 惯例一致:--foo-bar 自动映射为字段名 foo_bar + cli_name = raw_name[3:] if negated else raw_name + field_name = cli_name.replace("-", "_") + field = fields.get(field_name) + if field is None: + raise _command_parse_error(f"Unknown option: --{raw_name}") + option_name = _format_option_name(field_name) + negated_option_name = f"--no-{option_name[2:]}" + if field_name in explicit_values: + raise _command_parse_error(f"Duplicate option: {option_name}") + field_type, _is_optional = _supported_scalar_type(field.annotation) + if field_type is bool: + if explicit_value is not None: + raise _command_parse_error( + f"Boolean option '{option_name}' only supports {option_name} or {negated_option_name}" + ) + explicit_values[field_name] = not negated + index += 1 + continue + if negated: + raise _command_parse_error( + f"Non-boolean option '{option_name}' does not support {negated_option_name}" + ) + if explicit_value is None: + index += 1 + if index >= len(tokens): + raise _command_parse_error(f"Missing value for option: {option_name}") + explicit_value = tokens[index] + explicit_values[field_name] = explicit_value + index += 1 + + values = {**positional_values, **explicit_values} + + try: + model = model_param.model_cls.model_validate(values) + except Exception as exc: + raise AstrBotError.invalid_input( + "命令参数解析失败", + hint=str(exc), + docs_url=COMMAND_MODEL_DOCS_URL, + details={ + "command": command_name, + "parameter": model_param.name, + "values": values, + }, + ) from exc + return CommandModelParseResult(model=model) + + +def format_command_model_help(command_name: str, model_cls: type[BaseModel]) -> str: + _validate_supported_model(model_cls) + lines = [f"用法: /{command_name} [options]"] + if model_cls.model_fields: + lines.append("参数:") + for name, field in model_cls.model_fields.items(): + field_type, is_optional = _supported_scalar_type(field.annotation) + type_name = getattr(field_type, "__name__", str(field_type)) + required = field.is_required() + default_text = "" + if not required: + default_text = f",默认 {field.default!r}" + elif is_optional: + default_text = ",默认 None" + description = str(field.description or "").strip() + detail = f"{name}: {type_name}" + if description: + detail += f" - {description}" + detail += ",必填" if required else ",可选" + detail += default_text + if field_type is bool: + detail += f",使用 --{name} / --no-{name}" + lines.append(detail) + return "\n".join(lines) + + +def _validate_supported_model(model_cls: type[BaseModel]) -> None: + for name, field in model_cls.model_fields.items(): + try: + _supported_scalar_type(field.annotation) + except TypeError as exc: + raise ValueError( + f"Unsupported command model field '{name}': {exc}" + ) from exc + + +def _supported_scalar_type(annotation: Any) -> tuple[type[Any], bool]: + normalized, is_optional = unwrap_optional(annotation) + if normalized in {str, int, float, bool}: + return normalized, is_optional + raise TypeError("only str/int/float/bool and Optional variants are supported") + + +def _format_option_name(field_name: str) -> str: + # Surface the canonical CLI spelling so parse errors match the user's option syntax. + return f"--{field_name.replace('_', '-')}" + + +def _command_parse_error(message: str) -> AstrBotError: + return AstrBotError.invalid_input( + message, + docs_url=COMMAND_MODEL_DOCS_URL, + ) + + +def _is_injected_parameter(name: str, annotation: Any) -> bool: + return is_framework_injected_parameter(name, annotation) + + +__all__ = [ + "COMMAND_MODEL_DOCS_URL", + "CommandModelParseResult", + "ResolvedCommandModelParam", + "format_command_model_help", + "parse_command_model_remainder", + "resolve_command_model_param", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py new file mode 100644 index 0000000000..6ddb942c29 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py @@ -0,0 +1,599 @@ +from __future__ import annotations + +import asyncio +import inspect +from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any + +from pydantic import ValidationError + +from ..context import Context as RuntimeContext +from ..decorators import ( + BackgroundTaskMeta, + HttpApiMeta, + MCPServerMeta, + ValidateConfigMeta, + get_background_task_meta, + get_http_api_meta, + get_mcp_server_meta, + get_provider_change_meta, + get_skill_meta, + get_validate_config_meta, +) +from ..star import Star +from .sdk_logger import logger +from .star_runtime import bind_star_runtime + +_RUNTIME_STATE_ATTR = "__astrbot_decorator_runtime_state__" +_VALIDATED_CONFIGS_ATTR = "__astrbot_validated_configs__" + + +@dataclass(slots=True) +class DecoratorRuntimeState: + http_apis: list[tuple[str, list[str]]] = field(default_factory=list) + provider_hooks: list[asyncio.Task[None]] = field(default_factory=list) + background_tasks: list[asyncio.Task[Any]] = field(default_factory=list) + registered_skills: list[str] = field(default_factory=list) + local_mcp_servers: list[str] = field(default_factory=list) + global_mcp_servers: list[str] = field(default_factory=list) + + +def _runtime_state(instance: Any) -> DecoratorRuntimeState: + state = getattr(instance, _RUNTIME_STATE_ATTR, None) + if isinstance(state, DecoratorRuntimeState): + return state + state = DecoratorRuntimeState() + setattr(instance, _RUNTIME_STATE_ATTR, state) + return state + + +def _iter_bound_methods(instance: Any): + seen_names: set[str] = set() + for name in dir(instance.__class__): + if name.startswith("__") or name in seen_names: + continue + seen_names.add(name) + try: + raw_attr = inspect.getattr_static(instance, name) + except AttributeError: + continue + if isinstance(raw_attr, property): + continue + bound = getattr(instance, name, None) + if not callable(bound): + continue + raw = getattr(bound, "__func__", bound) + yield name, bound, raw + + +def _validated_config_store(instance: Any) -> dict[str, Any]: + values = getattr(instance, _VALIDATED_CONFIGS_ATTR, None) + if isinstance(values, dict): + return values + values = {} + setattr(instance, _VALIDATED_CONFIGS_ATTR, values) + return values + + +def _positional_arg_count(func: Any) -> int: + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + return 0 + return sum( + 1 + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ) + + +def _call_with_optional_context(bound: Any, context: RuntimeContext) -> Any: + return bound(context) if _positional_arg_count(bound) >= 1 else bound() + + +async def _await_if_needed(value: Any) -> Any: + if inspect.isawaitable(value): + return await value + return value + + +def _decorator_target_name(instance: Any, method_name: str | None = None) -> str: + class_name = instance.__class__.__name__ + if method_name is None: + return class_name + return f"{class_name}.{method_name}" + + +def _decorator_error( + *, + instance: Any, + decorator_name: str, + exc: Exception, + method_name: str | None = None, + details: str | None = None, +) -> RuntimeError: + message = f"{_decorator_target_name(instance, method_name)} {decorator_name} failed" + if details: + message += f" ({details})" + message += f": {exc}" + return RuntimeError(message) + + +def _http_api_details(meta: HttpApiMeta) -> str: + details = [f"route={meta.route!r}", f"methods={list(meta.methods)!r}"] + if meta.capability_name: + details.append(f"capability_name={meta.capability_name!r}") + return ", ".join(details) + + +def _provider_change_details(meta: Any) -> str: + return f"provider_types={list(meta.provider_types)!r}" + + +def _background_task_details(meta: BackgroundTaskMeta, method_name: str) -> str: + description = meta.description or f"background_task:{method_name}" + return ( + f"description={description!r}, auto_start={meta.auto_start!r}, " + f"on_error={meta.on_error!r}" + ) + + +def _mcp_server_details(meta: MCPServerMeta) -> str: + return ( + f"name={meta.name!r}, scope={meta.scope!r}, timeout={meta.timeout!r}, " + f"wait_until_ready={meta.wait_until_ready!r}" + ) + + +def _skill_details(name: str, path: str) -> str: + return f"name={name!r}, path={path!r}" + + +def _normalize_provider_type(value: Any) -> str: + enum_value = getattr(value, "value", None) + if isinstance(enum_value, str): + return enum_value.strip().lower() + return str(value).strip().lower() + + +def _is_valid_schema_expected_type(value: Any) -> bool: + if isinstance(value, type): + return True + return ( + isinstance(value, tuple) + and len(value) > 0 + and all(isinstance(item, type) for item in value) + ) + + +async def _run_model_validation( + *, + instance: Any, + method_name: str, + meta: ValidateConfigMeta, + config: dict[str, Any], +) -> None: + if meta.model is not None: + try: + validated = meta.model.model_validate(config) + except ValidationError as exc: + raise ValueError(str(exc)) from exc + _validated_config_store(instance)[method_name] = validated + return + + assert meta.schema is not None + validated = _validate_schema_config(meta.schema, config) + _validated_config_store(instance)[method_name] = validated + + +def _validate_schema_config( + schema: dict[str, Any], + config: dict[str, Any], +) -> dict[str, Any]: + validated: dict[str, Any] = {} + errors: list[str] = [] + + for field_name, field_schema in schema.items(): + if not isinstance(field_schema, dict): + errors.append(f"{field_name}: schema entry must be an object") + continue + present = field_name in config + value = config.get(field_name, field_schema.get("default")) + required = bool(field_schema.get("required", False)) + if value is None: + if required and "default" not in field_schema: + errors.append(f"{field_name}: is required") + validated[field_name] = value + continue + expected_type = field_schema.get("type") + if expected_type is not None and not _is_valid_schema_expected_type( + expected_type + ): + errors.append( + f"{field_name}: invalid schema 'type' entry {expected_type!r}; " + "expected a type or tuple of types" + ) + continue + if expected_type is not None and not isinstance(value, expected_type): + errors.append( + f"{field_name}: expected {getattr(expected_type, '__name__', expected_type)}, " + f"got {type(value).__name__}" + ) + continue + if isinstance(value, (int, float)) and not isinstance(value, bool): + minimum = field_schema.get("min") + maximum = field_schema.get("max") + range_value = field_schema.get("range") + if minimum is not None and value < minimum: + errors.append(f"{field_name}: must be >= {minimum}") + if maximum is not None and value > maximum: + errors.append(f"{field_name}: must be <= {maximum}") + if ( + isinstance(range_value, tuple) + and len(range_value) == 2 + and not (range_value[0] <= value <= range_value[1]) + ): + errors.append( + f"{field_name}: must be within [{range_value[0]}, {range_value[1]}]" + ) + if required and not present and "default" not in field_schema: + errors.append(f"{field_name}: is required") + validated[field_name] = value + + if errors: + raise ValueError("validate_config schema failed: " + "; ".join(errors)) + return validated + + +async def _run_validate_config(instance: Any, context: RuntimeContext) -> None: + config_payload = await context.metadata.get_plugin_config() + config = dict(config_payload or {}) + for method_name, _bound, raw in _iter_bound_methods(instance): + meta = get_validate_config_meta(raw) + if meta is None: + continue + try: + await _run_model_validation( + instance=instance, + method_name=method_name, + meta=meta, + config=config, + ) + except Exception as exc: + raise _decorator_error( + instance=instance, + method_name=method_name, + decorator_name="@validate_config", + exc=exc, + ) from exc + + +async def _register_http_apis(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for method_name, bound, raw in _iter_bound_methods(instance): + meta = get_http_api_meta(raw) + if meta is None: + continue + try: + await _register_http_api(bound=bound, meta=meta, context=context) + except Exception as exc: + raise _decorator_error( + instance=instance, + method_name=method_name, + decorator_name="@http_api", + details=_http_api_details(meta), + exc=exc, + ) from exc + state.http_apis.append((meta.route, list(meta.methods))) + + +async def _register_http_api( + *, + bound: Any, + meta: HttpApiMeta, + context: RuntimeContext, +) -> None: + if meta.capability_name: + await context.http.register_api( + route=meta.route, + handler_capability=meta.capability_name, + methods=list(meta.methods), + description=meta.description, + ) + return + await context.http.register_api( + route=meta.route, + handler=bound, + methods=list(meta.methods), + description=meta.description, + ) + + +async def _register_provider_change_hooks( + instance: Any, + context: RuntimeContext, +) -> None: + state = _runtime_state(instance) + for method_name, bound, raw in _iter_bound_methods(instance): + meta = get_provider_change_meta(raw) + if meta is None: + continue + target_name = _decorator_target_name(instance, method_name) + + async def callback( + provider_id: str, + provider_type: Any, + umo: str | None, + *, + _bound=bound, + _meta=meta, + ) -> None: + if _meta.provider_types: + current_type = _normalize_provider_type(provider_type) + if current_type not in _meta.provider_types: + return + owner = instance if isinstance(instance, Star) else None + try: + with bind_star_runtime(owner, context): + result = _bound(provider_id, provider_type, umo) + await _await_if_needed(result) + except Exception as exc: + raise RuntimeError( + f"{target_name} @on_provider_change callback failed " + f"(provider_id={provider_id!r}, provider_type={provider_type!r}, " + f"umo={umo!r}): {exc}" + ) from exc + + try: + task = await context.provider_manager.register_provider_change_hook( + callback + ) + except Exception as exc: + raise _decorator_error( + instance=instance, + method_name=method_name, + decorator_name="@on_provider_change", + details=_provider_change_details(meta), + exc=exc, + ) from exc + # TODO: provider.manager.watch_changes is currently restricted to + # reserved/system plugins. If this decorator should be public-facing, + # the capability boundary needs to be widened or a dedicated event feed + # should be introduced. + state.provider_hooks.append(task) + + +async def _start_background_tasks(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for method_name, bound, raw in _iter_bound_methods(instance): + meta = get_background_task_meta(raw) + if meta is None or not meta.auto_start: + continue + try: + task = await context.register_task( + _background_runner( + instance=instance, + bound=bound, + context=context, + meta=meta, + method_name=method_name, + ), + meta.description + or f"background_task:{instance.__class__.__name__}.{method_name}", + ) + except Exception as exc: + raise _decorator_error( + instance=instance, + method_name=method_name, + decorator_name="@background_task", + details=_background_task_details(meta, method_name), + exc=exc, + ) from exc + state.background_tasks.append(task) + + +async def _background_runner( + *, + instance: Any, + bound: Any, + context: RuntimeContext, + meta: BackgroundTaskMeta, + method_name: str, +) -> None: + while True: + try: + owner = instance if isinstance(instance, Star) else None + with bind_star_runtime(owner, context): + result = _call_with_optional_context(bound, context) + await _await_if_needed(result) + return + except asyncio.CancelledError: + raise + except Exception as exc: + if meta.on_error != "restart": + raise _decorator_error( + instance=instance, + method_name=method_name, + decorator_name="@background_task", + details=_background_task_details(meta, method_name), + exc=exc, + ) from exc + context.logger.exception( + "SDK decorator background_task restarting after failure: plugin_id={} task={} details={}", + context.plugin_id, + f"{instance.__class__.__name__}.{method_name}", + _background_task_details(meta, method_name), + ) + + +def _iter_class_and_method_meta_entries( + instance: Any, + getter, +) -> list[tuple[str, Any]]: + values = [ + (_decorator_target_name(instance), meta) for meta in getter(instance.__class__) + ] + for method_name, _bound, raw in _iter_bound_methods(instance): + values.extend( + (_decorator_target_name(instance, method_name), meta) + for meta in getter(raw) + ) + return values + + +async def _register_skills(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for target_name, meta in _iter_class_and_method_meta_entries( + instance, get_skill_meta + ): + try: + await context.register_skill( + name=meta.name, + path=meta.path, + description=meta.description, + ) + except Exception as exc: + raise RuntimeError( + f"{target_name} @register_skill failed " + f"({_skill_details(meta.name, meta.path)}): {exc}" + ) from exc + state.registered_skills.append(meta.name) + + +async def _register_mcp_servers(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for target_name, meta in _iter_class_and_method_meta_entries( + instance, get_mcp_server_meta + ): + try: + await _register_mcp_server(meta=meta, context=context) + except Exception as exc: + raise RuntimeError( + f"{target_name} @mcp_server failed ({_mcp_server_details(meta)}): {exc}" + ) from exc + if meta.scope == "global": + state.global_mcp_servers.append(meta.name) + else: + state.local_mcp_servers.append(meta.name) + + +async def _register_mcp_server( + *, + meta: MCPServerMeta, + context: RuntimeContext, +) -> None: + if meta.scope == "global": + if meta.config is None: + raise ValueError( + f"mcp_server(name={meta.name!r}, scope='global') requires config" + ) + await context.mcp.register_global_server( + meta.name, + dict(meta.config), + timeout=meta.timeout, + ) + return + + if meta.config not in (None, {}): + raise ValueError( + f"mcp_server(name={meta.name!r}, scope='local') does not support config registration" + ) + # TODO: local MCP only supports enable/disable of predeclared servers today. + # If the decorator is expected to register brand-new local servers, the MCP + # client/runtime needs a first-class local register/unregister API. + await context.mcp.enable_server(meta.name) + if meta.wait_until_ready: + await context.mcp.wait_until_ready(meta.name, timeout=meta.timeout) + + +async def _teardown_decorator_resources(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + + for task in reversed(state.provider_hooks): + with suppress(asyncio.CancelledError): + await context.provider_manager.unregister_provider_change_hook(task) + state.provider_hooks.clear() + + for task in reversed(state.background_tasks): + if not task.done(): + task.cancel() + for task in reversed(state.background_tasks): + with suppress(asyncio.CancelledError, Exception): + await task + state.background_tasks.clear() + + for route, methods in reversed(state.http_apis): + try: + await context.http.unregister_api(route, methods) + except Exception: + logger.exception( + "decorator http_api cleanup failed: plugin_id={} route={}", + context.plugin_id, + route, + ) + state.http_apis.clear() + + for name in reversed(state.registered_skills): + with suppress(Exception): + await context.unregister_skill(name) + state.registered_skills.clear() + + for name in reversed(state.local_mcp_servers): + with suppress(Exception): + await context.mcp.disable_server(name) + state.local_mcp_servers.clear() + + for name in reversed(state.global_mcp_servers): + with suppress(Exception): + await context.mcp.unregister_global_server(name) + state.global_mcp_servers.clear() + + +async def _invoke_hook( + *, + instance: Any, + hook: Any | None, + context: RuntimeContext, +) -> None: + if hook is None: + return + owner = instance if isinstance(instance, Star) else None + with bind_star_runtime(owner, context): + result = _call_with_optional_context(hook, context) + await _await_if_needed(result) + + +async def run_lifecycle_with_decorators( + *, + instance: Any, + hook: Any | None, + method_name: str, + context: RuntimeContext, +) -> None: + # Wrap decorator-managed startup failures with decorator-specific context so + # plugin authors do not only see a generic worker initialize timeout. + # Keep the lifecycle wrapper centralized so decorator-managed resources still + # work when plugins override on_start/on_stop without calling super(). + if method_name == "on_start": + await _run_validate_config(instance, context) + await _invoke_hook(instance=instance, hook=hook, context=context) + await _register_http_apis(instance, context) + await _register_provider_change_hooks(instance, context) + await _register_skills(instance, context) + await _register_mcp_servers(instance, context) + await _start_background_tasks(instance, context) + return + + try: + await _invoke_hook(instance=instance, hook=hook, context=context) + finally: + if method_name == "on_stop": + await _teardown_decorator_resources(instance, context) + + +__all__ = ["run_lifecycle_with_decorators"] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py new file mode 100644 index 0000000000..ced6229f93 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import functools +import inspect +from typing import Any + +try: + from typing import get_type_hints +except ImportError: # pragma: no cover + get_type_hints = None + +from .typing_utils import unwrap_optional + +_INJECTED_PARAMETER_NAMES = { + "event", + "ctx", + "context", + "sched", + "schedule", + "conversation", + "conv", +} + + +def is_framework_injected_parameter(name: str, annotation: Any) -> bool: + if name in _INJECTED_PARAMETER_NAMES: + return True + normalized, _is_optional = unwrap_optional(annotation) + if normalized is None: + return False + try: + injected_types = _framework_injected_types() + except Exception: + return False + if normalized in injected_types: + return True + if isinstance(normalized, type): + return issubclass(normalized, injected_types) + return False + + +def legacy_arg_parameter_names(handler: Any) -> list[str]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + if get_type_hints is None: + type_hints = {} + else: + type_hints = get_type_hints(handler) + except Exception: + type_hints = {} + + names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if is_framework_injected_parameter( + parameter.name, type_hints.get(parameter.name) + ): + continue + names.append(parameter.name) + return names + + +@functools.lru_cache(maxsize=1) +def _framework_injected_types() -> tuple[type[Any], ...]: + from ..clients.llm import LLMResponse + from ..context import Context + from ..conversation import ConversationSession + from ..events import MessageEvent + from ..llm.entities import ProviderRequest + from ..message.result import MessageEventResult + from ..schedule import ScheduleContext + + return ( + Context, + MessageEvent, + ScheduleContext, + ConversationSession, + ProviderRequest, + LLMResponse, + MessageEventResult, + ) + + +__all__ = ["is_framework_injected_parameter", "legacy_arg_parameter_names"] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py new file mode 100644 index 0000000000..2fe2ec1d5e --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py @@ -0,0 +1,86 @@ +"""插件调用者身份上下文管理。 + +本模块使用 contextvars 实现跨异步任务传播插件身份, +用于在 capability 调用时自动识别调用者插件。 + +典型场景: + - http.register_api: 记录哪个插件注册了 API + - metadata.get_plugin_config: 只允许查询当前插件自己的配置 + - 能力路由层权限校验 + +使用方式: + with caller_plugin_scope("my_plugin"): + # 在此作用域内,current_caller_plugin_id() 返回 "my_plugin" + await ctx.http.register_api(...) + +注意: + contextvars 会自动传播到子任务(asyncio.create_task), + 无需手动传递。 +""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar, Token + +# 存储当前调用者插件 ID 的上下文变量 +_CALLER_PLUGIN_ID: ContextVar[str | None] = ContextVar( + "astrbot_sdk_caller_plugin_id", + default=None, +) + + +def current_caller_plugin_id() -> str | None: + """获取当前上下文中的调用者插件 ID。 + + Returns: + 当前插件 ID,如果不在插件调用上下文中则返回 None + """ + return _CALLER_PLUGIN_ID.get() + + +def bind_caller_plugin_id(plugin_id: str | None) -> Token[str | None]: + """绑定调用者插件 ID 到当前上下文。 + + Args: + plugin_id: 插件 ID,空字符串会被视为 None + + Returns: + 用于后续 reset 的 Token + + Note: + 通常使用 caller_plugin_scope 上下文管理器而非直接调用此函数 + """ + normalized = plugin_id.strip() if isinstance(plugin_id, str) else "" + return _CALLER_PLUGIN_ID.set(normalized or None) + + +def reset_caller_plugin_id(token: Token[str | None]) -> None: + """重置调用者插件 ID 到之前的状态。 + + Args: + token: bind_caller_plugin_id 返回的 Token + """ + _CALLER_PLUGIN_ID.reset(token) + + +@contextmanager +def caller_plugin_scope(plugin_id: str | None) -> Iterator[None]: + """创建一个绑定插件身份的上下文作用域。 + + Args: + plugin_id: 要绑定的插件 ID + + Yields: + None + + 示例: + with caller_plugin_scope("my_plugin"): + await some_capability_call() + """ + token = bind_caller_plugin_id(plugin_id) + try: + yield + finally: + reset_caller_plugin_id(token) diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py new file mode 100644 index 0000000000..d13720b500 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import json +import math +import re +from datetime import datetime, timedelta, timezone +from typing import Any + + +def is_ttl_memory_entry(value: Any) -> bool: + """Return whether a stored memory payload uses the TTL wrapper shape.""" + + return isinstance(value, dict) and "value" in value and "ttl_seconds" in value + + +def memory_value_for_search(stored: Any) -> dict[str, Any] | None: + """Unwrap the search payload from a stored memory record when possible.""" + + if not isinstance(stored, dict): + return None + if is_ttl_memory_entry(stored): + value = stored.get("value") + return value if isinstance(value, dict) else None + return stored + + +def extract_memory_text(stored: Any) -> str: + """Pick the canonical text that keyword/vector search should index.""" + + value = memory_value_for_search(stored) + if not isinstance(value, dict): + return "" + for field_name in ("embedding_text", "content", "summary", "title", "text"): + item = value.get(field_name) + if isinstance(item, str) and item.strip(): + return item.strip() + return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str) + + +def memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None: + """Translate a TTL in seconds into an absolute UTC expiration timestamp.""" + + try: + ttl = int(ttl_seconds) + except (TypeError, ValueError): + return None + if ttl < 1: + return None + return datetime.now(timezone.utc) + timedelta(seconds=ttl) + + +def memory_expiration_from_stored_payload(stored: Any) -> datetime | None: + """Recover an absolute expiration timestamp from a stored TTL payload.""" + + if not is_ttl_memory_entry(stored) or not isinstance(stored, dict): + return None + raw_expires_at = stored.get("expires_at") + if isinstance(raw_expires_at, (int, float)): + return datetime.fromtimestamp(float(raw_expires_at), tz=timezone.utc) + if not isinstance(raw_expires_at, str): + return None + + normalized = raw_expires_at.strip() + if not normalized: + return None + if normalized.endswith("Z"): + normalized = f"{normalized[:-1]}+00:00" + try: + expires_at = datetime.fromisoformat(normalized) + except ValueError: + return None + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return expires_at.astimezone(timezone.utc) + + +def normalize_memory_namespace(value: Any) -> str: + """Normalize a namespace path into a stable slash-delimited string.""" + + if value is None: + return "" + if isinstance(value, (list, tuple)): + return join_memory_namespace(*value) + text = str(value).strip().replace("\\", "/") + if not text: + return "" + parts = [segment.strip() for segment in text.split("/") if segment.strip()] + return "/".join(parts) + + +def join_memory_namespace(*parts: Any) -> str: + """Join namespace segments while preserving the root namespace as empty.""" + + normalized_parts: list[str] = [] + for part in parts: + normalized = normalize_memory_namespace(part) + if not normalized: + continue + normalized_parts.extend( + segment for segment in normalized.split("/") if segment.strip() + ) + return "/".join(normalized_parts) + + +def memory_namespace_matches( + candidate: str, + namespace: str | None, + *, + include_descendants: bool, +) -> bool: + """Check whether a stored namespace belongs to the requested scope.""" + + if namespace is None: + return True + normalized_candidate = normalize_memory_namespace(candidate) + normalized_namespace = normalize_memory_namespace(namespace) + if not normalized_namespace: + return include_descendants or normalized_candidate == "" + if normalized_candidate == normalized_namespace: + return True + return include_descendants and normalized_candidate.startswith( + f"{normalized_namespace}/" + ) + + +def display_memory_namespace(value: Any) -> str | None: + """Return a user-facing namespace value.""" + + normalized = normalize_memory_namespace(value) + return normalized or None + + +def _memory_query_terms(value: str) -> list[str]: + normalized = re.sub(r"\s+", " ", str(value).strip().casefold()) + if not normalized: + return [] + terms = [item for item in re.findall(r"\w+", normalized, flags=re.UNICODE) if item] + if terms: + return terms + compact = normalized.replace(" ", "") + return [compact] if compact else [] + + +def memory_keyword_score(query: str, key: str, text: str) -> float: + """Score a keyword hit the same way across runtime and core bridge.""" + + normalized_query = str(query).casefold() + if not normalized_query: + return 1.0 + normalized_key = str(key).casefold() + normalized_text = str(text).casefold() + best = 0.0 + if normalized_query in normalized_key: + best = 1.0 + if normalized_query in normalized_text: + best = max(best, 0.92) + + terms = _memory_query_terms(normalized_query) + if not terms: + return best + + key_hits = sum(1 for term in terms if term in normalized_key) + text_hits = sum(1 for term in terms if term in normalized_text) + if key_hits: + best = max(best, 0.5 + 0.5 * (key_hits / len(terms))) + if text_hits: + best = max(best, 0.35 + 0.55 * (text_hits / len(terms))) + return min(best, 1.0) + + +def cosine_similarity(left: list[float], right: list[float]) -> float: + """Compute cosine similarity defensively for embedding vectors.""" + + if not left or not right or len(left) != len(right): + return 0.0 + left_norm = math.sqrt(sum(value * value for value in left)) + right_norm = math.sqrt(sum(value * value for value in right)) + if left_norm <= 0 or right_norm <= 0: + return 0.0 + return sum(a * b for a, b in zip(left, right, strict=False)) / ( + left_norm * right_norm + ) + + +def normalize_embedding(vector: list[float]) -> list[float]: + """Normalize an embedding for cosine/inner-product search.""" + + if not vector: + return [] + norm = math.sqrt(sum(value * value for value in vector)) + if norm <= 0: + return [0.0 for _ in vector] + return [float(value) / norm for value in vector] + + +def memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]: + """Normalize cached sidecar data into a stable memory index record.""" + + if isinstance(entry, dict): + return { + "text": str(entry.get("text", text)), + "embedding": ( + [float(item) for item in entry.get("embedding", [])] + if isinstance(entry.get("embedding"), list) + else None + ), + "provider_id": ( + str(entry.get("provider_id")).strip() + if entry.get("provider_id") is not None + else None + ), + } + return {"text": text, "embedding": None, "provider_id": None} diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py new file mode 100644 index 0000000000..471875e2fb --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from pathlib import Path + +PLUGIN_ID_PATTERN = re.compile(r"^[A-Za-z0-9_](?:[A-Za-z0-9._-]{0,126}[A-Za-z0-9_])?$") +_WINDOWS_RESERVED_PLUGIN_IDS = { + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", +} + + +def validate_plugin_id(plugin_id: str) -> str: + normalized = str(plugin_id).strip() + if not normalized: + raise ValueError("plugin_id must not be empty") + if not PLUGIN_ID_PATTERN.fullmatch(normalized): + raise ValueError( + "plugin_id must use only letters, digits, dots, underscores, or hyphens" + ) + upper_normalized = normalized.upper() + base_name = upper_normalized.split(".", 1)[0] + if ( + upper_normalized in _WINDOWS_RESERVED_PLUGIN_IDS + or base_name in _WINDOWS_RESERVED_PLUGIN_IDS + ): + raise ValueError("plugin_id must not use a reserved Windows device name") + return normalized + + +def plugin_capability_prefix(plugin_id: str) -> str: + return f"{validate_plugin_id(plugin_id)}." + + +def capability_belongs_to_plugin(capability_name: str, plugin_id: str) -> bool: + return str(capability_name).strip().startswith(plugin_capability_prefix(plugin_id)) + + +def plugin_http_route_root(plugin_id: str) -> str: + return f"/{validate_plugin_id(plugin_id)}" + + +def http_route_belongs_to_plugin(route: str, plugin_id: str) -> bool: + normalized_route = str(route).strip() + route_root = plugin_http_route_root(plugin_id) + return normalized_route == route_root or normalized_route.startswith( + f"{route_root}/" + ) + + +def resolve_plugin_data_dir(root: Path, plugin_id: str) -> Path: + normalized = validate_plugin_id(plugin_id) + resolved_root = root.resolve() + candidate = (resolved_root / normalized).resolve() + try: + candidate.relative_to(resolved_root) + except ValueError as exc: + raise ValueError("plugin_id escapes the plugin data root") from exc + return candidate diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py new file mode 100644 index 0000000000..b89fb8dc18 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import asyncio +import inspect +import os +import time +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +try: + from astrbot.core.config.default import VERSION as _ASTRBOT_VERSION +except Exception: # noqa: BLE001 + _ASTRBOT_VERSION = "" + +__all__ = ["PluginLogEntry", "PluginLogger"] + + +@dataclass(slots=True) +class PluginLogEntry: + level: str + time: float + message: str + plugin_id: str + context: dict[str, Any] = field(default_factory=dict) + + +class _PluginLogBroker: + def __init__(self, plugin_id: str) -> None: + self.plugin_id = plugin_id + self._subscribers: set[asyncio.Queue[PluginLogEntry]] = set() + + def publish(self, entry: PluginLogEntry) -> None: + for queue in list(self._subscribers): + try: + queue.put_nowait(entry) + except asyncio.QueueFull: + continue + + async def watch(self) -> AsyncIterator[PluginLogEntry]: + queue: asyncio.Queue[PluginLogEntry] = asyncio.Queue() + self._subscribers.add(queue) + try: + while True: + yield await queue.get() + finally: + self._subscribers.discard(queue) + + +_BROKERS: dict[str, _PluginLogBroker] = {} + +_SHORT_LEVEL_NAMES = { + "DEBUG": "DBUG", + "INFO": "INFO", + "WARNING": "WARN", + "ERROR": "ERRO", + "CRITICAL": "CRIT", +} + +_ANSI_RESET = "\u001b[0m" +_ANSI_GREEN = "\u001b[32m" +_ANSI_LEVEL_COLORS = { + "DEBUG": "\u001b[1;34m", + "INFO": "\u001b[1;36m", + "WARNING": "\u001b[1;33m", + "ERROR": "\u001b[31m", + "CRITICAL": "\u001b[1;31m", +} + + +def _get_short_level_name(level_name: str) -> str: + return _SHORT_LEVEL_NAMES.get(level_name.upper(), level_name[:4].upper()) + + +def _build_source_file(pathname: str | None) -> str: + if not pathname: + return "unknown" + dirname = os.path.dirname(pathname) + return ( + os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "") + ) + + +def _plugin_tag_from_path(pathname: str | None) -> str: + if not pathname: + return "[Plug]" + norm_path = os.path.normpath(pathname) + if any( + marker in norm_path + for marker in ( + os.path.normpath("data/plugins"), + os.path.normpath("data/sdk_plugins"), + os.path.normpath("astrbot/builtin_stars"), + ) + ): + return "[Plug]" + return "[Core]" + + +def _level_color(level: str) -> str: + return _ANSI_LEVEL_COLORS.get(level.upper(), _ANSI_RESET) + + +def _get_broker(plugin_id: str) -> _PluginLogBroker: + broker = _BROKERS.get(plugin_id) + if broker is None: + broker = _PluginLogBroker(plugin_id) + _BROKERS[plugin_id] = broker + return broker + + +class PluginLogger: + def __init__( + self, + *, + plugin_id: str, + logger: Any, + bound_context: dict[str, Any] | None = None, + ) -> None: + self._plugin_id = plugin_id + self._logger = logger + self._broker = _get_broker(plugin_id) + self._bound_context = dict(bound_context or {}) + + @property + def plugin_id(self) -> str: + return self._plugin_id + + def bind(self, **kwargs: Any) -> PluginLogger: + bind = getattr(self._logger, "bind", None) + next_logger = self._logger + if callable(bind): + try: + next_logger = bind(**kwargs) + except Exception: + next_logger = self._logger + return PluginLogger( + plugin_id=self._plugin_id, + logger=next_logger, + bound_context={**self._bound_context, **kwargs}, + ) + + def opt(self, *args: Any, **kwargs: Any) -> PluginLogger: + opt = getattr(self._logger, "opt", None) + next_logger = self._logger + if callable(opt): + try: + next_logger = opt(*args, **kwargs) + except Exception: + next_logger = self._logger + return PluginLogger( + plugin_id=self._plugin_id, + logger=next_logger, + bound_context=self._bound_context, + ) + + async def watch(self) -> AsyncIterator[PluginLogEntry]: + async for entry in self._broker.watch(): + yield entry + + def log(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None: + normalized_level = str(level).upper() + self._emit_console(normalized_level, message, *args, **kwargs) + self._publish(normalized_level, message, *args, **kwargs) + + def debug(self, message: Any, *args: Any, **kwargs: Any) -> None: + self._emit_console("DEBUG", message, *args, **kwargs) + self._publish("DEBUG", message, *args, **kwargs) + + def info(self, message: Any, *args: Any, **kwargs: Any) -> None: + self._emit_console("INFO", message, *args, **kwargs) + self._publish("INFO", message, *args, **kwargs) + + def warning(self, message: Any, *args: Any, **kwargs: Any) -> None: + self._emit_console("WARNING", message, *args, **kwargs) + self._publish("WARNING", message, *args, **kwargs) + + def error(self, message: Any, *args: Any, **kwargs: Any) -> None: + self._emit_console("ERROR", message, *args, **kwargs) + self._publish("ERROR", message, *args, **kwargs) + + def exception(self, message: Any, *args: Any, **kwargs: Any) -> None: + self._emit_console("ERROR", message, *args, exception=True, **kwargs) + self._publish("ERROR", message, *args, **kwargs) + + def _emit_console( + self, + level: str, + message: Any, + *args: Any, + exception: bool = False, + **kwargs: Any, + ) -> None: + if self._emit_console_with_opt( + level, + message, + *args, + exception=exception, + **kwargs, + ): + return + self._emit_console_fallback( + level, + message, + *args, + exception=exception, + **kwargs, + ) + + def _emit_console_with_opt( + self, + level: str, + message: Any, + *args: Any, + exception: bool = False, + **kwargs: Any, + ) -> bool: + opt = getattr(self._logger, "opt", None) + if not callable(opt): + return False + formatted_message = self._format_message(message, *args, **kwargs) + pathname, source_line = self._caller_info() + plugin_tag = _plugin_tag_from_path(pathname) + source_file = _build_source_file(pathname) + version_tag = ( + f" [v{_ASTRBOT_VERSION}]" + if _ASTRBOT_VERSION and level in {"WARNING", "ERROR", "CRITICAL"} + else "" + ) + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + level_text = _get_short_level_name(level) + level_color = _level_color(level) + line = ( + f"{_ANSI_GREEN}[{timestamp}]{_ANSI_RESET} {plugin_tag} " + f"{level_color}[{level_text}]{_ANSI_RESET}{version_tag} " + f"[{source_file}:{source_line}]: {level_color}{formatted_message}{_ANSI_RESET}" + ) + try: + emitter = opt(raw=True, exception=True) if exception else opt(raw=True) + log = getattr(emitter, "log", None) + if not callable(log): + return False + log(level, line + "\n") + return True + except Exception: + return False + + def _emit_console_fallback( + self, + level: str, + message: Any, + *args: Any, + exception: bool = False, + **kwargs: Any, + ) -> None: + method_names = [] + if exception: + method_names.append("exception") + method_names.append(str(level).lower()) + if exception: + method_names.append("error") + for method_name in method_names: + method = getattr(self._logger, method_name, None) + if not callable(method): + continue + try: + method(message, *args, **kwargs) + except Exception: + continue + return + log = getattr(self._logger, "log", None) + if callable(log): + try: + log(level, self._format_message(message, *args, **kwargs)) + except Exception: + return + + def _caller_info(self) -> tuple[str | None, int]: + frame = inspect.currentframe() + if frame is None: + return None, 0 + frame = frame.f_back + while frame is not None and frame.f_globals.get("__name__") == __name__: + frame = frame.f_back + if frame is None: + return None, 0 + return str(frame.f_code.co_filename), int(frame.f_lineno) + + def _publish(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None: + entry = PluginLogEntry( + level=level, + time=time.time(), + message=self._format_message(message, *args, **kwargs), + plugin_id=self._plugin_id, + context=dict(self._bound_context), + ) + self._broker.publish(entry) + + @staticmethod + def _format_message(message: Any, *args: Any, **kwargs: Any) -> str: + if not isinstance(message, str): + return str(message) + text = message + if not args and not kwargs: + return text + try: + return text.format(*args, **kwargs) + except Exception: + return text + + def __getattr__(self, name: str) -> Any: + return getattr(self._logger, name) diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py b/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py new file mode 100644 index 0000000000..687926ffea --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os + +from loguru import logger as _raw_loguru_logger + +try: + from astrbot.core.config.default import VERSION as _ASTRBOT_VERSION +except Exception: # noqa: BLE001 + _ASTRBOT_VERSION = "" + +_SHORT_LEVEL_NAMES = { + "DEBUG": "DBUG", + "INFO": "INFO", + "WARNING": "WARN", + "ERROR": "ERRO", + "CRITICAL": "CRIT", +} + + +def _get_short_level_name(level_name: str) -> str: + return _SHORT_LEVEL_NAMES.get(level_name.upper(), level_name[:4].upper()) + + +def _build_source_file(pathname: str | None) -> str: + if not pathname: + return "unknown" + dirname = os.path.dirname(pathname) + return ( + os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "") + ) + + +def _patch_record(record: dict) -> None: + extra = record["extra"] + extra.setdefault("plugin_tag", "[Core]") + extra.setdefault("short_levelname", _get_short_level_name(record["level"].name)) + level_no = record["level"].no + version_tag = ( + f" [v{_ASTRBOT_VERSION}]" if _ASTRBOT_VERSION and level_no >= 30 else "" + ) + extra.setdefault("astrbot_version_tag", version_tag) + extra.setdefault("source_file", _build_source_file(record["file"].path)) + extra.setdefault("source_line", record["line"]) + extra.setdefault("is_trace", False) + + +logger = _raw_loguru_logger.patch(_patch_record) + +__all__ = ["logger"] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py b/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py new file mode 100644 index 0000000000..37211735e6 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..context import Context + from ..star import Star + + +_CURRENT_STAR_CONTEXT: ContextVar[Context | None] = ContextVar( + "astrbot_sdk_current_star_context", + default=None, +) +_CURRENT_STAR_INSTANCE: ContextVar[Star | None] = ContextVar( + "astrbot_sdk_current_star_instance", + default=None, +) + + +def current_star_context() -> Context | None: + return _CURRENT_STAR_CONTEXT.get() + + +def current_runtime_context() -> Context | None: + return _CURRENT_STAR_CONTEXT.get() + + +def current_star_instance() -> Star | None: + return _CURRENT_STAR_INSTANCE.get() + + +@contextmanager +def bind_star_runtime(star: Star | None, ctx: Context | None) -> Iterator[None]: + context_token = _CURRENT_STAR_CONTEXT.set(ctx) + star_token = _CURRENT_STAR_INSTANCE.set(star) + instance_token = star._bind_runtime_context(ctx) if star is not None else None + try: + yield + finally: + if star is not None and instance_token is not None: + star._reset_runtime_context(instance_token) + _CURRENT_STAR_INSTANCE.reset(star_token) + _CURRENT_STAR_CONTEXT.reset(context_token) diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py b/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py new file mode 100644 index 0000000000..05a550f824 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py @@ -0,0 +1,606 @@ +"""Shared support primitives for local SDK testing.""" + +from __future__ import annotations + +import asyncio +import typing +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, TextIO + +from ..context import CancelToken +from ..context import Context as RuntimeContext +from ..events import MessageEvent +from ..protocol.messages import EventMessage, PeerInfo +from ..runtime._streaming import StreamExecution +from ..runtime.capability_router import CapabilityRouter + + +def _clone_payload_mapping(value: Any) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + return {str(key): item for key, item in value.items()} + + +@dataclass(slots=True) +class RecordedSend: + kind: str + message_id: str + session_id: str + text: str | None = None + image_url: str | None = None + chain: list[dict[str, Any]] | None = None + target: dict[str, Any] | None = None + raw: dict[str, Any] = field(default_factory=dict) + + @property + def session(self) -> str: + return self.session_id + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> RecordedSend: + if "text" in payload: + kind = "text" + elif "image_url" in payload: + kind = "image" + elif "chain" in payload: + kind = "chain" + else: + kind = "unknown" + return cls( + kind=kind, + message_id=str(payload.get("message_id", "")), + session_id=str(payload.get("session", "")), + text=payload.get("text") if isinstance(payload.get("text"), str) else None, + image_url=( + payload.get("image_url") + if isinstance(payload.get("image_url"), str) + else None + ), + chain=( + [dict(item) for item in payload.get("chain", [])] + if isinstance(payload.get("chain"), list) + else None + ), + target=_clone_payload_mapping(payload.get("target")), + raw=dict(payload), + ) + + +class StdoutPlatformSink: + def __init__(self, stream: TextIO | None = None) -> None: + self._stream = stream + self.records: list[RecordedSend] = [] + + def record(self, item: RecordedSend) -> None: + self.records.append(item) + if self._stream is None: + return + self._stream.write(self._format(item) + "\n") + self._stream.flush() + + def clear(self) -> None: + self.records.clear() + + def _format(self, item: RecordedSend) -> str: + if item.kind == "text": + return f"[text][{item.session_id}] {item.text or ''}" + if item.kind == "image": + return f"[image][{item.session_id}] {item.image_url or ''}" + if item.kind == "chain": + count = len(item.chain or []) + return f"[chain][{item.session_id}] {count} components" + return f"[send][{item.session_id}] {item.raw}" + + +class InMemoryDB: + def __init__(self, store: dict[str, Any]) -> None: + self._store = store + + def get(self, key: str, default: Any = None) -> Any: + return self._store.get(key, default) + + def set(self, key: str, value: Any) -> None: + self._store[key] = value + + def delete(self, key: str) -> None: + self._store.pop(key, None) + + def list(self, prefix: str | None = None) -> list[str]: + keys = sorted(self._store.keys()) + if prefix is None: + return keys + return [key for key in keys if key.startswith(prefix)] + + def get_many(self, keys: list[str]) -> list[dict[str, Any]]: + return [{"key": key, "value": self._store.get(key)} for key in keys] + + def set_many(self, items: list[dict[str, Any]]) -> None: + for item in items: + self.set(str(item.get("key", "")), item.get("value")) + + +class InMemoryMemory: + def __init__( + self, + store: dict[str, dict[str, Any]], + *, + expires_at: dict[str, datetime | None] | None = None, + ) -> None: + self._store = store + self._expires_at = expires_at if expires_at is not None else {} + + @staticmethod + def _is_ttl_entry(value: Any) -> bool: + """判断测试 memory 值是否使用 TTL 包装结构。 + + Args: + value: 待检查的存储值。 + + Returns: + bool: 如果包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。 + """ + return isinstance(value, dict) and "value" in value and "ttl_seconds" in value + + @classmethod + def _search_text(cls, value: Any) -> str: + """提取测试用 memory.search 的匹配文本。 + + Args: + value: 当前存储的 memory 值。 + + Returns: + str: 用于本地测试搜索的文本内容。 + """ + if cls._is_ttl_entry(value): + value = value.get("value") + if not isinstance(value, dict): + return "" + for field_name in ("embedding_text", "content", "summary", "title", "text"): + item = value.get(field_name) + if isinstance(item, str) and item.strip(): + return item.strip() + return str(value) + + def _is_expired(self, key: str) -> bool: + """判断测试 memory 键是否已经过期。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果当前时间已超过过期时间则返回 ``True``。 + """ + expires_at = self._expires_at.get(key) + return expires_at is not None and expires_at <= datetime.now(timezone.utc) + + def _purge_if_expired(self, key: str) -> bool: + """在测试 helper 中清理已过期的 memory 条目。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果条目已过期并被清理则返回 ``True``。 + """ + if not self._is_expired(key): + return False + self._store.pop(key, None) + self._expires_at.pop(key, None) + return True + + def get(self, key: str, default: Any = None) -> Any: + if self._purge_if_expired(key): + return default + return self._store.get(key, default) + + def save(self, key: str, value: dict[str, Any]) -> None: + self._store[key] = dict(value) + + def delete(self, key: str) -> None: + self._store.pop(key, None) + self._expires_at.pop(key, None) + + def search(self, query: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for key, value in list(self._store.items()): + if self._purge_if_expired(key): + continue + if query in key or query in self._search_text(value): + results.append({"key": key, "value": value}) + return results + + +class MockLLMClient: + def __init__(self, client: Any, router: MockCapabilityRouter) -> None: + self._client = client + self._router = router + + def mock_response(self, text: str) -> None: + self._router.enqueue_llm_response(text) + + def mock_stream_response(self, text: str) -> None: + self._router.enqueue_llm_stream_response(text) + + def clear_mock_responses(self) -> None: + self._router.clear_llm_responses() + + def __getattr__(self, name: str) -> Any: + return getattr(self._client, name) + + +class MockPlatformClient: + def __init__(self, client: Any, sink: StdoutPlatformSink) -> None: + self._client = client + self._sink = sink + + @property + def records(self) -> list[RecordedSend]: + return list(self._sink.records) + + def assert_sent( + self, + expected_text: str | None = None, + *, + kind: str = "text", + count: int | None = None, + ) -> None: + matched = [item for item in self._sink.records if item.kind == kind] + if expected_text is not None: + matched = [item for item in matched if item.text == expected_text] + if count is not None: + if len(matched) != count: + raise AssertionError( + f"expected {count} sent records, got {len(matched)}: {matched}" + ) + return + if not matched: + raise AssertionError( + f"expected sent record kind={kind!r} text={expected_text!r}, got {self._sink.records}" + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self._client, name) + + +class MockCapabilityRouter(CapabilityRouter): + def __init__(self, *, platform_sink: StdoutPlatformSink | None = None) -> None: + self.platform_sink = platform_sink or StdoutPlatformSink() + self._llm_responses: list[str] = [] + self._llm_stream_responses: list[str] = [] + super().__init__() + self.db = InMemoryDB(self.db_store) + self.memory = InMemoryMemory( + self.memory_store, + expires_at=self._memory_expires_at, + ) + + def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]: + return super().list_dynamic_command_routes(plugin_id) + + def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None: + super().remove_dynamic_command_routes_for_plugin(plugin_id) + + def emit_provider_change( + self, + provider_id: str, + provider_type: str, + umo: str | None = None, + ) -> None: + super().emit_provider_change(provider_id, provider_type, umo) + + def record_platform_error( + self, + platform_id: str, + message: str, + *, + traceback: str | None = None, + ) -> None: + super().record_platform_error(platform_id, message, traceback=traceback) + + def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None: + super().set_platform_stats(platform_id, stats) + + def enqueue_llm_response(self, text: str) -> None: + self._llm_responses.append(text) + + def enqueue_llm_stream_response(self, text: str) -> None: + self._llm_stream_responses.append(text) + + def clear_llm_responses(self) -> None: + self._llm_responses.clear() + self._llm_stream_responses.clear() + + async def execute( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool, + cancel_token, + request_id: str, + ) -> dict[str, Any] | StreamExecution: + if capability == "llm.chat": + return {"text": self._take_llm_response(str(payload.get("prompt", "")))} + if capability == "llm.chat_raw": + text = self._take_llm_response(str(payload.get("prompt", ""))) + return { + "text": text, + "usage": { + "input_tokens": len(str(payload.get("prompt", ""))), + "output_tokens": len(text), + }, + "finish_reason": "stop", + "tool_calls": [], + "role": "assistant", + "reasoning_content": None, + "reasoning_signature": None, + } + if capability == "llm.stream_chat": + text = self._take_llm_stream_response(str(payload.get("prompt", ""))) + + async def iterator() -> typing.AsyncIterator[dict[str, Any]]: + for char in text: + cancel_token.raise_if_cancelled() + await asyncio.sleep(0) + yield {"text": char} + + return StreamExecution( + iterator=iterator(), + finalize=lambda chunks: { + "text": "".join(item.get("text", "") for item in chunks) + }, + ) + before = len(self.sent_messages) + result = await super().execute( + capability, + payload, + stream=stream, + cancel_token=cancel_token, + request_id=request_id, + ) + self._flush_platform_records(before) + return result + + def _flush_platform_records(self, start_index: int) -> None: + for payload in self.sent_messages[start_index:]: + self.platform_sink.record(RecordedSend.from_payload(payload)) + + def _take_llm_response(self, prompt: str) -> str: + if self._llm_responses: + return self._llm_responses.pop(0) + return f"Echo: {prompt}" + + def _take_llm_stream_response(self, prompt: str) -> str: + if self._llm_stream_responses: + return self._llm_stream_responses.pop(0) + if self._llm_responses: + return self._llm_responses.pop(0) + return f"Echo: {prompt}" + + +class MockPeer: + def __init__(self, router: MockCapabilityRouter) -> None: + self._router = router + self._counter = 0 + self.remote_peer = PeerInfo( + name="astrbot-local-core", + role="core", + version="local", + ) + self.remote_capabilities = list(router.all_descriptors()) + self.remote_capability_map = { + item.name: item for item in self.remote_capabilities + } + self.remote_handlers: list[Any] = [] + self.remote_provided_capabilities: list[Any] = [] + self.remote_metadata = {"mode": "local"} + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: + if stream: + raise ValueError("stream=True 请使用 invoke_stream()") + return typing.cast( + dict[str, Any], + await self._router.execute( + capability, + payload, + stream=False, + cancel_token=CancelToken(), + request_id=request_id or self._next_id(), + ), + ) + + async def invoke_stream( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + include_completed: bool = False, + ): + request_id = request_id or self._next_id() + execution = typing.cast( + StreamExecution, + await self._router.execute( + capability, + payload, + stream=True, + cancel_token=CancelToken(), + request_id=request_id, + ), + ) + + async def iterator(): + yield EventMessage.model_validate({"id": request_id, "phase": "started"}) + chunks: list[dict[str, Any]] = [] + async for chunk in execution.iterator: + if execution.collect_chunks: + chunks.append(chunk) + yield EventMessage.model_validate( + {"id": request_id, "phase": "delta", "data": chunk} + ) + output = execution.finalize(chunks) + if include_completed: + yield EventMessage.model_validate( + {"id": request_id, "phase": "completed", "output": output} + ) + + return iterator() + + def _next_id(self) -> str: + self._counter += 1 + return f"local_{self._counter:04d}" + + +def _normalize_plugin_metadata( + plugin_id: str, + plugin_metadata: Mapping[str, Any] | None, +) -> dict[str, Any]: + if plugin_metadata is None: + plugin_metadata = {} + declared_name = plugin_metadata.get("name") + if declared_name is not None and str(declared_name) != plugin_id: + raise ValueError( + "MockContext.plugin_metadata['name'] 必须与 plugin_id 一致," + f"当前收到 {declared_name!r} != {plugin_id!r}" + ) + description = plugin_metadata.get("description") + if description is None: + description = plugin_metadata.get("desc", "") + return { + "name": plugin_id, + "display_name": str(plugin_metadata.get("display_name") or plugin_id), + "description": str(description or ""), + "author": str(plugin_metadata.get("author") or ""), + "version": str(plugin_metadata.get("version") or "0.0.0"), + "enabled": bool(plugin_metadata.get("enabled", True)), + "reserved": bool(plugin_metadata.get("reserved", False)), + "acknowledge_global_mcp_risk": bool( + plugin_metadata.get("acknowledge_global_mcp_risk", False) + ), + "local_mcp_servers": ( + { + str(server_name): dict(server_payload) + for server_name, server_payload in plugin_metadata.get( + "local_mcp_servers", + {}, + ).items() + if str(server_name).strip() and isinstance(server_payload, dict) + } + if isinstance(plugin_metadata.get("local_mcp_servers"), dict) + else {} + ), + "support_platforms": [ + str(item) + for item in plugin_metadata.get("support_platforms", []) + if isinstance(item, str) + ] + if isinstance(plugin_metadata.get("support_platforms"), list) + else [], + "astrbot_version": ( + str(plugin_metadata.get("astrbot_version")) + if plugin_metadata.get("astrbot_version") is not None + else None + ), + } + + +class MockContext(RuntimeContext): + def __init__( + self, + *, + plugin_id: str = "test-plugin", + logger: Any | None = None, + cancel_token: CancelToken | None = None, + platform_sink: StdoutPlatformSink | None = None, + plugin_metadata: Mapping[str, Any] | None = None, + ) -> None: + self.platform_sink = platform_sink or StdoutPlatformSink() + self.router = MockCapabilityRouter(platform_sink=self.platform_sink) + self.mock_peer = MockPeer(self.router) + super().__init__( + peer=self.mock_peer, + plugin_id=plugin_id, + cancel_token=cancel_token, + logger=logger, + ) + self.router.upsert_plugin( + metadata=_normalize_plugin_metadata(plugin_id, plugin_metadata), + config={}, + ) + self.llm = MockLLMClient(self.llm, self.router) + self.platform = MockPlatformClient(self.platform, self.platform_sink) + + @property + def sent_messages(self) -> list[RecordedSend]: + return list(self.platform_sink.records) + + @property + def event_actions(self) -> list[dict[str, Any]]: + return list(self.router.event_actions) + + +class MockMessageEvent(MessageEvent): + def __init__( + self, + *, + text: str = "", + user_id: str | None = "test-user", + group_id: str | None = None, + platform: str | None = "test", + session_id: str | None = "test-session", + raw: dict[str, Any] | None = None, + context: MockContext | None = None, + ) -> None: + self.replies: list[str] = [] + super().__init__( + text=text, + user_id=user_id, + group_id=group_id, + platform=platform, + session_id=session_id, + raw=raw, + context=context, + ) + if context is not None: + self.bind_runtime_reply(context) + elif self._reply_handler is None: + self.bind_reply_handler(self._capture_reply) + + @property + def is_private(self) -> bool: + return self.group_id is None + + def bind_runtime_reply(self, context: MockContext) -> None: + self._context = context + + async def reply(text: str) -> None: + self.replies.append(text) + await context.platform.send(self.session_ref or self.session_id, text) + + self.bind_reply_handler(reply) + + async def _capture_reply(self, text: str) -> None: + self.replies.append(text) + + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py new file mode 100644 index 0000000000..7cac7421ba --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import typing +from types import UnionType +from typing import Any + + +def unwrap_optional(annotation: Any) -> tuple[Any, bool]: + origin = typing.get_origin(annotation) + if origin in {typing.Union, UnionType}: + args = [item for item in typing.get_args(annotation) if item is not type(None)] + if len(args) == 1: + return args[0], True + return annotation, False + + +__all__ = ["unwrap_optional"] diff --git a/astrbot-sdk/src/astrbot_sdk/_memory_backend.py b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py new file mode 100644 index 0000000000..50f94cbced --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py @@ -0,0 +1,1515 @@ +from __future__ import annotations + +import asyncio +import json +import re +import sqlite3 +import threading +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, cast + +from ._internal.memory_utils import ( + cosine_similarity, + display_memory_namespace, + extract_memory_text, + join_memory_namespace, + memory_keyword_score, + memory_namespace_matches, + memory_value_for_search, + normalize_embedding, + normalize_memory_namespace, +) + + +def _utcnow() -> datetime: + # Centralize time access so expiry tests can advance time without mutating SQLite internals. + return datetime.now(timezone.utc) + + +def _sql_placeholders(count: int) -> str: + if count <= 0: + raise ValueError("count must be positive") + return ", ".join("?" for _ in range(count)) + + +def _normalize_scope_namespace(namespace: str | None) -> str | None: + if namespace is None: + return None + return normalize_memory_namespace(namespace) + + +def _escape_like_value(value: str) -> str: + return str(value).replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + +EmbedMany = Callable[[list[str]], Awaitable[list[list[float]]] | list[list[float]]] +EmbedOne = Callable[[str], Awaitable[list[float]] | list[float]] + + +@dataclass(slots=True) +class MemorySearchResult: + key: str + namespace: str + value: dict[str, Any] | None + score: float + match_type: str + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "key": self.key, + "value": self.value, + "score": self.score, + "match_type": self.match_type, + } + namespace = display_memory_namespace(self.namespace) + if namespace is not None: + payload["namespace"] = namespace + return payload + + +@dataclass(slots=True) +class _StoredRecord: + namespace: str + key: str + stored: dict[str, Any] + search_text: str + updated_at: str + + +@dataclass(slots=True) +class _VectorCandidate: + namespace: str + key: str + stored: dict[str, Any] + search_text: str + score: float + + +class PluginMemoryBackend: + """Persistent plugin-scoped memory backend with namespace-aware search.""" + + def __init__(self, data_dir: Path) -> None: + self._base_dir = Path(data_dir) / "memory" + self._db_path = self._base_dir / "memory.sqlite3" + self._vector_dir = self._base_dir / "vectors" + self._lock = threading.RLock() + self._initialized = False + self._fts_enabled = False + self._vector_indexes: dict[str, Any | None] = {} + self._vector_fallbacks: dict[str, list[tuple[int, list[float]]]] = {} + + async def save( + self, + key: str, + value: dict[str, Any], + *, + namespace: str | None = None, + ) -> None: + await asyncio.to_thread( + self._save_sync, + str(key), + dict(value), + normalize_memory_namespace(namespace), + None, + ) + + async def save_with_ttl( + self, + key: str, + value: dict[str, Any], + ttl_seconds: int, + *, + namespace: str | None = None, + ) -> None: + expires_at = _utcnow().timestamp() + max(int(ttl_seconds), 0) + await asyncio.to_thread( + self._save_sync, + str(key), + dict(value), + normalize_memory_namespace(namespace), + { + "ttl_seconds": int(ttl_seconds), + "expires_at": datetime.fromtimestamp( + expires_at, + tz=timezone.utc, + ).isoformat(), + }, + ) + + async def get( + self, + key: str, + *, + namespace: str | None = None, + ) -> dict[str, Any] | None: + return await asyncio.to_thread( + self._get_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def list_keys( + self, + *, + namespace: str | None = None, + ) -> list[str]: + return await asyncio.to_thread( + self._list_keys_sync, + normalize_memory_namespace(namespace), + ) + + async def exists( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + return await asyncio.to_thread( + self._exists_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def get_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> list[dict[str, Any]]: + normalized_namespace = normalize_memory_namespace(namespace) + return await asyncio.to_thread( + self._get_many_sync, + [str(item) for item in keys], + normalized_namespace, + ) + + async def delete( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + return await asyncio.to_thread( + self._delete_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def clear_namespace( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: + normalized_namespace = _normalize_scope_namespace(namespace) + return await asyncio.to_thread( + self._clear_namespace_sync, + normalized_namespace, + bool(include_descendants), + ) + + async def delete_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> int: + normalized_namespace = normalize_memory_namespace(namespace) + return await asyncio.to_thread( + self._delete_many_sync, + [str(item) for item in keys], + normalized_namespace, + ) + + async def count( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: + normalized_namespace = _normalize_scope_namespace(namespace) + return await asyncio.to_thread( + self._count_sync, + normalized_namespace, + bool(include_descendants), + ) + + async def stats( + self, + *, + namespace: str | None = None, + include_descendants: bool = True, + ) -> dict[str, Any]: + normalized_namespace = _normalize_scope_namespace(namespace) + return await asyncio.to_thread( + self._stats_sync, + normalized_namespace, + bool(include_descendants), + ) + + async def search( + self, + query: str, + *, + namespace: str | None = None, + include_descendants: bool = True, + mode: str, + limit: int | None, + min_score: float | None, + provider_id: str | None = None, + embed_one: EmbedOne | None = None, + embed_many: EmbedMany | None = None, + ) -> list[dict[str, Any]]: + normalized_namespace = _normalize_scope_namespace(namespace) + normalized_mode = str(mode).strip().lower() or "keyword" + query_text = str(query) + + await asyncio.to_thread(self._purge_expired_sync) + + keyword_candidates = await asyncio.to_thread( + self._keyword_candidates_sync, + query_text, + normalized_namespace, + bool(include_descendants), + limit, + ) + + vector_candidates: list[_VectorCandidate] = [] + if normalized_mode in {"vector", "hybrid"} and provider_id: + await self._ensure_embeddings( + provider_id=provider_id, + namespace=normalized_namespace, + include_descendants=bool(include_descendants), + embed_one=embed_one, + embed_many=embed_many, + ) + if embed_one is not None: + raw_query_embedding = await _maybe_await(embed_one(query_text)) + query_embedding = normalize_embedding( + [float(item) for item in raw_query_embedding] + ) + vector_candidates = await asyncio.to_thread( + self._vector_candidates_sync, + provider_id, + query_embedding, + normalized_namespace, + bool(include_descendants), + limit, + ) + + merged: dict[tuple[str, str], dict[str, Any]] = {} + for record in keyword_candidates: + identity = (record.namespace, record.key) + merged[identity] = { + "namespace": record.namespace, + "key": record.key, + "stored": record.stored, + "keyword_score": memory_keyword_score( + query_text, + record.key, + record.search_text, + ), + "vector_score": 0.0, + } + for record in vector_candidates: + identity = (record.namespace, record.key) + current = merged.setdefault( + identity, + { + "namespace": record.namespace, + "key": record.key, + "stored": record.stored, + "keyword_score": memory_keyword_score( + query_text, + record.key, + record.search_text, + ), + "vector_score": 0.0, + }, + ) + current["vector_score"] = max( + float(current["vector_score"]), + float(record.score), + ) + + results: list[MemorySearchResult] = [] + for item in merged.values(): + keyword_score = max(0.0, float(item["keyword_score"])) + vector_score = max(0.0, float(item["vector_score"])) + score = self._combined_score( + mode=normalized_mode, + keyword_score=keyword_score, + vector_score=vector_score, + ) + if score <= 0: + continue + if min_score is not None and score < float(min_score): + continue + + if normalized_mode == "keyword" or ( + keyword_score > 0 and vector_score <= 0 + ): + match_type = "keyword" + elif normalized_mode == "vector" or keyword_score <= 0: + match_type = "vector" + else: + match_type = "hybrid" + + results.append( + MemorySearchResult( + key=str(item["key"]), + namespace=str(item["namespace"]), + value=memory_value_for_search(item["stored"]), + score=score, + match_type=match_type, + ) + ) + + results.sort(key=lambda item: (-item.score, item.namespace, item.key)) + if limit is not None and limit >= 0: + results = results[:limit] + return [item.to_payload() for item in results] + + async def _ensure_embeddings( + self, + *, + provider_id: str, + namespace: str | None, + include_descendants: bool, + embed_one: EmbedOne | None, + embed_many: EmbedMany | None, + ) -> None: + missing = await asyncio.to_thread( + self._missing_embeddings_sync, + provider_id, + namespace, + include_descendants, + ) + if missing: + texts = [record.search_text for record in missing] + embeddings: list[list[float]] + if embed_many is not None: + raw_embeddings = await _maybe_await(embed_many(texts)) + embeddings = [ + normalize_embedding([float(value) for value in item]) + for item in raw_embeddings + ] + elif embed_one is not None: + embeddings = [] + for text in texts: + raw_vector = await _maybe_await(embed_one(text)) + embeddings.append( + normalize_embedding([float(value) for value in raw_vector]) + ) + else: + embeddings = [] + await asyncio.to_thread( + self._upsert_embeddings_sync, + provider_id, + missing, + embeddings, + ) + await asyncio.to_thread(self._ensure_vector_index_sync, provider_id) + + def _save_sync( + self, + key: str, + value: dict[str, Any], + namespace: str, + ttl_metadata: dict[str, Any] | None, + ) -> None: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + stored = dict(value) + expires_at: str | None = None + if ttl_metadata is not None: + expires_at = str(ttl_metadata.get("expires_at", "")).strip() or None + stored = { + "value": dict(value), + "ttl_seconds": int(ttl_metadata.get("ttl_seconds", 0)), + } + if expires_at is not None: + stored["expires_at"] = expires_at + search_text = extract_memory_text(stored) + stored_json = json.dumps( + stored, + ensure_ascii=False, + sort_keys=True, + default=str, + ) + updated_at = _utcnow().isoformat() + conn.execute( + """ + INSERT INTO memory_records(namespace, key, stored_json, search_text, expires_at, updated_at) + VALUES(?, ?, ?, ?, ?, ?) + ON CONFLICT(namespace, key) DO UPDATE SET + stored_json = excluded.stored_json, + search_text = excluded.search_text, + expires_at = excluded.expires_at, + updated_at = excluded.updated_at + """, + (namespace, key, stored_json, search_text, expires_at, updated_at), + ) + self._sync_fts_row_locked( + conn, + namespace=namespace, + key=key, + search_text=search_text, + ) + provider_rows = conn.execute( + """ + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE namespace = ? AND key = ? + """, + (namespace, key), + ).fetchall() + conn.execute( + "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?", + (namespace, key), + ) + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + conn.commit() + finally: + conn.close() + + def _get_sync(self, key: str, namespace: str) -> dict[str, Any] | None: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + row = conn.execute( + """ + SELECT stored_json + FROM memory_records + WHERE namespace = ? AND key = ? + """, + (namespace, key), + ).fetchone() + if row is None: + return None + stored = self._load_stored_json(row[0]) + return memory_value_for_search(stored) + finally: + conn.close() + + def _list_keys_sync(self, namespace: str) -> list[str]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + rows = conn.execute( + """ + SELECT key + FROM memory_records + WHERE namespace = ? + ORDER BY key COLLATE NOCASE ASC, key ASC + """, + (namespace,), + ).fetchall() + return [str(row[0]) for row in rows] + finally: + conn.close() + + def _exists_sync(self, key: str, namespace: str) -> bool: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + row = conn.execute( + """ + SELECT 1 + FROM memory_records + WHERE namespace = ? AND key = ? + LIMIT 1 + """, + (namespace, key), + ).fetchone() + return row is not None + finally: + conn.close() + + def _get_many_sync(self, keys: list[str], namespace: str) -> list[dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + if not keys: + return [] + lookup_keys = list(dict.fromkeys(keys)) + placeholders = _sql_placeholders(len(lookup_keys)) + rows = conn.execute( + f""" + SELECT key, stored_json + FROM memory_records + WHERE namespace = ? AND key IN ({placeholders}) + """, + (namespace, *lookup_keys), + ).fetchall() + stored_by_key = { + str(row[0]): self._load_stored_json(row[1]) for row in rows + } + return [ + { + "key": key, + "value": memory_value_for_search(stored_by_key.get(key)), + } + for key in keys + ] + finally: + conn.close() + + def _delete_sync(self, key: str, namespace: str) -> bool: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + deleted = self._delete_record_locked(conn, namespace=namespace, key=key) + conn.commit() + return deleted + finally: + conn.close() + + def _clear_namespace_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> int: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + deleted = self._delete_scope_locked( + conn, + namespace=namespace, + include_descendants=include_descendants, + ) + conn.commit() + return deleted + finally: + conn.close() + + def _delete_many_sync(self, keys: list[str], namespace: str) -> int: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + unique_keys = list(dict.fromkeys(keys)) + if not unique_keys: + conn.commit() + return 0 + placeholders = _sql_placeholders(len(unique_keys)) + provider_rows = conn.execute( + f""" + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE namespace = ? AND key IN ({placeholders}) + """, + (namespace, *unique_keys), + ).fetchall() + conn.execute( + f"DELETE FROM memory_embeddings WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ) + deleted = conn.execute( + f"DELETE FROM memory_records WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ).rowcount + if self._fts_enabled: + conn.execute( + f"DELETE FROM memory_records_fts WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ) + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + conn.commit() + return deleted + finally: + conn.close() + + def _count_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> int: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + return int( + conn.execute( + f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}", + params, + ).fetchone()[0] + ) + finally: + conn.close() + + def _stats_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> dict[str, Any]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + total_items = int( + conn.execute( + f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}", + params, + ).fetchone()[0] + ) + ttl_entries = int( + conn.execute( + f""" + SELECT COUNT(*) + FROM memory_records + WHERE {where_sql} AND expires_at IS NOT NULL + """, + params, + ).fetchone()[0] + ) + total_bytes = int( + conn.execute( + f""" + SELECT COALESCE(SUM(LENGTH(key) + LENGTH(stored_json)), 0) + FROM memory_records + WHERE {where_sql} + """, + params, + ).fetchone()[0] + ) + namespace_count = int( + conn.execute( + f""" + SELECT COUNT(DISTINCT namespace) + FROM memory_records + WHERE {where_sql} + """, + params, + ).fetchone()[0] + ) + embedding_where_sql, embedding_params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="e", + ) + embedded_items = int( + conn.execute( + f""" + SELECT COUNT(*) + FROM ( + SELECT DISTINCT e.namespace, e.key + FROM memory_embeddings e + WHERE {embedding_where_sql} + ) + """, + embedding_params, + ).fetchone()[0] + ) + indexed_items = total_items + dirty_items = max(indexed_items - embedded_items, 0) + provider_rows = conn.execute( + """ + SELECT provider_id, dirty + FROM memory_vector_state + ORDER BY provider_id + """ + ).fetchall() + return { + "total_items": total_items, + "total_bytes": total_bytes, + "ttl_entries": ttl_entries, + "namespace": ( + None + if namespace is None + else normalize_memory_namespace(namespace) + ), + "namespace_count": namespace_count, + "indexed_items": indexed_items, + "embedded_items": embedded_items, + "dirty_items": dirty_items, + "fts_enabled": self._fts_enabled, + "vector_backend": self._vector_backend_label(), + "vector_indexes": [ + { + "provider_id": str(provider_id), + "dirty": bool(dirty), + } + for provider_id, dirty in provider_rows + ], + } + finally: + conn.close() + + def _keyword_candidates_sync( + self, + query: str, + namespace: str | None, + include_descendants: bool, + limit: int | None, + ) -> list[_StoredRecord]: + with self._lock: + conn = self._connect() + try: + fetch_limit = max((int(limit) if limit is not None else 10) * 8, 50) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + seen: set[tuple[str, str]] = set() + records: list[_StoredRecord] = [] + fts_query = self._fts_query(query) + if self._fts_enabled and fts_query is not None: + fts_where_sql, fts_params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="r", + ) + rows = conn.execute( + f""" + SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at + FROM memory_records_fts f + JOIN memory_records r + ON r.namespace = f.namespace AND r.key = f.key + WHERE {fts_where_sql} AND memory_records_fts MATCH ? + ORDER BY bm25(memory_records_fts), r.updated_at DESC + LIMIT ? + """, + (*fts_params, fts_query, fetch_limit), + ).fetchall() + for row in rows: + record = self._stored_record_from_row(row) + identity = (record.namespace, record.key) + if identity not in seen: + seen.add(identity) + records.append(record) + + like_query = f"%{str(query).strip()}%" + if not records or len(records) < fetch_limit: + rows = conn.execute( + f""" + SELECT namespace, key, stored_json, search_text, updated_at + FROM memory_records + WHERE {where_sql} + AND (? = '%%' OR key LIKE ? COLLATE NOCASE OR search_text LIKE ? COLLATE NOCASE) + ORDER BY updated_at DESC + LIMIT ? + """, + (*params, like_query, like_query, like_query, fetch_limit), + ).fetchall() + for row in rows: + record = self._stored_record_from_row(row) + identity = (record.namespace, record.key) + if identity not in seen: + seen.add(identity) + records.append(record) + return records + finally: + conn.close() + + def _missing_embeddings_sync( + self, + provider_id: str, + namespace: str | None, + include_descendants: bool, + ) -> list[_StoredRecord]: + with self._lock: + conn = self._connect() + try: + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="r", + ) + rows = conn.execute( + f""" + SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at + FROM memory_records r + LEFT JOIN memory_embeddings e + ON e.namespace = r.namespace + AND e.key = r.key + AND e.provider_id = ? + WHERE {where_sql} AND e.id IS NULL + ORDER BY r.updated_at DESC + """, + (provider_id, *params), + ).fetchall() + return [self._stored_record_from_row(row) for row in rows] + finally: + conn.close() + + def _upsert_embeddings_sync( + self, + provider_id: str, + records: list[_StoredRecord], + embeddings: list[list[float]], + ) -> None: + if not records: + return + with self._lock: + conn = self._connect() + try: + for index, record in enumerate(records): + vector = embeddings[index] if index < len(embeddings) else [] + conn.execute( + """ + INSERT INTO memory_embeddings(namespace, key, provider_id, embedding_json, updated_at) + VALUES(?, ?, ?, ?, ?) + ON CONFLICT(namespace, key, provider_id) DO UPDATE SET + embedding_json = excluded.embedding_json, + updated_at = excluded.updated_at + """, + ( + record.namespace, + record.key, + provider_id, + json.dumps( + vector, ensure_ascii=False, separators=(",", ":") + ), + _utcnow().isoformat(), + ), + ) + self._mark_vector_dirty_locked(conn, provider_id) + conn.commit() + finally: + conn.close() + + def _vector_candidates_sync( + self, + provider_id: str, + query_embedding: list[float], + namespace: str | None, + include_descendants: bool, + limit: int | None, + ) -> list[_VectorCandidate]: + if not query_embedding: + return [] + with self._lock: + conn = self._connect() + try: + index = self._vector_indexes.get(provider_id) + fetch_limit = max((int(limit) if limit is not None else 10) * 10, 50) + if index is not None and self._faiss_available(): + return self._faiss_vector_candidates_locked( + conn=conn, + provider_id=provider_id, + query_embedding=query_embedding, + namespace=namespace, + include_descendants=include_descendants, + fetch_limit=fetch_limit, + ) + return self._fallback_vector_candidates_locked( + conn=conn, + provider_id=provider_id, + query_embedding=query_embedding, + namespace=namespace, + include_descendants=include_descendants, + fetch_limit=fetch_limit, + ) + finally: + conn.close() + + def _ensure_vector_index_sync(self, provider_id: str) -> None: + with self._lock: + conn = self._connect() + try: + self._init_storage_locked(conn) + row = conn.execute( + """ + SELECT dirty + FROM memory_vector_state + WHERE provider_id = ? + """, + (provider_id,), + ).fetchone() + dirty = True if row is None else bool(row[0]) + if not dirty and provider_id in self._vector_indexes: + return + + index_path = ( + self._vector_dir / f"{self._safe_filename(provider_id)}.faiss" + ) + if not dirty and index_path.exists() and self._faiss_available(): + try: + faiss = self._import_faiss() + self._vector_indexes[provider_id] = faiss.read_index( + str(index_path) + ) + self._vector_fallbacks.pop(provider_id, None) + return + except Exception: + pass + + rows = conn.execute( + """ + SELECT id, embedding_json + FROM memory_embeddings + WHERE provider_id = ? + ORDER BY id + """, + (provider_id,), + ).fetchall() + ids: list[int] = [] + vectors: list[list[float]] = [] + for raw_id, raw_vector in rows: + vector = self._load_embedding_json(raw_vector) + if not vector: + continue + ids.append(int(raw_id)) + vectors.append(vector) + + if self._faiss_available() and vectors: + faiss = self._import_faiss() + np = self._import_numpy() + dimension = len(vectors[0]) + base_index = faiss.IndexFlatIP(dimension) + index = faiss.IndexIDMap2(base_index) + index.add_with_ids( + np.array(vectors, dtype="float32"), + np.array(ids, dtype="int64"), + ) + self._vector_indexes[provider_id] = index + self._vector_fallbacks.pop(provider_id, None) + self._vector_dir.mkdir(parents=True, exist_ok=True) + faiss.write_index(index, str(index_path)) + else: + self._vector_indexes[provider_id] = None + self._vector_fallbacks[provider_id] = list( + zip(ids, vectors, strict=False) + ) + conn.execute( + """ + INSERT INTO memory_vector_state(provider_id, dirty, updated_at) + VALUES(?, 0, ?) + ON CONFLICT(provider_id) DO UPDATE SET + dirty = 0, + updated_at = excluded.updated_at + """, + (provider_id, _utcnow().isoformat()), + ) + conn.commit() + finally: + conn.close() + + def _faiss_vector_candidates_locked( + self, + *, + conn: sqlite3.Connection, + provider_id: str, + query_embedding: list[float], + namespace: str | None, + include_descendants: bool, + fetch_limit: int, + ) -> list[_VectorCandidate]: + index = self._vector_indexes.get(provider_id) + if index is None: + return [] + np = self._import_numpy() + total_count = int(getattr(index, "ntotal", 0) or 0) + if total_count <= 0: + return [] + + collected: list[_VectorCandidate] = [] + seen: set[tuple[str, str]] = set() + current_limit = min(fetch_limit, total_count) + while current_limit > 0: + scores, ids = index.search( + np.array([query_embedding], dtype="float32"), + current_limit, + ) + raw_ids = [int(item) for item in ids[0] if int(item) >= 0] + score_map = { + int(item_id): max(0.0, float(score)) + for item_id, score in zip(raw_ids, scores[0], strict=False) + } + if not score_map: + break + placeholders = ",".join("?" for _ in score_map) + rows = conn.execute( + f""" + SELECT e.id, r.namespace, r.key, r.stored_json, r.search_text + FROM memory_embeddings e + JOIN memory_records r + ON r.namespace = e.namespace AND r.key = e.key + WHERE e.provider_id = ? + AND e.id IN ({placeholders}) + """, + (provider_id, *score_map.keys()), + ).fetchall() + row_map = {int(row[0]): row for row in rows} + for item_id in raw_ids: + row = row_map.get(item_id) + if row is None: + continue + record_namespace = normalize_memory_namespace(row[1]) + if not memory_namespace_matches( + record_namespace, + namespace, + include_descendants=include_descendants, + ): + continue + identity = (record_namespace, str(row[2])) + if identity in seen: + continue + seen.add(identity) + collected.append( + _VectorCandidate( + namespace=record_namespace, + key=str(row[2]), + stored=self._load_stored_json(row[3]), + search_text=str(row[4]), + score=max(0.0, score_map.get(item_id, 0.0)), + ) + ) + if len(collected) >= fetch_limit or current_limit >= total_count: + break + next_limit = min(total_count, current_limit * 2) + if next_limit == current_limit: + break + current_limit = next_limit + return collected + + def _fallback_vector_candidates_locked( + self, + *, + conn: sqlite3.Connection, + provider_id: str, + query_embedding: list[float], + namespace: str | None, + include_descendants: bool, + fetch_limit: int, + ) -> list[_VectorCandidate]: + rows = conn.execute( + """ + SELECT e.namespace, e.key, e.embedding_json, r.stored_json, r.search_text + FROM memory_embeddings e + JOIN memory_records r + ON r.namespace = e.namespace AND r.key = e.key + WHERE e.provider_id = ? + """, + (provider_id,), + ).fetchall() + candidates: list[_VectorCandidate] = [] + for raw_namespace, raw_key, raw_embedding, raw_stored, raw_search_text in rows: + record_namespace = normalize_memory_namespace(raw_namespace) + if not memory_namespace_matches( + record_namespace, + namespace, + include_descendants=include_descendants, + ): + continue + embedding = self._load_embedding_json(raw_embedding) + score = max(0.0, cosine_similarity(query_embedding, embedding)) + if score <= 0: + continue + candidates.append( + _VectorCandidate( + namespace=record_namespace, + key=str(raw_key), + stored=self._load_stored_json(raw_stored), + search_text=str(raw_search_text), + score=score, + ) + ) + candidates.sort(key=lambda item: (-item.score, item.namespace, item.key)) + return candidates[:fetch_limit] + + def _purge_expired_sync(self) -> None: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + conn.commit() + finally: + conn.close() + + def _purge_expired_locked(self, conn: sqlite3.Connection) -> None: + self._init_storage_locked(conn) + now_iso = _utcnow().isoformat() + rows = conn.execute( + """ + SELECT namespace, key + FROM memory_records + WHERE expires_at IS NOT NULL AND expires_at <= ? + """, + (now_iso,), + ).fetchall() + for namespace, key in rows: + self._delete_record_locked( + conn, + namespace=normalize_memory_namespace(namespace), + key=str(key), + ) + + def _delete_record_locked( + self, + conn: sqlite3.Connection, + *, + namespace: str, + key: str, + ) -> bool: + provider_rows = conn.execute( + """ + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE namespace = ? AND key = ? + """, + (namespace, key), + ).fetchall() + conn.execute( + "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?", + (namespace, key), + ) + deleted = ( + conn.execute( + "DELETE FROM memory_records WHERE namespace = ? AND key = ?", + (namespace, key), + ).rowcount + > 0 + ) + if self._fts_enabled: + conn.execute( + "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?", + (namespace, key), + ) + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + return deleted + + def _delete_scope_locked( + self, + conn: sqlite3.Connection, + *, + namespace: str | None, + include_descendants: bool, + ) -> int: + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + affected_rows = conn.execute( + f""" + SELECT namespace, key + FROM memory_records + WHERE {where_sql} + """, + params, + ).fetchall() + if not affected_rows: + return 0 + + pair_placeholders = ", ".join("(?, ?)" for _ in affected_rows) + pair_params = tuple( + value + for raw_namespace, raw_key in affected_rows + for value in (normalize_memory_namespace(raw_namespace), str(raw_key)) + ) + + provider_rows = conn.execute( + f""" + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE (namespace, key) IN ({pair_placeholders}) + """, + pair_params, + ).fetchall() + conn.execute( + f""" + DELETE FROM memory_embeddings + WHERE (namespace, key) IN ({pair_placeholders}) + """, + pair_params, + ) + if self._fts_enabled: + conn.execute( + f""" + DELETE FROM memory_records_fts + WHERE (namespace, key) IN ({pair_placeholders}) + """, + pair_params, + ) + deleted = conn.execute( + f""" + DELETE FROM memory_records + WHERE (namespace, key) IN ({pair_placeholders}) + """, + pair_params, + ).rowcount + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + return deleted + + def _connect(self) -> sqlite3.Connection: + self._base_dir.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self._db_path) + conn.row_factory = sqlite3.Row + self._init_storage_locked(conn) + return conn + + def _init_storage_locked(self, conn: sqlite3.Connection) -> None: + if self._initialized: + return + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_records ( + namespace TEXT NOT NULL, + key TEXT NOT NULL, + stored_json TEXT NOT NULL, + search_text TEXT NOT NULL, + expires_at TEXT, + updated_at TEXT NOT NULL, + PRIMARY KEY(namespace, key) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_records_namespace + ON memory_records(namespace) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_records_expires_at + ON memory_records(expires_at) + """ + ) + try: + conn.execute( + """ + CREATE VIRTUAL TABLE IF NOT EXISTS memory_records_fts + USING fts5(namespace UNINDEXED, key, search_text, tokenize='unicode61') + """ + ) + self._fts_enabled = True + except sqlite3.OperationalError: + self._fts_enabled = False + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + namespace TEXT NOT NULL, + key TEXT NOT NULL, + provider_id TEXT NOT NULL, + embedding_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(namespace, key, provider_id) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_embeddings_provider + ON memory_embeddings(provider_id, namespace) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_vector_state ( + provider_id TEXT PRIMARY KEY, + dirty INTEGER NOT NULL DEFAULT 1, + updated_at TEXT NOT NULL + ) + """ + ) + conn.commit() + self._initialized = True + + def _sync_fts_row_locked( + self, + conn: sqlite3.Connection, + *, + namespace: str, + key: str, + search_text: str, + ) -> None: + if not self._fts_enabled: + return + conn.execute( + "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?", + (namespace, key), + ) + conn.execute( + """ + INSERT INTO memory_records_fts(namespace, key, search_text) + VALUES(?, ?, ?) + """, + (namespace, key, search_text), + ) + + def _mark_vector_dirty_locked( + self, + conn: sqlite3.Connection, + provider_id: str, + ) -> None: + conn.execute( + """ + INSERT INTO memory_vector_state(provider_id, dirty, updated_at) + VALUES(?, 1, ?) + ON CONFLICT(provider_id) DO UPDATE SET + dirty = 1, + updated_at = excluded.updated_at + """, + (provider_id, _utcnow().isoformat()), + ) + self._vector_indexes.pop(provider_id, None) + self._vector_fallbacks.pop(provider_id, None) + + @staticmethod + def _combined_score( + *, + mode: str, + keyword_score: float, + vector_score: float, + ) -> float: + if mode == "keyword": + return keyword_score + if mode == "vector": + return vector_score + if keyword_score > 0 and vector_score > 0: + return min(1.0, 0.65 * vector_score + 0.35 * keyword_score + 0.05) + if vector_score > 0: + return min(1.0, vector_score) + return min(1.0, keyword_score) + + @staticmethod + def _load_stored_json(raw_value: Any) -> dict[str, Any]: + if isinstance(raw_value, dict): + return dict(raw_value) + if isinstance(raw_value, str): + decoded = json.loads(raw_value) + return dict(decoded) if isinstance(decoded, dict) else {} + return {} + + @staticmethod + def _load_embedding_json(raw_value: Any) -> list[float]: + if isinstance(raw_value, list): + return [float(item) for item in raw_value] + if isinstance(raw_value, str): + decoded = json.loads(raw_value) + if isinstance(decoded, list): + return [float(item) for item in decoded] + return [] + + @staticmethod + def _stored_record_from_row(row: Any) -> _StoredRecord: + return _StoredRecord( + namespace=normalize_memory_namespace(row[0]), + key=str(row[1]), + stored=PluginMemoryBackend._load_stored_json(row[2]), + search_text=str(row[3]), + updated_at=str(row[4]), + ) + + @staticmethod + def _namespace_where( + namespace: str | None, + *, + include_descendants: bool, + alias: str | None = None, + ) -> tuple[str, tuple[Any, ...]]: + column = f"{alias}.namespace" if alias else "namespace" + if namespace is None: + return "1 = 1", () + normalized_namespace = normalize_memory_namespace(namespace) + if not normalized_namespace: + if include_descendants: + return "1 = 1", () + return f"{column} = ''", () + if include_descendants: + escaped_namespace = _escape_like_value(normalized_namespace) + return ( + f"({column} = ? OR {column} LIKE ? ESCAPE '\\')", + (normalized_namespace, f"{escaped_namespace}/%"), + ) + return f"{column} = ?", (normalized_namespace,) + + @staticmethod + def _fts_query(query: str) -> str | None: + stripped = str(query).strip() + if not stripped: + return None + terms = [ + item for item in re.findall(r"\w+", stripped, flags=re.UNICODE) if item + ] + if not terms: + return None + escaped_terms = [term.replace('"', '""') for term in terms[:8]] + return " OR ".join(f'"{term}"' for term in escaped_terms) + + @staticmethod + def _safe_filename(value: str) -> str: + return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(value)).strip("._") or "default" + + @staticmethod + def _import_faiss() -> Any: + # FAISS often ships without stable type stubs, so keep the lazy import + # boundary explicitly dynamic to avoid false-positive Pylance errors. + import faiss + + return cast(Any, faiss) + + @staticmethod + def _import_numpy(): + import numpy + + return numpy + + @classmethod + def _faiss_available(cls) -> bool: + try: + faiss = cls._import_faiss() + cls._import_numpy() + except Exception: + return False + required_attrs = ( + "IndexFlatIP", + "IndexIDMap2", + "read_index", + "write_index", + ) + return all(hasattr(faiss, attr) for attr in required_attrs) + + def _vector_backend_label(self) -> str: + return "faiss" if self._faiss_available() else "exact" + + +async def _maybe_await(value: Any) -> Any: + if asyncio.iscoroutine(value) or isinstance(value, asyncio.Future): + return await value + return value + + +def extend_memory_namespace( + base_namespace: str | None, + extra_namespace: str | None, +) -> str: + """Join a base namespace with a relative namespace override.""" + + return join_memory_namespace(base_namespace, extra_namespace) diff --git a/astrbot-sdk/src/astrbot_sdk/_message_types.py b/astrbot-sdk/src/astrbot_sdk/_message_types.py new file mode 100644 index 0000000000..1d2df56040 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_message_types.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +_GROUP_MESSAGE_TYPES = {"group", "groupmessage", "group_message"} +_PRIVATE_MESSAGE_TYPES = { + "private", + "privatemessage", + "private_message", + "friend", + "friendmessage", + "friend_message", +} +_OTHER_MESSAGE_TYPES = {"other", "othermessage", "other_message"} + + +def normalize_message_type( + value: Any, + *, + group_id: str | None = None, + user_id: str | None = None, + empty_default: str = "", +) -> str: + """Collapse SDK-visible message types to canonical values.""" + + normalized = str(getattr(value, "value", value) or "").strip().lower() + if normalized in _GROUP_MESSAGE_TYPES: + return "group" + if normalized in _PRIVATE_MESSAGE_TYPES: + return "private" + if normalized in _OTHER_MESSAGE_TYPES: + return "other" + if group_id: + return "group" + if user_id: + return "private" + if not normalized: + return empty_default + return "other" diff --git a/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py new file mode 100644 index 0000000000..5d2a3d9b17 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py @@ -0,0 +1,3 @@ +from ._internal.plugin_logger import PluginLogEntry, PluginLogger + +__all__ = ["PluginLogEntry", "PluginLogger"] diff --git a/astrbot-sdk/src/astrbot_sdk/_star_runtime.py b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py new file mode 100644 index 0000000000..d6d9fe215d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py @@ -0,0 +1,13 @@ +from ._internal.star_runtime import ( + bind_star_runtime, + current_runtime_context, + current_star_context, + current_star_instance, +) + +__all__ = [ + "bind_star_runtime", + "current_runtime_context", + "current_star_context", + "current_star_instance", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_testing_support.py b/astrbot-sdk/src/astrbot_sdk/_testing_support.py new file mode 100644 index 0000000000..1e945e8e06 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_testing_support.py @@ -0,0 +1,25 @@ +from ._internal.testing_support import ( + InMemoryDB, + InMemoryMemory, + MockCapabilityRouter, + MockContext, + MockLLMClient, + MockMessageEvent, + MockPeer, + MockPlatformClient, + RecordedSend, + StdoutPlatformSink, +) + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/cli.py b/astrbot-sdk/src/astrbot_sdk/cli.py new file mode 100644 index 0000000000..7977bbcc71 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/cli.py @@ -0,0 +1,1512 @@ +"""AstrBot SDK 的命令行入口。 + +本模块提供 astrbot-sdk 命令行工具的所有子命令,包括: +- init: 创建新插件骨架,生成 plugin.yaml、main.py、README.md 等模板文件 +- validate: 校验插件清单、导入路径和 handler 发现是否正常 +- build: 将插件打包为 .zip 发布包 +- dev: 本地开发模式,支持 --local/--watch/--interactive 等调试选项 +- run: 启动插件主管进程(supervisor),通过 stdio 与 AstrBot 核心通信 +- worker: 内部命令,由 supervisor 调用以启动单个插件工作进程 + +错误处理: +所有 CLI 异常都会被分类并返回标准化的退出码和错误提示, +便于 CI/CD 集成和用户快速定位问题。 +""" + +from __future__ import annotations + +import asyncio +import importlib.resources as resources +import os +import re +import sys +import typing +import zipfile +from collections.abc import Coroutine +from dataclasses import dataclass, field +from importlib.resources.abc import Traversable +from pathlib import Path +from textwrap import dedent +from typing import Any + +import click + +from ._internal.sdk_logger import logger +from .errors import AstrBotError +from .runtime.bootstrap import run_plugin_worker, run_supervisor, run_websocket_server +from .runtime.loader import load_plugin, load_plugin_spec, validate_plugin_spec + +EXIT_OK = 0 +EXIT_UNEXPECTED = 1 +EXIT_USAGE = 2 +EXIT_PLUGIN_LOAD = 3 +EXIT_RUNTIME = 4 +EXIT_PLUGIN_EXECUTION = 5 +BUILD_EXCLUDED_DIRS = { + ".agents", + ".claude", + ".git", + ".idea", + ".mypy_cache", + ".opencode", + ".pytest_cache", + ".ruff_cache", + ".venv", + "__pycache__", + "dist", +} +BUILD_EXCLUDED_FILES = { + "AGENTS.md", + "CLAUDE.md", + ".astrbot-worker-state.json", +} +WATCH_POLL_INTERVAL_SECONDS = 0.5 +SUPPORTED_INIT_AGENTS = ("claude", "codex", "opencode") +_TEMPLATE_VARIABLE_PATTERN = re.compile(r"{{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*}}") +INIT_AGENT_SKILL_ROOTS = { + "claude": Path(".claude") / "skills", + "codex": Path(".agents") / "skills", + "opencode": Path(".opencode") / "skills", +} +INIT_AGENT_DISPLAY_NAMES = { + "claude": "Claude Code", + "codex": "Codex", + "opencode": "OpenCode", +} +INIT_SKILL_TEMPLATE_NAME = "astrbot-plugin-dev" +INIT_PROJECT_NOTE_TEMPLATE_DIR = ("templates", "project_notes") +INIT_PROJECT_NOTE_TEMPLATE_NAMES = ("AGENTS.md", "CLAUDE.md") + + +class _CliPluginValidationError(RuntimeError): + """CLI 侧的插件结构或打包校验失败。""" + + +class _CliPluginLoadError(RuntimeError): + """CLI 侧的本地开发插件加载失败。""" + + +class _CliPluginExecutionError(RuntimeError): + """CLI 侧的本地开发插件执行失败。""" + + +@dataclass(slots=True) +class _PluginTreeWatcher: + plugin_dir: Path + snapshot: dict[str, tuple[int, int]] = field(init=False, default_factory=dict) + + def __post_init__(self) -> None: + self.snapshot = _snapshot_watch_files(self.plugin_dir) + + def poll_changes(self) -> list[str]: + current = _snapshot_watch_files(self.plugin_dir) + changed = sorted( + path + for path in set(self.snapshot) | set(current) + if self.snapshot.get(path) != current.get(path) + ) + self.snapshot = current + return changed + + +def setup_logger(verbose: bool = False) -> None: + """初始化 CLI 使用的日志配置。""" + logger.remove() + logger.add( + sys.stderr, + format="{time:HH:mm:ss} | {level: <8} | {message}", + level="DEBUG" if verbose else "INFO", + colorize=True, + ) + + +def _resolve_protocol_stdout( + protocol_stdout: str | None, +) -> tuple[typing.TextIO, typing.TextIO | None]: + configured = str(protocol_stdout).strip() if protocol_stdout is not None else "" + if not configured: + stdout = sys.stdout + if callable(getattr(stdout, "isatty", None)) and stdout.isatty(): + opened_stdout = open(os.devnull, "w", encoding="utf-8") + return opened_stdout, opened_stdout + return stdout, None + if configured.lower() == "console": + return sys.stdout, None + output_path = os.devnull if configured.lower() == "silent" else configured + opened_stdout = open(output_path, "w", encoding="utf-8") + return opened_stdout, opened_stdout + + +def _run_async_entrypoint( + entrypoint: Coroutine[Any, Any, object], + *, + log_message: str, + log_level: str = "info", + context: dict[str, Any] | None = None, +) -> None: + log_method = getattr(logger, log_level) + log_method(log_message) + try: + asyncio.run(entrypoint) + except (click.Abort, KeyboardInterrupt): + click.echo("\n创建插件已优雅地中断。", err=True) + raise SystemExit(130) + except Exception as exc: + exit_code, error_code, hint = _classify_cli_exception(exc) + docs_url = exc.docs_url if isinstance(exc, AstrBotError) else "" + details = exc.details if isinstance(exc, AstrBotError) else None + _render_cli_error( + error_code=error_code, + message=str(exc), + hint=hint, + docs_url=docs_url, + details=details, + context=context, + ) + if exit_code == EXIT_UNEXPECTED: + logger.exception("CLI 异常退出") + raise SystemExit(exit_code) from exc + + +def _run_sync_entrypoint( + entrypoint: typing.Callable[[], object], + *, + log_message: str, + log_level: str = "info", + context: dict[str, Any] | None = None, +) -> None: + log_method = getattr(logger, log_level) + log_method(log_message) + try: + entrypoint() + except (click.Abort, KeyboardInterrupt): + click.echo("\n创建插件已优雅地中断。", err=True) + raise SystemExit(130) + except Exception as exc: + exit_code, error_code, hint = _classify_cli_exception(exc) + docs_url = exc.docs_url if isinstance(exc, AstrBotError) else "" + details = exc.details if isinstance(exc, AstrBotError) else None + _render_cli_error( + error_code=error_code, + message=str(exc), + hint=hint, + docs_url=docs_url, + details=details, + context=context, + ) + if exit_code == EXIT_UNEXPECTED: + logger.exception("CLI 异常退出") + raise SystemExit(exit_code) from exc + + +def _classify_cli_exception(exc: Exception) -> tuple[int, str, str]: + if isinstance(exc, AstrBotError): + return ( + EXIT_RUNTIME, + exc.code, + exc.hint or "请检查本地 mock core 与插件调用参数", + ) + if isinstance( + exc, + ( + _CliPluginValidationError, + _CliPluginLoadError, + FileNotFoundError, + ImportError, + ModuleNotFoundError, + ), + ): + return ( + EXIT_PLUGIN_LOAD, + "plugin_load_error", + "请检查插件目录、plugin.yaml、requirements.txt(如有)和导入路径", + ) + if isinstance(exc, LookupError): + return ( + EXIT_RUNTIME, + "dispatch_error", + "请检查 handler 或 capability 是否已正确注册", + ) + if isinstance(exc, _CliPluginExecutionError): + return ( + EXIT_PLUGIN_EXECUTION, + "plugin_execution_error", + "请检查插件生命周期、handler 或 capability 的实现", + ) + return ( + EXIT_UNEXPECTED, + "unexpected_error", + "请查看详细日志,必要时使用 --verbose 重试", + ) + + +def _render_cli_error( + *, + error_code: str, + message: str, + hint: str = "", + docs_url: str = "", + details: dict[str, Any] | None = None, + context: dict[str, Any] | None = None, +) -> None: + click.echo(f"Error[{error_code}]: {message}", err=True) + if hint: + click.echo(f"Suggestion: {hint}", err=True) + if docs_url: + click.echo(f"Docs: {docs_url}", err=True) + if details: + click.echo(f"Details: {details}", err=True) + if not context: + return + for key, value in context.items(): + click.echo(f"{key}: {value}", err=True) + + +def _render_nonfatal_dev_error( + exc: Exception, + *, + context: dict[str, Any] | None = None, +) -> None: + exit_code, error_code, hint = _classify_cli_exception(exc) + _render_cli_error( + error_code=error_code, + message=str(exc), + hint=hint, + context=context, + ) + if exit_code == EXIT_UNEXPECTED: + logger.exception("watch 模式收到未分类异常") + + +def _iter_watch_files(plugin_dir: Path) -> typing.Iterator[Path]: + root = plugin_dir.resolve() + for path in sorted(root.rglob("*")): + if path.is_dir(): + continue + relative = path.relative_to(root) + if any(part in BUILD_EXCLUDED_DIRS for part in relative.parts[:-1]): + continue + if relative.name in BUILD_EXCLUDED_FILES: + continue + if path.suffix in {".pyc", ".pyo"}: + continue + yield path + + +def _snapshot_watch_files(plugin_dir: Path) -> dict[str, tuple[int, int]]: + root = plugin_dir.resolve() + snapshot: dict[str, tuple[int, int]] = {} + for path in _iter_watch_files(root): + try: + stat = path.stat() + except FileNotFoundError: + continue + snapshot[path.relative_to(root).as_posix()] = ( + stat.st_mtime_ns, + stat.st_size, + ) + return snapshot + + +def _format_watch_changes(changes: list[str], *, limit: int = 5) -> str: + if not changes: + return "未知文件" + preview = changes[:limit] + text = ", ".join(preview) + if len(changes) > limit: + text += f" 等 {len(changes)} 个文件" + return text + + +class _ReloadableLocalDevRunner: + def __init__( + self, + *, + plugin_dir: Path, + state: dict[str, Any], + plugin_load_error: type[Exception], + plugin_execution_error: type[Exception], + plugin_harness, + stdout_platform_sink, + ) -> None: + self.plugin_dir = plugin_dir + self.state = state + self._plugin_load_error = plugin_load_error + self._plugin_execution_error = plugin_execution_error + self._plugin_harness = plugin_harness + self._stdout_platform_sink = stdout_platform_sink + self._harness = None + self._lock = asyncio.Lock() + + async def close(self) -> None: + async with self._lock: + await self._stop_harness() + + async def reload(self) -> bool: + async with self._lock: + await self._stop_harness() + harness = self._plugin_harness.from_plugin_dir( + self.plugin_dir, + session_id=str(self.state["session_id"]), + user_id=str(self.state["user_id"]), + platform=str(self.state["platform"]), + group_id=typing.cast(str | None, self.state["group_id"]), + event_type=str(self.state["event_type"]), + platform_sink=self._stdout_platform_sink(stream=sys.stdout), + ) + try: + await harness.start() + except self._plugin_load_error as exc: + _render_nonfatal_dev_error( + _CliPluginLoadError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + except self._plugin_execution_error as exc: + _render_nonfatal_dev_error( + _CliPluginExecutionError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + self._harness = harness + return True + + async def dispatch_text(self, text: str) -> bool: + async with self._lock: + if self._harness is None: + click.echo("当前插件未成功加载,等待下一次文件变更后重试。") + return False + try: + await self._harness.dispatch_text( + text, + session_id=str(self.state["session_id"]), + user_id=str(self.state["user_id"]), + platform=str(self.state["platform"]), + group_id=typing.cast(str | None, self.state["group_id"]), + event_type=str(self.state["event_type"]), + ) + except (self._plugin_load_error, self._plugin_execution_error) as exc: + _render_nonfatal_dev_error( + _CliPluginExecutionError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + except Exception as exc: + _render_nonfatal_dev_error( + exc, + context={"plugin_dir": self.plugin_dir}, + ) + return False + return True + + async def _stop_harness(self) -> None: + if self._harness is None: + return + try: + await self._harness.stop() + finally: + self._harness = None + + +async def _run_local_dev_watch( + *, + runner: _ReloadableLocalDevRunner, + event_text: str | None, + interactive: bool, + watch_poll_interval: float, + max_watch_reloads: int | None = None, +) -> None: + watcher = _PluginTreeWatcher(runner.plugin_dir) + reload_count = 0 + + async def reload_and_maybe_rerun(*, announce: str | None) -> None: + if announce: + click.echo(announce) + if not await runner.reload(): + return + if event_text is not None: + await runner.dispatch_text(event_text) + + async def watch_loop(stop_event: asyncio.Event) -> None: + nonlocal reload_count + while not stop_event.is_set(): + await asyncio.sleep(watch_poll_interval) + changes = watcher.poll_changes() + if not changes: + continue + await reload_and_maybe_rerun( + announce=( + f"检测到文件变更,重新加载插件:{_format_watch_changes(changes)}" + ) + ) + reload_count += 1 + if max_watch_reloads is not None and reload_count >= max_watch_reloads: + stop_event.set() + return + + stop_event = asyncio.Event() + watch_task: asyncio.Task[None] | None = None + try: + await reload_and_maybe_rerun( + announce=( + "watch 模式已启动,监听插件目录变更。" + if event_text is not None + else "watch 模式已启动,监听插件目录变更并按需热重载。" + ) + ) + if max_watch_reloads == 0: + return + watch_task = asyncio.create_task(watch_loop(stop_event)) + if interactive: + click.echo( + "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit" + ) + while not stop_event.is_set(): + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + break + text = line.strip() + if not text: + continue + if _handle_dev_meta_command(text, runner.state): + if text in {"/exit", "/quit"}: + break + continue + await runner.dispatch_text(text) + stop_event.set() + return + await stop_event.wait() + finally: + stop_event.set() + if watch_task is not None: + watch_task.cancel() + try: + await watch_task + except asyncio.CancelledError: + pass + await runner.close() + + +async def _run_local_dev( + *, + plugin_dir: Path, + event_text: str | None, + interactive: bool, + watch: bool, + session_id: str, + user_id: str, + platform: str, + group_id: str | None, + event_type: str, + watch_poll_interval: float = WATCH_POLL_INTERVAL_SECONDS, + max_watch_reloads: int | None = None, +) -> None: + from .testing import ( + PluginHarness, + StdoutPlatformSink, + _PluginExecutionError, + _PluginLoadError, + ) + + state = { + "session_id": session_id, + "user_id": user_id, + "platform": platform, + "group_id": group_id, + "event_type": event_type, + } + if watch: + runner = _ReloadableLocalDevRunner( + plugin_dir=plugin_dir, + state=state, + plugin_load_error=_PluginLoadError, + plugin_execution_error=_PluginExecutionError, + plugin_harness=PluginHarness, + stdout_platform_sink=StdoutPlatformSink, + ) + await _run_local_dev_watch( + runner=runner, + event_text=event_text, + interactive=interactive, + watch_poll_interval=watch_poll_interval, + max_watch_reloads=max_watch_reloads, + ) + return + + sink = StdoutPlatformSink(stream=sys.stdout) + harness = PluginHarness.from_plugin_dir( + plugin_dir, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + platform_sink=sink, + ) + try: + async with harness: + if interactive: + click.echo( + "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit" + ) + while True: + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + break + text = line.strip() + if not text: + continue + if _handle_dev_meta_command(text, state): + if text in {"/exit", "/quit"}: + break + continue + await harness.dispatch_text( + text, + session_id=str(state["session_id"]), + user_id=str(state["user_id"]), + platform=str(state["platform"]), + group_id=typing.cast(str | None, state["group_id"]), + event_type=str(state["event_type"]), + ) + return + assert event_text is not None + await harness.dispatch_text( + event_text, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + ) + except _PluginLoadError as exc: + raise _CliPluginLoadError(str(exc)) from exc + except _PluginExecutionError as exc: + raise _CliPluginExecutionError(str(exc)) from exc + + +def _handle_dev_meta_command(command: str, state: dict[str, Any]) -> bool: + if command in {"/exit", "/quit"}: + return True + if command.startswith("/session "): + state["session_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 session_id -> {state['session_id']}") + return True + if command.startswith("/user "): + state["user_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 user_id -> {state['user_id']}") + return True + if command.startswith("/platform "): + state["platform"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 platform -> {state['platform']}") + return True + if command.startswith("/group "): + state["group_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 group_id -> {state['group_id']}") + return True + if command == "/private": + state["group_id"] = None + click.echo("已切换为私聊上下文") + return True + if command.startswith("/event "): + state["event_type"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 event_type -> {state['event_type']}") + return True + return False + + +def _slugify_plugin_name(value: str) -> str: + slug = re.sub(r"[^a-zA-Z0-9]+", "_", value).strip("_").lower() + return slug or "my_plugin" + + +def _normalize_plugin_name(value: str) -> str: + normalized = _slugify_plugin_name(value) + if normalized.startswith("astrbot_plugin_"): + return normalized + normalized = normalized.removeprefix("astrbot_plugin") + normalized = normalized.strip("_") + suffix = normalized or "my_plugin" + return f"astrbot_plugin_{suffix}" + + +def _class_name_for_plugin(value: str) -> str: + parts = [part for part in re.split(r"[^a-zA-Z0-9]+", value) if part] + if not parts: + return "MyPlugin" + return "".join(part[:1].upper() + part[1:] for part in parts) + + +def _sanitize_build_part(value: str) -> str: + sanitized = re.sub(r"[^a-zA-Z0-9._-]+", "_", value).strip("._-") + return sanitized or "artifact" + + +def _parse_init_agents( + _ctx: click.Context, + _param: click.Parameter, + value: str | None, +) -> tuple[str, ...]: + if value is None: + return () + + normalized_agents: list[str] = [] + seen: set[str] = set() + invalid_agents: list[str] = [] + for raw_agent in value.split(","): + candidate = raw_agent.strip().lower() + if not candidate: + invalid_agents.append("") + continue + if candidate not in SUPPORTED_INIT_AGENTS: + invalid_agents.append(raw_agent.strip()) + continue + if candidate in seen: + continue + seen.add(candidate) + normalized_agents.append(candidate) + + if invalid_agents: + supported = ", ".join(SUPPORTED_INIT_AGENTS) + invalid = ", ".join(invalid_agents) + raise click.BadParameter(f"仅支持以下 agent: {supported};非法值: {invalid}") + return tuple(normalized_agents) + + +def _render_init_plugin_yaml( + *, + plugin_name: str, + display_name: str, + desc: str, + author: str, + repo: str, + version: str, +) -> str: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + name: {plugin_name} + display_name: {display_name} + desc: {desc} + author: {author} + repo: {repo} + version: {version} + runtime: + python: "{python_version}" + components: + - class: main:{class_name} + """ + ) + + +def _render_init_main_py(*, plugin_name: str) -> str: + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + from astrbot_sdk import Context, MessageEvent, Star, on_command + + + class {class_name}(Star): + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("Hello, World!") + """ + ) + + +def _render_init_readme(*, plugin_name: str) -> str: + return dedent( + f"""\ + # {plugin_name} + + 一个最小可运行的 AstrBot SDK 插件。 + + ## 目录结构 + + ``` + . + ├── plugin.yaml + ├── requirements.txt + ├── main.py + └── tests + └── test_plugin.py + ``` + + ## 本地开发 + + ```bash + astrbot-sdk validate + astrbot-sdk dev --local --event-text hello + astrbot-sdk dev --local --watch --event-text hello + ``` + + ## 运行测试 + + ```bash + python -m pytest tests/test_plugin.py -v + ``` + """ + ) + + +def _render_init_gitignore() -> str: + return dedent( + """\ + # Python + __pycache__/ + *.py[cod] + *.pyo + *.egg-info/ + dist/ + build/ + *.egg + + # 虚拟环境 + .venv/ + venv/ + env/ + + # IDE + .idea/ + .vscode/ + *.swp + *.swo + *~ + + # OS + .DS_Store + Thumbs.db + desktop.ini + + # 测试 / 检查缓存 + .pytest_cache/ + .ruff_cache/ + .mypy_cache/ + .coverage + htmlcov/ + + # 开发/构建工具 + /.claude/ + /.agents/ + /.opencode/ + + # 图床配置(含 API 密钥等敏感信息) + /image_host/config.json + + # 插件测试产物 + /.astrbot_sdk_testing/ + """ + ) + + +def _render_init_test_py(*, plugin_name: str) -> str: + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + from pathlib import Path + + import pytest + + from astrbot_sdk.testing import MockContext, MockMessageEvent, PluginHarness + from main import {class_name} + + + @pytest.mark.asyncio + async def test_hello_handler(): + plugin = {class_name}() + ctx = MockContext( + plugin_id="{plugin_name}", + plugin_metadata={{"display_name": "{class_name}"}}, + ) + event = MockMessageEvent(text="/hello", context=ctx) + + await plugin.hello(event, ctx) + + assert event.replies == ["Hello, World!"] + ctx.platform.assert_sent("Hello, World!") + + + @pytest.mark.asyncio + async def test_hello_dispatch(): + plugin_dir = Path(__file__).resolve().parents[1] + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + records = await harness.dispatch_text("hello") + + assert any(record.text == "Hello, World!" for record in records) + """ + ) + + +def _plugin_root_hint_for_agent(agent: str) -> str: + skill_dir = INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME + return "/".join(".." for _ in skill_dir.parts) or "." + + +def _build_agent_template_context( + *, + plugin_name: str, + display_name: str, + agent: str, +) -> dict[str, str]: + return { + "plugin_name": plugin_name, + "display_name": display_name, + "class_name": _class_name_for_plugin(plugin_name), + "skill_name": f"{plugin_name}_project", + "plugin_root": _plugin_root_hint_for_agent(agent), + "agent_name": agent, + "agent_display_name": INIT_AGENT_DISPLAY_NAMES[agent], + "skill_dir_name": INIT_SKILL_TEMPLATE_NAME, + } + + +def _render_template_text(template_text: str, context: dict[str, str]) -> str: + def replace(match: re.Match[str]) -> str: + key = match.group(1) + if key not in context: + raise _CliPluginValidationError(f"agent 模板变量未定义:{key}") + return context[key] + + return _TEMPLATE_VARIABLE_PATTERN.sub(replace, template_text) + + +def _copy_rendered_template_tree( + source_dir: Traversable, + target_dir: Path, + *, + context: dict[str, str], +) -> None: + target_dir.mkdir(parents=True, exist_ok=True) + for entry in sorted(source_dir.iterdir(), key=lambda item: item.name): + destination = target_dir / entry.name + if entry.is_dir(): + _copy_rendered_template_tree(entry, destination, context=context) + continue + destination.write_text( + _render_template_text(entry.read_text(encoding="utf-8"), context), + encoding="utf-8", + ) + + +def _render_init_agent_templates( + *, + target_dir: Path, + plugin_name: str, + display_name: str, + agents: tuple[str, ...], +) -> None: + if not agents: + return + + template_root = resources.files("astrbot_sdk").joinpath( + "templates", + "skills", + INIT_SKILL_TEMPLATE_NAME, + ) + if not template_root.is_dir(): + raise _CliPluginValidationError( + f"未找到项目级 skill 模板:{INIT_SKILL_TEMPLATE_NAME}" + ) + + for agent in agents: + context = _build_agent_template_context( + plugin_name=plugin_name, + display_name=display_name, + agent=agent, + ) + _copy_rendered_template_tree( + template_root, + target_dir / INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME, + context=context, + ) + + +def _render_init_project_notes(*, target_dir: Path) -> None: + template_root = resources.files("astrbot_sdk").joinpath( + *INIT_PROJECT_NOTE_TEMPLATE_DIR + ) + if not template_root.is_dir(): + raise _CliPluginValidationError("未找到项目级说明模板:AGENTS.md / CLAUDE.md") + + for template_name in INIT_PROJECT_NOTE_TEMPLATE_NAMES: + template_path = template_root.joinpath(template_name) + if not template_path.is_file(): + raise _CliPluginValidationError( + f"未找到项目级说明模板文件:{template_name}" + ) + # Keep these notes as packaged resources so `astr init` behaves the same + # from a repo checkout, an sdist, and an installed wheel. + (target_dir / template_name).write_text( + template_path.read_text(encoding="utf-8"), + encoding="utf-8", + ) + + +def _ensure_plugin_dir_exists(plugin_dir: Path) -> Path: + resolved = plugin_dir.resolve() + if not resolved.exists() or not resolved.is_dir(): + raise _CliPluginValidationError(f"插件目录不存在:{plugin_dir}") + return resolved + + +def _resolve_dev_plugin_dir(plugin_dir: Path | None) -> Path: + if plugin_dir is not None: + return plugin_dir + current_dir = Path.cwd() + if (current_dir / "plugin.yaml").exists(): + return Path(".") + raise click.BadParameter( + "未提供 --plugin-dir,且当前目录未找到 plugin.yaml", + param_hint="--plugin-dir", + ) + + +def _load_validated_plugin(plugin_dir: Path) -> tuple[Any, Any]: + resolved_dir = _ensure_plugin_dir_exists(plugin_dir) + plugin = load_plugin_spec(resolved_dir) + try: + validate_plugin_spec(plugin) + except ValueError as exc: + raise _CliPluginValidationError(str(exc)) from exc + + loaded = load_plugin(plugin) + if not loaded.instances: + raise _CliPluginValidationError( + "未找到可加载的组件,请检查 plugin.yaml 中的 components" + ) + return plugin, loaded + + +def _build_kind(plugin: Any) -> str: + return ( + "legacy-main" + if bool(plugin.manifest_data.get("__legacy_main__")) + else "plugin-yaml" + ) + + +def _path_is_within(path: Path, root: Path) -> bool: + try: + path.resolve().relative_to(root.resolve()) + except ValueError: + return False + return True + + +def _iter_build_files(plugin_dir: Path, output_dir: Path) -> list[Path]: + files: list[Path] = [] + for path in sorted(plugin_dir.rglob("*")): + if path.is_dir(): + continue + if _path_is_within(path, output_dir): + continue + relative = path.relative_to(plugin_dir) + if any(part in BUILD_EXCLUDED_DIRS for part in relative.parts[:-1]): + continue + if relative.name in BUILD_EXCLUDED_FILES: + continue + if path.suffix in {".pyc", ".pyo"}: + continue + files.append(path) + return files + + +def _prompt_nonempty_text(prompt: str) -> str: + while True: + value = click.prompt(prompt, type=str, default="", show_default=False).strip() + if value: + return value + click.echo("该字段不能为空,请重新输入。") + + +def _default_init_repo_name(plugin_name: str) -> str: + return _normalize_plugin_name(plugin_name) + + +def _collect_init_metadata(name: str | None) -> tuple[str, str, str, str, str]: + plugin_name = name if name is not None else _prompt_nonempty_text("插件名字") + author = _prompt_nonempty_text("作者") + repo = _default_init_repo_name(plugin_name) + desc = click.prompt("描述", type=str, default="", show_default=False).strip() + version = click.prompt("版本", type=str, default="1.0.0", show_default=True).strip() + return plugin_name, author, repo, desc, version or "1.0.0" + + +def _init_plugin(name: str | None, agents: tuple[str, ...] = ()) -> None: + raw_name, author, repo, desc, version = _collect_init_metadata(name) + plugin_name = _normalize_plugin_name(raw_name) + target_dir = Path(plugin_name) + if target_dir.exists(): + raise _CliPluginValidationError(f"目标目录已存在:{target_dir}") + + display_name = raw_name.strip() or plugin_name + target_dir.mkdir(parents=True, exist_ok=False) + (target_dir / "tests").mkdir() + (target_dir / "plugin.yaml").write_text( + _render_init_plugin_yaml( + plugin_name=plugin_name, + display_name=display_name, + desc=desc, + author=author, + repo=repo, + version=version, + ), + encoding="utf-8", + ) + (target_dir / "requirements.txt").write_text("", encoding="utf-8") + (target_dir / "main.py").write_text( + _render_init_main_py(plugin_name=plugin_name), + encoding="utf-8", + ) + (target_dir / "README.md").write_text( + _render_init_readme(plugin_name=plugin_name), + encoding="utf-8", + ) + (target_dir / ".gitignore").write_text( + _render_init_gitignore(), + encoding="utf-8", + ) + (target_dir / "tests" / "test_plugin.py").write_text( + _render_init_test_py(plugin_name=plugin_name), + encoding="utf-8", + ) + _render_init_project_notes(target_dir=target_dir) + _render_init_agent_templates( + target_dir=target_dir, + plugin_name=plugin_name, + display_name=display_name, + agents=agents, + ) + + import subprocess + + try: + process = subprocess.run( + ["git", "init", str(target_dir)], + capture_output=True, + text=True, + ) + if process.returncode != 0: + stderr = process.stderr.strip() + raise RuntimeError( + f"Git 初始化失败(退出码 {process.returncode})" + + (f": {stderr}" if stderr else "") + ) + click.echo(f"Git 仓库已初始化: {target_dir}") + except FileNotFoundError: + click.echo("警告: 未找到 git 命令,请先安装 git 后手动执行 git init") + except RuntimeError as e: + click.echo(f"警告: {e}") + + click.echo(f"已创建插件:{target_dir}") + if agents: + generated_paths = ", ".join( + str(INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME) + for agent in agents + ) + click.echo(f"已生成项目级 skill:{generated_paths}") + click.echo("后续命令:") + click.echo(f" astrbot-sdk validate --plugin-dir {target_dir}") + click.echo( + f" astrbot-sdk dev --local --plugin-dir {target_dir} --event-text hello" + ) + + +def _validate_plugin(plugin_dir: Path) -> None: + plugin, loaded = _load_validated_plugin(plugin_dir) + click.echo(f"校验通过:{plugin.name}") + click.echo(f"kind: {_build_kind(plugin)}") + click.echo(f"plugin_dir: {plugin.plugin_dir}") + click.echo(f"handlers: {len(loaded.handlers)}") + click.echo(f"capabilities: {len(loaded.capabilities)}") + click.echo(f"instances: {len(loaded.instances)}") + + +def _build_plugin(plugin_dir: Path, output_dir: Path | None) -> None: + plugin, _ = _load_validated_plugin(plugin_dir) + build_dir = (output_dir or (plugin.plugin_dir / "dist")).resolve() + build_dir.mkdir(parents=True, exist_ok=True) + + version = _sanitize_build_part(str(plugin.manifest_data.get("version") or "0.0.0")) + archive_name = f"{_sanitize_build_part(plugin.name)}-{version}.zip" + archive_path = build_dir / archive_name + + with zipfile.ZipFile( + archive_path, + mode="w", + compression=zipfile.ZIP_DEFLATED, + ) as archive: + for path in _iter_build_files(plugin.plugin_dir, build_dir): + archive.write(path, arcname=path.relative_to(plugin.plugin_dir)) + + click.echo(f"构建完成:{archive_path}") + click.echo(f"artifact: {archive_path}") + + +def _run_websocket_worker_entrypoint( + *, + worker_id: str | None, + plugin_dirs: tuple[Path, ...], + host: str, + port: int, + path: str, + tls_ca_file: Path, + tls_cert_file: Path, + tls_key_file: Path, +) -> None: + resolved_plugin_dirs = list(plugin_dirs) if plugin_dirs else [Path.cwd()] + _run_async_entrypoint( + run_websocket_server( + worker_id=worker_id, + plugin_dirs=resolved_plugin_dirs, + host=host, + port=port, + path=path, + tls_ca_file=tls_ca_file, + tls_cert_file=tls_cert_file, + tls_key_file=tls_key_file, + ), + log_message=f"启动 WebSocket Worker,端口:{port}", + context={ + "worker_id": worker_id, + "plugin_dirs": resolved_plugin_dirs, + "port": port, + "path": path, + }, + ) + + +@click.group() +@click.option("-v", "--verbose", is_flag=True, help="Enable verbose output") +@click.pass_context +def cli(ctx, verbose: bool) -> None: + """AstrBot SDK CLI。""" + ctx.ensure_object(dict) + ctx.obj["verbose"] = verbose + setup_logger(verbose) + + +@cli.command() +@click.option( + "--plugins-dir", + default="plugins", + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Directory containing plugin folders", +) +@click.option( + "--workers-manifest", + default=None, + type=click.Path(file_okay=True, dir_okay=False, path_type=Path), + help="Supervisor manifest describing remote websocket workers", +) +@click.option( + "--protocol-stdout", + default=None, + type=str, + help="Redirect runtime protocol stdout to console, silent, or a file path", +) +def run( + plugins_dir: Path, + workers_manifest: Path | None, + protocol_stdout: str | None, +) -> None: + """Start the plugin supervisor over stdio.""" + transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout) + try: + _run_async_entrypoint( + run_supervisor( + plugins_dir=plugins_dir, + stdout=transport_stdout, + workers_manifest=workers_manifest, + ), + log_message=f"启动插件主管进程,插件目录:{plugins_dir}", + context={ + "plugins_dir": plugins_dir, + "workers_manifest": workers_manifest, + }, + ) + finally: + if opened_stdout is not None: + opened_stdout.close() + + +@cli.command() +@click.argument("name", type=str, required=False) +@click.option( + "--agents", + callback=_parse_init_agents, + metavar="claude,codex,opencode", + help="Generate per-agent project templates, comma-separated: claude,codex,opencode", +) +def init(name: str | None, agents: tuple[str, ...]) -> None: + """Create a new plugin skeleton in the target directory.""" + _run_sync_entrypoint( + lambda: _init_plugin(name, agents), + log_message=f"创建插件:{name or ''}", + context={"target": name or ""}, + ) + + +@cli.command() +@click.option( + "--plugin-dir", + default=".", + show_default=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to validate", +) +def validate(plugin_dir: Path) -> None: + """Validate plugin manifest, imports and handler discovery.""" + _run_sync_entrypoint( + lambda: _validate_plugin(plugin_dir), + log_message=f"校验插件目录:{plugin_dir}", + context={"plugin_dir": plugin_dir}, + ) + + +@cli.command() +@click.option( + "--plugin-dir", + default=".", + show_default=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to package", +) +@click.option( + "--output-dir", + default=None, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Directory for the build artifact, defaults to /dist", +) +def build(plugin_dir: Path, output_dir: Path | None) -> None: + """Validate and package a plugin into a zip artifact.""" + _run_sync_entrypoint( + lambda: _build_plugin(plugin_dir, output_dir), + log_message=f"构建插件包:{plugin_dir}", + context={"plugin_dir": plugin_dir, "output_dir": output_dir}, + ) + + +@cli.command() +@click.option( + "--plugin-dir", + required=False, + default=None, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to run locally, defaults to current directory when plugin.yaml exists", +) +@click.option("--local", "local_mode", is_flag=True, help="Run against local mock core") +@click.option( + "--standalone", + "standalone_mode", + is_flag=True, + help="Deprecated alias of --local", +) +@click.option("--event-text", type=str, help="Single message text to dispatch") +@click.option("--interactive", is_flag=True, help="Read follow-up messages from stdin") +@click.option( + "--watch", + is_flag=True, + help="Reload the local harness when plugin files change", +) +@click.option("--session-id", default="local-session", show_default=True) +@click.option("--user-id", default="local-user", show_default=True) +@click.option("--platform", "platform_name", default="test", show_default=True) +@click.option("--group-id", default=None) +@click.option("--event-type", default="message", show_default=True) +def dev( + plugin_dir: Path | None, + local_mode: bool, + standalone_mode: bool, + event_text: str | None, + interactive: bool, + watch: bool, + session_id: str, + user_id: str, + platform_name: str, + group_id: str | None, + event_type: str, +) -> None: + """Run a plugin against the local mock core for development.""" + if not (local_mode or standalone_mode): + raise click.BadParameter("当前 dev 只支持 --local/--standalone 模式") + if interactive and event_text: + raise click.BadParameter("--interactive 与 --event-text 不能同时使用") + if not interactive and not event_text: + raise click.BadParameter("请提供 --event-text,或改用 --interactive") + resolved_plugin_dir = _resolve_dev_plugin_dir(plugin_dir) + _run_async_entrypoint( + _run_local_dev( + plugin_dir=resolved_plugin_dir, + event_text=event_text, + interactive=interactive, + watch=watch, + session_id=session_id, + user_id=user_id, + platform=platform_name, + group_id=group_id, + event_type=event_type, + ), + log_message=f"启动本地开发模式:{resolved_plugin_dir}", + context={ + "plugin_dir": resolved_plugin_dir, + "session_id": session_id, + "platform": platform_name, + "event_type": event_type, + }, + ) + + +@cli.command(hidden=True) +@click.option( + "--plugin-dir", + required=False, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), +) +@click.option( + "--group-metadata", + required=False, + type=click.Path(file_okay=True, dir_okay=False, path_type=Path), +) +@click.option( + "--protocol-stdout", + default=None, + type=str, + help="Redirect runtime protocol stdout to console, silent, or a file path", +) +def worker( + plugin_dir: Path | None, + group_metadata: Path | None, + protocol_stdout: str | None, +) -> None: + """Internal command used by the supervisor to start a worker.""" + if plugin_dir is None and group_metadata is None: + raise click.UsageError("Either --plugin-dir or --group-metadata is required") + if plugin_dir is not None and group_metadata is not None: + raise click.UsageError( + "--plugin-dir and --group-metadata are mutually exclusive" + ) + + target = str(group_metadata or plugin_dir) + transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout) + if group_metadata is not None: + entrypoint = run_plugin_worker( + group_metadata=group_metadata, + stdout=transport_stdout, + ) + else: + entrypoint = run_plugin_worker( + plugin_dir=plugin_dir, + stdout=transport_stdout, + ) + try: + _run_async_entrypoint( + entrypoint, + log_message=f"启动插件工作进程:{target}", + log_level="debug", + context={"plugin_dir": plugin_dir}, + ) + finally: + if opened_stdout is not None: + opened_stdout.close() + + +@cli.command("serve-worker") +@click.option("--worker-id", default=None, type=str, help="Stable websocket worker id") +@click.option( + "--plugin-dir", + "plugin_dirs", + multiple=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to serve; repeat to host multiple plugins in one worker", +) +@click.option("--host", default="127.0.0.1", show_default=True) +@click.option("--port", default=8765, type=int, show_default=True) +@click.option("--path", default="/", show_default=True) +@click.option( + "--tls-ca-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +@click.option( + "--tls-cert-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +@click.option( + "--tls-key-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +def serve_worker( + worker_id: str | None, + plugin_dirs: tuple[Path, ...], + host: str, + port: int, + path: str, + tls_ca_file: Path, + tls_cert_file: Path, + tls_key_file: Path, +) -> None: + """Serve one or more plugins as a standalone websocket worker.""" + _run_websocket_worker_entrypoint( + worker_id=worker_id, + plugin_dirs=plugin_dirs, + host=host, + port=port, + path=path, + tls_ca_file=tls_ca_file, + tls_cert_file=tls_cert_file, + tls_key_file=tls_key_file, + ) + + +@cli.command(hidden=True) +@click.option("--worker-id", default=None, type=str) +@click.option( + "--plugin-dir", + "plugin_dirs", + multiple=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), +) +@click.option("--host", default="127.0.0.1", show_default=True) +@click.option("--port", default=8765, type=int, show_default=True) +@click.option("--path", default="/", show_default=True) +@click.option( + "--tls-ca-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +@click.option( + "--tls-cert-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +@click.option( + "--tls-key-file", + required=True, + type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path), +) +def websocket( + worker_id: str | None, + plugin_dirs: tuple[Path, ...], + host: str, + port: int, + path: str, + tls_ca_file: Path, + tls_cert_file: Path, + tls_key_file: Path, +) -> None: + """Deprecated websocket runtime wrapper for standalone worker scenarios.""" + logger.warning("'astr websocket' is deprecated; use 'astr serve-worker' instead") + _run_websocket_worker_entrypoint( + worker_id=worker_id, + plugin_dirs=plugin_dirs, + host=host, + port=port, + path=path, + tls_ca_file=tls_ca_file, + tls_cert_file=tls_cert_file, + tls_key_file=tls_key_file, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/__init__.py b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py new file mode 100644 index 0000000000..d70c7fc3ee --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py @@ -0,0 +1,107 @@ +"""原生 astrbot-sdk 能力客户端 + +这些客户端为 Context 提供了用于调用远程能力的狭窄且具类型化 (typed) 的接口。 +它们负责处理能力名称、载荷格式化(payload shaping)以及结果解码,且不会暴露协议或传输层的具体细节。 + +为了保持 Context 接口的精简与稳定,迁移适配层 (Migration shims) 以及高层级编排逻辑 (higher-level orchestration) 均不包含在这些原生能力客户端之内。 + +当前公开客户端: + - LLMClient: 文本/结构化/流式 LLM 调用 + - MemoryClient: 记忆搜索、保存、读取、删除 + - DBClient: 键值存储 get/set/delete/list + - FileServiceClient: 文件令牌注册与解析 + - PlatformClient: 平台消息发送与成员查询 + - ProviderClient: Provider 元信息与专用 provider proxy + - PersonaManagerClient: 人格管理 + - ConversationManagerClient: 对话管理 + - KnowledgeBaseManagerClient: 知识库管理 + - HTTPClient: Web API 注册 + - MetadataClient: 插件元数据查询 + - SkillClient: 运行时注册插件 skill +""" + +from .db import DBClient +from .files import FileRegistration, FileServiceClient +from .http import HTTPClient +from .llm import ChatMessage, LLMClient, LLMResponse +from .managers import ( + ConversationCreateParams, + ConversationManagerClient, + ConversationRecord, + ConversationUpdateParams, + KnowledgeBaseCreateParams, + KnowledgeBaseManagerClient, + KnowledgeBaseRecord, + MessageHistoryManagerClient, + MessageHistoryPage, + MessageHistoryRecord, + MessageHistorySender, + PersonaCreateParams, + PersonaManagerClient, + PersonaRecord, + PersonaUpdateParams, +) +from .mcp import MCPManagerClient, MCPServerRecord, MCPServerScope, MCPSession +from .memory import MemoryClient +from .metadata import MetadataClient, PluginMetadata, StarMetadata +from .permission import PermissionCheckResult, PermissionClient, PermissionManagerClient +from .platform import PlatformClient, PlatformError, PlatformStats, PlatformStatus +from .provider import ( + ManagedProviderRecord, + ProviderChangeEvent, + ProviderClient, + ProviderManagerClient, +) +from .registry import HandlerMetadata, RegistryClient +from .session import SessionPluginManager, SessionServiceManager +from .skills import SkillClient, SkillRegistration + +__all__ = [ + "ChatMessage", + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationRecord", + "ConversationUpdateParams", + "DBClient", + "FileRegistration", + "FileServiceClient", + "HTTPClient", + "KnowledgeBaseCreateParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + "MessageHistorySender", + "LLMClient", + "LLMResponse", + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", + "MemoryClient", + "ManagedProviderRecord", + "MetadataClient", + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", + "PlatformClient", + "PlatformError", + "PlatformStats", + "PlatformStatus", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", + "ProviderChangeEvent", + "ProviderClient", + "ProviderManagerClient", + "PluginMetadata", + "StarMetadata", + "HandlerMetadata", + "RegistryClient", + "SessionPluginManager", + "SessionServiceManager", + "SkillClient", + "SkillRegistration", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/_errors.py b/astrbot-sdk/src/astrbot_sdk/clients/_errors.py new file mode 100644 index 0000000000..e926321b25 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/_errors.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from ..errors import AstrBotError + + +def client_call_label( + client_name: str, + method_name: str, + details: str | None = None, +) -> str: + label = f"{client_name}.{method_name}" + if details: + return f"{label} ({details})" + return label + + +def wrap_client_exception( + *, + client_name: str, + method_name: str, + exc: Exception, + details: str | None = None, +) -> Exception: + message = f"{client_call_label(client_name, method_name, details)} failed: {exc}" + if isinstance(exc, AstrBotError): + return AstrBotError( + code=exc.code, + message=message, + hint=exc.hint, + retryable=exc.retryable, + docs_url=exc.docs_url, + details=exc.details, + ) + try: + rebuilt = exc.__class__(message) + except Exception: + return RuntimeError(message) + if isinstance(rebuilt, Exception): + return rebuilt + return RuntimeError(message) + + +__all__ = ["client_call_label", "wrap_client_exception"] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py new file mode 100644 index 0000000000..4a6e9db7d9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py @@ -0,0 +1,188 @@ +"""能力代理模块。 + +提供 CapabilityProxy 类,作为客户端与 Peer 之间的中间层,负责: +- 检查远程能力是否可用 +- 验证流式调用支持 +- 统一封装 invoke 和 invoke_stream 调用 + +设计说明: + CapabilityProxy 是新版架构的核心组件。每个专用客户端 (LLMClient, DBClient 等) + 都通过 CapabilityProxy 与远程通信,并在发起调用时绑定当前插件身份, + 让运行时把调用者信息放进协议层而不是业务 payload。 + +使用示例: + proxy = CapabilityProxy(peer) + + # 普通调用 + result = await proxy.call("llm.chat", {"prompt": "hello"}) + + # 流式调用 + async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}): + print(delta["text"]) +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from typing import Any, Protocol + +from .._internal.invocation_context import caller_plugin_scope +from ..errors import AstrBotError + + +class _CapabilityDescriptorLike(Protocol): + supports_stream: bool | None + + +class _CapabilityPeerLike(Protocol): + remote_capability_map: Mapping[str, _CapabilityDescriptorLike] + remote_peer: Any | None + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: ... + + async def invoke_stream( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> AsyncIterator[Any]: ... + + +class CapabilityProxy: + """能力代理类,封装 Peer 的能力调用接口。 + + 负责在调用前验证能力可用性和流式支持,提供统一的 call/stream 接口。 + + Attributes: + _peer: 底层 Peer 实例,负责实际的 RPC 通信 + """ + + def __init__( + self, + peer: _CapabilityPeerLike, + caller_plugin_id: str | None = None, + request_scope_id: str | None = None, + ) -> None: + """初始化能力代理。 + + Args: + peer: Peer 实例,提供 remote_capability_map 和 invoke/invoke_stream 方法 + """ + self._peer = peer + self._caller_plugin_id = caller_plugin_id + self._request_scope_id = request_scope_id + + def _get_descriptor(self, name: str) -> _CapabilityDescriptorLike | None: + """获取能力描述符。 + + Args: + name: 能力名称,如 "llm.chat" + + Returns: + 能力描述符,若不存在则返回 None + """ + capability_map = getattr(self._peer, "remote_capability_map", {}) + if not isinstance(capability_map, Mapping): + return None + return capability_map.get(name) + + def _remote_initialized(self) -> bool: + peer_attrs = getattr(self._peer, "__dict__", None) + if not isinstance(peer_attrs, dict): + return False + + # Avoid getattr() here: MagicMock synthesizes truthy child attributes and + # makes an uninitialized peer look ready. + remote_peer = peer_attrs.get("remote_peer") + capability_map = peer_attrs.get("remote_capability_map") + return bool(remote_peer) or ( + isinstance(capability_map, Mapping) and bool(capability_map) + ) + + def _ensure_available(self, name: str, *, stream: bool) -> None: + """确保能力可用且支持指定的调用模式。 + + Args: + name: 能力名称 + stream: 是否需要流式支持 + + Raises: + AstrBotError: 能力不存在或流式不支持 + """ + descriptor = self._get_descriptor(name) + if descriptor is None: + if self._remote_initialized(): + raise AstrBotError.capability_not_found(name) + return + if stream and not descriptor.supports_stream: + raise AstrBotError.invalid_input(f"{name} 不支持 stream=true") + + def _prepare_payload(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + if ( + not isinstance(self._request_scope_id, str) + or not self._request_scope_id + or not name.startswith("system.event.") + ): + return payload + scoped_payload = dict(payload) + scoped_payload.setdefault("_request_scope_id", self._request_scope_id) + return scoped_payload + + async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + """执行普通能力调用(非流式)。 + + Args: + name: 能力名称,如 "llm.chat", "db.get" + payload: 调用参数字典 + + Returns: + 调用结果字典 + + Raises: + AstrBotError: 能力不存在或调用失败 + + 示例: + result = await proxy.call("llm.chat", {"prompt": "hello"}) + print(result["text"]) + """ + self._ensure_available(name, stream=False) + prepared_payload = self._prepare_payload(name, payload) + with caller_plugin_scope(self._caller_plugin_id): + return await self._peer.invoke(name, prepared_payload, stream=False) + + async def stream( + self, + name: str, + payload: dict[str, Any], + ) -> AsyncIterator[dict[str, Any]]: + """执行流式能力调用。 + + Args: + name: 能力名称,如 "llm.stream_chat" + payload: 调用参数字典 + + Yields: + 每个增量数据块(phase="delta" 时的 data 字段) + + Raises: + AstrBotError: 能力不存在或不支持流式 + + 示例: + async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}): + print(delta["text"], end="") + """ + self._ensure_available(name, stream=True) + prepared_payload = self._prepare_payload(name, payload) + with caller_plugin_scope(self._caller_plugin_id): + event_stream = await self._peer.invoke_stream(name, prepared_payload) + async for event in event_stream: + if event.phase == "delta": + yield event.data diff --git a/astrbot-sdk/src/astrbot_sdk/clients/db.py b/astrbot-sdk/src/astrbot_sdk/clients/db.py new file mode 100644 index 0000000000..bf2783490d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/db.py @@ -0,0 +1,161 @@ +"""数据库客户端模块。 + +提供键值存储能力,用于持久化插件数据。 + +功能说明: + - 数据永久存储,除非用户显式删除 + - 值类型支持任意 JSON 数据 + - 支持前缀查询键列表 + - 支持批量读写 + - 支持订阅变更事件 +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any + +from ._proxy import CapabilityProxy + + +class DBClient: + """键值数据库客户端。 + + 提供插件数据的持久化存储能力,数据永久保存直到显式删除。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化数据库客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def get(self, key: str) -> Any | None: + """获取指定键的值。 + + Args: + key: 数据键名 + + Returns: + 存储的值,若键不存在则返回 None + + 示例: + data = await ctx.db.get("user_settings") + if data: + print(data["theme"]) + """ + output = await self._proxy.call("db.get", {"key": key}) + return output.get("value") + + async def set(self, key: str, value: Any) -> None: + """设置键值对。 + + Args: + key: 数据键名 + value: 要存储的 JSON 值 + + 示例: + await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"}) + await ctx.db.set("greeted", True) + """ + await self._proxy.call("db.set", {"key": key, "value": value}) + + async def delete(self, key: str) -> None: + """删除指定键的数据。 + + Args: + key: 要删除的数据键名 + + 示例: + await ctx.db.delete("user_settings") + """ + await self._proxy.call("db.delete", {"key": key}) + + async def list(self, prefix: str | None = None) -> list[str]: + """列出匹配前缀的所有键。 + + Args: + prefix: 键前缀过滤,None 表示列出所有键 + + Returns: + 匹配的键名列表 + + 示例: + # 列出所有用户设置相关的键 + keys = await ctx.db.list("user_") + # ["user_settings", "user_profile", "user_history"] + """ + output = await self._proxy.call("db.list", {"prefix": prefix}) + keys = output.get("keys") + if not isinstance(keys, (list, tuple)): + return [] + return [str(item) for item in keys] + + async def get_many(self, keys: Sequence[str]) -> dict[str, Any | None]: + """批量获取多个键的值。 + + Args: + keys: 要读取的键列表 + + Returns: + 一个 dict,key 为键名,value 为对应值(不存在则为 None) + + 示例: + values = await ctx.db.get_many(["user:1", "user:2"]) + if values["user:1"] is None: + print("user:1 missing") + """ + output = await self._proxy.call("db.get_many", {"keys": list(keys)}) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return {} + result: dict[str, Any | None] = {} + for item in items: + if not isinstance(item, dict): + continue + key = item.get("key") + if not isinstance(key, str): + continue + result[key] = item.get("value") + return result + + async def set_many( + self, items: Mapping[str, Any] | Sequence[tuple[str, Any]] + ) -> None: + """批量写入多个键值对。 + + Args: + items: 键值对集合(dict 或二元组序列) + + 示例: + await ctx.db.set_many({"user:1": {"name": "a"}, "user:2": {"name": "b"}}) + """ + if isinstance(items, Mapping): + pairs = list(items.items()) + else: + pairs = list(items) + + payload_items: list[dict[str, Any]] = [ + {"key": str(key), "value": value} for key, value in pairs + ] + await self._proxy.call("db.set_many", {"items": payload_items}) + + def watch(self, prefix: str | None = None) -> AsyncIterator[dict[str, Any]]: + """订阅 KV 变更事件(流式)。 + + Args: + prefix: 键前缀过滤;None 表示订阅所有键 + + Yields: + 变更事件 dict:{"op": "set"|"delete", "key": str, "value": Any|None} + + 示例: + async for event in ctx.db.watch("user:"): + print(event["op"], event["key"]) + """ + return self._proxy.stream("db.watch", {"prefix": prefix}) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/files.py b/astrbot-sdk/src/astrbot_sdk/clients/files.py new file mode 100644 index 0000000000..94d716151a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/files.py @@ -0,0 +1,79 @@ +"""文件服务客户端。 + +提供文件令牌注册和令牌反查能力,封装 `system.file.*` capabilities。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from ._proxy import CapabilityProxy + + +@dataclass(slots=True) +class FileRegistration: + """文件注册结果。""" + + token: str + url: str + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> FileRegistration: + return cls( + token=str(payload.get("token", "")), + url=str(payload.get("url", "")), + ) + + +class FileServiceClient: + """文件服务能力客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def _register( + self, + path: str, + *, + timeout: float | None, + ) -> FileRegistration: + output = await self._proxy.call( + "system.file.register", + {"path": str(path), "timeout": timeout}, + ) + return FileRegistration.from_payload(output) + + async def register_file( + self, + path: str, + timeout: float | None = None, + ) -> str: + """注册本地文件并返回文件令牌。""" + + return (await self._register(path, timeout=timeout)).token + + async def register_file_url( + self, + path: str, + timeout: float | None = None, + ) -> str: + """注册本地文件并返回公开访问 URL。""" + + return (await self._register(path, timeout=timeout)).url + + async def handle_file(self, token: str) -> str: + """将文件令牌解析回本地文件路径。""" + + output = await self._proxy.call( + "system.file.handle", + {"token": str(token)}, + ) + return str(output.get("path", "")) + + async def _register_file_url( + self, + path: str, + timeout: float | None = None, + ) -> str: + return await self.register_file_url(path, timeout=timeout) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/http.py b/astrbot-sdk/src/astrbot_sdk/clients/http.py new file mode 100644 index 0000000000..84c7417af6 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/http.py @@ -0,0 +1,187 @@ +"""HTTP 客户端模块。 + +提供 HTTP API 注册能力。 + +功能说明: + - 注册自定义 Web API 端点 + - 支持异步请求处理 + - 与宿主 Web 服务器集成 + +设计说明: + 由于跨进程架构,handler 函数无法直接序列化传递。 + 插件需要先声明处理 HTTP 请求的 capability,然后注册路由到 capability 的映射。 + 当前插件身份由运行时在协议层透传,客户端 payload 不暴露 `plugin_id`。 + + 调用流程: + HTTP 请求 → 宿主 Web 服务器 → 查找 route 映射 → invoke capability → Worker 执行 handler → 返回响应 + +示例: + # 插件声明处理 HTTP 请求的 capability + @provide_capability( + name="my_plugin.http_handler", + description="处理 /my_plugin/api 的 HTTP 请求", + input_schema={...}, + output_schema={...} + ) + async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + + # 注册路由 → capability 映射 + await ctx.http.register_api( + route="/my_plugin/api", + methods=["GET", "POST"], + handler_capability="my_plugin.http_handler", + description="我的 API" + ) +""" + +from __future__ import annotations + +from typing import Any + +from ..decorators import get_capability_meta +from ..errors import AstrBotError +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +def _resolve_handler_capability( + handler_capability: str | None, + handler: Any | None, +) -> str: + if handler_capability and handler is not None: + raise AstrBotError.invalid_input( + "register_api 不能同时提供 handler_capability 和 handler", + hint="请二选一:传 capability 名称字符串,或传 @provide_capability 标记的方法", + ) + if handler_capability: + return handler_capability + if handler is None: + raise AstrBotError.invalid_input( + "register_api 需要提供 handler_capability 或 handler", + hint="示例:handler_capability='demo.http_handler' 或 handler=self.http_handler_capability", + ) + target = getattr(handler, "__func__", handler) + meta = get_capability_meta(target) + if meta is None: + raise AstrBotError.invalid_input( + "register_api(handler=...) 需要传入使用 @provide_capability 声明的方法", + hint="请先用 @provide_capability(name='demo.http_handler', ...) 标记该方法", + ) + return meta.descriptor.name + + +class HTTPClient: + """HTTP 能力客户端。 + + 提供 Web API 注册能力,允许插件暴露自定义 HTTP 端点。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化 HTTP 客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def register_api( + self, + route: str, + handler_capability: str | None = None, + *, + handler: Any | None = None, + methods: list[str] | None = None, + description: str = "", + ) -> None: + """注册 Web API 端点。 + + Args: + route: API 路由路径(必须使用 "/{plugin_id}" 或 "/{plugin_id}/...") + handler_capability: 处理此路由的 capability 名称 + handler: 使用 @provide_capability 标记的方法引用 + methods: HTTP 方法列表,默认 ["GET"] + description: API 描述 + + 示例: + await ctx.http.register_api( + route="/my_plugin/api", + handler_capability="my_plugin.http_handler", + methods=["GET", "POST"], + description="我的 API" + ) + """ + if methods is None: + methods = ["GET"] + resolved_handler = _resolve_handler_capability(handler_capability, handler) + try: + await self._proxy.call( + "http.register_api", + { + "route": route, + "methods": methods, + "handler_capability": resolved_handler, + "description": description, + }, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="HTTPClient", + method_name="register_api", + details=f"route={route!r}, methods={list(methods)!r}", + exc=exc, + ) from exc + + async def unregister_api( + self, route: str, methods: list[str] | None = None + ) -> None: + """注销 Web API 端点。 + + Args: + route: API 路由路径 + methods: HTTP 方法列表,None 表示所有方法 + + 示例: + await ctx.http.unregister_api("/my_plugin/api") + """ + if methods is None: + methods = [] + try: + await self._proxy.call( + "http.unregister_api", + {"route": route, "methods": methods}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="HTTPClient", + method_name="unregister_api", + details=f"route={route!r}, methods={list(methods)!r}", + exc=exc, + ) from exc + + async def list_apis(self) -> list[dict[str, Any]]: + """列出当前插件注册的所有 API。 + + Returns: + API 列表,每项包含 route, methods, description + + 示例: + apis = await ctx.http.list_apis() + for api in apis: + print(f"{api['route']}: {api['methods']}") + """ + try: + output = await self._proxy.call( + "http.list_apis", + {}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="HTTPClient", + method_name="list_apis", + exc=exc, + ) from exc + return output.get("apis", []) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/llm.py b/astrbot-sdk/src/astrbot_sdk/clients/llm.py new file mode 100644 index 0000000000..62ff86d32c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/llm.py @@ -0,0 +1,293 @@ +"""大语言模型客户端模块。 + +提供 astrbot-sdk 原生的 LLM 能力调用接口。 + +设计边界: + - `chat()` 是便捷文本接口,返回最终文本 + - `chat_raw()` 返回完整结构化响应 + - `stream_chat()` 返回文本增量 + - Agent 循环、动态工具注册等更高层 orchestration 不放在客户端内, + 由上层运行时或独立迁移入口承接 +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Mapping, Sequence +from typing import Any + +from pydantic import BaseModel, Field + +from ._proxy import CapabilityProxy + + +class ChatMessage(BaseModel): + """聊天消息模型。 + + 用于构建对话历史,传递给 LLM。 + + Attributes: + role: 消息角色,如 "user", "assistant", "system" + content: 消息内容 + + 示例: + history = [ + ChatMessage(role="user", content="你好"), + ChatMessage(role="assistant", content="你好!有什么可以帮助你的?"), + ChatMessage(role="user", content="今天天气怎么样?"), + ] + """ + + role: str + content: str + + +ChatHistoryItem = ChatMessage | Mapping[str, Any] + + +def _serialize_history( + history: Sequence[ChatHistoryItem] | None, +) -> list[dict[str, Any]]: + if history is None: + return [] + + serialized: list[dict[str, Any]] = [] + for item in history: + if isinstance(item, ChatMessage): + serialized.append(item.model_dump()) + continue + if isinstance(item, Mapping): + serialized.append(dict(item)) + continue + raise TypeError("history 项必须是 ChatMessage 或 mapping") + return serialized + + +def _normalize_chat_context_payload( + *, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, +) -> dict[str, list[dict[str, Any]]]: + if contexts is not None: + return {"contexts": _serialize_history(contexts)} + if history is not None: + return {"contexts": _serialize_history(history)} + return {} + + +def _build_chat_payload( + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + extra: dict[str, Any] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = {"prompt": prompt} + if system is not None: + payload["system"] = system + payload.update(_normalize_chat_context_payload(history=history, contexts=contexts)) + if provider_id is not None: + payload["provider_id"] = provider_id + if tool_calls_result is not None: + payload["tool_calls_result"] = [dict(item) for item in tool_calls_result] + if model is not None: + payload["model"] = model + if temperature is not None: + payload["temperature"] = temperature + if extra: + payload.update(extra) + return payload + + +class LLMResponse(BaseModel): + """LLM 响应模型。 + + 包含完整的 LLM 响应信息,用于 chat_raw() 方法返回。 + + Attributes: + text: 生成的文本内容 + usage: Token 使用统计,如 {"prompt_tokens": 10, "completion_tokens": 20} + finish_reason: 结束原因,如 "stop", "length", "tool_calls" + tool_calls: 工具调用列表(如果 LLM 决定调用工具) + """ + + text: str + usage: dict[str, Any] | None = None + finish_reason: str | None = None + tool_calls: list[dict[str, Any]] = Field(default_factory=list) + role: str | None = None + reasoning_content: str | None = None + reasoning_signature: str | None = None + + +class LLMClient: + """大语言模型客户端。 + + 提供与 LLM 交互的能力,支持普通聊天和流式聊天。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化 LLM 客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def chat( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> str: + """发送聊天请求并返回文本响应。 + + 这是简化的聊天接口,仅返回生成的文本内容。 + 如需完整响应信息(包括 usage、tool_calls),请使用 chat_raw()。 + + Args: + prompt: 用户输入的提示文本 + system: 系统提示词,用于指导 LLM 行为 + history: 对话历史,用于保持上下文连续性 + model: 指定使用的模型名称(可选,由核心自动选择) + temperature: 生成温度,控制随机性(0-1) + **kwargs: 额外透传参数,如 `image_urls`、`tools` + + Returns: + LLM 生成的文本内容 + + 示例: + # 简单对话 + reply = await ctx.llm.chat("你好,介绍一下自己") + + # 带历史的对话 + history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), + ] + reply = await ctx.llm.chat("你记得我的名字吗?", history=history) + """ + output = await self._proxy.call( + "llm.chat", + _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ), + ) + return str(output.get("text", "")) + + async def chat_raw( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> LLMResponse: + """发送聊天请求并返回完整响应。 + + 与 chat() 不同,此方法返回完整的 LLMResponse 对象, + 包含 usage、finish_reason、tool_calls 等信息。 + + Args: + prompt: 用户输入的提示文本 + **kwargs: 额外参数,如 system, history, model, temperature 等 + + Returns: + LLMResponse 对象,包含完整响应信息 + + 示例: + response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) + print(f"生成文本: {response.text}") + print(f"Token 使用: {response.usage}") + """ + payload = _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ) + output = await self._proxy.call( + "llm.chat_raw", + payload, + ) + return LLMResponse.model_validate(output) + + async def stream_chat( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + """流式聊天,逐块返回响应文本。 + + 适用于需要实时显示生成内容的场景,如聊天界面。 + + Args: + prompt: 用户输入的提示文本 + system: 系统提示词 + history: 对话历史 + model: 指定模型 + temperature: 采样温度 + **kwargs: 额外透传参数,如 `image_urls`、`tools` + + Yields: + 每个生成的文本块 + + 示例: + async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="", flush=True) + """ + async for data in self._proxy.stream( + "llm.stream_chat", + _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ), + ): + yield str(data.get("text", "")) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/managers.py b/astrbot-sdk/src/astrbot_sdk/clients/managers.py new file mode 100644 index 0000000000..1809689e9b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/managers.py @@ -0,0 +1,886 @@ +"""Typed SDK manager clients for persona, conversation, and knowledge base.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ..errors import AstrBotError, ErrorCodes +from ..message.components import ( + BaseMessageComponent, + component_to_payload_sync, + payload_to_component, +) +from ..message.session import MessageSession +from ._proxy import CapabilityProxy + + +class _ManagerModel(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + def to_update_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_unset=True) + + +def _normalize_session(session: str | MessageSession) -> str: + if isinstance(session, MessageSession): + return str(session) + return str(session) + + +def _require_message_history_session( + session: MessageSession, +) -> dict[str, str]: + if not isinstance(session, MessageSession): + raise TypeError( + "message_history requires astrbot_sdk.message.session.MessageSession" + ) + return { + "platform_id": str(session.platform_id), + "message_type": str(session.message_type), + "session_id": str(session.session_id), + } + + +def _normalize_message_history_parts( + parts: list[BaseMessageComponent], +) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for part in parts: + if not isinstance(part, BaseMessageComponent): + raise TypeError( + "message_history.append requires BaseMessageComponent items in parts" + ) + normalized.append(component_to_payload_sync(part)) + return normalized + + +def _normalize_message_history_boundary(value: datetime) -> str: + if not isinstance(value, datetime): + raise TypeError("message_history boundary requires datetime") + normalized = value + if normalized.tzinfo is None: + normalized = normalized.replace(tzinfo=timezone.utc) + else: + normalized = normalized.astimezone(timezone.utc) + return normalized.isoformat() + + +class PersonaRecord(_ManagerModel): + persona_id: str + system_prompt: str + begin_dialogs: list[str] = Field(default_factory=list) + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: str | None = None + folder_id: str | None = None + sort_order: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PersonaRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PersonaCreateParams(_ManagerModel): + persona_id: str + system_prompt: str + begin_dialogs: list[str] = Field(default_factory=list) + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: str | None = None + folder_id: str | None = None + sort_order: int = 0 + + +class PersonaUpdateParams(_ManagerModel): + system_prompt: str | None = None + begin_dialogs: list[str] | None = None + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: str | None = None + + +class ConversationRecord(_ManagerModel): + conversation_id: str + session: str + platform_id: str + history: list[dict[str, Any]] = Field(default_factory=list) + title: str | None = None + persona_id: str | None = None + created_at: str | None = None + updated_at: str | None = None + token_usage: int | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> ConversationRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ConversationCreateParams(_ManagerModel): + platform_id: str | None = None + history: list[dict[str, Any]] | None = None + title: str | None = None + persona_id: str | None = None + + +class ConversationUpdateParams(_ManagerModel): + history: list[dict[str, Any]] | None = None + title: str | None = None + persona_id: str | None = None + token_usage: int | None = None + + +class MessageHistorySender(_ManagerModel): + sender_id: str | None = None + sender_name: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistorySender | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class MessageHistoryRecord(_ManagerModel): + id: int + session: MessageSession + sender: MessageHistorySender = Field(default_factory=MessageHistorySender) + parts: list[BaseMessageComponent] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + updated_at: datetime | None = None + idempotency_key: str | None = None + + @model_validator(mode="before") + @classmethod + def _normalize_payload(cls, value: Any) -> Any: + if not isinstance(value, dict): + return value + normalized = dict(value) + + session_payload = normalized.get("session") + if isinstance(session_payload, dict): + normalized["session"] = MessageSession( + platform_id=str(session_payload.get("platform_id", "")), + message_type=str(session_payload.get("message_type", "")), + session_id=str(session_payload.get("session_id", "")), + ) + + sender_payload = normalized.get("sender") + if isinstance(sender_payload, dict): + normalized["sender"] = MessageHistorySender.model_validate(sender_payload) + elif sender_payload is None: + normalized["sender"] = MessageHistorySender() + + parts_payload = normalized.get("parts") + if isinstance(parts_payload, list): + normalized["parts"] = [ + payload_to_component(item) + for item in parts_payload + if isinstance(item, dict) + ] + + metadata_payload = normalized.get("metadata") + if not isinstance(metadata_payload, dict): + normalized["metadata"] = {} + + return normalized + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistoryRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class MessageHistoryPage(_ManagerModel): + records: list[MessageHistoryRecord] = Field(default_factory=list) + next_cursor: str | None = None + total: int | None = None + + @model_validator(mode="before") + @classmethod + def _normalize_payload(cls, value: Any) -> Any: + if not isinstance(value, dict): + return value + normalized = dict(value) + records_payload = normalized.get("records") + if isinstance(records_payload, list): + normalized["records"] = [ + record + for record in ( + MessageHistoryRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in records_payload + ) + if record is not None + ] + return normalized + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistoryPage | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseRecord(_ManagerModel): + kb_id: str + kb_name: str + description: str | None = None + emoji: str | None = None + embedding_provider_id: str + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + doc_count: int = 0 + chunk_count: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> KnowledgeBaseRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseCreateParams(_ManagerModel): + kb_name: str + embedding_provider_id: str + description: str | None = None + emoji: str | None = None + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + +class KnowledgeBaseUpdateParams(_ManagerModel): + kb_name: str | None = None + embedding_provider_id: str | None = None + description: str | None = None + emoji: str | None = None + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + +class KnowledgeBaseDocumentRecord(_ManagerModel): + doc_id: str + kb_id: str + doc_name: str + file_type: str + file_size: int + file_path: str = "" + chunk_count: int = 0 + media_count: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseDocumentRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseRetrieveResultItem(_ManagerModel): + chunk_id: str + doc_id: str + kb_id: str + kb_name: str + doc_name: str + chunk_index: int + content: str + score: float + char_count: int + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseRetrieveResultItem | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseRetrieveResult(_ManagerModel): + context_text: str + results: list[KnowledgeBaseRetrieveResultItem] = Field(default_factory=list) + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseRetrieveResult | None: + if not isinstance(payload, dict): + return None + items = payload.get("results") + normalized_items = ( + [ + item.model_dump() + for item in ( + KnowledgeBaseRetrieveResultItem.from_payload(candidate) + if isinstance(candidate, dict) + else None + for candidate in items + ) + if item is not None + ] + if isinstance(items, list) + else [] + ) + return cls.model_validate( + { + "context_text": str(payload.get("context_text", "")), + "results": normalized_items, + } + ) + + +class KnowledgeBaseDocumentUploadParams(_ManagerModel): + file_token: str | None = None + url: str | None = None + text: str | None = None + file_name: str | None = None + file_type: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + batch_size: int | None = None + tasks_limit: int | None = None + max_retries: int | None = None + enable_cleaning: bool | None = None + cleaning_provider_id: str | None = None + + @model_validator(mode="after") + def _validate_source(self) -> KnowledgeBaseDocumentUploadParams: + if any( + isinstance(value, str) and value.strip() + for value in (self.file_token, self.url, self.text) + ): + return self + raise ValueError( + "knowledge base document upload requires file_token, url, or text" + ) + + +class PersonaManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def get_persona(self, persona_id: str) -> PersonaRecord: + try: + output = await self._proxy.call( + "persona.get", + {"persona_id": str(persona_id)}, + ) + except AstrBotError as exc: + if exc.code == ErrorCodes.INVALID_INPUT: + raise ValueError(f"persona not found: {persona_id}") from exc + raise + persona = PersonaRecord.from_payload(output.get("persona")) + if persona is None: + raise ValueError(f"persona not found: {persona_id}") + return persona + + async def get_all_personas(self) -> list[PersonaRecord]: + output = await self._proxy.call("persona.list", {}) + items = output.get("personas") + if not isinstance(items, list): + return [] + return [ + persona + for persona in ( + PersonaRecord.from_payload(item) if isinstance(item, dict) else None + for item in items + ) + if persona is not None + ] + + async def create_persona(self, params: PersonaCreateParams) -> PersonaRecord: + output = await self._proxy.call( + "persona.create", + {"persona": params.to_payload()}, + ) + persona = PersonaRecord.from_payload(output.get("persona")) + if persona is None: + raise ValueError("persona.create returned no persona") + return persona + + async def update_persona( + self, + persona_id: str, + params: PersonaUpdateParams, + ) -> PersonaRecord | None: + output = await self._proxy.call( + "persona.update", + {"persona_id": str(persona_id), "persona": params.to_update_payload()}, + ) + return PersonaRecord.from_payload(output.get("persona")) + + async def delete_persona(self, persona_id: str) -> None: + await self._proxy.call("persona.delete", {"persona_id": str(persona_id)}) + + +class ConversationManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def new_conversation( + self, + session: str | MessageSession, + params: ConversationCreateParams | None = None, + ) -> str: + output = await self._proxy.call( + "conversation.new", + { + "session": _normalize_session(session), + "conversation": (params.to_payload() if params is not None else {}), + }, + ) + return str(output.get("conversation_id", "")) + + async def switch_conversation( + self, + session: str | MessageSession, + conversation_id: str, + ) -> None: + await self._proxy.call( + "conversation.switch", + { + "session": _normalize_session(session), + "conversation_id": str(conversation_id), + }, + ) + + async def delete_conversation( + self, + session: str | MessageSession, + conversation_id: str | None = None, + ) -> None: + """Delete one conversation for the session. + + When ``conversation_id`` is ``None``, this deletes the current selected + conversation for the session only. It does not delete all conversations + under the session. + """ + + await self._proxy.call( + "conversation.delete", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + }, + ) + + async def get_conversation( + self, + session: str | MessageSession, + conversation_id: str, + *, + create_if_not_exists: bool = False, + ) -> ConversationRecord | None: + output = await self._proxy.call( + "conversation.get", + { + "session": _normalize_session(session), + "conversation_id": str(conversation_id), + "create_if_not_exists": bool(create_if_not_exists), + }, + ) + return ConversationRecord.from_payload(output.get("conversation")) + + async def get_current_conversation( + self, + session: str | MessageSession, + *, + create_if_not_exists: bool = False, + ) -> ConversationRecord | None: + output = await self._proxy.call( + "conversation.get_current", + { + "session": _normalize_session(session), + "create_if_not_exists": bool(create_if_not_exists), + }, + ) + return ConversationRecord.from_payload(output.get("conversation")) + + async def get_conversations( + self, + session: str | MessageSession | None = None, + *, + platform_id: str | None = None, + ) -> list[ConversationRecord]: + output = await self._proxy.call( + "conversation.list", + { + "session": ( + _normalize_session(session) if session is not None else None + ), + "platform_id": platform_id, + }, + ) + items = output.get("conversations") + if not isinstance(items, list): + return [] + return [ + conversation + for conversation in ( + ConversationRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if conversation is not None + ] + + async def update_conversation( + self, + session: str | MessageSession, + conversation_id: str | None = None, + params: ConversationUpdateParams | None = None, + ) -> None: + await self._proxy.call( + "conversation.update", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + "conversation": ( + params.to_update_payload() if params is not None else {} + ), + }, + ) + + async def unset_persona( + self, + session: str | MessageSession, + conversation_id: str | None = None, + ) -> None: + await self._proxy.call( + "conversation.unset_persona", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + }, + ) + + +class MessageHistoryManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list( + self, + session: MessageSession, + *, + cursor: str | None = None, + limit: int = 50, + ) -> MessageHistoryPage: + output = await self._proxy.call( + "message_history.list", + { + "session": _require_message_history_session(session), + "cursor": str(cursor) if cursor is not None else None, + "limit": int(limit), + }, + ) + page = MessageHistoryPage.from_payload(output.get("page")) + if page is None: + raise ValueError("message_history.list returned no page") + return page + + async def get( + self, + session: MessageSession, + record_id: int, + ) -> MessageHistoryRecord | None: + output = await self._proxy.call( + "message_history.get_by_id", + { + "session": _require_message_history_session(session), + "record_id": int(record_id), + }, + ) + return MessageHistoryRecord.from_payload(output.get("record")) + + async def get_by_id( + self, + session: MessageSession, + record_id: int, + ) -> MessageHistoryRecord | None: + return await self.get(session, record_id) + + async def append( + self, + session: MessageSession, + *, + parts: list[BaseMessageComponent], + sender: MessageHistorySender | dict[str, Any], + metadata: dict[str, Any] | None = None, + idempotency_key: str | None = None, + ) -> MessageHistoryRecord: + if isinstance(sender, MessageHistorySender): + sender_payload = sender.to_payload() + elif isinstance(sender, dict): + sender_payload = MessageHistorySender.model_validate(sender).to_payload() + else: + raise TypeError( + "message_history.append requires MessageHistorySender for sender" + ) + output = await self._proxy.call( + "message_history.append", + { + "session": _require_message_history_session(session), + "sender": sender_payload, + "parts": _normalize_message_history_parts(parts), + "metadata": dict(metadata or {}), + "idempotency_key": ( + str(idempotency_key) if idempotency_key is not None else None + ), + }, + ) + record = MessageHistoryRecord.from_payload(output.get("record")) + if record is None: + raise ValueError("message_history.append returned no record") + return record + + async def delete_before( + self, + session: MessageSession, + *, + before: datetime, + ) -> int: + output = await self._proxy.call( + "message_history.delete_before", + { + "session": _require_message_history_session(session), + "before": _normalize_message_history_boundary(before), + }, + ) + return int(output.get("deleted_count", 0) or 0) + + async def delete_after( + self, + session: MessageSession, + *, + after: datetime, + ) -> int: + output = await self._proxy.call( + "message_history.delete_after", + { + "session": _require_message_history_session(session), + "after": _normalize_message_history_boundary(after), + }, + ) + return int(output.get("deleted_count", 0) or 0) + + async def delete_all(self, session: MessageSession) -> int: + output = await self._proxy.call( + "message_history.delete_all", + {"session": _require_message_history_session(session)}, + ) + return int(output.get("deleted_count", 0) or 0) + + +class KnowledgeBaseManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list_kbs(self) -> list[KnowledgeBaseRecord]: + output = await self._proxy.call("kb.list", {}) + items = output.get("kbs") + if not isinstance(items, list): + return [] + return [ + kb + for kb in ( + KnowledgeBaseRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if kb is not None + ] + + async def get_kb(self, kb_id: str) -> KnowledgeBaseRecord | None: + output = await self._proxy.call("kb.get", {"kb_id": str(kb_id)}) + return KnowledgeBaseRecord.from_payload(output.get("kb")) + + async def create_kb( + self, + params: KnowledgeBaseCreateParams, + ) -> KnowledgeBaseRecord: + output = await self._proxy.call("kb.create", {"kb": params.to_payload()}) + kb = KnowledgeBaseRecord.from_payload(output.get("kb")) + if kb is None: + raise ValueError("kb.create returned no knowledge base") + return kb + + async def update_kb( + self, + kb_id: str, + params: KnowledgeBaseUpdateParams, + ) -> KnowledgeBaseRecord | None: + output = await self._proxy.call( + "kb.update", + {"kb_id": str(kb_id), "kb": params.to_update_payload()}, + ) + return KnowledgeBaseRecord.from_payload(output.get("kb")) + + async def delete_kb(self, kb_id: str) -> bool: + output = await self._proxy.call("kb.delete", {"kb_id": str(kb_id)}) + return bool(output.get("deleted", False)) + + async def retrieve( + self, + query: str, + *, + kb_ids: list[str] | None = None, + kb_names: list[str] | None = None, + top_k_fusion: int | None = None, + top_m_final: int | None = None, + ) -> KnowledgeBaseRetrieveResult | None: + request_payload: dict[str, Any] = { + "query": str(query), + "kb_ids": [str(item) for item in (kb_ids or [])], + "kb_names": [str(item) for item in (kb_names or [])], + } + if top_k_fusion is not None: + request_payload["top_k_fusion"] = int(top_k_fusion) + if top_m_final is not None: + request_payload["top_m_final"] = int(top_m_final) + output = await self._proxy.call( + "kb.retrieve", + request_payload, + ) + return KnowledgeBaseRetrieveResult.from_payload(output.get("result")) + + async def upload_document( + self, + kb_id: str, + params: KnowledgeBaseDocumentUploadParams, + ) -> KnowledgeBaseDocumentRecord: + output = await self._proxy.call( + "kb.document.upload", + {"kb_id": str(kb_id), "document": params.to_payload()}, + ) + document = KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + if document is None: + raise ValueError("kb.document.upload returned no document") + return document + + async def list_documents( + self, + kb_id: str, + *, + offset: int = 0, + limit: int = 100, + ) -> list[KnowledgeBaseDocumentRecord]: + output = await self._proxy.call( + "kb.document.list", + {"kb_id": str(kb_id), "offset": int(offset), "limit": int(limit)}, + ) + items = output.get("documents") + if not isinstance(items, list): + return [] + return [ + document + for document in ( + KnowledgeBaseDocumentRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if document is not None + ] + + async def get_document( + self, + kb_id: str, + doc_id: str, + ) -> KnowledgeBaseDocumentRecord | None: + output = await self._proxy.call( + "kb.document.get", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) + return KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + + async def delete_document( + self, + kb_id: str, + doc_id: str, + ) -> bool: + output = await self._proxy.call( + "kb.document.delete", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) + return bool(output.get("deleted", False)) + + async def refresh_document( + self, + kb_id: str, + doc_id: str, + ) -> KnowledgeBaseDocumentRecord | None: + output = await self._proxy.call( + "kb.document.refresh", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) + return KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + + +__all__ = [ + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationRecord", + "ConversationUpdateParams", + "KnowledgeBaseCreateParams", + "KnowledgeBaseDocumentRecord", + "KnowledgeBaseDocumentUploadParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "KnowledgeBaseRetrieveResult", + "KnowledgeBaseRetrieveResultItem", + "KnowledgeBaseUpdateParams", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + "MessageHistorySender", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/mcp.py b/astrbot-sdk/src/astrbot_sdk/clients/mcp.py new file mode 100644 index 0000000000..90a5f3391d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/mcp.py @@ -0,0 +1,415 @@ +"""MCP 管理客户端。 + +提供本地 MCP 服务、全局 MCP 服务和临时 MCP session 的 SDK 封装。 +""" + +from __future__ import annotations + +from contextlib import AbstractAsyncContextManager +from dataclasses import dataclass, field +from enum import Enum +from types import TracebackType +from typing import Any + +from ..errors import AstrBotError +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +class MCPServerScope(str, Enum): + local = "local" + global_ = "global" + + +@dataclass(slots=True) +class MCPServerRecord: + """MCP 服务快照。""" + + name: str + scope: MCPServerScope + active: bool + running: bool + config: dict[str, Any] = field(default_factory=dict) + tools: list[str] = field(default_factory=list) + errlogs: list[str] = field(default_factory=list) + last_error: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MCPServerRecord | None: + if not isinstance(payload, dict): + return None + scope_value = str(payload.get("scope") or MCPServerScope.local.value).strip() + try: + scope = MCPServerScope(scope_value) + except ValueError: + scope = MCPServerScope.local + return cls( + name=str(payload.get("name", "")), + scope=scope, + active=bool(payload.get("active", False)), + running=bool(payload.get("running", False)), + config=( + dict(payload.get("config")) + if isinstance(payload.get("config"), dict) + else {} + ), + tools=[ + str(item) + for item in payload.get("tools", []) + if isinstance(item, str) and item + ] + if isinstance(payload.get("tools"), list) + else [], + errlogs=[ + str(item) + for item in payload.get("errlogs", []) + if isinstance(item, str) + ] + if isinstance(payload.get("errlogs"), list) + else [], + last_error=( + str(payload.get("last_error")) + if payload.get("last_error") is not None + else None + ), + ) + + +def _server_records_from_payload(items: Any) -> list[MCPServerRecord]: + if not isinstance(items, list): + return [] + return [ + record + for record in ( + MCPServerRecord.from_payload(item) if isinstance(item, dict) else None + for item in items + ) + if record is not None + ] + + +def _require_server_record( + payload: dict[str, Any], + *, + action: str, +) -> MCPServerRecord: + record = MCPServerRecord.from_payload(payload.get("server")) + if record is None: + raise ValueError(f"{action} returned no server") + return record + + +class MCPSession(AbstractAsyncContextManager["MCPSession"]): + """临时 MCP session 的异步上下文封装。""" + + def __init__( + self, + proxy: CapabilityProxy, + *, + name: str, + config: dict[str, Any], + timeout: float, + ) -> None: + self._proxy = proxy + self._name = str(name) + self._config = dict(config) + self._timeout = float(timeout) + self._session_id: str | None = None + self._tools: list[str] = [] + + async def __aenter__(self) -> MCPSession: + try: + output = await self._proxy.call( + "mcp.session.open", + { + "name": self._name, + "config": dict(self._config), + "timeout": self._timeout, + }, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPSession", + method_name="open", + details=f"name={self._name!r}, timeout={self._timeout!r}", + exc=exc, + ) from exc + session_id = str(output.get("session_id", "")).strip() + if not session_id: + raise ValueError("mcp.session.open returned no session_id") + self._session_id = session_id + tools = output.get("tools") + self._tools = ( + [str(item) for item in tools if isinstance(item, str)] + if isinstance(tools, list) + else [] + ) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + session_id = self._session_id + self._session_id = None + self._tools = [] + if not session_id: + return + try: + await self._proxy.call("mcp.session.close", {"session_id": session_id}) + except AstrBotError: + raise + except Exception: + # Session cleanup should not mask the original error raised inside the + # managed block. + if exc_type is None: + raise + + async def call_tool( + self, + tool_name: str, + args: dict[str, Any] | None = None, + ) -> dict[str, Any]: + session_id = self._require_session_id() + try: + output = await self._proxy.call( + "mcp.session.call_tool", + { + "session_id": session_id, + "tool_name": str(tool_name), + "args": dict(args or {}), + }, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPSession", + method_name="call_tool", + details=f"session_id={session_id!r}, tool_name={str(tool_name)!r}", + exc=exc, + ) from exc + result = output.get("result") + if not isinstance(result, dict): + raise ValueError("mcp.session.call_tool returned no result object") + return dict(result) + + async def list_tools(self) -> list[str]: + session_id = self._require_session_id() + try: + output = await self._proxy.call( + "mcp.session.list_tools", + {"session_id": session_id}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPSession", + method_name="list_tools", + details=f"session_id={session_id!r}", + exc=exc, + ) from exc + tools = output.get("tools") + self._tools = ( + [str(item) for item in tools if isinstance(item, str)] + if isinstance(tools, list) + else [] + ) + return list(self._tools) + + def _require_session_id(self) -> str: + if self._session_id is None: + raise RuntimeError("MCP session is not active; use 'async with'") + return self._session_id + + +class MCPManagerClient: + """MCP 服务管理客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def get_server(self, name: str) -> MCPServerRecord | None: + try: + output = await self._proxy.call("mcp.local.get", {"name": str(name)}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="get_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return MCPServerRecord.from_payload(output.get("server")) + + async def list_servers(self) -> list[MCPServerRecord]: + try: + output = await self._proxy.call("mcp.local.list", {}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="list_servers", + exc=exc, + ) from exc + return _server_records_from_payload(output.get("servers")) + + async def enable_server(self, name: str) -> MCPServerRecord: + try: + output = await self._proxy.call("mcp.local.enable", {"name": str(name)}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="enable_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.local.enable") + + async def disable_server(self, name: str) -> MCPServerRecord: + try: + output = await self._proxy.call("mcp.local.disable", {"name": str(name)}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="disable_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.local.disable") + + async def wait_until_ready( + self, + name: str, + *, + timeout: float = 30.0, + ) -> MCPServerRecord: + try: + output = await self._proxy.call( + "mcp.local.wait_until_ready", + {"name": str(name), "timeout": float(timeout)}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="wait_until_ready", + details=f"name={str(name)!r}, timeout={float(timeout)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.local.wait_until_ready") + + def session( + self, + name: str, + config: dict[str, Any], + *, + timeout: float = 30.0, + ) -> MCPSession: + return MCPSession( + self._proxy, + name=str(name), + config=dict(config), + timeout=float(timeout), + ) + + async def register_global_server( + self, + name: str, + config: dict[str, Any], + *, + timeout: float = 30.0, + ) -> MCPServerRecord: + try: + output = await self._proxy.call( + "mcp.global.register", + { + "name": str(name), + "config": dict(config), + "timeout": float(timeout), + }, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="register_global_server", + details=f"name={str(name)!r}, timeout={float(timeout)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.global.register") + + async def get_global_server(self, name: str) -> MCPServerRecord | None: + try: + output = await self._proxy.call("mcp.global.get", {"name": str(name)}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="get_global_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return MCPServerRecord.from_payload(output.get("server")) + + async def list_global_servers(self) -> list[MCPServerRecord]: + try: + output = await self._proxy.call("mcp.global.list", {}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="list_global_servers", + exc=exc, + ) from exc + return _server_records_from_payload(output.get("servers")) + + async def enable_global_server( + self, + name: str, + *, + timeout: float = 30.0, + ) -> MCPServerRecord: + try: + output = await self._proxy.call( + "mcp.global.enable", + {"name": str(name), "timeout": float(timeout)}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="enable_global_server", + details=f"name={str(name)!r}, timeout={float(timeout)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.global.enable") + + async def disable_global_server(self, name: str) -> MCPServerRecord: + try: + output = await self._proxy.call("mcp.global.disable", {"name": str(name)}) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="disable_global_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.global.disable") + + async def unregister_global_server(self, name: str) -> MCPServerRecord: + try: + output = await self._proxy.call( + "mcp.global.unregister", {"name": str(name)} + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MCPManagerClient", + method_name="unregister_global_server", + details=f"name={str(name)!r}", + exc=exc, + ) from exc + return _require_server_record(output, action="mcp.global.unregister") + + +__all__ = [ + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/memory.py b/astrbot-sdk/src/astrbot_sdk/clients/memory.py new file mode 100644 index 0000000000..55d302ca4f --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/memory.py @@ -0,0 +1,426 @@ +"""记忆客户端模块。 + +提供 AI 记忆存储能力,用于存储和检索对话记忆、用户偏好等上下文数据。 + +设计说明: + MemoryClient 与 DBClient 的区别: + - DBClient: 简单的键值存储,精确匹配 + - MemoryClient: 支持基于当前 bridge 行为的记忆检索,适合 AI 上下文管理 + + 记忆系统可用于: + - 存储用户偏好和设置 + - 记录对话摘要 + - 缓存 AI 推理结果 +""" + +from __future__ import annotations + +from typing import Any, Literal + +from .._internal.memory_utils import join_memory_namespace +from ._proxy import CapabilityProxy + + +def _normalize_search_item(item: Any) -> dict[str, Any] | None: + if not isinstance(item, dict): + return None + normalized = dict(item) + value = normalized.get("value") + if isinstance(value, dict): + for key, payload_value in value.items(): + normalized.setdefault(str(key), payload_value) + return normalized + + +class MemoryClient: + """记忆客户端。 + + 提供 AI 记忆的存储和检索能力。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__( + self, + proxy: CapabilityProxy, + *, + namespace: str | None = None, + ) -> None: + """初始化记忆客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + self._namespace = join_memory_namespace(namespace) + + def namespace(self, *parts: Any) -> MemoryClient: + """创建一个工作在子命名空间中的派生客户端。""" + + return MemoryClient( + self._proxy, + namespace=join_memory_namespace(self._namespace, *parts), + ) + + def _resolve_exact_namespace(self, namespace: str | None) -> str: + if namespace is None: + return self._namespace + return join_memory_namespace(self._namespace, namespace) + + def _resolve_scope_namespace(self, namespace: str | None) -> tuple[bool, str]: + if namespace is None: + if self._namespace: + return True, self._namespace + return False, "" + return True, join_memory_namespace(self._namespace, namespace) + + async def search( + self, + query: str, + *, + mode: Literal["auto", "keyword", "vector", "hybrid"] = "auto", + limit: int | None = None, + min_score: float | None = None, + provider_id: str | None = None, + namespace: str | None = None, + include_descendants: bool = True, + ) -> list[dict[str, Any]]: + """搜索记忆项。 + + 默认会在有 embedding provider 时执行 hybrid 检索, + 否则退化为关键词检索。返回结果包含 `score` 与 `match_type` 字段。 + + Args: + query: 搜索查询文本 + mode: 搜索模式,支持 auto/keyword/vector/hybrid + limit: 最大返回条数 + min_score: 最低分数阈值 + provider_id: 指定 embedding provider,默认使用当前激活的 provider + + Returns: + 匹配的记忆项列表,按相关度排序 + + 示例: + results = await ctx.memory.search( + "用户喜欢什么颜色", + mode="hybrid", + limit=5, + ) + for item in results: + print(item["key"], item["score"], item["match_type"]) + """ + payload: dict[str, Any] = {"query": query, "mode": mode} + if limit is not None: + payload["limit"] = limit + if min_score is not None: + payload["min_score"] = min_score + if provider_id is not None: + payload["provider_id"] = provider_id + has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace) + if has_namespace: + payload["namespace"] = resolved_namespace + payload["include_descendants"] = bool(include_descendants) + output = await self._proxy.call("memory.search", payload) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return [] + normalized_items: list[dict[str, Any]] = [] + for item in items: + normalized = _normalize_search_item(item) + if normalized is not None: + normalized_items.append(normalized) + return normalized_items + + async def save( + self, + key: str, + value: dict[str, Any] | None = None, + namespace: str | None = None, + **extra: Any, + ) -> None: + """保存记忆项。 + + 将数据存储到记忆系统,可通过 search() 检索或 get() 精确获取。 + + Args: + key: 记忆项的唯一标识键 + value: 要存储的数据字典 + **extra: 额外的键值对,会合并到 value 中 + Raises: + TypeError: 如果 value 不是 dict 类型 + 示例: + 保存用户偏好 + await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"}) + + 使用关键字参数 + await ctx.memory.save("note", None, content="重要笔记", tags=["work"]) + + 使用 embedding_text 显式指定检索文本 + await ctx.memory.save( + "profile", + {"name": "alice", "embedding_text": "Alice 喜欢蓝色和海边"}, + ) + """ + if value is not None and not isinstance(value, dict): + raise TypeError("memory.save 的 value 必须是 dict") + payload = dict(value or {}) + if extra: + payload.update(extra) + request: dict[str, Any] = {"key": key, "value": payload} + request["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.save", request) + + async def get( + self, + key: str, + *, + namespace: str | None = None, + ) -> dict[str, Any] | None: + """精确获取单个记忆项。 + + 通过唯一键精确获取记忆内容,不经过搜索匹配。 + + Args: + key: 记忆项的唯一键 + + Returns: + 记忆项内容字典,若不存在则返回 None + + 示例: + pref = await ctx.memory.get("user_pref") + if pref: + print(f"用户偏好主题: {pref.get('theme')}") + """ + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.get", payload) + value = output.get("value") + return value if isinstance(value, dict) else None + + async def list_keys( + self, + *, + namespace: str | None = None, + ) -> list[str]: + """列出指定精确命名空间下的全部键。""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace) + } + output = await self._proxy.call("memory.list_keys", payload) + keys = output.get("keys") + if not isinstance(keys, (list, tuple)): + return [] + return [str(item) for item in keys] + + async def exists( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + """检查指定精确命名空间中是否存在某个键。""" + + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.exists", payload) + return bool(output.get("exists", False)) + + async def delete( + self, + key: str, + *, + namespace: str | None = None, + ) -> None: + """删除记忆项。 + + Args: + key: 要删除的记忆项键名 + + 示例: + await ctx.memory.delete("old_note") + """ + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.delete", payload) + + async def clear_namespace( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: + """清空命名空间中的记忆项,可选递归清空子命名空间。""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace), + "include_descendants": bool(include_descendants), + } + output = await self._proxy.call("memory.clear_namespace", payload) + return int(output.get("deleted_count", 0)) + + async def save_with_ttl( + self, + key: str, + value: dict[str, Any], + ttl_seconds: int, + *, + namespace: str | None = None, + ) -> None: + """保存带过期时间的记忆项。 + + 与 save() 不同,此方法允许设置记忆项的存活时间(TTL), + 过期后记忆项将自动删除。 + + Args: + key: 记忆项的唯一标识键 + value: 要存储的数据字典 + ttl_seconds: 存活时间(秒),必须大于 0 + + Raises: + TypeError: 如果 value 不是 dict 类型 + ValueError: 如果 ttl_seconds 小于 1 + + 示例: + # 保存临时会话状态,1小时后过期 + await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting"}, + ttl_seconds=3600, + ) + """ + if not isinstance(value, dict): + raise TypeError("memory.save_with_ttl 的 value 必须是 dict") + if ttl_seconds < 1: + raise ValueError("ttl_seconds 必须大于 0") + payload: dict[str, Any] = { + "key": key, + "value": value, + "ttl_seconds": ttl_seconds, + } + payload["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.save_with_ttl", payload) + + async def get_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> list[dict[str, Any]]: + """批量获取多个记忆项。 + + 一次性获取多个键对应的记忆内容,比多次调用 get() 更高效。 + + Args: + keys: 记忆项键名列表 + + Returns: + 记忆项列表,每项包含 key 和 value 字段, + 不存在的键返回 value 为 None + + 示例: + items = await ctx.memory.get_many(["pref1", "pref2", "pref3"]) + for item in items: + if item["value"]: + print(f"{item['key']}: {item['value']}") + """ + payload: dict[str, Any] = {"keys": keys} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.get_many", payload) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return [] + return [dict(item) for item in items if isinstance(item, dict)] + + async def delete_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> int: + """批量删除多个记忆项。 + + 一次性删除多个键对应的记忆项,返回实际删除的数量。 + + Args: + keys: 要删除的记忆项键名列表 + + Returns: + 实际删除的记忆项数量 + + 示例: + deleted = await ctx.memory.delete_many(["old1", "old2", "old3"]) + print(f"删除了 {deleted} 条记忆") + """ + payload: dict[str, Any] = {"keys": keys} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.delete_many", payload) + return int(output.get("deleted_count", 0)) + + async def count( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: + """统计命名空间中的记忆项数量,可选包含子命名空间。""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace), + "include_descendants": bool(include_descendants), + } + output = await self._proxy.call("memory.count", payload) + return int(output.get("count", 0)) + + async def stats( + self, + *, + namespace: str | None = None, + include_descendants: bool = True, + ) -> dict[str, Any]: + """获取记忆系统统计信息。 + + 返回记忆系统的当前状态,包括条目数、索引状态和脏索引数量。 + + Returns: + 统计信息字典,包含: + - total_items: 总记忆条目数 + - total_bytes: 总占用字节数(可选) + - ttl_entries: 带过期时间的条目数(可选) + - indexed_items: 已建立检索索引的条目数(可选) + - embedded_items: 已生成向量的条目数(可选) + - dirty_items: 等待重建索引的条目数(可选) + + 示例: + stats = await ctx.memory.stats() + print(f"记忆库共有 {stats['total_items']} 条记录") + if "embedded_items" in stats: + print(f"其中 {stats['embedded_items']} 条已经向量化") + """ + payload: dict[str, Any] = { + "include_descendants": bool(include_descendants), + } + has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace) + if has_namespace: + payload["namespace"] = resolved_namespace + output = await self._proxy.call("memory.stats", payload) + stats = { + "total_items": output.get("total_items", 0), + "total_bytes": output.get("total_bytes"), + } + for key in ( + "namespace", + "namespace_count", + "fts_enabled", + "vector_backend", + "vector_indexes", + "plugin_id", + "ttl_entries", + "indexed_items", + "embedded_items", + "dirty_items", + ): + if key in output: + stats[key] = output.get(key) + return stats diff --git a/astrbot-sdk/src/astrbot_sdk/clients/metadata.py b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py new file mode 100644 index 0000000000..9d68314b22 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py @@ -0,0 +1,145 @@ +"""元数据客户端模块。 + +提供插件元数据查询能力。 + +功能说明: + - 查询已加载插件信息 + - 获取插件列表 + - 访问当前插件配置 + +安全边界: + 插件身份由运行时透传到协议层;客户端只暴露业务参数,不接受外部指定调用者。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +@dataclass +class StarMetadata: + """插件元数据。""" + + name: str + display_name: str + description: str + repo: str + author: str + version: str + enabled: bool = True + support_platforms: list[str] = field(default_factory=list) + astrbot_version: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> StarMetadata: + raw_support_platforms = data.get("support_platforms") + support_platforms = ( + [str(item) for item in raw_support_platforms if isinstance(item, str)] + if isinstance(raw_support_platforms, list) + else [] + ) + return cls( + name=str(data.get("name", "")), + display_name=str(data.get("display_name", data.get("name", ""))), + description=str(data.get("desc", data.get("description", ""))), + repo=str(data.get("repo", "")), + author=str(data.get("author", "")), + version=str(data.get("version", "0.0.0")), + enabled=bool(data.get("enabled", True)), + support_platforms=support_platforms, + astrbot_version=( + str(data.get("astrbot_version")) + if data.get("astrbot_version") is not None + else None + ), + ) + + +PluginMetadata = StarMetadata + + +class MetadataClient: + """元数据能力客户端。""" + + def __init__(self, proxy: CapabilityProxy, plugin_id: str) -> None: + self._proxy = proxy + self._plugin_id = plugin_id + + async def get_plugin(self, name: str) -> StarMetadata | None: + try: + output = await self._proxy.call( + "metadata.get_plugin", + {"name": name}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MetadataClient", + method_name="get_plugin", + details=f"name={name!r}", + exc=exc, + ) from exc + data = output.get("plugin") + if data is None: + return None + return StarMetadata.from_dict(data) + + async def list_plugins(self) -> list[StarMetadata]: + try: + output = await self._proxy.call("metadata.list_plugins", {}) + except Exception as exc: + raise wrap_client_exception( + client_name="MetadataClient", + method_name="list_plugins", + exc=exc, + ) from exc + items = output.get("plugins", []) + return [ + StarMetadata.from_dict(item) for item in items if isinstance(item, dict) + ] + + async def get_current_plugin(self) -> StarMetadata | None: + return await self.get_plugin(self._plugin_id) + + async def get_plugin_config(self, name: str | None = None) -> dict[str, Any] | None: + target = name or self._plugin_id + if target != self._plugin_id: + raise PermissionError( + "get_plugin_config 只允许访问当前插件自己的配置," + f"请求的插件 '{target}' 不是当前插件 '{self._plugin_id}'" + ) + try: + output = await self._proxy.call( + "metadata.get_plugin_config", + {"name": target}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MetadataClient", + method_name="get_plugin_config", + details=f"name={target!r}", + exc=exc, + ) from exc + config = output.get("config") + return dict(config) if isinstance(config, dict) else None + + async def save_plugin_config(self, config: dict[str, Any]) -> dict[str, Any]: + if not isinstance(config, dict): + raise TypeError("save_plugin_config requires a dict payload") + try: + output = await self._proxy.call( + "metadata.save_plugin_config", + {"config": dict(config)}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="MetadataClient", + method_name="save_plugin_config", + details=f"keys={sorted(str(key) for key in config)!r}", + exc=exc, + ) from exc + saved = output.get("config") + return dict(saved) if isinstance(saved, dict) else {} diff --git a/astrbot-sdk/src/astrbot_sdk/clients/permission.py b/astrbot-sdk/src/astrbot_sdk/clients/permission.py new file mode 100644 index 0000000000..546c8ea589 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/permission.py @@ -0,0 +1,100 @@ +"""权限能力客户端。""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict + +from ._proxy import CapabilityProxy + + +class PermissionCheckResult(BaseModel): + """权限检查结果。""" + + model_config = ConfigDict(extra="forbid") + + is_admin: bool + role: Literal["member", "admin"] + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> PermissionCheckResult | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PermissionClient: + """权限查询客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def check( + self, + user_id: str, + session_id: str | None = None, + ) -> PermissionCheckResult: + payload: dict[str, Any] = {"user_id": str(user_id)} + if session_id is not None: + payload["session_id"] = str(session_id) + output = await self._proxy.call("permission.check", payload) + result = PermissionCheckResult.from_payload(output) + if result is None: + return PermissionCheckResult(is_admin=False, role="member") + return result + + async def get_admins(self) -> list[str]: + output = await self._proxy.call("permission.get_admins", {}) + admins = output.get("admins") + if not isinstance(admins, list): + return [] + return [str(item) for item in admins] + + +class PermissionManagerClient: + """权限管理客户端。""" + + def __init__( + self, + proxy: CapabilityProxy, + *, + source_event_payload: dict[str, Any] | None = None, + ) -> None: + self._proxy = proxy + self._source_event_payload = ( + dict(source_event_payload) if isinstance(source_event_payload, dict) else {} + ) + + def _caller_is_admin(self) -> bool: + return bool(self._source_event_payload.get("is_admin", False)) + + async def add_admin(self, user_id: str) -> bool: + output = await self._proxy.call( + "permission.manager.add_admin", + { + "user_id": str(user_id), + "_caller_is_admin": self._caller_is_admin(), + }, + ) + return bool(output.get("changed", False)) + + async def remove_admin(self, user_id: str) -> bool: + output = await self._proxy.call( + "permission.manager.remove_admin", + { + "user_id": str(user_id), + "_caller_is_admin": self._caller_is_admin(), + }, + ) + return bool(output.get("changed", False)) + + +__all__ = [ + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/platform.py b/astrbot-sdk/src/astrbot_sdk/clients/platform.py new file mode 100644 index 0000000000..7a4bcccacf --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/platform.py @@ -0,0 +1,339 @@ +"""平台客户端模块。 + +提供 astrbot-sdk 原生的平台能力调用。 + +设计边界: + - `PlatformClient` 只负责直接的平台 capability + - 迁移期消息桥接由独立迁移入口承接,不放进原生客户端 + - 富消息链通过 `platform.send_chain` 发送,链构建能力位于专门的消息模块 +""" + +from __future__ import annotations + +from collections.abc import Sequence +from enum import Enum +from typing import Any, cast + +from pydantic import BaseModel, ConfigDict, Field + +from ..message.components import BaseMessageComponent, Plain +from ..message.result import MessageChain +from ..message.session import MessageSession +from ..protocol.descriptors import SessionRef +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +class _PlatformModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class PlatformStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + ERROR = "error" + STOPPED = "stopped" + + @classmethod + def from_value(cls, value: Any) -> PlatformStatus: + if isinstance(value, cls): + return value + try: + return cls(str(value).strip().lower()) + except ValueError: + return cls.PENDING + + +class PlatformError(_PlatformModel): + message: str + timestamp: str + traceback: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PlatformError | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PlatformStats(_PlatformModel): + id: str + type: str + display_name: str + status: PlatformStatus + started_at: str | None = None + error_count: int + last_error: PlatformError | None = None + unified_webhook: bool + meta: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PlatformStats | None: + if not isinstance(payload, dict): + return None + normalized = dict(payload) + normalized["status"] = PlatformStatus.from_value(payload.get("status")) + normalized["last_error"] = PlatformError.from_payload(payload.get("last_error")) + meta = payload.get("meta") + normalized["meta"] = dict(meta) if isinstance(meta, dict) else {} + return cls.model_validate(normalized) + + +class PlatformClient: + """平台消息客户端。 + + 提供向聊天平台发送消息和获取信息的能力。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化平台客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + def _build_target_payload( + self, + session: str | SessionRef | MessageSession, + ) -> tuple[str, dict[str, Any]]: + if isinstance(session, SessionRef): + return session.session, {"target": session.to_payload()} + if isinstance(session, MessageSession): + return str(session), {} + return str(session), {} + + async def _coerce_chain_payload( + self, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> list[dict[str, Any]]: + if isinstance(content, str): + return await MessageChain( + [Plain(content, convert=False)] + ).to_payload_async() + if isinstance(content, MessageChain): + return await content.to_payload_async() + if ( + isinstance(content, Sequence) + and not isinstance(content, (str, bytes)) + and all(isinstance(item, BaseMessageComponent) for item in content) + ): + components = cast(Sequence[BaseMessageComponent], content) + return await MessageChain(list(components)).to_payload_async() + if ( + isinstance(content, Sequence) + and not isinstance(content, (str, bytes)) + and all(isinstance(item, dict) for item in content) + ): + payload_items = cast(Sequence[dict[str, Any]], content) + return [dict(item) for item in payload_items] + raise TypeError( + "content must be str, MessageChain, sequence of message components, " + "or sequence of platform.send_chain payload dicts" + ) + + async def send( + self, + session: str | SessionRef | MessageSession, + text: str, + ) -> dict[str, Any]: + """发送文本消息。 + + 向指定的会话(用户或群组)发送文本消息。 + + Args: + session: 统一消息来源标识 (UMO),格式如 "platform:instance:user_id" + text: 要发送的文本内容 + + Returns: + 发送结果,可能包含消息 ID 等信息 + + 示例: + # 发送消息到当前会话 + await ctx.platform.send(event.session_id, "收到您的消息!") + """ + session_id, extra = self._build_target_payload(session) + try: + return await self._proxy.call( + "platform.send", + {"session": session_id, "text": text, **extra}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="PlatformClient", + method_name="send", + details=f"session={session_id!r}", + exc=exc, + ) from exc + + async def send_image( + self, + session: str | SessionRef | MessageSession, + image_url: str, + ) -> dict[str, Any]: + """发送图片消息。 + + 向指定的会话发送图片,支持 URL 或本地路径。 + + Args: + session: 统一消息来源标识 (UMO) + image_url: 图片 URL 或本地文件路径 + + Returns: + 发送结果 + + 示例: + await ctx.platform.send_image( + event.session_id, + "https://example.com/image.png" + ) + """ + session_id, extra = self._build_target_payload(session) + try: + return await self._proxy.call( + "platform.send_image", + {"session": session_id, "image_url": image_url, **extra}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="PlatformClient", + method_name="send_image", + details=f"session={session_id!r}", + exc=exc, + ) from exc + + async def send_chain( + self, + session: str | SessionRef | MessageSession, + chain: MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]], + ) -> dict[str, Any]: + """发送富消息链。 + + Args: + session: 统一消息来源标识 (UMO) + chain: 序列化后的消息组件数组 + + Returns: + 发送结果 + """ + session_id, extra = self._build_target_payload(session) + chain_payload = await self._coerce_chain_payload(chain) + try: + return await self._proxy.call( + "platform.send_chain", + {"session": session_id, "chain": chain_payload, **extra}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="PlatformClient", + method_name="send_chain", + details=f"session={session_id!r}, items={len(chain_payload)!r}", + exc=exc, + ) from exc + + async def send_by_session( + self, + session: str | MessageSession, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> dict[str, Any]: + """主动向指定会话发送消息链。 + + `Sequence[dict]` 的结构与 `platform.send_chain` 完全一致: + 每一项都应是 `{"type": "...", "data": {...}}`。 + """ + chain_payload = await self._coerce_chain_payload(content) + session_id = str(session) + try: + return await self._proxy.call( + "platform.send_by_session", + {"session": session_id, "chain": chain_payload}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="PlatformClient", + method_name="send_by_session", + details=f"session={session_id!r}, items={len(chain_payload)!r}", + exc=exc, + ) from exc + + async def send_by_id( + self, + platform_id: str, + session_id: str, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + *, + message_type: str = "private", + ) -> dict[str, Any]: + """主动向指定平台会话发送消息。""" + session = MessageSession( + platform_id=str(platform_id), + message_type=str(message_type), + session_id=str(session_id), + ) + return await self.send_by_session(session, content) + + async def get_members( + self, + session: str | SessionRef | MessageSession, + ) -> list[dict[str, Any]]: + """获取群组成员列表。 + + 获取指定群组的成员信息列表。注意仅对群组会话有效。 + + Args: + session: 群组会话的统一消息来源标识 (UMO) + + Returns: + 成员信息列表,每个成员是一个字典,可能包含: + - user_id: 用户 ID + - nickname: 昵称 + - role: 角色 (owner, admin, member) + + 示例: + members = await ctx.platform.get_members(event.session_id) + for member in members: + print(f"{member['nickname']} ({member['user_id']})") + """ + session_id, extra = self._build_target_payload(session) + try: + output = await self._proxy.call( + "platform.get_members", + {"session": session_id, **extra}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="PlatformClient", + method_name="get_members", + details=f"session={session_id!r}", + exc=exc, + ) from exc + members = output.get("members") + if not isinstance(members, (list, tuple)): + return [] + return list(members) + + +__all__ = [ + "PlatformClient", + "PlatformError", + "PlatformStats", + "PlatformStatus", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/provider.py b/astrbot-sdk/src/astrbot_sdk/clients/provider.py new file mode 100644 index 0000000000..7142efee0a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/provider.py @@ -0,0 +1,353 @@ +"""Provider discovery and provider-management clients.""" + +from __future__ import annotations + +import asyncio +import contextlib +import inspect +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from ..llm.entities import ProviderMeta, ProviderType +from ..llm.providers import ( + ProviderProxy, + STTProvider, + TTSProvider, + provider_proxy_from_meta, +) +from ._proxy import CapabilityProxy + + +class _ProviderModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class ManagedProviderRecord(_ProviderModel): + id: str + model: str | None = None + type: str + provider_type: ProviderType + loaded: bool + enabled: bool + provider_source_id: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> ManagedProviderRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ProviderChangeEvent(_ProviderModel): + provider_id: str + provider_type: ProviderType + umo: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> ProviderChangeEvent | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ProviderClient: + """Provider 查询客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + @staticmethod + def _provider_meta_list(items: Any) -> list[ProviderMeta]: + if not isinstance(items, list): + return [] + providers: list[ProviderMeta] = [] + for item in items: + if not isinstance(item, dict): + continue + provider = ProviderMeta.from_payload(item) + if provider is not None: + providers.append(provider) + return providers + + async def list_all(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_tts(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_tts", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_stt(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_stt", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_embedding(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_embedding", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_rerank(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_rerank", {}) + return self._provider_meta_list(output.get("providers")) + + async def _get_tts_support_stream(self, provider_id: str) -> bool: + output = await self._proxy.call( + "provider.tts.support_stream", + {"provider_id": str(provider_id)}, + ) + return bool(output.get("supported", False)) + + async def _build_proxy(self, meta: ProviderMeta | None) -> ProviderProxy | None: + if meta is None: + return None + tts_supports_stream = None + if meta.provider_type == ProviderType.TEXT_TO_SPEECH: + tts_supports_stream = await self._get_tts_support_stream(meta.id) + return provider_proxy_from_meta( + self._proxy, + meta, + tts_supports_stream=tts_supports_stream, + ) + + async def get(self, provider_id: str) -> ProviderProxy | None: + output = await self._proxy.call( + "provider.get_by_id", + {"provider_id": str(provider_id)}, + ) + return await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + + async def get_using_chat(self, umo: str | None = None) -> ProviderMeta | None: + output = await self._proxy.call("provider.get_using", {"umo": umo}) + return ProviderMeta.from_payload(output.get("provider")) + + async def get_using_tts(self, umo: str | None = None) -> TTSProvider | None: + output = await self._proxy.call("provider.get_using_tts", {"umo": umo}) + provider = await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + return provider if isinstance(provider, TTSProvider) else None + + async def get_using_stt(self, umo: str | None = None) -> STTProvider | None: + output = await self._proxy.call("provider.get_using_stt", {"umo": umo}) + provider = await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + return provider if isinstance(provider, STTProvider) else None + + +class ProviderManagerClient: + """Provider 管理客户端。""" + + def __init__( + self, + proxy: CapabilityProxy, + *, + plugin_id: str | None = None, + logger: Any | None = None, + ) -> None: + self._proxy = proxy + self._plugin_id = plugin_id + self._logger = logger + self._change_hook_tasks: set[asyncio.Task[None]] = set() + + @staticmethod + def _provider_type_value(provider_type: ProviderType | str) -> str: + if isinstance(provider_type, ProviderType): + return provider_type.value + return str(provider_type).strip() + + @staticmethod + def _record_from_output(output: dict[str, Any]) -> ManagedProviderRecord | None: + return ManagedProviderRecord.from_payload(output.get("provider")) + + async def set_provider( + self, + provider_id: str, + provider_type: ProviderType | str, + umo: str | None = None, + ) -> None: + await self._proxy.call( + "provider.manager.set", + { + "provider_id": str(provider_id), + "provider_type": self._provider_type_value(provider_type), + "umo": umo, + }, + ) + + async def get_provider_by_id( + self, + provider_id: str, + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.get_by_id", + {"provider_id": str(provider_id)}, + ) + return self._record_from_output(output) + + async def get_merged_provider_config( + self, + provider_id: str, + ) -> dict[str, Any] | None: + output = await self._proxy.call( + "provider.manager.get_merged_provider_config", + {"provider_id": str(provider_id).strip()}, + ) + config = output.get("config") + return dict(config) if isinstance(config, dict) else None + + async def load_provider( + self, + provider_config: dict[str, Any], + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.load", + {"provider_config": dict(provider_config)}, + ) + return self._record_from_output(output) + + async def terminate_provider(self, provider_id: str) -> None: + await self._proxy.call( + "provider.manager.terminate", + {"provider_id": str(provider_id)}, + ) + + async def create_provider( + self, + provider_config: dict[str, Any], + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.create", + {"provider_config": dict(provider_config)}, + ) + return self._record_from_output(output) + + async def update_provider( + self, + origin_provider_id: str, + new_config: dict[str, Any], + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.update", + { + "origin_provider_id": str(origin_provider_id), + "new_config": dict(new_config), + }, + ) + return self._record_from_output(output) + + async def delete_provider( + self, + provider_id: str | None = None, + provider_source_id: str | None = None, + ) -> None: + await self._proxy.call( + "provider.manager.delete", + { + "provider_id": provider_id, + "provider_source_id": provider_source_id, + }, + ) + + async def get_insts(self) -> list[ManagedProviderRecord]: + output = await self._proxy.call("provider.manager.get_insts", {}) + items = output.get("providers") + if not isinstance(items, list): + return [] + return [ + record + for record in ( + ManagedProviderRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if record is not None + ] + + async def watch_changes(self) -> AsyncIterator[ProviderChangeEvent]: + async for chunk in self._proxy.stream("provider.manager.watch_changes", {}): + event = ProviderChangeEvent.from_payload(chunk) + if event is not None: + yield event + + async def register_provider_change_hook( + self, + callback: Callable[ + [str, ProviderType, str | None], + Awaitable[None] | None, + ], + ) -> asyncio.Task[None]: + async def runner() -> None: + async for event in self.watch_changes(): + result = callback( + event.provider_id, + event.provider_type, + event.umo, + ) + if inspect.isawaitable(result): + await result + + task = asyncio.create_task(runner()) + self._change_hook_tasks.add(task) + task.add_done_callback(self._log_change_hook_result) + return task + + async def unregister_provider_change_hook( + self, + task: asyncio.Task[None], + ) -> None: + if task not in self._change_hook_tasks: + return + self._change_hook_tasks.discard(task) + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + def _log_change_hook_result(self, task: asyncio.Task[None]) -> None: + self._change_hook_tasks.discard(task) + if task.cancelled(): + debug_logger = getattr(self._logger, "debug", None) + if callable(debug_logger): + debug_logger( + "Provider change hook cancelled: plugin_id={}", + self._plugin_id, + ) + return + try: + task.result() + except asyncio.CancelledError: + debug_logger = getattr(self._logger, "debug", None) + if callable(debug_logger): + debug_logger( + "Provider change hook cancelled: plugin_id={}", + self._plugin_id, + ) + except Exception: + exception_logger = getattr(self._logger, "exception", None) + if callable(exception_logger): + exception_logger( + "Provider change hook failed: plugin_id={}", + self._plugin_id, + ) + + +__all__ = [ + "ManagedProviderRecord", + "ProviderChangeEvent", + "ProviderClient", + "ProviderManagerClient", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/registry.py b/astrbot-sdk/src/astrbot_sdk/clients/registry.py new file mode 100644 index 0000000000..7cb9288b13 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/registry.py @@ -0,0 +1,167 @@ +"""只读 handler 注册表客户端。""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +def _coerce_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +@dataclass(slots=True) +class HandlerMetadata: + plugin_name: str + handler_full_name: str + trigger_type: str + description: str | None = None + event_types: list[str] = field(default_factory=list) + enabled: bool = True + group_path: list[str] = field(default_factory=list) + priority: int = 0 + kind: str = "handler" + require_admin: bool = False + required_role: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> HandlerMetadata: + return cls( + plugin_name=str(data.get("plugin_name", "")), + handler_full_name=str(data.get("handler_full_name", "")), + trigger_type=str(data.get("trigger_type", "")), + description=( + None + if data.get("description") is None + else str(data.get("description", "")).strip() or None + ), + event_types=[ + str(item) + for item in data.get("event_types", []) + if isinstance(item, str) + ], + enabled=bool(data.get("enabled", True)), + group_path=[ + str(item) + for item in data.get("group_path", []) + if isinstance(item, str) + ], + priority=_coerce_int(data.get("priority", 0), 0), + kind=str(data.get("kind", "handler") or "handler"), + require_admin=bool(data.get("require_admin", False)), + required_role=( + None + if data.get("required_role") is None + else str(data.get("required_role", "")).strip() or None + ), + ) + + +class RegistryClient: + """只读 handler 注册表客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def get_handlers_by_event_type( + self, + event_type: str, + ) -> list[HandlerMetadata]: + try: + output = await self._proxy.call( + "registry.get_handlers_by_event_type", + {"event_type": event_type}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="RegistryClient", + method_name="get_handlers_by_event_type", + details=f"event_type={event_type!r}", + exc=exc, + ) from exc + return [ + HandlerMetadata.from_dict(item) + for item in output.get("handlers", []) + if isinstance(item, dict) + ] + + async def get_handler_by_full_name( + self, + full_name: str, + ) -> HandlerMetadata | None: + try: + output = await self._proxy.call( + "registry.get_handler_by_full_name", + {"full_name": full_name}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="RegistryClient", + method_name="get_handler_by_full_name", + details=f"full_name={full_name!r}", + exc=exc, + ) from exc + handler = output.get("handler") + if not isinstance(handler, dict): + return None + return HandlerMetadata.from_dict(handler) + + async def set_handler_whitelist( + self, + plugin_names: list[str] | set[str] | None, + ) -> list[str] | None: + names = None + if plugin_names is not None: + names = sorted({str(item) for item in plugin_names if str(item).strip()}) + try: + output = await self._proxy.call( + "system.event.handler_whitelist.set", + {"plugin_names": names}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="RegistryClient", + method_name="set_handler_whitelist", + details=f"plugin_names={names!r}", + exc=exc, + ) from exc + result = output.get("plugin_names") + if not isinstance(result, list): + return None + return [str(item) for item in result] + + async def get_handler_whitelist(self) -> list[str] | None: + try: + output = await self._proxy.call("system.event.handler_whitelist.get", {}) + except Exception as exc: + raise wrap_client_exception( + client_name="RegistryClient", + method_name="get_handler_whitelist", + exc=exc, + ) from exc + result = output.get("plugin_names") + if not isinstance(result, list): + return None + return [str(item) for item in result] + + async def clear_handler_whitelist(self) -> None: + try: + await self._proxy.call( + "system.event.handler_whitelist.set", + {"plugin_names": None}, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="RegistryClient", + method_name="clear_handler_whitelist", + exc=exc, + ) from exc + + +__all__ = ["HandlerMetadata", "RegistryClient"] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/session.py b/astrbot-sdk/src/astrbot_sdk/clients/session.py new file mode 100644 index 0000000000..c2901708cd --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/session.py @@ -0,0 +1,135 @@ +"""Session-scoped SDK managers.""" + +from __future__ import annotations + +from typing import Any + +from ..events import MessageEvent +from ..message.session import MessageSession +from ._proxy import CapabilityProxy +from .registry import HandlerMetadata + + +def _normalize_session(session: str | MessageSession | MessageEvent) -> str: + if isinstance(session, MessageEvent): + return str(session.unified_msg_origin) + if isinstance(session, MessageSession): + return str(session) + return str(session) + + +def _handler_to_payload(handler: HandlerMetadata) -> dict[str, Any]: + return { + "plugin_name": handler.plugin_name, + "handler_full_name": handler.handler_full_name, + "trigger_type": handler.trigger_type, + "description": handler.description, + "event_types": list(handler.event_types), + "enabled": handler.enabled, + "group_path": list(handler.group_path), + "priority": handler.priority, + "kind": handler.kind, + "require_admin": handler.require_admin, + } + + +class SessionPluginManager: + """Session-scoped plugin status manager.""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def is_plugin_enabled_for_session( + self, + session: str | MessageSession | MessageEvent, + plugin_name: str, + ) -> bool: + output = await self._proxy.call( + "session.plugin.is_enabled", + { + "session": _normalize_session(session), + "plugin_name": str(plugin_name), + }, + ) + return bool(output.get("enabled", False)) + + async def filter_handlers_by_session( + self, + session: str | MessageSession | MessageEvent, + handlers: list[HandlerMetadata], + ) -> list[HandlerMetadata]: + output = await self._proxy.call( + "session.plugin.filter_handlers", + { + "session": _normalize_session(session), + "handlers": [_handler_to_payload(handler) for handler in handlers], + }, + ) + items = output.get("handlers") + if not isinstance(items, list): + return [] + return [ + HandlerMetadata.from_dict(item) for item in items if isinstance(item, dict) + ] + + +class SessionServiceManager: + """Session-scoped LLM/TTS service status manager.""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def is_llm_enabled_for_session( + self, + session: str | MessageSession | MessageEvent, + ) -> bool: + output = await self._proxy.call( + "session.service.is_llm_enabled", + {"session": _normalize_session(session)}, + ) + return bool(output.get("enabled", False)) + + async def set_llm_status_for_session( + self, + session: str | MessageSession | MessageEvent, + enabled: bool, + ) -> None: + await self._proxy.call( + "session.service.set_llm_status", + {"session": _normalize_session(session), "enabled": bool(enabled)}, + ) + + async def should_process_llm_request( + self, + event_or_session: str | MessageSession | MessageEvent, + ) -> bool: + return await self.is_llm_enabled_for_session(event_or_session) + + async def is_tts_enabled_for_session( + self, + session: str | MessageSession | MessageEvent, + ) -> bool: + output = await self._proxy.call( + "session.service.is_tts_enabled", + {"session": _normalize_session(session)}, + ) + return bool(output.get("enabled", False)) + + async def set_tts_status_for_session( + self, + session: str | MessageSession | MessageEvent, + enabled: bool, + ) -> None: + await self._proxy.call( + "session.service.set_tts_status", + {"session": _normalize_session(session), "enabled": bool(enabled)}, + ) + + async def should_process_tts_request( + self, + event_or_session: str | MessageSession | MessageEvent, + ) -> bool: + return await self.is_tts_enabled_for_session(event_or_session) + + +__all__ = ["SessionPluginManager", "SessionServiceManager"] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/skills.py b/astrbot-sdk/src/astrbot_sdk/clients/skills.py new file mode 100644 index 0000000000..54115a2bfb --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/skills.py @@ -0,0 +1,90 @@ +"""技能注册客户端。""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from ._errors import wrap_client_exception +from ._proxy import CapabilityProxy + + +@dataclass(slots=True) +class SkillRegistration: + """已注册技能的元数据。""" + + name: str + description: str + path: str + skill_dir: str + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SkillRegistration: + return cls( + name=str(data.get("name", "")), + description=str(data.get("description", "") or ""), + path=str(data.get("path", "")), + skill_dir=str(data.get("skill_dir", "")), + ) + + +class SkillClient: + """技能管理能力客户端。""" + + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def register( + self, + *, + name: str, + path: str, + description: str = "", + ) -> SkillRegistration: + try: + output = await self._proxy.call( + "skill.register", + { + "name": name, + "path": path, + "description": description, + }, + ) + except Exception as exc: + raise wrap_client_exception( + client_name="SkillClient", + method_name="register", + details=f"name={name!r}, path={path!r}", + exc=exc, + ) from exc + return SkillRegistration.from_dict(output) + + async def unregister(self, name: str) -> bool: + try: + output = await self._proxy.call("skill.unregister", {"name": name}) + except Exception as exc: + raise wrap_client_exception( + client_name="SkillClient", + method_name="unregister", + details=f"name={name!r}", + exc=exc, + ) from exc + return bool(output.get("removed", False)) + + async def list(self) -> list[SkillRegistration]: + try: + output = await self._proxy.call("skill.list", {}) + except Exception as exc: + raise wrap_client_exception( + client_name="SkillClient", + method_name="list", + exc=exc, + ) from exc + return [ + SkillRegistration.from_dict(item) + for item in output.get("skills", []) + if isinstance(item, dict) + ] + + +__all__ = ["SkillClient", "SkillRegistration"] diff --git a/astrbot-sdk/src/astrbot_sdk/commands.py b/astrbot-sdk/src/astrbot_sdk/commands.py new file mode 100644 index 0000000000..1d4f278e1c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/commands.py @@ -0,0 +1,161 @@ +"""SDK-native command group helpers. + +本模块提供命令分组工具,用于组织具有层级关系的命令。 + +CommandGroup 允许以嵌套方式定义命令树,例如: + admin + ├── user + │ ├── add + │ └── remove + └── config + ├── get + └── set + +特性: +- 支持命令别名,自动展开父级路径的所有别名组合 +- 自动生成命令树的可视化输出 (print_cmd_tree) +- 与 @on_command 装饰器无缝集成 +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from itertools import product +from typing import Any + +from .decorators import on_command, set_command_route_meta +from .protocol.descriptors import CommandRouteSpec + + +@dataclass(slots=True) +class _CommandNode: + name: str + aliases: list[str] = field(default_factory=list) + description: str | None = None + subgroups: list[CommandGroup] = field(default_factory=list) + commands: list[tuple[str, str | None]] = field(default_factory=list) + + +class CommandGroup: + def __init__( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + parent: CommandGroup | None = None, + ) -> None: + self.name = name + self.aliases = list(aliases or []) + self.description = description + self.parent = parent + self._tree = _CommandNode( + name=name, aliases=self.aliases, description=description + ) + + def group( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + ) -> CommandGroup: + child = CommandGroup( + name, + aliases=aliases, + description=description, + parent=self, + ) + self._tree.subgroups.append(child) + return child + + def command( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + full_command = " ".join([*self.path, name]) + full_aliases = self._expand_aliases(name=name, aliases=aliases or []) + display_command = full_command + route = CommandRouteSpec( + group_path=self.path, + display_command=display_command, + group_help=self.description, + ) + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + decorated = on_command( + full_command, + aliases=full_aliases, + description=description, + )(func) + self._tree.commands.append((name, description)) + set_command_route_meta(decorated, route) + return decorated + + return decorator + + @property + def path(self) -> list[str]: + if self.parent is None: + return [self.name] + return [*self.parent.path, self.name] + + def print_cmd_tree(self) -> str: + lines: list[str] = [] + self._append_tree_lines(lines, indent=0) + return "\n".join(lines) + + def _append_tree_lines(self, lines: list[str], *, indent: int) -> None: + prefix = " " * indent + label = self.name + if self.aliases: + label += f" ({', '.join(self.aliases)})" + lines.append(f"{prefix}{label}") + for command_name, description in self._tree.commands: + command_label = f"{prefix} - {command_name}" + if description: + command_label += f": {description}" + lines.append(command_label) + for subgroup in self._tree.subgroups: + subgroup._append_tree_lines(lines, indent=indent + 1) + + def _expand_aliases(self, *, name: str, aliases: list[str]) -> list[str]: + group_segments: list[list[str]] = [] + cursor: CommandGroup | None = self + ancestry: list[CommandGroup] = [] + while cursor is not None: + ancestry.append(cursor) + cursor = cursor.parent + for group in reversed(ancestry): + group_segments.append([group.name, *group.aliases]) + leaf_segments = [name, *aliases] + expanded: set[str] = set() + for parts in product(*group_segments, leaf_segments): + route = " ".join(parts) + if route != " ".join([*self.path, name]): + expanded.add(route) + return sorted(expanded) + + +def command_group( + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> CommandGroup: + return CommandGroup( + name, + aliases=aliases, + description=description, + ) + + +def print_cmd_tree(group: CommandGroup) -> str: + return group.print_cmd_tree() + + +__all__ = ["CommandGroup", "command_group", "print_cmd_tree"] diff --git a/astrbot-sdk/src/astrbot_sdk/context.py b/astrbot-sdk/src/astrbot_sdk/context.py new file mode 100644 index 0000000000..5cff122933 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/context.py @@ -0,0 +1,880 @@ +"""astrbot-sdk 原生运行时上下文。 + +`Context` 是插件与 AstrBot Core 交互的主要入口, +负责组合所有 capability 客户端并提供统一的访问接口。 + +每个 handler 调用都会创建一个新的 Context 实例, +绑定到当前的 Peer、插件 ID 和取消令牌。 + +Attributes: + llm: LLM 能力客户端,用于 AI 对话 + memory: 记忆能力客户端,用于语义存储 + db: 数据库客户端,用于 KV 持久化 + files: 文件服务客户端,用于文件令牌注册与解析 + platform: 平台客户端,用于发送消息 + permission: 权限客户端,用于查询用户角色 + providers: Provider 客户端,用于查询和调用专用 Provider + provider_manager: Provider 管理客户端,用于 reserved/system 级操作 + permission_manager: 权限管理客户端,用于 reserved/system 级管理员维护 + personas: 人格管理客户端 + conversations: 对话管理客户端 + kbs: 知识库管理客户端 + message_history: 消息历史管理客户端 + http: HTTP 客户端,用于注册 API 端点 + metadata: 元数据客户端,用于查询插件信息 + mcp: MCP 管理客户端,用于本地/全局 MCP 服务管理 + skills: Skill 客户端,用于向 AstrBot 注册插件技能 + plugin_id: 当前插件的唯一标识 + logger: 绑定了插件 ID 的日志器 + cancel_token: 取消令牌,用于处理请求取消 +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from ._internal.plugin_logger import PluginLogger +from ._internal.sdk_logger import logger as base_logger +from ._internal.star_runtime import current_star_instance +from ._message_types import normalize_message_type +from .clients import ( + DBClient, + HTTPClient, + LLMClient, + MCPManagerClient, + MemoryClient, + MetadataClient, + PermissionClient, + PermissionManagerClient, + PlatformClient, + PlatformError, + PlatformStats, + PlatformStatus, + RegistryClient, + SkillClient, +) +from .clients._proxy import CapabilityProxy +from .clients.files import FileServiceClient +from .clients.llm import LLMResponse +from .clients.managers import ( + ConversationManagerClient, + KnowledgeBaseManagerClient, + MessageHistoryManagerClient, + PersonaManagerClient, +) +from .clients.provider import ProviderClient, ProviderManagerClient +from .clients.session import SessionPluginManager, SessionServiceManager +from .clients.skills import SkillRegistration +from .errors import AstrBotError +from .llm.entities import LLMToolSpec, ProviderMeta, ProviderRequest +from .llm.tools import LLMToolManager +from .message.components import BaseMessageComponent +from .message.result import MessageChain +from .message.session import MessageSession +from .session_waiter import ( + _mark_session_waiter_background_task, + _unmark_session_waiter_background_task, +) + +PlatformCompatContent = ( + str | MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]] +) + + +def _context_call_label(method_name: str, details: str | None = None) -> str: + label = f"Context.{method_name}" + if details: + return f"{label} ({details})" + return label + + +def _wrap_context_exception( + *, + method_name: str, + exc: Exception, + details: str | None = None, +) -> Exception: + message = f"{_context_call_label(method_name, details)} failed: {exc}" + if isinstance(exc, AstrBotError): + return AstrBotError( + code=exc.code, + message=message, + hint=exc.hint, + retryable=exc.retryable, + docs_url=exc.docs_url, + details=exc.details, + ) + return RuntimeError(message) + + +@dataclass(slots=True) +class PlatformCompatFacade: + """兼容层平台入口,仅暴露安全元信息和主动发送能力。""" + + _ctx: Context + id: str + name: str + type: str + status: PlatformStatus = PlatformStatus.PENDING + errors: list[PlatformError] = field(default_factory=list) + last_error: PlatformError | None = None + unified_webhook: bool = False + _state_lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False) + + async def send_by_session( + self, + session: str | MessageSession, + content: PlatformCompatContent, + ) -> dict[str, Any]: + return await self._ctx.platform.send_by_session(session, content) + + async def send_by_id( + self, + session_id: str, + content: PlatformCompatContent, + *, + message_type: str = "private", + ) -> dict[str, Any]: + return await self._ctx.platform.send_by_id( + self.id, + session_id, + content, + message_type=message_type, + ) + + async def send( + self, + session: str | MessageSession, + content: PlatformCompatContent, + *, + message_type: str = "private", + ) -> dict[str, Any]: + if isinstance(session, MessageSession): + return await self.send_by_session(session, content) + session_text = str(session).strip() + if ":" in session_text: + return await self.send_by_session(session_text, content) + return await self.send_by_id( + session_text, + content, + message_type=message_type, + ) + + async def refresh(self) -> None: + async with self._state_lock: + await self._refresh_locked() + + async def clear_errors(self) -> None: + async with self._state_lock: + try: + await self._ctx._proxy.call( + "platform.manager.clear_errors", + {"platform_id": self.id}, + ) + await self._refresh_locked() + except Exception as exc: + raise _wrap_context_exception( + method_name="platform.clear_errors", + details=f"platform_id={self.id!r}", + exc=exc, + ) from exc + + async def get_stats(self) -> PlatformStats | None: + try: + output = await self._ctx._proxy.call( + "platform.manager.get_stats", + {"platform_id": self.id}, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="platform.get_stats", + details=f"platform_id={self.id!r}", + exc=exc, + ) from exc + return PlatformStats.from_payload(output.get("stats")) + + def _apply_snapshot(self, payload: Any) -> None: + if not isinstance(payload, dict): + return + self.name = str(payload.get("name", self.name)) + self.type = str(payload.get("type", self.type)) + self.status = PlatformStatus.from_value(payload.get("status")) + errors_payload = payload.get("errors") + if isinstance(errors_payload, list): + self.errors = [ + error + for error in ( + PlatformError.from_payload(item) if isinstance(item, dict) else None + for item in errors_payload + ) + if error is not None + ] + self.last_error = PlatformError.from_payload(payload.get("last_error")) + self.unified_webhook = bool(payload.get("unified_webhook", False)) + + async def _refresh_locked(self) -> None: + try: + output = await self._ctx._proxy.call( + "platform.manager.get_by_id", + {"platform_id": self.id}, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="platform.refresh", + details=f"platform_id={self.id!r}", + exc=exc, + ) from exc + self._apply_snapshot(output.get("platform")) + + +@dataclass(slots=True) +class CancelToken: + """请求取消令牌。 + + 用于协调长时间运行操作的取消。当用户取消请求或 + 上游超时时,令牌会被触发,允许 handler 及时清理资源。 + + Example: + async def long_operation(ctx: Context): + for item in large_list: + ctx.cancel_token.raise_if_cancelled() + await process(item) + """ + + _cancelled: asyncio.Event + + def __init__(self) -> None: + self._cancelled = asyncio.Event() + + def cancel(self) -> None: + """触发取消信号。""" + self._cancelled.set() + + @property + def cancelled(self) -> bool: + """检查是否已被取消。""" + return self._cancelled.is_set() + + async def wait(self) -> None: + """等待取消信号。""" + await self._cancelled.wait() + + def raise_if_cancelled(self) -> None: + """如果已取消则抛出 CancelledError。 + + Raises: + asyncio.CancelledError: 如果令牌已被取消 + """ + if self.cancelled: + raise asyncio.CancelledError + + +class Context: + """插件运行时上下文。 + + 组合所有 capability 客户端,提供统一的访问接口。 + 每个 handler 调用都会创建新的 Context 实例。 + + Attributes: + peer: 协议对等端,用于底层通信 + llm: LLM 客户端 + memory: 记忆客户端 + db: 数据库客户端 + files: 文件服务客户端 + platform: 平台客户端 + permission: 权限客户端 + providers: Provider 客户端 + provider_manager: Provider 管理客户端 + permission_manager: 权限管理客户端 + personas: 人格管理客户端 + conversations: 对话管理客户端 + kbs: 知识库管理客户端 + message_history: 消息历史管理客户端 + http: HTTP 客户端 + metadata: 元数据客户端 + registry: 能力注册客户端 + skills: 技能客户端 + session_plugins: 会话插件管理器 + session_services: 会话服务管理器 + mcp: MCP 管理客户端 + plugin_id: 当前插件 ID + logger: 日志器 + cancel_token: 取消令牌 + """ + + def __init__( + self, + *, + peer, + plugin_id: str, + request_id: str | None = None, + cancel_token: CancelToken | None = None, + logger: Any | None = None, + source_event_payload: dict[str, Any] | None = None, + ) -> None: + """初始化上下文。 + + Args: + peer: 协议对等端实例 + plugin_id: 当前插件 ID + cancel_token: 取消令牌,None 时创建新令牌 + logger: 日志器,None 时使用默认 logger 并绑定 plugin_id + """ + proxy = CapabilityProxy( + peer, + caller_plugin_id=plugin_id, + request_scope_id=request_id, + ) + if isinstance(logger, PluginLogger): + bound_logger = logger + else: + bound_logger = logger or base_logger.bind(plugin_id=plugin_id) + self._proxy = proxy + self.peer = peer + self.llm = LLMClient(proxy) + self.memory = MemoryClient(proxy) + self.db = DBClient(proxy) + self.files = FileServiceClient(proxy) + self.platform = PlatformClient(proxy) + self.permission = PermissionClient(proxy) + self.providers = ProviderClient(proxy) + self.provider_manager = ProviderManagerClient( + proxy, + plugin_id=plugin_id, + logger=bound_logger, + ) + self.permission_manager = PermissionManagerClient( + proxy, + source_event_payload=source_event_payload, + ) + self.personas = PersonaManagerClient(proxy) + self.conversations = ConversationManagerClient(proxy) + self.kbs = KnowledgeBaseManagerClient(proxy) + self.message_history = MessageHistoryManagerClient(proxy) + self.http = HTTPClient(proxy) + self.metadata = MetadataClient(proxy, plugin_id) + self.mcp = MCPManagerClient(proxy) + self.registry = RegistryClient(proxy) + self.skills = SkillClient(proxy) + self.session_plugins = SessionPluginManager(proxy) + self.session_services = SessionServiceManager(proxy) + self.persona_manager = self.personas + self.conversation_manager = self.conversations + self.kb_manager = self.kbs + self.message_history_manager = self.message_history + self.mcp_manager = self.mcp + self._llm_tool_manager = LLMToolManager(proxy) + self.plugin_id = plugin_id + self.logger: PluginLogger = ( + bound_logger + if isinstance(bound_logger, PluginLogger) + else PluginLogger(plugin_id=plugin_id, logger=bound_logger) + ) + self.cancel_token = cancel_token or CancelToken() + self.request_id = request_id + self._source_event_payload = ( + dict(source_event_payload) if isinstance(source_event_payload, dict) else {} + ) + + async def get_data_dir(self) -> Path: + """Return the plugin-scoped data directory path.""" + try: + output = await self._proxy.call("system.get_data_dir", {}) + except Exception as exc: + raise _wrap_context_exception( + method_name="get_data_dir", + exc=exc, + ) from exc + return Path(str(output.get("path", ""))) + + async def _register_file_url( + self, + path: str, + timeout: float | None = None, + ) -> str: + try: + return await self.files.register_file_url(path, timeout=timeout) + except Exception as exc: + raise _wrap_context_exception( + method_name="register_file_url", + details=f"path={str(path)!r}, timeout={timeout!r}", + exc=exc, + ) from exc + + async def text_to_image( + self, + text: str, + *, + return_url: bool = True, + ) -> str: + """Render plain text into an image using the host renderer.""" + try: + output = await self._proxy.call( + "system.text_to_image", + {"text": text, "return_url": return_url}, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="text_to_image", + details=f"return_url={return_url!r}", + exc=exc, + ) from exc + return str(output.get("result", "")) + + async def html_render( + self, + tmpl: str, + data: dict[str, Any], + *, + return_url: bool = True, + options: dict[str, Any] | None = None, + ) -> str: + """Render an HTML template using the host renderer.""" + try: + output = await self._proxy.call( + "system.html_render", + { + "tmpl": tmpl, + "data": dict(data), + "return_url": return_url, + "options": options, + }, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="html_render", + details=f"tmpl={tmpl!r}, return_url={return_url!r}", + exc=exc, + ) from exc + return str(output.get("result", "")) + + async def get_using_provider(self, umo: str | None = None) -> ProviderMeta | None: + return await self.providers.get_using_chat(umo) + + async def get_current_chat_provider_id(self, umo: str | None = None) -> str | None: + try: + output = await self._proxy.call( + "provider.get_current_chat_provider_id", + {"umo": umo}, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="get_current_chat_provider_id", + details=f"umo={umo!r}", + exc=exc, + ) from exc + value = output.get("provider_id") + return str(value) if value else None + + async def get_all_providers(self) -> list[ProviderMeta]: + return await self.providers.list_all() + + async def get_all_tts_providers(self) -> list[ProviderMeta]: + return await self.providers.list_tts() + + async def get_all_stt_providers(self) -> list[ProviderMeta]: + return await self.providers.list_stt() + + async def get_all_embedding_providers(self) -> list[ProviderMeta]: + return await self.providers.list_embedding() + + async def get_all_rerank_providers(self) -> list[ProviderMeta]: + return await self.providers.list_rerank() + + async def get_using_tts_provider( + self, umo: str | None = None + ) -> ProviderMeta | None: + provider = await self.providers.get_using_tts(umo) + return provider.meta() if provider is not None else None + + async def get_using_stt_provider( + self, umo: str | None = None + ) -> ProviderMeta | None: + provider = await self.providers.get_using_stt(umo) + return provider.meta() if provider is not None else None + + async def send_message( + self, + session: str | MessageSession, + content: PlatformCompatContent, + ) -> dict[str, Any]: + return await self.platform.send_by_session(session, content) + + async def send_message_by_id( + self, + type: str, + id: str, + content: PlatformCompatContent, + *, + platform: str, + ) -> dict[str, Any]: + platform_payload = await self._resolve_platform_target(platform) + return await self.platform.send_by_id( + str(platform_payload.get("id", "")), + str(id), + content, + message_type=self._normalize_compat_message_type(type), + ) + + @staticmethod + def _normalize_compat_message_type(value: str) -> str: + normalized = normalize_message_type(value) + if not normalized: + raise AstrBotError.invalid_input("send_message_by_id requires type") + return normalized + + async def _resolve_platform_target(self, platform: str) -> dict[str, Any]: + target = str(platform).strip() + if not target: + raise AstrBotError.invalid_input( + "send_message_by_id requires explicit platform" + ) + instances = await self._list_platform_instances() + id_matches = [ + item for item in instances if str(item.get("id", "")).strip() == target + ] + if len(id_matches) == 1: + return id_matches[0] + normalized_target = target.lower() + alias_matches = [ + item + for item in instances + if str(item.get("type", "")).strip().lower() == normalized_target + or str(item.get("name", "")).strip().lower() == normalized_target + ] + if len(alias_matches) == 1: + return alias_matches[0] + if len(alias_matches) > 1: + raise AstrBotError.invalid_input( + f"send_message_by_id platform '{target}' is ambiguous" + ) + raise AstrBotError.invalid_input( + f"send_message_by_id cannot resolve platform '{target}'" + ) + + def get_llm_tool_manager(self) -> LLMToolManager: + return self._llm_tool_manager + + async def activate_llm_tool(self, name: str) -> bool: + return await self._llm_tool_manager.activate(name) + + async def deactivate_llm_tool(self, name: str) -> bool: + return await self._llm_tool_manager.deactivate(name) + + async def add_llm_tools(self, *tools: LLMToolSpec) -> list[str]: + return await self._llm_tool_manager.add(*tools) + + async def register_llm_tool( + self, + name: str, + parameters_schema: dict[str, Any], + desc: str, + func_obj: Callable[..., Any] | Callable[..., Awaitable[Any]], + *, + active: bool = True, + ) -> list[str]: + if not callable(func_obj): + raise TypeError("register_llm_tool requires a callable func_obj") + tool_name = str(name).strip() + if not tool_name: + raise AstrBotError.invalid_input("register_llm_tool requires name") + if not isinstance(parameters_schema, dict): + raise TypeError("register_llm_tool requires parameters_schema dict") + + handler_ref = f"__dynamic_llm_tool__:{tool_name}" + tool_spec = LLMToolSpec.create( + name=tool_name, + description=str(desc), + parameters_schema=dict(parameters_schema), + handler_ref=handler_ref, + active=bool(active), + ) + owner = getattr(func_obj, "__self__", None) or current_star_instance() + dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None) + if dispatcher is not None and hasattr(dispatcher, "add_dynamic_llm_tool"): + dispatcher.add_dynamic_llm_tool( + plugin_id=self.plugin_id, + spec=tool_spec, + callable_obj=func_obj, + owner=owner, + ) + try: + return await self._llm_tool_manager.add(tool_spec) + except Exception as exc: + if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"): + dispatcher.remove_llm_tool(self.plugin_id, tool_name) + raise _wrap_context_exception( + method_name="register_llm_tool", + details=f"name={tool_name!r}, active={bool(active)!r}", + exc=exc, + ) from exc + + async def unregister_llm_tool(self, name: str) -> bool: + removed = await self._llm_tool_manager.remove(str(name)) + dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None) + if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"): + dispatcher.remove_llm_tool(self.plugin_id, str(name)) + return removed + + async def register_skill( + self, + *, + name: str, + path: str | Path, + description: str = "", + ) -> SkillRegistration: + try: + return await self.skills.register( + name=name, + path=str(path), + description=description, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="register_skill", + details=f"name={name!r}, path={str(path)!r}", + exc=exc, + ) from exc + + async def unregister_skill(self, name: str) -> bool: + try: + return await self.skills.unregister(name) + except Exception as exc: + raise _wrap_context_exception( + method_name="unregister_skill", + details=f"name={name!r}", + exc=exc, + ) from exc + + async def tool_loop_agent( + self, + request: ProviderRequest | None = None, + **kwargs: Any, + ) -> LLMResponse: + provider_request = request or ProviderRequest() + if kwargs: + merged = provider_request.model_dump() + merged.update(kwargs) + provider_request = ProviderRequest.model_validate(merged) + payload = provider_request.to_payload() + target_payload = self._source_event_payload.get("target") + if isinstance(target_payload, dict): + # Preserve the original message target so core can recover the + # dispatch token for message-bound tool loop execution. + payload["target"] = dict(target_payload) + try: + output = await self._proxy.call("agent.tool_loop.run", payload) + except Exception as exc: + raise _wrap_context_exception( + method_name="tool_loop_agent", + details=( + f"session_id={provider_request.session_id!r}, " + f"contexts={len(provider_request.contexts)!r}" + ), + exc=exc, + ) from exc + return LLMResponse.model_validate(output) + + def _source_event_type(self) -> str: + event_type = self._source_event_payload.get("event_type") + if isinstance(event_type, str) and event_type.strip(): + return event_type.strip() + fallback_type = self._source_event_payload.get("type") + if isinstance(fallback_type, str) and fallback_type.strip(): + return fallback_type.strip() + raw_payload = self._source_event_payload.get("raw") + if isinstance(raw_payload, dict): + raw_event_type = raw_payload.get("event_type") + if isinstance(raw_event_type, str) and raw_event_type.strip(): + return raw_event_type.strip() + return "" + + async def register_commands( + self, + command_name: str, + handler_full_name: str, + *, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ignore_prefix: bool = False, + ) -> None: + source_event_type = self._source_event_type() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if ignore_prefix: + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + if isinstance(priority, bool) or not isinstance(priority, int): + raise AstrBotError.invalid_input( + "register_commands priority must be an integer" + ) + normalized_command_name = str(command_name) + normalized_handler_name = str(handler_full_name) + try: + await self._proxy.call( + "registry.command.register", + { + "command_name": normalized_command_name, + "handler_full_name": normalized_handler_name, + "source_event_type": source_event_type, + "desc": str(desc), + "priority": priority, + "use_regex": bool(use_regex), + "ignore_prefix": False, + }, + ) + except Exception as exc: + raise _wrap_context_exception( + method_name="register_commands", + details=( + f"command_name={normalized_command_name!r}, " + f"handler_full_name={normalized_handler_name!r}" + ), + exc=exc, + ) from exc + + async def register_task( + self, + task: Awaitable[Any], + desc: str, + ) -> asyncio.Task[Any]: + """Register a background task owned by the current SDK context. + + This is the recommended way to launch follow-up work that should outlive + the current handler dispatch, including `session_waiter(...)` flows. + Directly awaiting a waiter inside the current handler keeps the original + dispatch open until the next message arrives. + + Example: + await event.reply("请输入用户名:") + await ctx.register_task( + self.collect_username(event), + "waiter:collect_username", + ) + """ + task_desc = str(desc) + + async def _wrap_future(future: asyncio.Future[Any]) -> Any: + return await future + + if isinstance(task, asyncio.Task): + background_task = task + elif asyncio.isfuture(task): + background_task = asyncio.create_task(_wrap_future(task)) + elif asyncio.iscoroutine(task): + background_task = asyncio.create_task(task) + else: + raise TypeError( + "Context.register_task requires an awaitable task object; " + f"got {type(task).__name__} for desc={task_desc!r}" + ) + + _mark_session_waiter_background_task(background_task) + + def _on_done(done_task: asyncio.Task[Any]) -> None: + _unmark_session_waiter_background_task(done_task) + if done_task.cancelled(): + debug_logger = getattr(self.logger, "debug", None) + if callable(debug_logger): + debug_logger( + "SDK background task cancelled: plugin_id={} desc={}", + self.plugin_id, + task_desc, + ) + return + try: + done_task.result() + except Exception as exc: + exception_logger = getattr(self.logger, "exception", None) + if callable(exception_logger): + exception_logger( + "SDK background task failed: plugin_id={} desc={} error={}", + self.plugin_id, + task_desc, + str(exc), + ) + + background_task.add_done_callback(_on_done) + return background_task + + async def _list_platform_instances(self) -> list[dict[str, Any]]: + try: + output = await self._proxy.call("platform.list_instances", {}) + except Exception as exc: + raise _wrap_context_exception( + method_name="list_platforms", + exc=exc, + ) from exc + items = output.get("platforms") + if not isinstance(items, list): + return [] + normalized: list[dict[str, Any]] = [] + for item in items: + if not isinstance(item, dict): + continue + platform_id = str(item.get("id", "")).strip() + platform_type = str(item.get("type", "")).strip() + if not platform_id or not platform_type: + continue + normalized.append( + { + "id": platform_id, + "name": str(item.get("name", platform_id)), + "type": platform_type, + "status": PlatformStatus.from_value(item.get("status")), + } + ) + return normalized + + def _build_platform_facade( + self, + platform_payload: dict[str, Any], + ) -> PlatformCompatFacade: + return PlatformCompatFacade( + _ctx=self, + id=str(platform_payload.get("id", "")), + name=str(platform_payload.get("name", "")), + type=str(platform_payload.get("type", "")), + status=PlatformStatus.from_value(platform_payload.get("status")), + ) + + async def list_platforms(self) -> list[PlatformCompatFacade]: + """获取所有平台实例的兼容层列表。 + + Returns: + 所有可见平台实例的兼容层对象列表 + + Example: + for platform in await ctx.list_platforms(): + print(platform.id, platform.status) + """ + return [ + self._build_platform_facade(item) + for item in await self._list_platform_instances() + ] + + async def get_platform(self, platform_type: str) -> PlatformCompatFacade | None: + target_type = str(platform_type).strip().lower() + if not target_type: + return None + for item in await self._list_platform_instances(): + if str(item.get("type", "")).strip().lower() == target_type: + return self._build_platform_facade(item) + return None + + async def get_platform_inst(self, platform_id: str) -> PlatformCompatFacade | None: + target_id = str(platform_id).strip() + if not target_id: + return None + for item in await self._list_platform_instances(): + if str(item.get("id", "")).strip() == target_id: + return self._build_platform_facade(item) + return None diff --git a/astrbot-sdk/src/astrbot_sdk/conversation.py b/astrbot-sdk/src/astrbot_sdk/conversation.py new file mode 100644 index 0000000000..78e3cd9095 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/conversation.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from .context import Context +from .events import MessageEvent +from .message.components import BaseMessageComponent +from .message.result import MessageChain +from .session_waiter import SessionWaiterManager + +DEFAULT_BUSY_MESSAGE = "当前会话已有进行中的交互,请先完成后再试。" + + +class ConversationState(str, Enum): + ACTIVE = "active" + REJECTED_BUSY = "rejected_busy" + REPLACED = "replaced" + TIMEOUT = "timeout" + COMPLETED = "completed" + CANCELLED = "cancelled" + + +class ConversationReplaced(RuntimeError): + pass + + +class ConversationClosed(RuntimeError): + pass + + +@dataclass(slots=True) +class ConversationSession: + ctx: Context + event: MessageEvent + waiter_manager: SessionWaiterManager + timeout: int + state: ConversationState = ConversationState.ACTIVE + _owner_task: asyncio.Task[Any] | None = None + + def __post_init__(self) -> None: + if self.state is None: + self.state = ConversationState.ACTIVE + return + if not isinstance(self.state, ConversationState): + self.state = ConversationState(str(self.state)) + + def bind_owner_task(self, task: asyncio.Task[Any]) -> None: + self._owner_task = task + + @property + def session_key(self) -> str: + return self.event.unified_msg_origin + + @property + def active(self) -> bool: + return self.state == ConversationState.ACTIVE + + async def ask(self, prompt: str, timeout: int | None = None) -> MessageEvent: + self._ensure_usable("ask") + if prompt: + await self.reply(prompt) + try: + return await self.waiter_manager.wait_for_event( + event=self.event, + timeout=timeout or self.timeout, + record_history_chains=False, + ) + except asyncio.TimeoutError: + self.close(ConversationState.TIMEOUT) + raise + except asyncio.CancelledError as exc: + if self.state == ConversationState.REPLACED: + raise ConversationReplaced( + "conversation replaced by a newer session" + ) from exc + self.close(ConversationState.CANCELLED) + raise + + async def reply(self, text: str) -> None: + self._ensure_usable("reply") + await self.event.reply(text) + + async def reply_chain( + self, + chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]], + ) -> None: + self._ensure_usable("reply_chain") + await self.event.reply_chain(chain) + + async def send_message( + self, + content: str | MessageChain | list[BaseMessageComponent] | list[dict[str, Any]], + ) -> dict[str, Any]: + self._ensure_usable("send_message") + return await self.ctx.platform.send_by_session(self.event.session_id, content) + + def end(self) -> None: + self.close(ConversationState.COMPLETED) + + def mark_replaced(self) -> None: + self.close(ConversationState.REPLACED) + + def close(self, state: ConversationState) -> None: + if self.state != ConversationState.ACTIVE and state == self.state: + return + if ( + self.state != ConversationState.ACTIVE + and state != ConversationState.REPLACED + ): + return + self.state = state + + def _ensure_usable(self, action: str) -> None: + if ( + self._owner_task is not None + and asyncio.current_task() is not self._owner_task + ): + raise ConversationClosed( + f"ConversationSession cannot be used outside its owner task during {action}" + ) + if not self.active: + raise ConversationClosed( + f"ConversationSession is already closed ({self.state.value}) during {action}" + ) + + +__all__ = [ + "ConversationClosed", + "ConversationReplaced", + "ConversationSession", + "ConversationState", + "DEFAULT_BUSY_MESSAGE", +] diff --git a/astrbot-sdk/src/astrbot_sdk/decorators.py b/astrbot-sdk/src/astrbot_sdk/decorators.py new file mode 100644 index 0000000000..98afba0713 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/decorators.py @@ -0,0 +1,1393 @@ +"""astrbot-sdk 原生装饰器。 + +提供声明式的方法来注册 handler 和 capability。 +装饰器会在方法上附加元数据,由 Star.__init_subclass__ 自动收集。 + +触发器装饰器: + - @on_command: 命令触发器 + - @on_message: 消息触发器(关键词/正则) + - @on_event: 事件触发器 + - @on_schedule: 定时任务触发器 + - @conversation_command: 带会话生命周期的命令触发器 + +权限与过滤装饰器: + - @require_admin / @admin_only: 管理员权限标记 + - @require_permission: 通用角色权限标记 + - @platforms: 限定平台 + - @group_only / @private_only: 群聊/私聊限定 + - @message_types: 消息类型过滤 + +限流装饰器: + - @rate_limit: 滑动窗口限流 + - @cooldown: 冷却时间 + +优先级装饰器: + - @priority: 设置执行优先级 + +能力导出装饰器: + - @provide_capability: 声明对外暴露的能力 + - @register_llm_tool: 注册 LLM 工具 + - @register_agent: 注册 Agent + +Example: + class MyPlugin(Star): + @on_command("hello", aliases=["hi"]) + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") + + @on_message(keywords=["help"]) + async def help(self, event: MessageEvent, ctx: Context): + await event.reply("Help info...") + + @provide_capability("my_plugin.calculate", description="计算") + async def calculate(self, payload: dict, ctx: Context): + return {"result": payload["x"] * 2} +""" + +from __future__ import annotations + +import inspect +import typing +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal, TypeVar, cast + +from pydantic import BaseModel + +from ._internal.typing_utils import unwrap_optional +from .llm.agents import AgentSpec, BaseAgentRunner +from .llm.entities import LLMToolSpec +from .protocol.descriptors import ( + RESERVED_CAPABILITY_PREFIXES, + CapabilityDescriptor, + CommandRouteSpec, + CommandTrigger, + EventTrigger, + FilterSpec, + MessageTrigger, + MessageTypeFilterSpec, + Permissions, + PlatformFilterSpec, + ScheduleTrigger, +) + +HandlerCallable = Callable[..., Any] +_HandlerT = TypeVar("_HandlerT", bound=Callable[..., Any]) +HANDLER_META_ATTR = "__astrbot_handler_meta__" +CAPABILITY_META_ATTR = "__astrbot_capability_meta__" +LLM_TOOL_META_ATTR = "__astrbot_llm_tool_meta__" +AGENT_META_ATTR = "__astrbot_agent_meta__" +HTTP_API_META_ATTR = "__astrbot_http_api_meta__" +VALIDATE_CONFIG_META_ATTR = "__astrbot_validate_config_meta__" +PROVIDER_CHANGE_META_ATTR = "__astrbot_provider_change_meta__" +BACKGROUND_TASK_META_ATTR = "__astrbot_background_task_meta__" +MCP_SERVER_META_ATTR = "__astrbot_mcp_server_meta__" +SKILL_META_ATTR = "__astrbot_skill_meta__" + +LimiterScope = Literal["session", "user", "group", "global"] +LimiterBehavior = Literal["hint", "silent", "error"] +ConversationMode = Literal["replace", "reject"] + + +@dataclass(slots=True) +class LimiterMeta: + kind: Literal["rate_limit", "cooldown"] + limit: int + window: float + scope: LimiterScope = "session" + behavior: LimiterBehavior = "hint" + message: str | None = None + + +@dataclass(slots=True) +class ConversationMeta: + timeout: int = 60 + mode: ConversationMode = "replace" + busy_message: str | None = None + grace_period: float = 1.0 + + +@dataclass(slots=True) +class HandlerMeta: + """Handler 元数据。 + + 存储在方法上的 __astrbot_handler_meta__ 属性中。 + + Attributes: + trigger: 触发器(命令/消息/事件/定时) + kind: handler 类型标识 + contract: 契约类型(可选) + priority: 执行优先级(数值越大越先执行) + permissions: 权限要求 + """ + + trigger: CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger | None = ( + None + ) + kind: str = "handler" + contract: str | None = None + description: str | None = None + priority: int = 0 + permissions: Permissions = field(default_factory=Permissions) + filters: list[FilterSpec] = field(default_factory=list) + local_filters: list[Any] = field(default_factory=list) + command_route: CommandRouteSpec | None = None + limiter: LimiterMeta | None = None + conversation: ConversationMeta | None = None + decorator_sources: dict[str, str] = field(default_factory=dict) + + +@dataclass(slots=True) +class CapabilityMeta: + """Capability 元数据。 + + 存储在方法上的 __astrbot_capability_meta__ 属性中。 + + Attributes: + descriptor: 能力描述符 + """ + + descriptor: CapabilityDescriptor + + +@dataclass(slots=True) +class LLMToolMeta: + spec: LLMToolSpec + + +@dataclass(slots=True) +class AgentMeta: + spec: AgentSpec + + +@dataclass(slots=True) +class HttpApiMeta: + route: str + methods: list[str] = field(default_factory=lambda: ["GET"]) + description: str = "" + capability_name: str | None = None + + +@dataclass(slots=True) +class ValidateConfigMeta: + model: type[BaseModel] | None = None + schema: dict[str, Any] | None = None + + +def _is_valid_validate_config_expected_type(value: Any) -> bool: + if isinstance(value, type): + return True + return ( + isinstance(value, tuple) + and len(value) > 0 + and all(isinstance(item, type) for item in value) + ) + + +def _validate_validate_config_schema(schema: dict[str, Any]) -> None: + for field_name, field_schema in schema.items(): + if not isinstance(field_schema, dict): + raise TypeError( + f"validate_config schema field {field_name!r} must be a dict" + ) + expected_type = field_schema.get("type") + if expected_type is not None and not _is_valid_validate_config_expected_type( + expected_type + ): + raise TypeError( + "validate_config schema field " + f"{field_name!r} has invalid 'type' entry {expected_type!r}; " + "expected a type or tuple of types" + ) + + +@dataclass(slots=True) +class ProviderChangeMeta: + provider_types: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class BackgroundTaskMeta: + description: str = "" + auto_start: bool = True + on_error: Literal["log", "restart"] = "log" + + +@dataclass(slots=True) +class MCPServerMeta: + name: str + scope: Literal["local", "global"] = "global" + config: dict[str, Any] | None = None + timeout: float = 30.0 + wait_until_ready: bool = True + + +@dataclass(slots=True) +class SkillMeta: + name: str + path: str + description: str = "" + + +def _get_or_create_meta(func: HandlerCallable) -> HandlerMeta: + """获取或创建 handler 元数据。""" + meta = getattr(func, HANDLER_META_ATTR, None) + if meta is None: + meta = HandlerMeta() + setattr(func, HANDLER_META_ATTR, meta) + return meta + + +def get_handler_meta(func: HandlerCallable) -> HandlerMeta | None: + """获取方法的 handler 元数据。 + + Args: + func: 要检查的方法 + + Returns: + HandlerMeta 实例,如果没有则返回 None + """ + return getattr(func, HANDLER_META_ATTR, None) + + +def get_capability_meta(func: HandlerCallable) -> CapabilityMeta | None: + """获取方法的 capability 元数据。 + + Args: + func: 要检查的方法 + + Returns: + CapabilityMeta 实例,如果没有则返回 None + """ + return getattr(func, CAPABILITY_META_ATTR, None) + + +def get_llm_tool_meta(func: HandlerCallable) -> LLMToolMeta | None: + return getattr(func, LLM_TOOL_META_ATTR, None) + + +def get_agent_meta(obj: Any) -> AgentMeta | None: + return getattr(obj, AGENT_META_ATTR, None) + + +def get_http_api_meta(func: HandlerCallable) -> HttpApiMeta | None: + return getattr(func, HTTP_API_META_ATTR, None) + + +def get_validate_config_meta(func: HandlerCallable) -> ValidateConfigMeta | None: + return getattr(func, VALIDATE_CONFIG_META_ATTR, None) + + +def get_provider_change_meta(func: HandlerCallable) -> ProviderChangeMeta | None: + return getattr(func, PROVIDER_CHANGE_META_ATTR, None) + + +def get_background_task_meta(func: HandlerCallable) -> BackgroundTaskMeta | None: + return getattr(func, BACKGROUND_TASK_META_ATTR, None) + + +def get_mcp_server_meta(obj: Any) -> list[MCPServerMeta]: + values = getattr(obj, MCP_SERVER_META_ATTR, None) + if not isinstance(values, list): + return [] + return [item for item in values if isinstance(item, MCPServerMeta)] + + +def get_skill_meta(obj: Any) -> list[SkillMeta]: + values = getattr(obj, SKILL_META_ATTR, None) + if not isinstance(values, list): + return [] + return [item for item in values if isinstance(item, SkillMeta)] + + +def _append_list_meta(obj: Any, attr_name: str, value: Any) -> None: + values = getattr(obj, attr_name, None) + if not isinstance(values, list): + values = [] + setattr(obj, attr_name, values) + values.append(value) + + +def _replace_filter(meta: HandlerMeta, spec: FilterSpec) -> None: + kind = getattr(spec, "kind", None) + meta.filters = [ + item for item in meta.filters if getattr(item, "kind", None) != kind + ] + meta.filters.append(spec) + + +def _has_filter_kind(meta: HandlerMeta, kind: str) -> bool: + return any(getattr(item, "kind", None) == kind for item in meta.filters) + + +def _set_platform_filter( + meta: HandlerMeta, + values: list[str], + *, + source: str, +) -> None: + normalized = [ + value for value in dict.fromkeys(str(item).strip() for item in values) if value + ] + if not normalized: + return + existing = meta.decorator_sources.get("platforms") + if existing is not None and existing != source: + raise ValueError("platforms(...) 不能与 on_message(platforms=...) 混用") + if existing is None and _has_filter_kind(meta, "platform"): + raise ValueError("platforms(...) 不能与已有平台过滤器混用") + meta.decorator_sources["platforms"] = source + _replace_filter(meta, PlatformFilterSpec(platforms=normalized)) + + +def _set_message_type_filter( + meta: HandlerMeta, + values: list[str], + *, + source: str, +) -> None: + normalized = [ + value + for value in dict.fromkeys(str(item).strip().lower() for item in values) + if value + ] + if not normalized: + return + existing = meta.decorator_sources.get("message_types") + if existing is not None and existing != source: + raise ValueError( + "group_only()/private_only()/message_types(...) 不能与已有消息类型约束混用" + ) + if existing is None and _has_filter_kind(meta, "message_type"): + raise ValueError( + "group_only()/private_only()/message_types(...) 不能与已有消息类型过滤器混用" + ) + meta.decorator_sources["message_types"] = source + _replace_filter(meta, MessageTypeFilterSpec(message_types=normalized)) + + +def _validate_message_trigger_compatibility(meta: HandlerMeta) -> None: + if meta.limiter is None or meta.trigger is None: + return + trigger_type = getattr(meta.trigger, "type", None) + if trigger_type not in {"command", "message"}: + raise ValueError( + "rate_limit(...) 和 cooldown(...) 只适用于 on_command/on_message" + ) + + +def _set_required_role( + meta: HandlerMeta, + role: Literal["member", "admin"], +) -> None: + current = meta.permissions.required_role + if current is not None and current != role: + raise ValueError( + f"require_permission({role!r}) 与已有权限要求 {current!r} 冲突" + ) + meta.permissions.required_role = role + meta.permissions.require_admin = role == "admin" + + +def _normalize_description(description: str | None) -> str | None: + if description is None: + return None + text = str(description).strip() + return text or None + + +def _require_handler_callable( + target: Any, + *, + decorator_name: str, +) -> None: + if not callable(target): + raise TypeError(f"{decorator_name} can only decorate callables") + + +def _validate_limiter_args( + *, + kind: str, + limit: int, + window: float, + scope: LimiterScope, + behavior: LimiterBehavior, +) -> None: + if isinstance(limit, bool) or int(limit) <= 0: + raise ValueError(f"{kind} requires a positive limit") + if float(window) <= 0: + raise ValueError(f"{kind} requires a positive window") + if scope not in {"session", "user", "group", "global"}: + raise ValueError(f"unsupported limiter scope: {scope}") + if behavior not in {"hint", "silent", "error"}: + raise ValueError(f"unsupported limiter behavior: {behavior}") + + +def _set_limiter( + func: _HandlerT, + limiter: LimiterMeta, +) -> _HandlerT: + meta = _get_or_create_meta(func) + if meta.limiter is not None: + raise ValueError("rate_limit(...) 和 cooldown(...) 不能叠加在同一个 handler 上") + meta.limiter = limiter + _validate_message_trigger_compatibility(meta) + return func + + +def _model_to_schema( + model: type[BaseModel] | None, + *, + label: str, +) -> dict[str, Any] | None: + """将 pydantic 模型转换为 JSON Schema。 + + Args: + model: pydantic BaseModel 子类 + label: 错误消息中的字段名 + + Returns: + JSON Schema 字典,如果 model 为 None 则返回 None + + Raises: + TypeError: 如果 model 不是 BaseModel 子类 + """ + if model is None: + return None + if not isinstance(model, type) or not issubclass(model, BaseModel): + raise TypeError(f"{label} 必须是 pydantic BaseModel 子类") + return cast(dict[str, Any], model.model_json_schema()) + + +def on_command( + command: str | typing.Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, + group: str | typing.Sequence[str] | None = None, + group_help: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + """注册命令处理方法。 + + 当用户发送指定命令时触发。命令格式为 `/{command}` 或直接 `{command}`, + 取决于平台配置。 + + Args: + command: 命令名称(不包含前缀符) + aliases: 命令别名列表 + description: 命令描述,用于帮助信息 + group: 指令组路径。传入 "admin" 表示一级组;传入 ["admin", "user"] 表示多级组 + 设置后实际命令为 ``"admin command"`` 或 ``"admin user command"`` + group_help: 指令组描述,用于帮助信息 + + Returns: + 装饰器函数 + + Example: + @on_command("echo", aliases=["repeat"], description="重复消息") + async def echo(self, event: MessageEvent, ctx: Context): + await event.reply(event.text) + + @on_command("ban", group="admin", description="封禁用户") + async def admin_ban(self, event: MessageEvent, ctx: Context): + await event.reply("已封禁") + """ + + if aliases is not None and not isinstance(aliases, list): + raise TypeError("on_command aliases must be a list of strings") + + commands = ( + [str(command).strip()] + if isinstance(command, str) + else [str(item).strip() for item in command] + ) + commands = [item for item in commands if item] + if not commands: + raise ValueError("on_command requires at least one non-empty command name") + + group_path: list[str] = [] + if group is not None: + group_path = ( + [str(group).strip()] + if isinstance(group, str) + else [str(item).strip() for item in group] + ) + group_path = [item for item in group_path if item] + + canonical = commands[0] + display_command = " ".join([*group_path, canonical]) if group_path else canonical + merged_aliases: list[str] = [ + item + for item in dict.fromkeys([*commands[1:], *(aliases or [])]) + if isinstance(item, str) and item and item != canonical + ] + expanded_aliases: list[str] = ( + [" ".join([*group_path, alias]) for alias in merged_aliases] + if group_path + else merged_aliases + ) + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="on_command(...)") + meta = _get_or_create_meta(func) + normalized_description = _normalize_description(description) + trigger_command = display_command if group_path else canonical + meta.trigger = CommandTrigger( + command=trigger_command, + aliases=expanded_aliases if group_path else merged_aliases, + description=normalized_description, + ) + meta.description = normalized_description + if group_path: + meta.command_route = CommandRouteSpec( + group_path=group_path, + display_command=display_command, + group_help=_normalize_description(group_help), + ) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def on_message( + *, + regex: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + message_types: list[str] | None = None, + description: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + """注册消息处理方法。 + + 当消息匹配指定条件时触发。支持正则表达式或关键词匹配。 + + Args: + regex: 正则表达式模式 + keywords: 关键词列表(任一匹配即可) + platforms: 限定平台列表(如 ["qq", "wechat"]) + + Returns: + 装饰器函数 + + Note: + regex 和 keywords 至少提供一个 + + Example: + @on_message(keywords=["help", "帮助"]) + async def help(self, event: MessageEvent, ctx: Context): + await event.reply("帮助信息") + + @on_message(regex=r"\\d+") # 匹配数字 + async def number_handler(self, event: MessageEvent, ctx: Context): + await event.reply("收到了数字") + """ + + if keywords is not None and not isinstance(keywords, list): + raise TypeError("on_message keywords must be a list of strings") + if platforms is not None and not isinstance(platforms, list): + raise TypeError("on_message platforms must be a list of strings") + if message_types is not None and not isinstance(message_types, list): + raise TypeError("on_message message_types must be a list of strings") + + normalized_regex = None if regex is None else str(regex).strip() + normalized_keywords = [ + str(item).strip() for item in (keywords or []) if str(item).strip() + ] + if not normalized_regex and not normalized_keywords: + raise ValueError("on_message(...) requires regex or at least one keyword") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="on_message(...)") + meta = _get_or_create_meta(func) + meta.trigger = MessageTrigger( + regex=normalized_regex, + keywords=normalized_keywords, + platforms=platforms or [], + message_types=message_types or [], + ) + meta.description = _normalize_description(description) + if platforms: + _set_platform_filter(meta, list(platforms), source="trigger.platforms") + if message_types: + _set_message_type_filter( + meta, + list(message_types), + source="trigger.message_types", + ) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def append_filter_meta( + func: _HandlerT, + *, + specs: list[FilterSpec] | None = None, + local_bindings: list[Any] | None = None, +) -> _HandlerT: + """追加过滤器元数据。""" + meta = _get_or_create_meta(func) + if specs: + meta.filters.extend(specs) + if local_bindings: + meta.local_filters.extend(local_bindings) + return func + + +def set_command_route_meta( + func: _HandlerT, + route: CommandRouteSpec, +) -> _HandlerT: + """设置命令路由元数据。""" + meta = _get_or_create_meta(func) + meta.command_route = route + return func + + +def on_event( + event_type: str, + *, + description: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + """注册事件处理方法。 + + 当特定类型的事件发生时触发。用于处理非消息类型的事件, + 如群成员变动、好友请求等。 + + Args: + event_type: 事件类型标识 + + Returns: + 装饰器函数 + + Example: + @on_event("group_member_join") + async def on_join(self, event, ctx): + await ctx.platform.send(event.group_id, "欢迎新人!") + """ + + normalized_event_type = str(event_type).strip() + if not normalized_event_type: + raise ValueError("on_event(...) requires a non-empty event_type") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="on_event(...)") + meta = _get_or_create_meta(func) + meta.trigger = EventTrigger(event_type=normalized_event_type) + meta.description = _normalize_description(description) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def on_schedule( + *, + name: str | None = None, + cron: str | None = None, + interval_seconds: int | None = None, + timezone: str | None = None, + description: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + """注册定时任务方法。 + + 按指定的时间计划定期执行。 + + Args: + name: 调度任务名称,默认回退为插件 ID 与 handler ID 组合 + cron: cron 表达式(如 "0 8 * * *" 表示每天 8:00) + interval_seconds: 执行间隔(秒) + timezone: IANA 时区名称(如 "Asia/Shanghai") + + Returns: + 装饰器函数 + + Note: + cron 和 interval_seconds 至少提供一个 + + Example: + @on_schedule(cron="0 8 * * *") # 每天 8:00 + async def morning_greeting(self, ctx): + await ctx.platform.send("group_123", "早上好!") + + @on_schedule(interval_seconds=3600) # 每小时 + async def hourly_check(self, ctx): + pass + """ + + normalized_name = None if name is None else str(name).strip() or None + normalized_cron = None if cron is None else str(cron).strip() or None + normalized_timezone = None if timezone is None else str(timezone).strip() or None + if normalized_cron is None and interval_seconds is None: + raise ValueError("on_schedule(...) requires cron or interval_seconds") + if interval_seconds is not None and ( + isinstance(interval_seconds, bool) or int(interval_seconds) <= 0 + ): + raise ValueError("on_schedule(...) interval_seconds must be a positive integer") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="on_schedule(...)") + meta = _get_or_create_meta(func) + meta.trigger = ScheduleTrigger( + name=normalized_name, + cron=normalized_cron, + interval_seconds=( + None if interval_seconds is None else int(interval_seconds) + ), + timezone=normalized_timezone, + ) + meta.description = _normalize_description(description) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def http_api( + route: str, + *, + methods: list[str] | None = None, + description: str = "", + capability_name: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + normalized_route = str(route).strip() + if not normalized_route: + raise ValueError("http_api(...) requires a non-empty route") + normalized_methods = methods or ["GET"] + normalized_methods = [ + str(item).strip().upper() for item in normalized_methods if str(item).strip() + ] + if not normalized_methods: + raise ValueError("http_api(...) requires at least one HTTP method") + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="http_api(...)") + setattr( + func, + HTTP_API_META_ATTR, + HttpApiMeta( + route=normalized_route, + methods=normalized_methods, + description=str(description), + capability_name=( + str(capability_name).strip() + if capability_name is not None + else None + ), + ), + ) + return func + + return decorator + + +def validate_config( + *, + model: type[BaseModel] | None = None, + schema: dict[str, Any] | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + if model is None and schema is None: + raise ValueError("validate_config(...) requires model or schema") + if model is not None and schema is not None: + raise ValueError("validate_config(...) cannot accept model and schema together") + if model is not None and ( + not isinstance(model, type) or not issubclass(model, BaseModel) + ): + raise TypeError("validate_config model must be a pydantic BaseModel subclass") + if schema is not None and not isinstance(schema, dict): + raise TypeError("validate_config schema must be a dict") + if isinstance(schema, dict): + _validate_validate_config_schema(schema) + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="validate_config(...)") + setattr( + func, + VALIDATE_CONFIG_META_ATTR, + ValidateConfigMeta( + model=model, + schema=dict(schema) if isinstance(schema, dict) else None, + ), + ) + return func + + return decorator + + +def on_provider_change( + *, + provider_types: list[str] | tuple[str, ...] | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + normalized = [ + str(item).strip().lower() + for item in (provider_types or []) + if str(item).strip() + ] + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="on_provider_change(...)") + setattr( + func, + PROVIDER_CHANGE_META_ATTR, + ProviderChangeMeta(provider_types=normalized), + ) + return func + + return decorator + + +def background_task( + *, + description: str = "", + auto_start: bool = True, + on_error: Literal["log", "restart"] = "log", +) -> Callable[[HandlerCallable], HandlerCallable]: + if on_error not in {"log", "restart"}: + raise ValueError("background_task on_error must be 'log' or 'restart'") + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="background_task(...)") + setattr( + func, + BACKGROUND_TASK_META_ATTR, + BackgroundTaskMeta( + description=str(description), + auto_start=bool(auto_start), + on_error=on_error, + ), + ) + return func + + return decorator + + +def mcp_server( + *, + name: str, + scope: Literal["local", "global"] = "global", + config: dict[str, Any] | None = None, + timeout: float = 30.0, + wait_until_ready: bool = True, +): + normalized_name = str(name).strip() + if not normalized_name: + raise ValueError("mcp_server(...) requires a non-empty name") + if scope not in {"local", "global"}: + raise ValueError("mcp_server scope must be 'local' or 'global'") + if config is not None and not isinstance(config, dict): + raise TypeError("mcp_server config must be a dict") + if float(timeout) <= 0: + raise ValueError("mcp_server timeout must be positive") + + meta = MCPServerMeta( + name=normalized_name, + scope=scope, + config=dict(config) if isinstance(config, dict) else None, + timeout=float(timeout), + wait_until_ready=bool(wait_until_ready), + ) + + def decorator(target): + _append_list_meta(target, MCP_SERVER_META_ATTR, meta) + return target + + return decorator + + +def register_skill( + *, + name: str, + path: str, + description: str = "", +): + normalized_name = str(name).strip() + normalized_path = str(path).strip() + if not normalized_name: + raise ValueError("register_skill(...) requires a non-empty name") + if not normalized_path: + raise ValueError("register_skill(...) requires a non-empty path") + + meta = SkillMeta( + name=normalized_name, + path=normalized_path, + description=str(description), + ) + + def decorator(target): + _append_list_meta(target, SKILL_META_ATTR, meta) + return target + + return decorator + + +def require_admin(func: _HandlerT) -> _HandlerT: + """标记 handler 需要管理员权限。 + + 当用户不是管理员时,handler 将不会被调用。 + + Args: + func: 要标记的方法 + + Returns: + 标记后的方法 + + Example: + @on_command("admin") + @require_admin + async def admin_only(self, event: MessageEvent, ctx: Context): + await event.reply("管理员命令执行成功") + """ + _require_handler_callable(func, decorator_name="require_admin") + meta = _get_or_create_meta(func) + _set_required_role(meta, "admin") + return func + + +def admin_only(func: _HandlerT) -> _HandlerT: + return require_admin(func) + + +def require_permission( + role: Literal["member", "admin"], +) -> Callable[[_HandlerT], _HandlerT]: + normalized_role = str(role).strip().lower() + if normalized_role not in {"member", "admin"}: + raise ValueError("require_permission(...) 只支持 'member' 或 'admin'") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="require_permission(...)") + meta = _get_or_create_meta(func) + _set_required_role( + meta, + cast(Literal["member", "admin"], normalized_role), + ) + return func + + return decorator + + +def platforms(*names: str) -> Callable[[_HandlerT], _HandlerT]: + normalized_names = [str(name).strip() for name in names if str(name).strip()] + if not normalized_names: + raise ValueError("platforms(...) requires at least one non-empty platform name") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="platforms(...)") + meta = _get_or_create_meta(func) + _set_platform_filter(meta, normalized_names, source="decorator.platforms") + return func + + return decorator + + +def message_types(*types: str) -> Callable[[_HandlerT], _HandlerT]: + normalized_types = [str(item).strip() for item in types if str(item).strip()] + if not normalized_types: + raise ValueError("message_types(...) requires at least one non-empty type") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="message_types(...)") + meta = _get_or_create_meta(func) + _set_message_type_filter( + meta, + normalized_types, + source="decorator.message_types", + ) + return func + + return decorator + + +def group_only() -> Callable[[_HandlerT], _HandlerT]: + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="group_only()") + meta = _get_or_create_meta(func) + _set_message_type_filter(meta, ["group"], source="decorator.group_only") + return func + + return decorator + + +def private_only() -> Callable[[_HandlerT], _HandlerT]: + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="private_only()") + meta = _get_or_create_meta(func) + _set_message_type_filter(meta, ["private"], source="decorator.private_only") + return func + + return decorator + + +def priority(value: int) -> Callable[[_HandlerT], _HandlerT]: + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError("priority(...) requires an integer") + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="priority(...)") + meta = _get_or_create_meta(func) + meta.priority = value + return func + + return decorator + + +def rate_limit( + limit: int, + window: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + _validate_limiter_args( + kind="rate_limit", + limit=limit, + window=window, + scope=scope, + behavior=behavior, + ) + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="rate_limit(...)") + return _set_limiter( + func, + LimiterMeta( + kind="rate_limit", + limit=int(limit), + window=float(window), + scope=scope, + behavior=behavior, + message=message, + ), + ) + + return decorator + + +def cooldown( + seconds: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[_HandlerT], _HandlerT]: + _validate_limiter_args( + kind="cooldown", + limit=1, + window=seconds, + scope=scope, + behavior=behavior, + ) + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="cooldown(...)") + return _set_limiter( + func, + LimiterMeta( + kind="cooldown", + limit=1, + window=float(seconds), + scope=scope, + behavior=behavior, + message=message, + ), + ) + + return decorator + + +def conversation_command( + command: str | typing.Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, + group: str | typing.Sequence[str] | None = None, + group_help: str | None = None, + timeout: int = 60, + mode: ConversationMode = "replace", + busy_message: str | None = None, + grace_period: float = 1.0, +) -> Callable[[_HandlerT], _HandlerT]: + """注册带会话生命周期的命令处理方法。 + + 在 ``on_command`` 基础上附加会话元数据,支持超时、并发策略和宽限期控制。 + + Args: + command: 命令名称或序列(首项为正式名,其余视为别名) + aliases: 额外别名列表 + description: 命令描述 + group: 指令组路径,例如 ``"admin"`` 或 ``["admin", "user"]`` + group_help: 指令组描述,用于帮助信息 + timeout: 会话超时时间(秒),必须为正整数 + mode: 会话冲突时的行为: + - ``"replace"``: 替换当前会话 + - ``"reject"``: 拒绝新请求 + busy_message: 拒绝新请求时的提示消息 + grace_period: 宽限期(秒),用于会话生命周期处理 + + Returns: + 装饰器函数 + + Raises: + ValueError: mode 不合法、timeout 非正整数或 grace_period 非正数 + + Example: + @conversation_command("chat", timeout=120, mode="reject", busy_message="请稍后再试") + async def chat(self, event: MessageEvent, ctx: Context): + await event.reply("开始对话...") + """ + if mode not in {"replace", "reject"}: + raise ValueError("conversation_command mode must be 'replace' or 'reject'") + # bool 是 int 子类,需单独排除 + if isinstance(timeout, bool) or int(timeout) <= 0: + raise ValueError("conversation_command timeout must be a positive integer") + if float(grace_period) <= 0: + raise ValueError("conversation_command grace_period must be positive") + + command_decorator = on_command( + command, + aliases=aliases, + description=description, + group=group, + group_help=group_help, + ) + + def decorator(func: _HandlerT) -> _HandlerT: + _require_handler_callable(func, decorator_name="conversation_command(...)") + decorated = command_decorator(func) + meta = _get_or_create_meta(decorated) + meta.conversation = ConversationMeta( + timeout=int(timeout), + mode=mode, + busy_message=busy_message, + grace_period=float(grace_period), + ) + return decorated + + return decorator + + +def provide_capability( + name: str, + *, + description: str, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + input_model: type[BaseModel] | None = None, + output_model: type[BaseModel] | None = None, + supports_stream: bool = False, + cancelable: bool = False, +) -> Callable[[HandlerCallable], HandlerCallable]: + """声明插件对外暴露的 capability。 + + 允许其他插件或 Core 通过 capability 名称调用此方法。 + 支持使用 JSON Schema 或 pydantic 模型定义输入输出。 + + Args: + name: capability 名称(不能使用保留命名空间,且运行时必须以当前 plugin_id 为前缀) + description: 能力描述 + input_schema: 输入 JSON Schema + output_schema: 输出 JSON Schema + input_model: 输入 pydantic 模型(与 input_schema 二选一) + output_model: 输出 pydantic 模型(与 output_schema 二选一) + supports_stream: 是否支持流式输出 + cancelable: 是否可取消 + + Returns: + 装饰器函数 + + Raises: + ValueError: 如果使用保留命名空间,或同时提供 schema 和 model + + Example: + @provide_capability( + "my_plugin.calculate", + description="执行计算", + input_model=CalculateInput, + output_model=CalculateOutput, + ) + async def calculate(self, payload: dict, ctx: Context): + return {"result": payload["x"] * 2} + """ + + normalized_name = str(name).strip() + if not normalized_name: + raise ValueError("provide_capability(...) requires a non-empty name") + normalized_description = _normalize_description(description) + if normalized_description is None: + raise ValueError("provide_capability(...) requires a non-empty description") + if input_schema is not None and not isinstance(input_schema, dict): + raise TypeError("input_schema must be a dict") + if output_schema is not None and not isinstance(output_schema, dict): + raise TypeError("output_schema must be a dict") + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="provide_capability(...)") + if normalized_name.startswith(RESERVED_CAPABILITY_PREFIXES): + raise ValueError( + f"保留 capability 命名空间不能用于插件导出:{normalized_name}" + ) + if input_schema is not None and input_model is not None: + raise ValueError("input_schema 和 input_model 不能同时提供") + if output_schema is not None and output_model is not None: + raise ValueError("output_schema 和 output_model 不能同时提供") + descriptor = CapabilityDescriptor( + name=normalized_name, + description=normalized_description, + input_schema=( + input_schema + if input_schema is not None + else _model_to_schema(input_model, label="input_model") + ), + output_schema=( + output_schema + if output_schema is not None + else _model_to_schema(output_model, label="output_model") + ), + supports_stream=supports_stream, + cancelable=cancelable, + ) + setattr(func, CAPABILITY_META_ATTR, CapabilityMeta(descriptor=descriptor)) + return func + + return decorator + + +def _annotation_to_schema(annotation: Any) -> dict[str, Any]: + normalized, _is_optional = unwrap_optional(annotation) + origin = typing.get_origin(normalized) + if normalized is str: + return {"type": "string"} + if normalized is int: + return {"type": "integer"} + if normalized is float: + return {"type": "number"} + if normalized is bool: + return {"type": "boolean"} + if normalized is dict or origin is dict: + return {"type": "object"} + if normalized is list or origin is list: + args = typing.get_args(normalized) + item_schema = _annotation_to_schema(args[0]) if args else {} + return {"type": "array", "items": item_schema} + return {"type": "string"} + + +def _callable_parameters_schema(func: HandlerCallable) -> dict[str, Any]: + signature = inspect.signature(func) + type_hints: dict[str, Any] = {} + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {} + + properties: dict[str, Any] = {} + required: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if parameter.name == "self": + continue + annotation = type_hints.get(parameter.name) + normalized, _is_optional = unwrap_optional(annotation) + if parameter.name in {"event", "ctx", "context"}: + continue + properties[parameter.name] = _annotation_to_schema(normalized) + if parameter.default is inspect.Parameter.empty and not _is_optional: + required.append(parameter.name) + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required + return schema + + +def register_llm_tool( + name: str | None = None, + *, + description: str | None = None, + parameters_schema: dict[str, Any] | None = None, + active: bool = True, +) -> Callable[[HandlerCallable], HandlerCallable]: + if parameters_schema is not None and not isinstance(parameters_schema, dict): + raise TypeError("register_llm_tool parameters_schema must be a dict") + if not isinstance(active, bool): + raise TypeError("register_llm_tool active must be a bool") + + def decorator(func: HandlerCallable) -> HandlerCallable: + _require_handler_callable(func, decorator_name="register_llm_tool(...)") + tool_name = str(name or func.__name__).strip() + if not tool_name: + raise ValueError("LLM tool name must not be empty") + setattr( + func, + LLM_TOOL_META_ATTR, + LLMToolMeta( + spec=LLMToolSpec.create( + name=tool_name, + description=description + or (inspect.getdoc(func) or "").splitlines()[0] + if inspect.getdoc(func) + else "", + parameters_schema=parameters_schema + or _callable_parameters_schema(func), + handler_ref=tool_name, + active=active, + ) + ), + ) + return func + + return decorator + + +def register_agent( + name: str, + *, + description: str = "", + tool_names: list[str] | None = None, +) -> Callable[[type[BaseAgentRunner]], type[BaseAgentRunner]]: + if tool_names is not None and not isinstance(tool_names, list): + raise TypeError("register_agent tool_names must be a list of strings") + normalized_name = str(name).strip() + if not normalized_name: + raise ValueError("register_agent(...) requires a non-empty name") + normalized_tool_names = [ + str(tool_name).strip() + for tool_name in dict.fromkeys(tool_names or []) + if str(tool_name).strip() + ] + + def decorator(cls: type[BaseAgentRunner]) -> type[BaseAgentRunner]: + if not inspect.isclass(cls) or not issubclass(cls, BaseAgentRunner): + raise TypeError("@register_agent() 只接受 BaseAgentRunner 子类") + setattr( + cls, + AGENT_META_ATTR, + AgentMeta( + spec=AgentSpec( + name=normalized_name, + description=description, + tool_names=normalized_tool_names, + runner_class=f"{cls.__module__}.{cls.__qualname__}", + ) + ), + ) + return cls + + return decorator + + +def acknowledge_global_mcp_risk(cls: type[Any]) -> type[Any]: + """Mark an SDK plugin class as eligible to mutate global MCP state. + + This is intentionally a coarse, class-level marker. Runtime enforcement lives + in the Core MCP capability bridge. + """ + + setattr(cls, "__astrbot_acknowledge_global_mcp_risk__", True) + return cls diff --git a/astrbot-sdk/src/astrbot_sdk/errors.py b/astrbot-sdk/src/astrbot_sdk/errors.py new file mode 100644 index 0000000000..c33244f387 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/errors.py @@ -0,0 +1,311 @@ +"""跨运行时边界传递的统一错误模型。 + +AstrBotError 是 SDK 中所有可预期错误的标准格式, +支持跨进程传递(通过 to_payload/from_payload 序列化)。 + +错误处理流程: + 1. 运行时抛出 AstrBotError 子类或实例 + 2. 错误被捕获并序列化为 payload + 3. 跨进程传输后反序列化 + 4. 在 on_error 钩子中统一处理 + +Example: + # 抛出错误 + raise AstrBotError.invalid_input("参数不能为空") + + # 捕获并处理 + try: + await some_operation() + except AstrBotError as e: + if e.retryable: + # 可重试的错误 + await retry() + else: + # 不可重试的错误 + await event.reply(e.hint or e.message) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +class ErrorCodes: + """AstrBot SDK 的稳定错误码常量。 + + 这些错误码在协议层稳定,不应随意更改。 + 新增错误码应放在对应分类的末尾。 + + 分类: + - 不可重试错误(retryable=False):配置错误、权限错误等 + - 可重试错误(retryable=True):网络超时、临时故障等 + """ + + UNKNOWN_ERROR = "unknown_error" + + # 不可重试错误 - 配置或使用问题 + LLM_NOT_CONFIGURED = "llm_not_configured" + CAPABILITY_NOT_FOUND = "capability_not_found" + PERMISSION_DENIED = "permission_denied" + LLM_ERROR = "llm_error" + INVALID_INPUT = "invalid_input" + CANCELLED = "cancelled" + PROTOCOL_VERSION_MISMATCH = "protocol_version_mismatch" + PROTOCOL_ERROR = "protocol_error" + INTERNAL_ERROR = "internal_error" + RATE_LIMITED = "rate_limited" + COOLDOWN_ACTIVE = "cooldown_active" + + # 可重试错误 - 临时故障 + CAPABILITY_TIMEOUT = "capability_timeout" + NETWORK_ERROR = "network_error" + LLM_TEMPORARY_ERROR = "llm_temporary_error" + + +@dataclass(slots=True) +class AstrBotError(Exception): + """AstrBot SDK 的标准错误类型。 + + 所有可预期的错误都应使用此类或其工厂方法创建。 + 支持跨进程传递,包含用户友好的提示信息。 + + Attributes: + code: 错误码,来自 ErrorCodes 常量 + message: 错误消息,面向开发者 + hint: 用户提示,面向终端用户 + retryable: 是否可重试 + + Example: + # 使用工厂方法创建错误 + raise AstrBotError.invalid_input("参数格式错误", hint="请使用 JSON 格式") + + # 检查错误类型 + try: + await operation() + except AstrBotError as e: + if e.code == ErrorCodes.CAPABILITY_NOT_FOUND: + logger.error(f"能力不存在: {e.message}") + """ + + code: str + message: str + hint: str = "" + retryable: bool = False + docs_url: str = "" + details: dict[str, Any] | None = None + + def __str__(self) -> str: + return self.message + + @classmethod + def cancelled(cls, message: str = "调用被取消") -> AstrBotError: + """创建取消错误。 + + Args: + message: 错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.CANCELLED, + message=message, + hint="", + retryable=False, + ) + + @classmethod + def capability_not_found(cls, name: str) -> AstrBotError: + """创建能力未找到错误。 + + Args: + name: 未找到的能力名称 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.CAPABILITY_NOT_FOUND, + message=f"未找到能力:{name}", + hint="请确认 AstrBot Core 是否已注册该 capability", + retryable=False, + ) + + @classmethod + def invalid_input( + cls, + message: str, + *, + hint: str = "请检查调用参数", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + """创建输入无效错误。 + + Args: + message: 详细错误消息 + hint: 用户提示 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.INVALID_INPUT, + message=message, + hint=hint, + retryable=False, + docs_url=docs_url, + details=details, + ) + + @classmethod + def protocol_version_mismatch(cls, message: str) -> AstrBotError: + """创建协议版本不匹配错误。 + + Args: + message: 详细错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.PROTOCOL_VERSION_MISMATCH, + message=message, + hint="请升级 astrbot_sdk 至最新版本", + retryable=False, + ) + + @classmethod + def protocol_error(cls, message: str) -> AstrBotError: + """创建协议错误。 + + Args: + message: 详细错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.PROTOCOL_ERROR, + message=message, + hint="请检查通信双方的协议实现", + retryable=False, + ) + + @classmethod + def internal_error( + cls, + message: str, + *, + hint: str = "请联系插件作者", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + """创建内部错误。 + + Args: + message: 详细错误消息 + hint: 用户提示 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.INTERNAL_ERROR, + message=message, + hint=hint, + retryable=False, + docs_url=docs_url, + details=details, + ) + + @classmethod + def network_error( + cls, + message: str, + *, + hint: str = "网络请求失败,请稍后重试", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.NETWORK_ERROR, + message=message, + hint=hint, + retryable=True, + docs_url=docs_url, + details=details, + ) + + @classmethod + def rate_limited( + cls, + *, + hint: str = "操作过于频繁,请稍后再试。", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.RATE_LIMITED, + message="handler invocation is rate limited", + hint=hint, + retryable=False, + details=details, + ) + + @classmethod + def cooldown_active( + cls, + *, + hint: str, + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.COOLDOWN_ACTIVE, + message="handler cooldown is active", + hint=hint, + retryable=False, + details=details, + ) + + def to_payload(self) -> dict[str, object]: + """序列化为可传输的字典格式。 + + 用于跨进程传递错误信息。 + + Returns: + 包含错误信息的字典 + """ + return { + "code": self.code, + "message": self.message, + "hint": self.hint, + "retryable": self.retryable, + "docs_url": self.docs_url, + "details": dict(self.details) if isinstance(self.details, dict) else None, + } + + @classmethod + def from_payload(cls, payload: dict[str, object]) -> AstrBotError: + """从字典反序列化错误实例。 + + Args: + payload: 包含错误信息的字典 + + Returns: + AstrBotError 实例 + """ + details_payload = payload.get("details") + details = ( + {str(key): value for key, value in details_payload.items()} + if isinstance(details_payload, dict) + else None + ) + return cls( + code=str(payload.get("code", ErrorCodes.UNKNOWN_ERROR)), + message=str(payload.get("message", "未知错误")), + hint=str(payload.get("hint", "")), + retryable=bool(payload.get("retryable", False)), + docs_url=str(payload.get("docs_url", "")), + details=details, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/events.py b/astrbot-sdk/src/astrbot_sdk/events.py new file mode 100644 index 0000000000..22f85255c7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/events.py @@ -0,0 +1,794 @@ +"""astrbot-sdk 原生事件对象。 + +顶层 ``MessageEvent`` 保持精简,只承载 astrbot-sdk 运行时真正需要的基础能力。 +迁移期扩展事件能力放在独立模块中,而不是继续塞回顶层事件类型。 + +MessageEvent 是 handler 接收的主要事件类型,封装了: + - 消息文本内容 + - 发送者信息(user_id, group_id) + - 平台标识 + - 回复能力(reply, reply_image, reply_chain) +""" + +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, TypeVar + +from ._message_types import normalize_message_type +from .message.components import ( + At, + BaseMessageComponent, + File, + Image, + Plain, + component_to_payload_sync, + payloads_to_components, +) +from .message.result import EventResultType, MessageChain, MessageEventResult +from .protocol.descriptors import SessionRef + +if TYPE_CHECKING: + from .context import Context + + +@dataclass(slots=True) +class PlainTextResult: + """纯文本结果。 + + 用于 handler 返回简单的文本结果。 + """ + + text: str + + +ReplyHandler = Callable[[str], Awaitable[None]] +_MessageComponentT = TypeVar("_MessageComponentT", bound=BaseMessageComponent) + +_JSON_DROP = object() + + +def _coerce_str(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return str(value) + + +def _coerce_optional_str(value: Any) -> str | None: + if value is None: + return None + text = value if isinstance(value, str) else str(value) + return text or None + + +def _json_safe_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + items = [] + for item in value: + normalized = _json_safe_value(item) + if normalized is not _JSON_DROP: + items.append(normalized) + return items + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized = _json_safe_value(item) + if normalized is not _JSON_DROP: + normalized_dict[str(key)] = normalized + return normalized_dict + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + return _json_safe_value(model_dump()) + except Exception: + return _JSON_DROP + try: + json.dumps(value) + except (TypeError, ValueError): + return _JSON_DROP + return value + + +def _json_safe_mapping(value: Any) -> dict[str, Any]: + if not isinstance(value, dict): + return {} + normalized: dict[str, Any] = {} + for key, item in value.items(): + safe_item = _json_safe_value(item) + if safe_item is not _JSON_DROP: + normalized[str(key)] = safe_item + return normalized + + +class MessageEvent: + """消息事件对象。 + + 封装收到的消息,提供便捷的回复方法。 + 每个 handler 调用都会创建新的 MessageEvent 实例。 + + Attributes: + text: 消息文本内容 + user_id: 发送者用户 ID,缺失时为空字符串 + group_id: 群组 ID(私聊时为 None) + platform: 平台标识(如 "qq", "wechat"),缺失时为空字符串 + session_id: 会话 ID(通常是 group_id 或 user_id,缺失时为空字符串) + raw: 原始消息数据 + + Example: + @on_command("echo") + async def echo(self, event: MessageEvent, ctx: Context): + await event.reply(f"你说: {event.text}") + """ + + text: str + user_id: str + group_id: str | None + platform: str + session_id: str + self_id: str + platform_id: str + message_type: str + sender_name: str + raw: dict[str, Any] + _is_admin: bool + _stopped: bool + _host_extras: dict[str, Any] + _host_extras_present: bool + _sdk_local_extras: dict[str, Any] + _sdk_local_extras_present: bool + _sdk_local_extras_dirty: bool + _messages: list[BaseMessageComponent] + _messages_present: bool + _message_outline: str + _sent_messages: list[BaseMessageComponent] + _sent_messages_present: bool + _sent_message_outline: str + _sent_message_outline_present: bool + _context: Context | None + _reply_handler: ReplyHandler | None + + def __init__( + self, + *, + text: str = "", + user_id: str | None = None, + group_id: str | None = None, + platform: str | None = None, + session_id: str | None = None, + self_id: str | None = None, + platform_id: str | None = None, + message_type: str | None = None, + sender_name: str | None = None, + is_admin: bool = False, + raw: dict[str, Any] | None = None, + context: Context | None = None, + reply_handler: ReplyHandler | None = None, + ) -> None: + """初始化消息事件。 + + Args: + text: 消息文本 + user_id: 用户 ID + group_id: 群组 ID + platform: 平台标识 + session_id: 会话 ID,None 时自动从 group_id/user_id 推断 + raw: 原始消息数据 + context: 运行时上下文 + reply_handler: 自定义回复处理器 + """ + normalized_user_id = _coerce_str(user_id) + normalized_group_id = _coerce_optional_str(group_id) + normalized_platform = _coerce_str(platform) + normalized_session_id = _coerce_str(session_id) + + self.text = text + self.user_id = normalized_user_id + self.group_id = normalized_group_id + self.platform = normalized_platform + self.session_id = ( + normalized_session_id or normalized_group_id or normalized_user_id or "" + ) + self.self_id = _coerce_str(self_id) + self.platform_id = _coerce_str(platform_id) or normalized_platform + self.message_type = normalize_message_type( + message_type, + group_id=normalized_group_id, + user_id=normalized_user_id, + ) + self.sender_name = _coerce_str(sender_name) + self._is_admin = bool(is_admin) + self.raw = raw or {} + self._stopped = False + host_extras = self.raw.get("host_extras") + raw_extras = self.raw.get("extras") + self._host_extras = _json_safe_mapping( + host_extras if isinstance(host_extras, dict) else raw_extras + ) + self._host_extras_present = "host_extras" in self.raw or "extras" in self.raw + sdk_local_extras = self.raw.get("sdk_local_extras") + self._sdk_local_extras = _json_safe_mapping(sdk_local_extras) + self._sdk_local_extras_present = "sdk_local_extras" in self.raw + self._sdk_local_extras_dirty = False + messages_payload = self.raw.get("messages") + self._messages = ( + payloads_to_components(messages_payload) + if isinstance(messages_payload, list) + else [] + ) + self._messages_present = "messages" in self.raw + self._message_outline = str(self.raw.get("message_outline", self.text)) + sent_messages_payload = self.raw.get("sent_messages") + self._sent_messages = ( + payloads_to_components(sent_messages_payload) + if isinstance(sent_messages_payload, list) + else [] + ) + self._sent_messages_present = "sent_messages" in self.raw + self._sent_message_outline = str(self.raw.get("sent_message_outline", "")) + self._sent_message_outline_present = "sent_message_outline" in self.raw + self._context = context + self._reply_handler = reply_handler + if self._reply_handler is None and context is not None: + self._reply_handler = lambda text: context.platform.send( + self.session_ref or self.session_id, + text, + ) + + def _require_runtime_context(self, action: str) -> Context: + """获取运行时上下文,不存在则抛出异常。""" + if self._context is None: + raise RuntimeError(f"MessageEvent 未绑定运行时上下文,无法 {action}") + return self._context + + def _reply_target(self) -> SessionRef | str: + """获取回复目标。""" + return self.session_ref or self.session_id + + @classmethod + def from_payload( + cls, + payload: dict[str, Any], + *, + context: Context | None = None, + reply_handler: ReplyHandler | None = None, + ) -> MessageEvent: + """从协议载荷创建事件实例。 + + Args: + payload: 协议层传递的消息数据 + context: 运行时上下文 + reply_handler: 自定义回复处理器 + + Returns: + 新的 MessageEvent 实例 + """ + target_payload = payload.get("target") + session_id = payload.get("session_id") + platform = payload.get("platform") + if isinstance(target_payload, dict): + target = SessionRef.model_validate(target_payload) + session_id = session_id or target.session + platform = platform or target.platform + return cls( + text=str(payload.get("text", "")), + user_id=payload.get("user_id"), + group_id=payload.get("group_id"), + platform=platform, + session_id=session_id, + self_id=payload.get("self_id"), + platform_id=payload.get("platform_id"), + message_type=payload.get("message_type"), + sender_name=payload.get("sender_name"), + is_admin=bool(payload.get("is_admin", False)), + raw=payload, + context=context, + reply_handler=reply_handler, + ) + + def to_payload(self) -> dict[str, Any]: + """转换为协议载荷格式。 + + Returns: + 可序列化的字典 + """ + payload = dict(self.raw) + payload.update( + { + "text": self.text, + "user_id": self.user_id, + "group_id": self.group_id, + "platform": self.platform, + "session_id": self.session_id, + "self_id": self.self_id, + "platform_id": self.platform_id, + "message_type": self.message_type, + "sender_name": self.sender_name, + "is_admin": self._is_admin, + } + ) + if self.session_ref is not None: + payload["target"] = self.session_ref.to_payload() + merged_extras = dict(self._host_extras) + merged_extras.update(self._sdk_local_extras_payload()) + if merged_extras: + payload["extras"] = merged_extras + elif self._host_extras_present: + payload["extras"] = {} + else: + payload.pop("extras", None) + if self._host_extras or self._host_extras_present: + payload["host_extras"] = dict(self._host_extras) + else: + payload.pop("host_extras", None) + sdk_local_extras = self._sdk_local_extras_payload() + if sdk_local_extras or self._should_serialize_sdk_local_extras(): + payload["sdk_local_extras"] = sdk_local_extras + else: + payload.pop("sdk_local_extras", None) + if self._messages or self._messages_present: + payload["messages"] = [ + component_to_payload_sync(component) for component in self._messages + ] + else: + payload.pop("messages", None) + payload["message_outline"] = self._message_outline + if self._sent_messages or self._sent_messages_present: + payload["sent_messages"] = [ + component_to_payload_sync(component) + for component in self._sent_messages + ] + else: + payload.pop("sent_messages", None) + if self._sent_message_outline or self._sent_message_outline_present: + payload["sent_message_outline"] = self._sent_message_outline + else: + payload.pop("sent_message_outline", None) + return payload + + @property + def session_ref(self) -> SessionRef | None: + """获取会话引用对象。 + + Returns: + SessionRef 实例,如果没有有效的 session_id 则返回 None + """ + if not self.session_id: + return None + return SessionRef( + conversation_id=self.session_id, + platform=self.platform, + raw=self.raw or None, + ) + + @property + def target(self) -> SessionRef | None: + """session_ref 的别名。""" + return self.session_ref + + @property + def unified_msg_origin(self) -> str: + """Unified message origin string.""" + return self.session_id + + def is_private_chat(self) -> bool: + """Whether the current event belongs to a private chat.""" + if self.message_type: + return self.message_type == "private" + return not bool(self.group_id) + + def is_group_chat(self) -> bool: + if self.message_type: + return self.message_type == "group" + return bool(self.group_id) + + def get_platform_id(self) -> str: + """Get the platform instance identifier.""" + return self.platform_id + + def get_message_type(self) -> str: + """Get the normalized message type.""" + return self.message_type + + def get_session_id(self) -> str: + """Get the current session identifier.""" + return self.session_id + + def is_admin(self) -> bool: + """Whether the sender has admin permission.""" + return self._is_admin + + def has_admin_permission(self) -> bool: + """Return whether the sender currently has administrator permission.""" + return self.is_admin() + + def get_messages(self) -> list[BaseMessageComponent]: + """Return SDK message components for the current event.""" + return list(self._messages) + + def get_sent_messages(self) -> list[BaseMessageComponent]: + """Return outbound SDK message components for after-send events.""" + return list(self._sent_messages) + + def has_component(self, type_: type[BaseMessageComponent]) -> bool: + return any(isinstance(component, type_) for component in self._messages) + + def get_components( + self, + type_: type[_MessageComponentT], + ) -> list[_MessageComponentT]: + return [ + component for component in self._messages if isinstance(component, type_) + ] + + def get_images(self) -> list[Image]: + return self.get_components(Image) + + def get_files(self) -> list[File]: + return self.get_components(File) + + def extract_plain_text(self) -> str: + return " ".join( + component.text + for component in self._messages + if isinstance(component, Plain) + ) + + def get_at_users(self) -> list[str]: + return [ + str(component.qq) + for component in self._messages + if isinstance(component, At) and str(component.qq).lower() != "all" + ] + + def get_message_outline(self) -> str: + """Return the normalized message outline.""" + return self._message_outline + + def get_sent_message_outline(self) -> str: + """Return the outbound message outline for after-send events.""" + return self._sent_message_outline + + async def get_group(self) -> dict[str, Any] | None: + """Get current-group metadata for the bound message request.""" + context = self._require_runtime_context("get_group") + output = await context._proxy.call( # noqa: SLF001 + "platform.get_group", + { + "session": self.session_id, + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + payload = output.get("group") + if not isinstance(payload, dict): + return None + return dict(payload) + + def set_extra(self, key: str, value: Any) -> None: + """Store SDK-local transient event data. + + Values written here are immediately available through ``get_extra()`` + inside the current handler invocation. If you expect the value to remain + available after the event crosses the SDK bridge into a later handler or + lifecycle event, store only JSON-serializable data. + + Recommended approach: + - Keep values to ``dict`` / ``list`` / ``str`` / ``int`` / ``float`` / + ``bool`` / ``None`` and nested combinations of those types. + - Convert framework objects into payloads before storing them. For + message components, use ``component_to_payload_sync()`` before + ``set_extra()`` and ``payload_to_component()`` after ``get_extra()``. + + Non-serializable values may still be readable in the current handler, + but they will be dropped when the SDK bridge serializes extras for a + later event. + """ + self._sdk_local_extras[key] = value + self._sdk_local_extras_dirty = True + + def get_extra(self, key: str | None = None, default: Any = None) -> Any: + """Read SDK-local transient event data. + + Extras returned here merge host-provided extras with values previously + written via ``set_extra()``. If a key was written with a + non-serializable value, it may disappear after the event is serialized + across the SDK bridge. In that case, persist a JSON-safe payload + instead of the original object. + """ + extras = dict(self._host_extras) + extras.update(self._sdk_local_extras) + if key is None: + return extras + return extras.get(key, default) + + def clear_extra(self) -> None: + """Clear SDK-local transient event data.""" + self._sdk_local_extras.clear() + self._sdk_local_extras_dirty = True + + def _sdk_local_extras_payload(self) -> dict[str, Any]: + return _json_safe_mapping(self._sdk_local_extras) + + def _should_serialize_sdk_local_extras(self) -> bool: + return ( + self._sdk_local_extras_present + or self._sdk_local_extras_dirty + or bool(self._sdk_local_extras) + ) + + async def request_llm(self) -> bool: + """Request the default LLM chain for the current message request.""" + context = self._require_runtime_context("request_llm") + output = await context._proxy.call( # noqa: SLF001 + "system.event.llm.request", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("should_call_llm", False)) + + async def should_call_llm(self) -> bool: + """Read the current default-LLM decision from the host bridge.""" + context = self._require_runtime_context("should_call_llm") + output = await context._proxy.call( # noqa: SLF001 + "system.event.llm.get_state", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("should_call_llm", False)) + + async def set_result(self, result: MessageEventResult) -> MessageEventResult: + """Store a request-scoped SDK result in the host bridge.""" + context = self._require_runtime_context("set_result") + await context._proxy.call( # noqa: SLF001 + "system.event.result.set", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "result": result.to_payload(), + }, + ) + return result + + async def get_result(self) -> MessageEventResult | None: + """Read the current request-scoped SDK result from the host bridge.""" + context = self._require_runtime_context("get_result") + output = await context._proxy.call( # noqa: SLF001 + "system.event.result.get", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + payload = output.get("result") + if not isinstance(payload, dict): + return None + return MessageEventResult.from_payload(payload) + + async def clear_result(self) -> None: + """Clear the current request-scoped SDK result.""" + context = self._require_runtime_context("clear_result") + await context._proxy.call( # noqa: SLF001 + "system.event.result.clear", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + + def stop_event(self) -> None: + """Mark the SDK-local event as stopped.""" + self._stopped = True + + def continue_event(self) -> None: + """Clear the SDK-local stop flag.""" + self._stopped = False + + def is_stopped(self) -> bool: + """Return whether the SDK-local event is stopped.""" + return self._stopped + + async def reply(self, text: str) -> None: + """回复文本消息。 + + Args: + text: 要回复的文本内容 + + Raises: + RuntimeError: 如果未绑定 reply handler + """ + if self._reply_handler is None: + raise RuntimeError("MessageEvent 未绑定 reply handler,无法 reply") + await self._reply_handler(text) + + async def reply_image(self, image_url: str) -> None: + """回复图片消息。 + + Args: + image_url: 图片 URL + + Raises: + RuntimeError: 如果未绑定运行时上下文 + """ + context = self._require_runtime_context("reply_image") + await context.platform.send_image(self._reply_target(), image_url) + + async def reply_chain( + self, + chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]], + ) -> None: + """回复消息链(多类型消息组合)。 + + Args: + chain: 消息链组件列表 + + Raises: + RuntimeError: 如果未绑定运行时上下文 + """ + context = self._require_runtime_context("reply_chain") + await context.platform.send_chain(self._reply_target(), chain) + + async def react(self, emoji: str) -> bool: + """Send a platform reaction when supported.""" + context = self._require_runtime_context("react") + output = await context._proxy.call( # noqa: SLF001 + "system.event.react", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "emoji": emoji, + }, + ) + return bool(output.get("supported", False)) + + async def send_typing(self) -> bool: + """Emit typing state when the host platform supports it.""" + context = self._require_runtime_context("send_typing") + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_typing", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("supported", False)) + + async def send_streaming( + self, + generator, + use_fallback: bool = False, + ) -> bool: + """Replay normalized chunks through the host streaming pathway.""" + context = self._require_runtime_context("send_streaming") + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "use_fallback": use_fallback, + }, + ) + if not bool(output.get("supported", False)): + return False + + stream_id = str(output.get("stream_id", "")) + if not stream_id: + return False + + try: + async for item in generator: + if isinstance(item, str): + chain = MessageChain([Plain(item, convert=False)]) + else: + chain = self._coerce_chain_or_raise(item) + await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming_chunk", + { + "stream_id": stream_id, + "chain": await chain.to_payload_async(), + }, + ) + finally: + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming_close", + {"stream_id": stream_id}, + ) + return bool(output.get("supported", False)) + + def bind_reply_handler(self, reply_handler: ReplyHandler) -> None: + """绑定自定义回复处理器。 + + Args: + reply_handler: 回复处理函数 + """ + self._reply_handler = reply_handler + + def plain_result(self, text: str) -> PlainTextResult: + """创建纯文本结果。 + + Args: + text: 结果文本 + + Returns: + PlainTextResult 实例 + """ + return PlainTextResult(text=text) + + def make_result(self) -> MessageEventResult: + """Create an empty SDK-local result wrapper.""" + return MessageEventResult(type=EventResultType.EMPTY) + + def image_result(self, url_or_path: str) -> MessageEventResult: + """Create a chain result that contains one image component.""" + if url_or_path.startswith(("http://", "https://")): + image = Image.fromURL(url_or_path) + elif url_or_path.startswith("base64://"): + image = Image.fromBase64(url_or_path.removeprefix("base64://")) + else: + image = Image.fromFileSystem(url_or_path) + return MessageEventResult( + type=EventResultType.CHAIN, + chain=MessageChain([image]), + ) + + def chain_result( + self, + chain: MessageChain | list[BaseMessageComponent], + ) -> MessageEventResult: + """Create a chain result from SDK components.""" + normalized = ( + chain if isinstance(chain, MessageChain) else MessageChain(list(chain)) + ) + return MessageEventResult(type=EventResultType.CHAIN, chain=normalized) + + @staticmethod + def _coerce_chain_or_raise(item: Any) -> MessageChain: + if isinstance(item, MessageEventResult): + return item.chain + if isinstance(item, MessageChain): + return item + if isinstance(item, BaseMessageComponent): + return MessageChain([item]) + if isinstance(item, list) and all( + isinstance(component, BaseMessageComponent) for component in item + ): + return MessageChain(list(item)) + raise TypeError( + "send_streaming only accepts str, MessageChain, MessageEventResult or SDK message components" + ) diff --git a/astrbot-sdk/src/astrbot_sdk/filters.py b/astrbot-sdk/src/astrbot_sdk/filters.py new file mode 100644 index 0000000000..a47e3ec090 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/filters.py @@ -0,0 +1,228 @@ +"""SDK-native filter declarations. + +本模块提供事件过滤器的声明式 API,用于在 handler 执行前进行条件判断。 + +内置过滤器类型: +- PlatformFilter: 按平台名称过滤(如 qq、wechat) +- MessageTypeFilter: 按消息类型过滤(如 group、private) +- CustomFilter: 用户自定义的同步布尔函数 + +组合操作: +- all_of(*filters): 所有过滤器都通过才执行(AND 逻辑) +- any_of(*filters): 任一过滤器通过即可执行(OR 逻辑) +- 支持 & 和 | 运算符进行链式组合 + +例子: +@custom_filter( + all_of( + PlatformFilter(["qq"]), + MessageTypeFilter(["group"]), + CustomFilter(lambda event: "hello" in event.text), + ) +) + +过滤器在本地(SDK worker 进程内)求值,避免不必要的跨进程调用。 +""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal, TypeAlias, TypeVar + +from .decorators import append_filter_meta +from .protocol.descriptors import ( + CompositeFilterSpec, + FilterSpec, + LocalFilterRefSpec, + MessageTypeFilterSpec, + PlatformFilterSpec, +) + +FilterOperator: TypeAlias = Literal["and", "or"] +_HandlerT = TypeVar("_HandlerT", bound=Callable[..., Any]) + + +@dataclass(slots=True) +class LocalFilterBinding: + filter_id: str + callable: Callable[..., bool] + args: dict[str, Any] = field(default_factory=dict) + + def evaluate(self, *, event=None, ctx=None) -> bool: + signature = inspect.signature(self.callable) + kwargs: dict[str, Any] = {} + if "event" in signature.parameters: + kwargs["event"] = event + if "ctx" in signature.parameters: + kwargs["ctx"] = ctx + result = self.callable(**kwargs) + if inspect.isawaitable(result): + raise TypeError("CustomFilter must return a synchronous bool") + if not isinstance(result, bool): + raise TypeError("CustomFilter must return bool") + return result + + +class FilterBinding: + def __and__(self, other: FilterBinding) -> CompositeFilter: + return CompositeFilter("and", [self, other]) + + def __or__(self, other: FilterBinding) -> CompositeFilter: + return CompositeFilter("or", [self, other]) + + def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]: + raise NotImplementedError + + +@dataclass(slots=True) +class PlatformFilter(FilterBinding): + platforms: list[str] + + def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]: + return PlatformFilterSpec(platforms=list(self.platforms)), [] + + +@dataclass(slots=True) +class MessageTypeFilter(FilterBinding): + message_types: list[str] + + def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]: + return MessageTypeFilterSpec(message_types=list(self.message_types)), [] + + +@dataclass(slots=True) +class CustomFilter(FilterBinding): + callable: Callable[..., bool] + filter_id: str | None = None + + def __post_init__(self) -> None: + if self.filter_id is None: + self.filter_id = f"{self.callable.__module__}.{getattr(self.callable, '__qualname__', self.callable.__name__)}" + + def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]: + assert self.filter_id is not None + return LocalFilterRefSpec(filter_id=self.filter_id), [ + LocalFilterBinding(filter_id=self.filter_id, callable=self.callable), + ] + + +@dataclass(slots=True) +class CompositeFilter(FilterBinding): + operator: FilterOperator + children: list[FilterBinding] + + def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]: + compiled_children: list[FilterSpec] = [] + local_bindings: list[LocalFilterBinding] = [] + for child in self.children: + spec, locals_for_child = child.compile() + compiled_children.append(spec) + local_bindings.extend(locals_for_child) + + if local_bindings: + filter_id = ( + "composite:" + + ":".join(binding.filter_id for binding in local_bindings) + + f":{self.operator}" + ) + + def _evaluate(*, event=None, ctx=None) -> bool: + results = [ + _evaluate_filter_spec_locally( + spec, local_bindings, event=event, ctx=ctx + ) + for spec in compiled_children + ] + if self.operator == "and": + return all(results) + return any(results) + + return ( + LocalFilterRefSpec(filter_id=filter_id), + [LocalFilterBinding(filter_id=filter_id, callable=_evaluate)], + ) + + return CompositeFilterSpec(kind=self.operator, children=compiled_children), [] + + +def _evaluate_filter_spec_locally( + spec: FilterSpec, + local_bindings: list[LocalFilterBinding], + *, + event=None, + ctx=None, +) -> bool: + if isinstance(spec, PlatformFilterSpec): + if event is None: + return True + platform = getattr(event, "platform", "") or "" + return platform in spec.platforms + if isinstance(spec, MessageTypeFilterSpec): + if event is None: + return True + message_type = getattr(event, "message_type", "") or "" + return message_type in spec.message_types + if isinstance(spec, LocalFilterRefSpec): + binding = next( + (item for item in local_bindings if item.filter_id == spec.filter_id), + None, + ) + if binding is None: + # LocalFilterRefSpec 只在当前 worker 持有同名 local binding 时可真正执行。 + # 缺失 binding 往往意味着描述符来自远端/测试快照,此时保持 fail-open, + # 避免因为无法调用进程内函数而把原本可执行的 handler 错误过滤掉。 + return True + return binding.evaluate(event=event, ctx=ctx) + if isinstance(spec, CompositeFilterSpec): + results = [ + _evaluate_filter_spec_locally( + child, + local_bindings, + event=event, + ctx=ctx, + ) + for child in spec.children + ] + if spec.kind == "and": + return all(results) + return any(results) + return True + + +def custom_filter( + binding: FilterBinding, +) -> Callable[[_HandlerT], _HandlerT]: + """Attach a filter declaration to a handler.""" + + def decorator(func: _HandlerT) -> _HandlerT: + spec, local_bindings = binding.compile() + append_filter_meta( + func, + specs=[spec], + local_bindings=local_bindings, + ) + return func + + return decorator + + +def all_of(*bindings: FilterBinding) -> CompositeFilter: + return CompositeFilter("and", list(bindings)) + + +def any_of(*bindings: FilterBinding) -> CompositeFilter: + return CompositeFilter("or", list(bindings)) + + +__all__ = [ + "CustomFilter", + "FilterBinding", + "LocalFilterBinding", + "MessageTypeFilter", + "PlatformFilter", + "all_of", + "any_of", + "custom_filter", +] diff --git a/astrbot-sdk/src/astrbot_sdk/llm/__init__.py b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py new file mode 100644 index 0000000000..02e15b9d2f --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py @@ -0,0 +1,105 @@ +"""Canonical SDK LLM/tool/provider entrypoints for P0.5.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .agents import AgentSpec, BaseAgentRunner + from .entities import ( + LLMToolSpec, + ProviderMeta, + ProviderRequest, + ProviderType, + RerankResult, + ToolCallsResult, + ) + from .providers import ( + EmbeddingProvider, + ProviderProxy, + RerankProvider, + STTProvider, + TTSAudioChunk, + TTSProvider, + ) + from .tools import LLMToolManager + +__all__ = [ + "AgentSpec", + "BaseAgentRunner", + "EmbeddingProvider", + "LLMToolManager", + "LLMToolSpec", + "ProviderMeta", + "ProviderProxy", + "ProviderRequest", + "ProviderType", + "RerankProvider", + "RerankResult", + "STTProvider", + "TTSAudioChunk", + "TTSProvider", + "ToolCallsResult", +] + + +def __getattr__(name: str) -> Any: + if name in {"AgentSpec", "BaseAgentRunner"}: + from .agents import AgentSpec, BaseAgentRunner + + return {"AgentSpec": AgentSpec, "BaseAgentRunner": BaseAgentRunner}[name] + if name in { + "LLMToolSpec", + "ProviderMeta", + "ProviderRequest", + "ProviderType", + "RerankResult", + "ToolCallsResult", + }: + from .entities import ( + LLMToolSpec, + ProviderMeta, + ProviderRequest, + ProviderType, + RerankResult, + ToolCallsResult, + ) + + return { + "LLMToolSpec": LLMToolSpec, + "ProviderMeta": ProviderMeta, + "ProviderRequest": ProviderRequest, + "ProviderType": ProviderType, + "RerankResult": RerankResult, + "ToolCallsResult": ToolCallsResult, + }[name] + if name in { + "EmbeddingProvider", + "ProviderProxy", + "RerankProvider", + "STTProvider", + "TTSAudioChunk", + "TTSProvider", + }: + from .providers import ( + EmbeddingProvider, + ProviderProxy, + RerankProvider, + STTProvider, + TTSAudioChunk, + TTSProvider, + ) + + return { + "EmbeddingProvider": EmbeddingProvider, + "ProviderProxy": ProviderProxy, + "RerankProvider": RerankProvider, + "STTProvider": STTProvider, + "TTSAudioChunk": TTSAudioChunk, + "TTSProvider": TTSProvider, + }[name] + if name == "LLMToolManager": + from .tools import LLMToolManager + + return LLMToolManager + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/astrbot-sdk/src/astrbot_sdk/llm/agents.py b/astrbot-sdk/src/astrbot_sdk/llm/agents.py new file mode 100644 index 0000000000..c2d6b21e62 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/agents.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, ConfigDict, Field + +from .entities import ProviderRequest + +if TYPE_CHECKING: + from ..context import Context + + +class AgentSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + description: str = "" + tool_names: list[str] = Field(default_factory=list) + runner_class: str + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> AgentSpec: + return cls.model_validate(payload) + + +class BaseAgentRunner(ABC): + """agent registration surface. + + only supports agent registration metadata. Actual execution remains + owned by the core tool loop and is not directly callable from SDK plugins. + """ + + @abstractmethod + async def run(self, ctx: Context, request: ProviderRequest) -> Any: + raise NotImplementedError diff --git a/astrbot-sdk/src/astrbot_sdk/llm/entities.py b/astrbot-sdk/src/astrbot_sdk/llm/entities.py new file mode 100644 index 0000000000..ba252db24b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/entities.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import enum +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _EntityModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class ProviderType(str, enum.Enum): + CHAT_COMPLETION = "chat_completion" + SPEECH_TO_TEXT = "speech_to_text" + TEXT_TO_SPEECH = "text_to_speech" + EMBEDDING = "embedding" + RERANK = "rerank" + + +class ProviderMeta(_EntityModel): + id: str + model: str | None = None + type: str + provider_type: ProviderType = ProviderType.CHAT_COMPLETION + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> ProviderMeta | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ToolCallsResult(_EntityModel): + tool_call_id: str | None = None + tool_name: str + content: str + success: bool = True + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ToolCallsResult: + return cls.model_validate(payload) + + +class RerankResult(_EntityModel): + index: int + score: float + document: str + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> RerankResult: + return cls.model_validate(payload) + + +class LLMToolSpec(_EntityModel): + name: str + description: str = "" + parameters_schema: dict[str, Any] = Field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + handler_ref: str | None = Field( + default=None, + description="Worker-side handler reference used to resolve the tool callable.", + ) + handler_capability: str | None = Field( + default=None, + description="Optional capability name override for executing this tool handler.", + ) + active: bool = True + + @classmethod + def create( + cls, + *, + name: str, + description: str = "", + parameters_schema: dict[str, Any] | None = None, + handler_ref: str | None = None, + handler_capability: str | None = None, + active: bool = True, + ) -> LLMToolSpec: + # Keep an explicit factory signature so static analyzers do not depend on + # Pydantic's generated __init__ when SDK call sites construct tool specs. + payload: dict[str, Any] = { + "name": name, + "description": description, + "parameters_schema": parameters_schema + if parameters_schema is not None + else {"type": "object", "properties": {}}, + "active": active, + } + if handler_ref is not None: + payload["handler_ref"] = handler_ref + if handler_capability is not None: + payload["handler_capability"] = handler_capability + return cls.from_payload(payload) + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> LLMToolSpec: + return cls.model_validate(payload) + + +class ProviderRequest(_EntityModel): + prompt: str | None = None + system_prompt: str | None = None + session_id: str | None = None + contexts: list[dict[str, Any]] = Field(default_factory=list) + image_urls: list[str] = Field(default_factory=list) + tool_names: list[str] | None = None + tool_calls_result: list[ToolCallsResult] = Field(default_factory=list) + provider_id: str | None = None + model: str | None = None + temperature: float | None = None + max_steps: int | None = None + tool_call_timeout: int | None = None + + def to_payload(self) -> dict[str, Any]: + payload = super().to_payload() + payload["tool_calls_result"] = [ + item.to_payload() for item in self.tool_calls_result + ] + return payload + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ProviderRequest: + normalized = dict(payload) + raw_results = normalized.get("tool_calls_result") + if isinstance(raw_results, list): + normalized["tool_calls_result"] = [ + ToolCallsResult.from_payload(item) + for item in raw_results + if isinstance(item, dict) + ] + return cls.model_validate(normalized) diff --git a/astrbot-sdk/src/astrbot_sdk/llm/providers.py b/astrbot-sdk/src/astrbot_sdk/llm/providers.py new file mode 100644 index 0000000000..591e1d57d5 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/providers.py @@ -0,0 +1,199 @@ +"""Provider-facing SDK entities and typed proxy helpers.""" + +from __future__ import annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from dataclasses import dataclass + +from ..clients._proxy import CapabilityProxy +from .entities import ProviderMeta, ProviderType, RerankResult + + +@dataclass(slots=True) +class TTSAudioChunk: + audio: bytes + text: str | None = None + + +class _BaseProviderProxy: + def __init__(self, proxy: CapabilityProxy, meta: ProviderMeta) -> None: + self._proxy = proxy + self._meta = meta + + @property + def id(self) -> str: + return self._meta.id + + @property + def model(self) -> str | None: + return self._meta.model + + @property + def type(self) -> str: + return self._meta.type + + @property + def provider_type(self) -> ProviderType: + return self._meta.provider_type + + def meta(self) -> ProviderMeta: + return self._meta + + +class STTProvider(_BaseProviderProxy): + async def get_text(self, audio_url: str) -> str: + output = await self._proxy.call( + "provider.stt.get_text", + {"provider_id": self.id, "audio_url": str(audio_url)}, + ) + return str(output.get("text", "")) + + +class TTSProvider(_BaseProviderProxy): + def __init__( + self, + proxy: CapabilityProxy, + meta: ProviderMeta, + *, + supports_stream: bool = False, + ) -> None: + super().__init__(proxy, meta) + self._supports_stream = supports_stream + + async def get_audio(self, text: str) -> str: + output = await self._proxy.call( + "provider.tts.get_audio", + {"provider_id": self.id, "text": str(text)}, + ) + return str(output.get("audio_path", "")) + + def support_stream(self) -> bool: + return self._supports_stream + + async def get_audio_stream( + self, + text: str | AsyncIterable[str], + ) -> AsyncIterator[TTSAudioChunk]: + payload = await self._build_stream_payload(text) + async for chunk in self._proxy.stream("provider.tts.get_audio_stream", payload): + audio_base64 = str(chunk.get("audio_base64", "")) + yield TTSAudioChunk( + audio=base64.b64decode(audio_base64) if audio_base64 else b"", + text=( + str(chunk.get("text")) if chunk.get("text") is not None else None + ), + ) + + async def _build_stream_payload( + self, + text: str | AsyncIterable[str], + ) -> dict[str, object]: + payload: dict[str, object] = {"provider_id": self.id} + if isinstance(text, str): + payload["text"] = text + return payload + payload["text_chunks"] = [str(item) async for item in text] + return payload + + +class EmbeddingProvider(_BaseProviderProxy): + async def get_embedding(self, text: str) -> list[float]: + output = await self._proxy.call( + "provider.embedding.get_embedding", + {"provider_id": self.id, "text": str(text)}, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + output = await self._proxy.call( + "provider.embedding.get_embeddings", + { + "provider_id": self.id, + "texts": [str(item) for item in texts], + }, + ) + embeddings = output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def get_dim(self) -> int: + output = await self._proxy.call( + "provider.embedding.get_dim", + {"provider_id": self.id}, + ) + return int(output.get("dim", 0)) + + +class RerankProvider(_BaseProviderProxy): + async def rerank( + self, + query: str, + documents: list[str], + top_n: int | None = None, + ) -> list[RerankResult]: + output = await self._proxy.call( + "provider.rerank.rerank", + { + "provider_id": self.id, + "query": str(query), + "documents": [str(item) for item in documents], + "top_n": top_n, + }, + ) + results = output.get("results") + if not isinstance(results, list): + return [] + return [ + RerankResult.from_payload(item) + for item in results + if isinstance(item, dict) + ] + + +ProviderProxy = STTProvider | TTSProvider | EmbeddingProvider | RerankProvider + + +def provider_proxy_from_meta( + proxy: CapabilityProxy, + meta: ProviderMeta | None, + *, + tts_supports_stream: bool | None = None, +) -> ProviderProxy | None: + if meta is None: + return None + if meta.provider_type == ProviderType.SPEECH_TO_TEXT: + return STTProvider(proxy, meta) + if meta.provider_type == ProviderType.TEXT_TO_SPEECH: + return TTSProvider( + proxy, + meta, + supports_stream=bool(tts_supports_stream), + ) + if meta.provider_type == ProviderType.EMBEDDING: + return EmbeddingProvider(proxy, meta) + if meta.provider_type == ProviderType.RERANK: + return RerankProvider(proxy, meta) + return None + + +__all__ = [ + "EmbeddingProvider", + "ProviderMeta", + "ProviderProxy", + "ProviderType", + "RerankProvider", + "RerankResult", + "STTProvider", + "TTSAudioChunk", + "TTSProvider", + "provider_proxy_from_meta", +] diff --git a/astrbot-sdk/src/astrbot_sdk/llm/tools.py b/astrbot-sdk/src/astrbot_sdk/llm/tools.py new file mode 100644 index 0000000000..d1a67b30c7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/tools.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .entities import LLMToolSpec + +if TYPE_CHECKING: + from ..clients._proxy import CapabilityProxy + + +class LLMToolManager: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list_registered(self) -> list[LLMToolSpec]: + output = await self._proxy.call("llm_tool.manager.get", {}) + items = output.get("registered") + if not isinstance(items, list): + return [] + return [ + LLMToolSpec.from_payload(item) for item in items if isinstance(item, dict) + ] + + async def list_active(self) -> list[LLMToolSpec]: + output = await self._proxy.call("llm_tool.manager.get", {}) + items = output.get("active") + if not isinstance(items, list): + return [] + return [ + LLMToolSpec.from_payload(item) for item in items if isinstance(item, dict) + ] + + async def activate(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.activate", {"name": name}) + return bool(output.get("activated", False)) + + async def deactivate(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.deactivate", {"name": name}) + return bool(output.get("deactivated", False)) + + async def add(self, *tools: LLMToolSpec) -> list[str]: + output = await self._proxy.call( + "llm_tool.manager.add", + {"tools": [tool.to_payload() for tool in tools]}, + ) + result = output.get("names") + if not isinstance(result, list): + return [] + return [str(item) for item in result] + + async def remove(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.remove", {"name": name}) + return bool(output.get("removed", False)) + + async def get(self, name: str) -> LLMToolSpec | None: + for tool in await self.list_registered(): + if tool.name == name: + return tool + return None diff --git a/astrbot-sdk/src/astrbot_sdk/message/__init__.py b/astrbot-sdk/src/astrbot_sdk/message/__init__.py new file mode 100644 index 0000000000..4125a0db12 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/__init__.py @@ -0,0 +1,103 @@ +"""Message component, result, and session subpackage.""" + +from .components import ( + At as At, +) +from .components import ( + AtAll as AtAll, +) +from .components import ( + BaseMessageComponent as BaseMessageComponent, +) +from .components import ( + File as File, +) +from .components import ( + Forward as Forward, +) +from .components import ( + Image as Image, +) +from .components import ( + MediaHelper as MediaHelper, +) +from .components import ( + Plain as Plain, +) +from .components import ( + Poke as Poke, +) +from .components import ( + Record as Record, +) +from .components import ( + Reply as Reply, +) +from .components import ( + UnknownComponent as UnknownComponent, +) +from .components import ( + Video as Video, +) +from .components import ( + build_media_component_from_url as build_media_component_from_url, +) +from .components import ( + component_to_payload as component_to_payload, +) +from .components import ( + component_to_payload_sync as component_to_payload_sync, +) +from .components import ( + is_message_component as is_message_component, +) +from .components import ( + payload_to_component as payload_to_component, +) +from .components import ( + payloads_to_components as payloads_to_components, +) +from .result import ( + EventResultType as EventResultType, +) +from .result import ( + MessageBuilder as MessageBuilder, +) +from .result import ( + MessageChain as MessageChain, +) +from .result import ( + MessageEventResult as MessageEventResult, +) +from .result import ( + coerce_message_chain as coerce_message_chain, +) +from .session import MessageSession as MessageSession + +__all__ = [ + "At", + "AtAll", + "BaseMessageComponent", + "EventResultType", + "File", + "Forward", + "Image", + "MediaHelper", + "MessageBuilder", + "MessageChain", + "MessageEventResult", + "MessageSession", + "Plain", + "Poke", + "Record", + "Reply", + "UnknownComponent", + "Video", + "build_media_component_from_url", + "coerce_message_chain", + "component_to_payload", + "component_to_payload_sync", + "is_message_component", + "payload_to_component", + "payloads_to_components", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/components.py b/astrbot-sdk/src/astrbot_sdk/message/components.py new file mode 100644 index 0000000000..5c5423499d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/components.py @@ -0,0 +1,625 @@ +"""SDK message component compatibility layer. + +该模块有意避免在导入时导入遗留核心组件模块。 +SDK工作线程应该保持轻量级并且不能依赖于主机核心引导程序 +仅用于构造消息对象的路径。 +""" + +from __future__ import annotations + +import asyncio +import base64 +import inspect +import os +import tempfile +import uuid +from collections.abc import Mapping +from pathlib import Path +from typing import Any +from urllib.parse import urlparse +from urllib.request import urlretrieve + +from .._internal.star_runtime import current_runtime_context +from ..errors import AstrBotError + +_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} +_RECORD_SUFFIXES = {".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a"} +_VIDEO_SUFFIXES = {".mp4", ".webm", ".mov", ".mkv", ".avi"} + + +def _temp_path(prefix: str, suffix: str = "") -> Path: + return Path(tempfile.gettempdir()) / f"{prefix}_{uuid.uuid4().hex}{suffix}" + + +def _guess_suffix_from_url(url: str, fallback: str = "") -> str: + suffix = Path(urlparse(url).path).suffix + return suffix or fallback + + +def _download_to_temp(url: str, prefix: str, fallback_suffix: str = "") -> str: + target = _temp_path(prefix, _guess_suffix_from_url(url, fallback_suffix)) + urlretrieve(url, target) + return str(target.resolve()) + + +async def _download_to_temp_async( + url: str, + prefix: str, + fallback_suffix: str = "", +) -> str: + return await asyncio.to_thread( + _download_to_temp, + url, + prefix, + fallback_suffix, + ) + + +def _stringify_mapping(mapping: Mapping[Any, Any]) -> dict[str, Any]: + return {str(key): value for key, value in mapping.items()} + + +async def _register_file_to_service(path: str) -> str: + context = current_runtime_context() + if context is None: + raise RuntimeError("message component file service requires runtime context") + return await context._register_file_url(path) + + +def _reply_chain_payloads_sync(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [component_to_payload_sync(item) for item in value] + + +async def _reply_chain_payloads(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [await component_to_payload(item) for item in value] + + +def _coerce_reply_chain(value: Any) -> list[BaseMessageComponent]: + if not isinstance(value, list): + return [] + if value and all(isinstance(item, BaseMessageComponent) for item in value): + return list(value) + return payloads_to_components(value) + + +def _component_type_name(component: Any) -> str: + raw_type = getattr(component, "type", "unknown") + normalized = getattr(raw_type, "value", raw_type) + return str(normalized or "unknown").lower() + + +def _plain_payload(text: Any) -> dict[str, Any]: + return {"type": "text", "data": {"text": str(text)}} + + +def _reply_payload_data( + component: Any, + *, + chain_payloads: list[dict[str, Any]], +) -> dict[str, Any]: + return { + "id": getattr(component, "id", ""), + "chain": chain_payloads, + "sender_id": getattr(component, "sender_id", 0), + "sender_nickname": getattr(component, "sender_nickname", ""), + "time": getattr(component, "time", 0), + "message_str": getattr(component, "message_str", ""), + "text": getattr(component, "text", ""), + "qq": getattr(component, "qq", 0), + "seq": getattr(component, "seq", 0), + } + + +def _resolve_media_kind(url: str, kind: str = "auto") -> str: + normalized_kind = str(kind).strip().lower() or "auto" + if normalized_kind != "auto": + return normalized_kind + suffix = Path(urlparse(url).path).suffix.lower() + if suffix in _IMAGE_SUFFIXES: + return "image" + if suffix in _RECORD_SUFFIXES: + return "record" + if suffix in _VIDEO_SUFFIXES: + return "video" + return "file" + + +def build_media_component_from_url( + url: str, + *, + kind: str = "auto", +) -> BaseMessageComponent: + url_text = str(url).strip() + if not url_text: + raise AstrBotError.invalid_input( + "MediaHelper.from_url requires a non-empty url" + ) + resolved_kind = _resolve_media_kind(url_text, kind=kind) + if resolved_kind == "image": + return Image.fromURL(url_text) + if resolved_kind in {"record", "audio"}: + return Record.fromURL(url_text) + if resolved_kind == "video": + return Video.fromURL(url_text) + if resolved_kind == "file": + return File(name=_filename_from_url(url_text), url=url_text) + raise AstrBotError.invalid_input( + f"Unsupported media kind: {kind}", + details={"kind": kind, "url": url_text}, + ) + + +def _filename_from_url(url: str) -> str: + name = Path(urlparse(url).path).name + return name or "download" + + +class BaseMessageComponent: + type: str = "unknown" + + def toDict(self) -> dict[str, Any]: + data: dict[str, Any] = {} + for key, value in self.__dict__.items(): + if key == "type" or value is None: + continue + data["type" if key == "_type" else key] = value + return {"type": str(self.type).lower(), "data": data} + + async def to_dict(self) -> dict[str, Any]: + return self.toDict() + + +class Plain(BaseMessageComponent): + type = "plain" + + def __init__(self, text: str, convert: bool = True, **_: Any) -> None: + self.text = text + self.convert = convert + + def toDict(self) -> dict[str, Any]: + return _plain_payload(self.text) + + async def to_dict(self) -> dict[str, Any]: + return _plain_payload(self.text) + + +class At(BaseMessageComponent): + type = "at" + + def __init__(self, qq: int | str, name: str | None = "", **_: Any) -> None: + self.qq = qq + self.name = name or "" + + def toDict(self) -> dict[str, Any]: + return {"type": "at", "data": {"qq": str(self.qq)}} + + +class AtAll(At): + def __init__(self, **_: Any) -> None: + super().__init__(qq="all") + + +class Reply(BaseMessageComponent): + type = "reply" + + def __init__(self, **kwargs: Any) -> None: + self.id = kwargs.get("id", "") + self.chain = _coerce_reply_chain(kwargs.get("chain", [])) + self.sender_id = kwargs.get("sender_id", 0) + self.sender_nickname = kwargs.get("sender_nickname", "") + self.time = kwargs.get("time", 0) + self.message_str = kwargs.get("message_str", "") + self.text = kwargs.get("text", "") + self.qq = kwargs.get("qq", 0) + self.seq = kwargs.get("seq", 0) + + def toDict(self) -> dict[str, Any]: + return { + "type": "reply", + "data": _reply_payload_data( + self, + chain_payloads=_reply_chain_payloads_sync(self.chain), + ), + } + + async def to_dict(self) -> dict[str, Any]: + return { + "type": "reply", + "data": _reply_payload_data( + self, + chain_payloads=await _reply_chain_payloads(self.chain), + ), + } + + +class Image(BaseMessageComponent): + type = "image" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self._type = kwargs.get("_type", "") + self.subType = kwargs.get("subType", 0) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.id = kwargs.get("id", 40000) + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") + self.file_unique = kwargs.get("file_unique", "") + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Image: + return Image(url, **kwargs) + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Image: + return Image(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromBase64(base64_data: str, **kwargs: Any) -> Image: + return Image(f"base64://{base64_data}", **kwargs) + + async def convert_to_file_path(self) -> str: + url = self.url or self.file + if not url: + raise ValueError("No valid file or URL provided") + if url.startswith("file:///"): + return os.path.abspath(url[8:]) + if url.startswith(("http://", "https://")): + return await _download_to_temp_async(url, "imgseg", ".jpg") + if url.startswith("base64://"): + file_path = _temp_path("imgseg", ".jpg") + file_path.write_bytes(base64.b64decode(url.removeprefix("base64://"))) + return str(file_path.resolve()) + if os.path.exists(url): + return os.path.abspath(url) + raise ValueError(f"not a valid file: {url}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.convert_to_file_path()) + + +class Record(BaseMessageComponent): + type = "record" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self.magic = kwargs.get("magic", False) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.proxy = kwargs.get("proxy", True) + self.timeout = kwargs.get("timeout", 0) + self.text = kwargs.get("text") + self.path = kwargs.get("path") + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Record: + return Record(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Record: + return Record(url, **kwargs) + + async def convert_to_file_path(self) -> str: + if self.file.startswith("file:///"): + return os.path.abspath(self.file[8:]) + if self.file.startswith(("http://", "https://")): + return await _download_to_temp_async(self.file, "recordseg", ".dat") + if self.file.startswith("base64://"): + file_path = _temp_path("recordseg", ".dat") + file_path.write_bytes(base64.b64decode(self.file.removeprefix("base64://"))) + return str(file_path.resolve()) + if os.path.exists(self.file): + return os.path.abspath(self.file) + raise ValueError(f"not a valid file: {self.file}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.convert_to_file_path()) + + +class Video(BaseMessageComponent): + type = "video" + + def __init__(self, file: str, **kwargs: Any) -> None: + self.file = file + self.cover = kwargs.get("cover", "") + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Video: + return Video(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Video: + return Video(url, **kwargs) + + async def convert_to_file_path(self) -> str: + if self.file.startswith("file:///"): + return os.path.abspath(self.file[8:]) + if self.file.startswith(("http://", "https://")): + return await _download_to_temp_async(self.file, "videoseg") + if os.path.exists(self.file): + return os.path.abspath(self.file) + raise ValueError(f"not a valid file: {self.file}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.convert_to_file_path()) + + +class File(BaseMessageComponent): + type = "file" + + def __init__(self, name: str, file: str = "", url: str = "") -> None: + self.name = name + self.file_ = file + self.url = url + + @property + def file(self) -> str: + return self.file_ + + @file.setter + def file(self, value: str) -> None: + if value.startswith(("http://", "https://")): + self.url = value + else: + self.file_ = value + + async def get_file(self, allow_return_url: bool = False) -> str: + if allow_return_url and self.url: + return self.url + if self.file_: + path = self.file_ + if path.startswith("file://"): + path = path[7:] + if ( + os.name == "nt" + and len(path) > 2 + and path[0] == "/" + and path[2] == ":" + ): + path = path[1:] + if os.path.exists(path): + return os.path.abspath(path) + if self.url: + suffix = Path(urlparse(self.url).path).suffix + target = await _download_to_temp_async(self.url, "fileseg", suffix) + self.file_ = target + return target + return "" + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.get_file()) + + def toDict(self) -> dict[str, Any]: + payload_file = self.url or self.file_ + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + + async def to_dict(self) -> dict[str, Any]: + payload_file = await self.get_file(allow_return_url=True) + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + + +class Poke(BaseMessageComponent): + type = "poke" + + def __init__(self, poke_type: str | int | None = None, **kwargs: Any) -> None: + legacy_type = kwargs.pop("type", None) + if poke_type is None: + poke_type = legacy_type + if poke_type in (None, "", "poke", "Poke"): + poke_type = "126" + self._type = str(poke_type) + self.id = kwargs.get("id") + self.qq = kwargs.get("qq", 0) + + def target_id(self) -> str | None: + for value in (self.id, self.qq): + if value is None: + continue + text = str(value).strip() + if text and text != "0": + return text + return None + + def toDict(self) -> dict[str, Any]: + data = {"type": str(self._type or "126")} + target_id = self.target_id() + if target_id: + data["id"] = target_id + return {"type": "poke", "data": data} + + +class Forward(BaseMessageComponent): + type = "forward" + + def __init__(self, id: str, **_: Any) -> None: + self.id = id + + +class UnknownComponent(BaseMessageComponent): + type = "unknown" + + def __init__( + self, + *, + raw_type: str = "unknown", + raw_data: dict[str, Any] | None = None, + ) -> None: + self.raw_type = raw_type + self.raw_data = raw_data or {} + + def toDict(self) -> dict[str, Any]: + return { + "type": self.raw_type or "unknown", + "data": dict(self.raw_data), + } + + +def is_message_component(value: Any) -> bool: + return isinstance(value, BaseMessageComponent) + + +def payload_to_component(payload: Any) -> BaseMessageComponent: + if not isinstance(payload, dict): + return UnknownComponent(raw_data={"value": payload}) + + raw_type = str(payload.get("type", "unknown") or "unknown").lower() + data = payload.get("data") + if not isinstance(data, dict): + data = {} + + if raw_type in {"text", "plain"}: + return Plain(str(data.get("text", "")), convert=False) + if raw_type == "image": + return Image(str(data.get("file") or data.get("url") or "")) + if raw_type == "at": + qq_value = data.get("qq") + if str(qq_value).lower() == "all": + return AtAll() + qq = "" if qq_value is None else str(qq_value) + return At(qq=qq, name=str(data.get("name", ""))) + if raw_type == "reply": + return Reply(**data) + if raw_type == "record": + return Record(str(data.get("file") or data.get("url") or ""), **data) + if raw_type == "video": + return Video(str(data.get("file") or ""), **data) + if raw_type == "file": + file_value = str(data.get("file") or data.get("file_") or "") + if not file_value: + file_value = str(data.get("url") or "") + return File( + str(data.get("name", "")), + file="" if file_value.startswith(("http://", "https://")) else file_value, + url=file_value if file_value.startswith(("http://", "https://")) else "", + ) + if raw_type == "poke": + return Poke( + poke_type=data.get("type"), + id=data.get("id"), + qq=data.get("qq"), + ) + if raw_type == "forward": + return Forward(id=str(data.get("id", ""))) + + return UnknownComponent(raw_type=raw_type, raw_data=_stringify_mapping(data)) + + +def payloads_to_components(payloads: list[Any]) -> list[BaseMessageComponent]: + return [payload_to_component(item) for item in payloads] + + +def component_to_payload_sync(component: Any) -> dict[str, Any]: + if isinstance(component, UnknownComponent): + return component.toDict() + if isinstance(component, Plain): + return _plain_payload(component.text) + if _component_type_name(component) == "reply": + return { + "type": "reply", + "data": _reply_payload_data( + component, + chain_payloads=_reply_chain_payloads_sync( + getattr(component, "chain", []) + ), + ), + } + to_dict = getattr(component, "toDict", None) + if callable(to_dict): + result = to_dict() + if isinstance(result, Mapping): + return _stringify_mapping(result) + return {"type": "unknown", "data": {"value": str(component)}} + + +async def component_to_payload(component: Any) -> dict[str, Any]: + if isinstance(component, (UnknownComponent, Plain)): + return component_to_payload_sync(component) + async_method = getattr(component, "to_dict", None) + if callable(async_method): + payload = async_method() + if inspect.isawaitable(payload): + result = await payload + if isinstance(result, dict): + return result + return component_to_payload_sync(component) + + +class MediaHelper: + @staticmethod + async def from_url( + url: str, + *, + kind: str = "auto", + ) -> BaseMessageComponent: + return build_media_component_from_url(url, kind=kind) + + @staticmethod + async def download(url: str, save_dir: Path) -> Path: + url_text = str(url).strip() + if not url_text: + raise AstrBotError.invalid_input( + "MediaHelper.download requires a non-empty url" + ) + parsed = urlparse(url_text) + if parsed.scheme not in {"http", "https"}: + raise AstrBotError.invalid_input( + "MediaHelper.download only supports http/https urls", + details={"url": url_text}, + ) + target_dir = Path(save_dir) + try: + target_dir.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise AstrBotError.internal_error( + f"Failed to prepare download directory: {target_dir}", + details={"save_dir": str(target_dir)}, + ) from exc + target_path = target_dir / _filename_from_url(url_text) + try: + await asyncio.to_thread(urlretrieve, url_text, target_path) + except Exception as exc: + raise AstrBotError.network_error( + f"Failed to download media from '{url_text}'", + details={"url": url_text}, + ) from exc + return target_path.resolve() + + +__all__ = [ + "At", + "AtAll", + "BaseMessageComponent", + "File", + "Forward", + "Image", + "MediaHelper", + "Plain", + "Poke", + "Record", + "Reply", + "UnknownComponent", + "Video", + "component_to_payload", + "component_to_payload_sync", + "is_message_component", + "payload_to_component", + "payloads_to_components", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/result.py b/astrbot-sdk/src/astrbot_sdk/message/result.py new file mode 100644 index 0000000000..a38c207099 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/result.py @@ -0,0 +1,174 @@ +"""SDK-local rich message result objects. + +本模块定义消息事件的结果对象,用于构建和返回富文本/多媒体消息。 + +核心类: +- MessageChain: 消息组件列表,支持同步/异步序列化为协议 payload +- MessageEventResult: 事件处理结果,包含类型标记和消息链 +- EventResultType: 结果类型枚举(EMPTY / CHAIN) + +辅助函数: +- coerce_message_chain: 将多种输入格式统一转换为 MessageChain, + 支持 MessageEventResult、MessageChain、单个组件或组件列表 +""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from .components import ( + At, + AtAll, + BaseMessageComponent, + File, + Plain, + Reply, + build_media_component_from_url, + component_to_payload, + component_to_payload_sync, + is_message_component, + payloads_to_components, +) + + +class EventResultType(str, Enum): + EMPTY = "empty" + CHAIN = "chain" + + +@dataclass(slots=True) +class MessageChain: + components: list[BaseMessageComponent] = field(default_factory=list) + + def append(self, component: BaseMessageComponent) -> MessageChain: + self.components.append(component) + return self + + def extend(self, components: list[BaseMessageComponent]) -> MessageChain: + self.components.extend(components) + return self + + def __iter__(self) -> Iterator[BaseMessageComponent]: + return iter(self.components) + + def __len__(self) -> int: + return len(self.components) + + def to_payload(self) -> list[dict[str, Any]]: + return [component_to_payload_sync(component) for component in self.components] + + async def to_payload_async(self) -> list[dict[str, Any]]: + return [await component_to_payload(component) for component in self.components] + + def get_plain_text(self, with_other_comps_mark: bool = False) -> str: + texts: list[str] = [] + for component in self.components: + if isinstance(component, Plain): + texts.append(component.text) + elif with_other_comps_mark: + texts.append(f"[{component.__class__.__name__}]") + return " ".join(texts) + + def plain_text(self, with_other_comps_mark: bool = False) -> str: + return self.get_plain_text(with_other_comps_mark=with_other_comps_mark) + + +@dataclass(slots=True) +class MessageEventResult: + type: EventResultType = EventResultType.EMPTY + chain: MessageChain = field(default_factory=MessageChain) + + def to_payload(self) -> dict[str, Any]: + return { + "type": self.type.value, + "chain": self.chain.to_payload(), + } + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> MessageEventResult: + result_type_raw = str(payload.get("type", EventResultType.EMPTY.value)) + try: + result_type = EventResultType(result_type_raw) + except ValueError: + result_type = EventResultType.EMPTY + chain_payload = payload.get("chain") + components = ( + payloads_to_components(chain_payload) + if isinstance(chain_payload, list) + else [] + ) + return cls(type=result_type, chain=MessageChain(components)) + + +@dataclass(slots=True) +class MessageBuilder: + components: list[BaseMessageComponent] = field(default_factory=list) + + def text(self, content: str) -> MessageBuilder: + self.components.append(Plain(content, convert=False)) + return self + + def at(self, user_id: str) -> MessageBuilder: + self.components.append(At(user_id)) + return self + + def at_all(self) -> MessageBuilder: + self.components.append(AtAll()) + return self + + def image(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="image")) + return self + + def record(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="record")) + return self + + def video(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="video")) + return self + + def file(self, name: str, *, file: str = "", url: str = "") -> MessageBuilder: + self.components.append(File(name=name, file=file, url=url)) + return self + + def reply(self, **kwargs: Any) -> MessageBuilder: + self.components.append(Reply(**kwargs)) + return self + + def append(self, component: BaseMessageComponent) -> MessageBuilder: + self.components.append(component) + return self + + def extend(self, components: list[BaseMessageComponent]) -> MessageBuilder: + self.components.extend(components) + return self + + def build(self) -> MessageChain: + return MessageChain(list(self.components)) + + +def coerce_message_chain(value: Any) -> MessageChain | None: + if isinstance(value, MessageEventResult): + return value.chain + if isinstance(value, MessageChain): + return value + if is_message_component(value): + return MessageChain([value]) + if isinstance(value, (list, tuple)) and all( + is_message_component(item) for item in value + ): + return MessageChain(list(value)) + return None + + +__all__ = [ + "EventResultType", + "MessageChain", + "MessageBuilder", + "MessageEventResult", + "coerce_message_chain", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/session.py b/astrbot-sdk/src/astrbot_sdk/message/session.py new file mode 100644 index 0000000000..951e34d25c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/session.py @@ -0,0 +1,55 @@ +"""SDK-visible message session identifier. + +本模块定义 MessageSession 类,用于统一表示消息会话标识符。 +会话标识符格式为:platform_id:message_type:session_id + +例如: +- qq:group:123456 表示 QQ 群 123456 +- wechat:private:user789 表示微信私聊用户 user789 + +该格式与 AstrBot 核心的 unified_msg_origin 保持兼容, +确保 SDK 与核心之间的会话信息能够正确传递。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .._message_types import normalize_message_type + + +@dataclass(slots=True) +class MessageSession: + """SDK-visible message session identifier. + + The string form stays compatible with AstrBot's unified message origin: + ``platform_id:message_type:session_id``. + """ + + platform_id: str + message_type: str + session_id: str + + def __post_init__(self) -> None: + self.platform_id = str(self.platform_id) + self.message_type = normalize_message_type(self.message_type) + self.session_id = str(self.session_id) + + def __str__(self) -> str: + return f"{self.platform_id}:{self.message_type}:{self.session_id}" + + @classmethod + def from_str(cls, session: str) -> MessageSession: + raw_session = str(session) + parts = raw_session.split(":", 2) + if len(parts) != 3 or any(part == "" for part in parts): + raise ValueError( + "invalid message session format, expected " + "'platform_id:message_type:session_id'" + ) + platform_id, message_type, session_id = parts + return cls( + platform_id=platform_id, + message_type=message_type, + session_id=session_id, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/message_components.py b/astrbot-sdk/src/astrbot_sdk/message_components.py new file mode 100644 index 0000000000..372bd54a67 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_components.py @@ -0,0 +1,13 @@ +"""Backward-compatible alias for ``astrbot_sdk.message.components``. + +This module intentionally aliases the implementation module instead of re-exporting +names one by one so private helpers keep working with existing monkeypatch sites. +""" + +from __future__ import annotations + +import sys + +from .message import components as _components_module + +sys.modules[__name__] = _components_module diff --git a/astrbot-sdk/src/astrbot_sdk/message_result.py b/astrbot-sdk/src/astrbot_sdk/message_result.py new file mode 100644 index 0000000000..0b575aad5c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_result.py @@ -0,0 +1,13 @@ +"""Backward-compatible alias for ``astrbot_sdk.message.result``. + +Use a module alias so callers patching helper functions on the legacy module path +still affect ``MessageBuilder`` and other implementation globals. +""" + +from __future__ import annotations + +import sys + +from .message import result as _result_module + +sys.modules[__name__] = _result_module diff --git a/astrbot-sdk/src/astrbot_sdk/message_session.py b/astrbot-sdk/src/astrbot_sdk/message_session.py new file mode 100644 index 0000000000..ec87255555 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_session.py @@ -0,0 +1,9 @@ +"""Backward-compatible message session exports. + +The canonical implementation moved to ``astrbot_sdk.message.session``. Preserve the +legacy import path to avoid breaking existing plugins. +""" + +from .message.session import MessageSession + +__all__ = ["MessageSession"] diff --git a/astrbot-sdk/src/astrbot_sdk/plugin_kv.py b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py new file mode 100644 index 0000000000..de1922b60b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +if TYPE_CHECKING: + from .context import Context + +_VT = TypeVar("_VT") + + +class _HasRuntimeContext(Protocol): + def _require_runtime_context(self) -> Context: ... + + +class PluginKVStoreMixin: + """Plugin-scoped KV helpers backed by the runtime db client.""" + + def _runtime_context(self) -> Context: + owner = cast(_HasRuntimeContext, self) + return owner._require_runtime_context() + + @property + def plugin_id(self) -> str: + ctx = self._runtime_context() + return ctx.plugin_id + + async def put_kv_data(self, key: str, value: Any) -> None: + ctx = self._runtime_context() + await ctx.db.set(str(key), value) + + async def get_kv_data(self, key: str, default: _VT) -> _VT: + ctx = self._runtime_context() + value = await ctx.db.get(str(key)) + return default if value is None else value + + async def delete_kv_data(self, key: str) -> None: + ctx = self._runtime_context() + await ctx.db.delete(str(key)) diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py new file mode 100644 index 0000000000..f7bf9ba2b6 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py @@ -0,0 +1,160 @@ +"""AstrBot s5r 协议公共入口。 + +这里暴露 s5r 原生协议的消息模型、描述符和解析函数。 + +握手阶段由 `InitializeMessage` 发起,返回值不是另一条 initialize 消息,而是 +`ResultMessage(kind="initialize_result")`,其 `output` 负载可解析为 +`InitializeOutput`。 + +## 插件作者指南:什么时候用什么? + +### CapabilityDescriptor vs BUILTIN_CAPABILITY_SCHEMAS + +**CapabilityDescriptor** 用于**声明**能力: +- 当你的插件想**暴露**一个可被其他插件或核心调用的能力时 +- 例如:你的插件提供了一个翻译功能,想让其他插件调用 + + ```python + from astrbot_sdk.protocol import CapabilityDescriptor + + descriptor = CapabilityDescriptor( + name="my_plugin.translate", # 格式: 插件名.能力名 + description="翻译文本到指定语言", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "要翻译的文本"}, + "target_lang": {"type": "string", "description": "目标语言"}, + }, + "required": ["text", "target_lang"], + }, + output_schema={ + "type": "object", + "properties": { + "translated": {"type": "string"}, + }, + }, + ) + ``` + +**BUILTIN_CAPABILITY_SCHEMAS** 用于**查询**内置能力的参数格式: +- 当你想**调用**核心提供的内置能力时,用它了解参数结构 +- 例如:你想调用 `llm.chat`,但不确定参数格式 + + ```python + from astrbot_sdk.protocol import BUILTIN_CAPABILITY_SCHEMAS + + # 查看 llm.chat 的输入参数格式 + schema = BUILTIN_CAPABILITY_SCHEMAS["llm.chat"] + print(schema["input"]) # 输入参数的 JSON Schema + print(schema["output"]) # 输出结果的 JSON Schema + ``` + +### 命名规范 + +能力名称必须遵循 `{namespace}.{action}` 或 `{namespace}.{sub_namespace}.{action}` 格式: +- `llm.chat` - LLM 对话 +- `db.set` - 数据库写入 +- `llm_tool.manager.activate` - LLM 工具管理 + +**保留命名空间**(插件不可使用): +- `handler.` - 处理器相关 +- `system.` - 系统内部能力 +- `internal.` - 内部实现细节 + +### 常用内置能力速查 + +| 能力名 | 用途 | +|-------|------| +| `llm.chat` | 同步 LLM 对话 | +| `llm.stream_chat` | 流式 LLM 对话 | +| `memory.save` / `memory.get` | 短期记忆存储 | +| `db.set` / `db.get` | 持久化键值存储 | +| `platform.send` | 发送消息 | +| `provider.get_using` | 获取当前 Provider | +""" + +from __future__ import annotations + +from typing import Any + +from . import _builtin_schemas as builtin_schemas +from .descriptors import ( # noqa: F401 + BUILTIN_CAPABILITY_SCHEMAS, + CapabilityDescriptor, + CommandRouteSpec, + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + FilterSpec, + HandlerDescriptor, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + ParamSpec, + Permissions, + PlatformFilterSpec, + ScheduleTrigger, + SessionRef, + Trigger, +) +from .messages import ( # noqa: F401 + CancelMessage, + ErrorPayload, + EventMessage, + InitializeMessage, + InitializeOutput, + InvokeMessage, + PeerInfo, + ProtocolMessage, + ResultMessage, + parse_message, +) + +_DIRECT_EXPORTS = [ + "BUILTIN_CAPABILITY_SCHEMAS", + "CapabilityDescriptor", + "CommandRouteSpec", + "CommandTrigger", + "CancelMessage", + "builtin_schemas", + "CompositeFilterSpec", + "ErrorPayload", + "EventTrigger", + "EventMessage", + "FilterSpec", + "HandlerDescriptor", + "InitializeMessage", + "InitializeOutput", + "InvokeMessage", + "LocalFilterRefSpec", + "MessageTrigger", + "MessageTypeFilterSpec", + "ParamSpec", + "PeerInfo", + "PlatformFilterSpec", + "Permissions", + "ProtocolMessage", + "ResultMessage", + "ScheduleTrigger", + "SessionRef", + "Trigger", + "parse_message", +] + +_BUILTIN_SCHEMA_EXPORTS = tuple( + name for name in builtin_schemas.__all__ if name != "BUILTIN_CAPABILITY_SCHEMAS" +) + + +def __getattr__(name: str) -> Any: + if name in _BUILTIN_SCHEMA_EXPORTS: + return getattr(builtin_schemas, name) + raise AttributeError(name) + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(_BUILTIN_SCHEMA_EXPORTS)) + + +__all__ = list(dict.fromkeys([*_DIRECT_EXPORTS, *_BUILTIN_SCHEMA_EXPORTS])) diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py new file mode 100644 index 0000000000..f1ee985c2b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py @@ -0,0 +1,2470 @@ +"""Builtin protocol schema constants. + +本模块定义了 AstrBot SDK s5r 协议中所有内置能力的 JSON Schema。 +这些 Schema 用于: +1. 验证能力调用的输入参数是否符合预期格式 +2. 生成能力描述文档,供插件开发者参考 +3. 确保跨进程/跨语言调用时的类型安全 + +所有 Schema 遵循 JSON Schema 规范,支持基本类型检查、必填字段、数组元素约束等。 +""" + +from __future__ import annotations + +from typing import Any + +JSONSchema = dict[str, Any] + + +def _object_schema( + *, + required: tuple[str, ...] = (), + **properties: Any, +) -> JSONSchema: + return { + "type": "object", + "properties": properties, + "required": list(required), + } + + +def _nullable(schema: JSONSchema) -> JSONSchema: + return {"anyOf": [schema, {"type": "null"}]} + + +_OPTIONAL_CHAT_PROPERTIES: dict[str, Any] = { + "system": {"type": "string"}, + "history": {"type": "array", "items": {"type": "object"}}, + "contexts": {"type": "array", "items": {"type": "object"}}, + "provider_id": {"type": "string"}, + "tool_calls_result": {"type": "array", "items": {"type": "object"}}, + "model": {"type": "string"}, + "temperature": {"type": "number"}, + "image_urls": {"type": "array", "items": {"type": "string"}}, + "tools": {"type": "array"}, + "max_steps": {"type": "integer"}, +} + +LLM_CHAT_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_CHAT_OUTPUT_SCHEMA = _object_schema(required=("text",), text={"type": "string"}) +LLM_CHAT_RAW_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_CHAT_RAW_OUTPUT_SCHEMA = _object_schema( + required=("text",), + text={"type": "string"}, + usage=_nullable({"type": "object"}), + finish_reason=_nullable({"type": "string"}), + tool_calls={"type": "array", "items": {"type": "object"}}, + role=_nullable({"type": "string"}), + reasoning_content=_nullable({"type": "string"}), + reasoning_signature=_nullable({"type": "string"}), +) +LLM_STREAM_CHAT_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_STREAM_CHAT_OUTPUT_SCHEMA = _object_schema( + required=("text",), text={"type": "string"} +) +MEMORY_SEARCH_INPUT_SCHEMA = _object_schema( + required=("query",), + query={"type": "string"}, + mode={"type": "string", "enum": ["auto", "keyword", "vector", "hybrid"]}, + limit={"type": "integer", "minimum": 1}, + min_score={"type": "number"}, + provider_id={"type": "string"}, + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_SEARCH_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value", "score", "match_type"), + key={"type": "string"}, + namespace=_nullable({"type": "string"}), + value=_nullable({"type": "object"}), + score={"type": "number"}, + match_type={ + "type": "string", + "enum": ["keyword", "vector", "hybrid"], + }, + ), + }, +) +MEMORY_SAVE_INPUT_SCHEMA = _object_schema( + required=("key", "value"), + key={"type": "string"}, + value={"type": "object"}, + namespace={"type": "string"}, +) +MEMORY_SAVE_OUTPUT_SCHEMA = _object_schema() +MEMORY_GET_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_GET_OUTPUT_SCHEMA = _object_schema( + required=("value",), + value=_nullable({"type": "object"}), +) +MEMORY_LIST_KEYS_INPUT_SCHEMA = _object_schema(namespace={"type": "string"}) +MEMORY_LIST_KEYS_OUTPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +MEMORY_EXISTS_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_EXISTS_OUTPUT_SCHEMA = _object_schema( + required=("exists",), + exists={"type": "boolean"}, +) +MEMORY_DELETE_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_DELETE_OUTPUT_SCHEMA = _object_schema() +MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA = _object_schema( + required=("key", "value", "ttl_seconds"), + key={"type": "string"}, + value={"type": "object"}, + ttl_seconds={"type": "integer", "minimum": 1}, + namespace={"type": "string"}, +) +MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA = _object_schema() +MEMORY_GET_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, + namespace={"type": "string"}, +) +MEMORY_GET_MANY_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value=_nullable({"type": "object"}), + ), + }, +) +MEMORY_DELETE_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, + namespace={"type": "string"}, +) +MEMORY_DELETE_MANY_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MEMORY_COUNT_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_COUNT_OUTPUT_SCHEMA = _object_schema( + required=("count",), + count={"type": "integer"}, +) +MEMORY_STATS_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_STATS_OUTPUT_SCHEMA = _object_schema( + total_items={"type": "integer"}, + total_bytes=_nullable({"type": "integer"}), + plugin_id=_nullable({"type": "string"}), + ttl_entries=_nullable({"type": "integer"}), + namespace=_nullable({"type": "string"}), + namespace_count=_nullable({"type": "integer"}), + indexed_items=_nullable({"type": "integer"}), + embedded_items=_nullable({"type": "integer"}), + dirty_items=_nullable({"type": "integer"}), + fts_enabled={"type": "boolean"}, + vector_backend=_nullable({"type": "string"}), + vector_indexes={"type": "array", "items": {"type": "object"}}, +) +SYSTEM_GET_DATA_DIR_INPUT_SCHEMA = _object_schema() +SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, +) +SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA = _object_schema( + required=("text",), + text={"type": "string"}, + return_url={"type": "boolean"}, +) +SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "string"}, +) +SYSTEM_HTML_RENDER_INPUT_SCHEMA = _object_schema( + required=("tmpl", "data"), + tmpl={"type": "string"}, + data={"type": "object"}, + return_url={"type": "boolean"}, + options=_nullable({"type": "object"}), +) +SYSTEM_HTML_RENDER_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "string"}, +) +SYSTEM_FILE_REGISTER_INPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, + timeout=_nullable({"type": "number"}), +) +SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA = _object_schema( + required=("token", "url"), + token={"type": "string"}, + url={"type": "string"}, +) +SYSTEM_FILE_HANDLE_INPUT_SCHEMA = _object_schema( + required=("token",), + token={"type": "string"}, +) +SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, +) +SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA = _object_schema( + required=("session_key",), + session_key={"type": "string"}, +) +SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA = _object_schema() +SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA = _object_schema( + required=("session_key",), + session_key={"type": "string"}, +) +SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA = _object_schema() +DB_GET_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"}) +DB_GET_OUTPUT_SCHEMA = _object_schema( + required=("value",), + value=_nullable({}), +) +DB_SET_INPUT_SCHEMA = _object_schema( + required=("key", "value"), + key={"type": "string"}, + value={}, +) +DB_SET_OUTPUT_SCHEMA = _object_schema() +DB_DELETE_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"}) +DB_DELETE_OUTPUT_SCHEMA = _object_schema() +DB_LIST_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"})) +DB_LIST_OUTPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +DB_GET_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +DB_GET_MANY_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value=_nullable({}), + ), + }, +) +DB_SET_MANY_INPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value={}, + ), + }, +) +DB_SET_MANY_OUTPUT_SCHEMA = _object_schema() +DB_WATCH_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"})) +DB_WATCH_OUTPUT_SCHEMA = _object_schema() +SESSION_REF_SCHEMA = _object_schema( + required=("conversation_id",), + conversation_id={"type": "string"}, + platform=_nullable({"type": "string"}), + raw=_nullable({"type": "object"}), +) +SYSTEM_EVENT_REACT_INPUT_SCHEMA = _object_schema( + required=("emoji",), + target=_nullable(SESSION_REF_SCHEMA), + emoji={"type": "string"}, +) +SYSTEM_EVENT_REACT_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), + use_fallback={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, + stream_id=_nullable({"type": "string"}), +) +SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA = _object_schema( + required=("stream_id", "chain"), + stream_id={"type": "string"}, + chain={"type": "array", "items": {"type": "object"}}, +) +SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA = _object_schema() +SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA = _object_schema( + required=("stream_id",), + stream_id={"type": "string"}, +) +SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA = _object_schema( + required=("should_call_llm", "requested_llm"), + should_call_llm={"type": "boolean"}, + requested_llm={"type": "boolean"}, +) +SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA = _object_schema( + required=("should_call_llm", "requested_llm"), + should_call_llm={"type": "boolean"}, + requested_llm={"type": "boolean"}, +) +SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result=_nullable({"type": "object"}), +) +SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA = _object_schema( + required=("result",), + target=_nullable(SESSION_REF_SCHEMA), + result={"type": "object"}, +) +SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "object"}, +) +SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA = _object_schema() +SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA = _object_schema( + required=("plugin_names",), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA = _object_schema( + required=("plugin_names",), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +PLATFORM_SEND_INPUT_SCHEMA = _object_schema( + required=("session", "text"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + text={"type": "string"}, +) +PLATFORM_SEND_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_IMAGE_INPUT_SCHEMA = _object_schema( + required=("session", "image_url"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + image_url={"type": "string"}, +) +PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_CHAIN_INPUT_SCHEMA = _object_schema( + required=("session", "chain"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + chain={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA = _object_schema( + required=("session", "chain"), + session={"type": "string"}, + chain={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_GET_GROUP_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), +) +PLATFORM_GET_GROUP_OUTPUT_SCHEMA = _object_schema( + required=("group",), + group=_nullable({"type": "object"}), +) +PLATFORM_GET_MEMBERS_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), +) +PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA = _object_schema( + required=("members",), + members={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_INSTANCE_SCHEMA = _object_schema( + required=("id", "name", "type", "status"), + id={"type": "string"}, + name={"type": "string"}, + type={"type": "string"}, + status={"type": "string"}, +) +PLATFORM_LIST_INSTANCES_INPUT_SCHEMA = _object_schema() +PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA = _object_schema( + required=("platforms",), + platforms={"type": "array", "items": PLATFORM_INSTANCE_SCHEMA}, +) +PLATFORM_ERROR_SCHEMA = _object_schema( + required=("message", "timestamp"), + message={"type": "string"}, + timestamp={"type": "string"}, + traceback=_nullable({"type": "string"}), +) +PLATFORM_MANAGER_STATE_SCHEMA = _object_schema( + required=("id", "name", "type", "status", "errors", "unified_webhook"), + id={"type": "string"}, + name={"type": "string"}, + type={"type": "string"}, + status={"type": "string"}, + errors={"type": "array", "items": PLATFORM_ERROR_SCHEMA}, + last_error=_nullable(PLATFORM_ERROR_SCHEMA), + unified_webhook={"type": "boolean"}, +) +PLATFORM_STATS_SCHEMA = _object_schema( + required=( + "id", + "type", + "display_name", + "status", + "error_count", + "unified_webhook", + ), + id={"type": "string"}, + type={"type": "string"}, + display_name={"type": "string"}, + status={"type": "string"}, + started_at=_nullable({"type": "string"}), + error_count={"type": "integer"}, + last_error=_nullable(PLATFORM_ERROR_SCHEMA), + unified_webhook={"type": "boolean"}, + meta={"type": "object"}, +) +PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema( + required=("platform",), + platform=_nullable(PLATFORM_MANAGER_STATE_SCHEMA), +) +PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA = _object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA = _object_schema() +PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA = _object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA = _object_schema( + required=("stats",), + stats=_nullable(PLATFORM_STATS_SCHEMA), +) +PERMISSION_ROLE_SCHEMA = {"type": "string", "enum": ["member", "admin"]} +PERMISSION_CHECK_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, + session_id=_nullable({"type": "string"}), +) +PERMISSION_CHECK_RESULT_SCHEMA = _object_schema( + required=("is_admin", "role"), + is_admin={"type": "boolean"}, + role=PERMISSION_ROLE_SCHEMA, +) +PERMISSION_CHECK_OUTPUT_SCHEMA = PERMISSION_CHECK_RESULT_SCHEMA +PERMISSION_GET_ADMINS_INPUT_SCHEMA = _object_schema() +PERMISSION_GET_ADMINS_OUTPUT_SCHEMA = _object_schema( + required=("admins",), + admins={"type": "array", "items": {"type": "string"}}, +) +PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, +) +PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA = _object_schema( + required=("changed",), + changed={"type": "boolean"}, +) +PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, +) +PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA = _object_schema( + required=("changed",), + changed={"type": "boolean"}, +) +SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA = _object_schema( + required=("session", "plugin_name"), + session={"type": "string"}, + plugin_name={"type": "string"}, +) +SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA = _object_schema( + required=("enabled",), + enabled={"type": "boolean"}, +) +SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA = _object_schema( + required=("session", "handlers"), + session={"type": "string"}, + handlers={"type": "array", "items": {"type": "object"}}, +) +SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA = _object_schema( + required=("handlers",), + handlers={"type": "array", "items": {"type": "object"}}, +) +SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, +) +SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA = _object_schema( + required=("enabled",), + enabled={"type": "boolean"}, +) +SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA = _object_schema( + required=("session", "enabled"), + session={"type": "string"}, + enabled={"type": "boolean"}, +) +SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA = _object_schema() +SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, +) +SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA = _object_schema( + required=("enabled",), + enabled={"type": "boolean"}, +) +SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA = _object_schema( + required=("session", "enabled"), + session={"type": "string"}, + enabled={"type": "boolean"}, +) +SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA = _object_schema() +PERSONA_RECORD_SCHEMA = _object_schema( + required=("persona_id", "system_prompt", "begin_dialogs", "sort_order"), + persona_id={"type": "string"}, + system_prompt={"type": "string"}, + begin_dialogs={"type": "array", "items": {"type": "string"}}, + tools=_nullable({"type": "array", "items": {"type": "string"}}), + skills=_nullable({"type": "array", "items": {"type": "string"}}), + custom_error_message=_nullable({"type": "string"}), + folder_id=_nullable({"type": "string"}), + sort_order={"type": "integer"}, + created_at=_nullable({"type": "string"}), + updated_at=_nullable({"type": "string"}), +) +PERSONA_CREATE_SCHEMA = _object_schema( + required=("persona_id", "system_prompt"), + persona_id={"type": "string"}, + system_prompt={"type": "string"}, + begin_dialogs={"type": "array", "items": {"type": "string"}}, + tools=_nullable({"type": "array", "items": {"type": "string"}}), + skills=_nullable({"type": "array", "items": {"type": "string"}}), + custom_error_message=_nullable({"type": "string"}), + folder_id=_nullable({"type": "string"}), + sort_order={"type": "integer"}, +) +PERSONA_UPDATE_SCHEMA = _object_schema( + system_prompt=_nullable({"type": "string"}), + begin_dialogs=_nullable({"type": "array", "items": {"type": "string"}}), + tools=_nullable({"type": "array", "items": {"type": "string"}}), + skills=_nullable({"type": "array", "items": {"type": "string"}}), + custom_error_message=_nullable({"type": "string"}), +) +PERSONA_GET_INPUT_SCHEMA = _object_schema( + required=("persona_id",), + persona_id={"type": "string"}, +) +PERSONA_GET_OUTPUT_SCHEMA = _object_schema( + required=("persona",), + persona=PERSONA_RECORD_SCHEMA, +) +PERSONA_LIST_INPUT_SCHEMA = _object_schema() +PERSONA_LIST_OUTPUT_SCHEMA = _object_schema( + required=("personas",), + personas={"type": "array", "items": PERSONA_RECORD_SCHEMA}, +) +PERSONA_CREATE_INPUT_SCHEMA = _object_schema( + required=("persona",), + persona=PERSONA_CREATE_SCHEMA, +) +PERSONA_CREATE_OUTPUT_SCHEMA = _object_schema( + required=("persona",), + persona=PERSONA_RECORD_SCHEMA, +) +PERSONA_UPDATE_INPUT_SCHEMA = _object_schema( + required=("persona_id", "persona"), + persona_id={"type": "string"}, + persona=PERSONA_UPDATE_SCHEMA, +) +PERSONA_UPDATE_OUTPUT_SCHEMA = _object_schema( + required=("persona",), + persona=_nullable(PERSONA_RECORD_SCHEMA), +) +PERSONA_DELETE_INPUT_SCHEMA = _object_schema( + required=("persona_id",), + persona_id={"type": "string"}, +) +PERSONA_DELETE_OUTPUT_SCHEMA = _object_schema() +CONVERSATION_RECORD_SCHEMA = _object_schema( + required=("conversation_id", "session", "platform_id", "history"), + conversation_id={"type": "string"}, + session={"type": "string"}, + platform_id={"type": "string"}, + history={"type": "array", "items": {"type": "object"}}, + title=_nullable({"type": "string"}), + persona_id=_nullable({"type": "string"}), + created_at=_nullable({"type": "string"}), + updated_at=_nullable({"type": "string"}), + token_usage=_nullable({"type": "integer"}), +) +CONVERSATION_CREATE_SCHEMA = _object_schema( + platform_id=_nullable({"type": "string"}), + history=_nullable({"type": "array", "items": {"type": "object"}}), + title=_nullable({"type": "string"}), + persona_id=_nullable({"type": "string"}), +) +CONVERSATION_UPDATE_SCHEMA = _object_schema( + history=_nullable({"type": "array", "items": {"type": "object"}}), + title=_nullable({"type": "string"}), + persona_id=_nullable({"type": "string"}), + token_usage=_nullable({"type": "integer"}), +) +CONVERSATION_NEW_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + conversation=_nullable(CONVERSATION_CREATE_SCHEMA), +) +CONVERSATION_NEW_OUTPUT_SCHEMA = _object_schema( + required=("conversation_id",), + conversation_id={"type": "string"}, +) +CONVERSATION_SWITCH_INPUT_SCHEMA = _object_schema( + required=("session", "conversation_id"), + session={"type": "string"}, + conversation_id={"type": "string"}, +) +CONVERSATION_SWITCH_OUTPUT_SCHEMA = _object_schema() +CONVERSATION_DELETE_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + conversation_id=_nullable({"type": "string"}), +) +CONVERSATION_DELETE_OUTPUT_SCHEMA = _object_schema() +CONVERSATION_GET_INPUT_SCHEMA = _object_schema( + required=("session", "conversation_id"), + session={"type": "string"}, + conversation_id={"type": "string"}, + create_if_not_exists={"type": "boolean"}, +) +CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema( + required=("conversation",), + conversation=_nullable(CONVERSATION_RECORD_SCHEMA), +) +CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + create_if_not_exists={"type": "boolean"}, +) +CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema( + required=("conversation",), + conversation=_nullable(CONVERSATION_RECORD_SCHEMA), +) +CONVERSATION_LIST_INPUT_SCHEMA = _object_schema( + session=_nullable({"type": "string"}), + platform_id=_nullable({"type": "string"}), +) +CONVERSATION_LIST_OUTPUT_SCHEMA = _object_schema( + required=("conversations",), + conversations={"type": "array", "items": CONVERSATION_RECORD_SCHEMA}, +) +CONVERSATION_UPDATE_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + conversation_id=_nullable({"type": "string"}), + conversation=_nullable(CONVERSATION_UPDATE_SCHEMA), +) +CONVERSATION_UPDATE_OUTPUT_SCHEMA = _object_schema() +CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + conversation_id=_nullable({"type": "string"}), +) +CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA = _object_schema() +MESSAGE_HISTORY_SESSION_SCHEMA = _object_schema( + required=("platform_id", "message_type", "session_id"), + platform_id={"type": "string"}, + message_type={"type": "string", "enum": ["group", "private", "other"]}, + session_id={"type": "string"}, +) +MESSAGE_HISTORY_SENDER_SCHEMA = _object_schema( + sender_id=_nullable({"type": "string"}), + sender_name=_nullable({"type": "string"}), +) +MESSAGE_HISTORY_RECORD_SCHEMA = _object_schema( + required=("id", "session", "sender", "parts", "metadata"), + id={"type": "integer"}, + session=MESSAGE_HISTORY_SESSION_SCHEMA, + sender=MESSAGE_HISTORY_SENDER_SCHEMA, + parts={"type": "array", "items": {"type": "object"}}, + metadata={"type": "object"}, + created_at=_nullable({"type": "string"}), + updated_at=_nullable({"type": "string"}), + idempotency_key=_nullable({"type": "string"}), +) +MESSAGE_HISTORY_PAGE_SCHEMA = _object_schema( + required=("records",), + records={"type": "array", "items": MESSAGE_HISTORY_RECORD_SCHEMA}, + next_cursor=_nullable({"type": "string"}), + total=_nullable({"type": "integer"}), +) +MESSAGE_HISTORY_LIST_INPUT_SCHEMA = _object_schema( + required=("session",), + session=MESSAGE_HISTORY_SESSION_SCHEMA, + cursor=_nullable({"type": "string", "pattern": "^(|[1-9][0-9]*)$"}), + limit={"type": "integer", "minimum": 1}, +) +MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA = _object_schema( + required=("page",), + page=MESSAGE_HISTORY_PAGE_SCHEMA, +) +MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA = _object_schema( + required=("session", "record_id"), + session=MESSAGE_HISTORY_SESSION_SCHEMA, + record_id={"type": "integer", "minimum": 1}, +) +MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA = _object_schema( + required=("record",), + record=_nullable(MESSAGE_HISTORY_RECORD_SCHEMA), +) +MESSAGE_HISTORY_APPEND_INPUT_SCHEMA = _object_schema( + required=("session", "sender", "parts"), + session=MESSAGE_HISTORY_SESSION_SCHEMA, + sender=MESSAGE_HISTORY_SENDER_SCHEMA, + parts={"type": "array", "items": {"type": "object"}}, + metadata=_nullable({"type": "object"}), + idempotency_key=_nullable({"type": "string"}), +) +MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA = _object_schema( + required=("record",), + record=MESSAGE_HISTORY_RECORD_SCHEMA, +) +MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA = _object_schema( + required=("session", "before"), + session=MESSAGE_HISTORY_SESSION_SCHEMA, + before={"type": "string"}, +) +MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA = _object_schema( + required=("session", "after"), + session=MESSAGE_HISTORY_SESSION_SCHEMA, + after={"type": "string"}, +) +MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA = _object_schema( + required=("session",), + session=MESSAGE_HISTORY_SESSION_SCHEMA, +) +MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MCP_SERVER_SCOPE_SCHEMA = {"type": "string", "enum": ["local", "global"]} +MCP_SERVER_RECORD_SCHEMA = _object_schema( + required=("name", "scope", "active", "running", "config", "tools", "errlogs"), + name={"type": "string"}, + scope=MCP_SERVER_SCOPE_SCHEMA, + active={"type": "boolean"}, + running={"type": "boolean"}, + config={"type": "object"}, + tools={"type": "array", "items": {"type": "string"}}, + errlogs={"type": "array", "items": {"type": "string"}}, + last_error=_nullable({"type": "string"}), +) +MCP_LOCAL_GET_INPUT_SCHEMA = _object_schema(required=("name",), name={"type": "string"}) +MCP_LOCAL_GET_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=_nullable(MCP_SERVER_RECORD_SCHEMA), +) +MCP_LOCAL_LIST_INPUT_SCHEMA = _object_schema() +MCP_LOCAL_LIST_OUTPUT_SCHEMA = _object_schema( + required=("servers",), + servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA}, +) +MCP_LOCAL_ENABLE_INPUT_SCHEMA = _object_schema( + required=("name",), name={"type": "string"} +) +MCP_LOCAL_ENABLE_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_LOCAL_DISABLE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +MCP_LOCAL_DISABLE_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, + timeout={"type": "number"}, +) +MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_SESSION_OPEN_INPUT_SCHEMA = _object_schema( + required=("name", "config"), + name={"type": "string"}, + config={"type": "object"}, + timeout={"type": "number"}, +) +MCP_SESSION_OPEN_OUTPUT_SCHEMA = _object_schema( + required=("session_id", "tools"), + session_id={"type": "string"}, + tools={"type": "array", "items": {"type": "string"}}, +) +MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA = _object_schema( + required=("session_id",), + session_id={"type": "string"}, +) +MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA = _object_schema( + required=("tools",), + tools={"type": "array", "items": {"type": "string"}}, +) +MCP_SESSION_CALL_TOOL_INPUT_SCHEMA = _object_schema( + required=("session_id", "tool_name", "args"), + session_id={"type": "string"}, + tool_name={"type": "string"}, + args={"type": "object"}, +) +MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "object"}, +) +MCP_SESSION_CLOSE_INPUT_SCHEMA = _object_schema( + required=("session_id",), + session_id={"type": "string"}, +) +MCP_SESSION_CLOSE_OUTPUT_SCHEMA = _object_schema() +MCP_GLOBAL_REGISTER_INPUT_SCHEMA = _object_schema( + required=("name", "config"), + name={"type": "string"}, + config={"type": "object"}, + timeout={"type": "number"}, +) +MCP_GLOBAL_REGISTER_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_GLOBAL_GET_INPUT_SCHEMA = _object_schema( + required=("name",), name={"type": "string"} +) +MCP_GLOBAL_GET_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=_nullable(MCP_SERVER_RECORD_SCHEMA), +) +MCP_GLOBAL_LIST_INPUT_SCHEMA = _object_schema() +MCP_GLOBAL_LIST_OUTPUT_SCHEMA = _object_schema( + required=("servers",), + servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA}, +) +MCP_GLOBAL_ENABLE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, + timeout={"type": "number"}, +) +MCP_GLOBAL_ENABLE_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_GLOBAL_DISABLE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +MCP_GLOBAL_DISABLE_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +MCP_GLOBAL_UNREGISTER_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +MCP_GLOBAL_UNREGISTER_OUTPUT_SCHEMA = _object_schema( + required=("server",), + server=MCP_SERVER_RECORD_SCHEMA, +) +INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA = _object_schema( + required=("plugin_id", "server_name", "tool_name", "tool_args"), + plugin_id={"type": "string"}, + server_name={"type": "string"}, + tool_name={"type": "string"}, + tool_args={"type": "object"}, +) +INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA = _object_schema( + required=("content", "success"), + content=_nullable({"type": "string"}), + success={"type": "boolean"}, +) +KNOWLEDGE_BASE_RECORD_SCHEMA = _object_schema( + required=("kb_id", "kb_name", "embedding_provider_id", "doc_count", "chunk_count"), + kb_id={"type": "string"}, + kb_name={"type": "string"}, + description=_nullable({"type": "string"}), + emoji=_nullable({"type": "string"}), + embedding_provider_id={"type": "string"}, + rerank_provider_id=_nullable({"type": "string"}), + chunk_size=_nullable({"type": "integer"}), + chunk_overlap=_nullable({"type": "integer"}), + top_k_dense=_nullable({"type": "integer"}), + top_k_sparse=_nullable({"type": "integer"}), + top_m_final=_nullable({"type": "integer"}), + doc_count={"type": "integer"}, + chunk_count={"type": "integer"}, + created_at=_nullable({"type": "string"}), + updated_at=_nullable({"type": "string"}), +) +KNOWLEDGE_BASE_CREATE_SCHEMA = _object_schema( + required=("kb_name", "embedding_provider_id"), + kb_name={"type": "string"}, + embedding_provider_id={"type": "string"}, + description=_nullable({"type": "string"}), + emoji=_nullable({"type": "string"}), + rerank_provider_id=_nullable({"type": "string"}), + chunk_size=_nullable({"type": "integer"}), + chunk_overlap=_nullable({"type": "integer"}), + top_k_dense=_nullable({"type": "integer"}), + top_k_sparse=_nullable({"type": "integer"}), + top_m_final=_nullable({"type": "integer"}), +) +KNOWLEDGE_BASE_UPDATE_SCHEMA = _object_schema( + kb_name=_nullable({"type": "string"}), + description=_nullable({"type": "string"}), + emoji=_nullable({"type": "string"}), + embedding_provider_id=_nullable({"type": "string"}), + rerank_provider_id=_nullable({"type": "string"}), + chunk_size=_nullable({"type": "integer"}), + chunk_overlap=_nullable({"type": "integer"}), + top_k_dense=_nullable({"type": "integer"}), + top_k_sparse=_nullable({"type": "integer"}), + top_m_final=_nullable({"type": "integer"}), +) +KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA = _object_schema( + required=( + "doc_id", + "kb_id", + "doc_name", + "file_type", + "file_size", + "chunk_count", + "media_count", + ), + doc_id={"type": "string"}, + kb_id={"type": "string"}, + doc_name={"type": "string"}, + file_type={"type": "string"}, + file_size={"type": "integer"}, + file_path={"type": "string"}, + chunk_count={"type": "integer"}, + media_count={"type": "integer"}, + created_at=_nullable({"type": "string"}), + updated_at=_nullable({"type": "string"}), +) +KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA = _object_schema( + required=( + "chunk_id", + "doc_id", + "kb_id", + "kb_name", + "doc_name", + "chunk_index", + "content", + "score", + "char_count", + ), + chunk_id={"type": "string"}, + doc_id={"type": "string"}, + kb_id={"type": "string"}, + kb_name={"type": "string"}, + doc_name={"type": "string"}, + chunk_index={"type": "integer"}, + content={"type": "string"}, + score={"type": "number"}, + char_count={"type": "integer"}, +) +KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA = _object_schema( + file_token=_nullable({"type": "string"}), + url=_nullable({"type": "string"}), + text=_nullable({"type": "string"}), + file_name=_nullable({"type": "string"}), + file_type=_nullable({"type": "string"}), + chunk_size=_nullable({"type": "integer"}), + chunk_overlap=_nullable({"type": "integer"}), + batch_size=_nullable({"type": "integer"}), + tasks_limit=_nullable({"type": "integer"}), + max_retries=_nullable({"type": "integer"}), + enable_cleaning=_nullable({"type": "boolean"}), + cleaning_provider_id=_nullable({"type": "string"}), +) +KB_LIST_INPUT_SCHEMA = _object_schema() +KB_LIST_OUTPUT_SCHEMA = _object_schema( + required=("kbs",), + kbs={"type": "array", "items": KNOWLEDGE_BASE_RECORD_SCHEMA}, +) +KB_GET_INPUT_SCHEMA = _object_schema( + required=("kb_id",), + kb_id={"type": "string"}, +) +KB_GET_OUTPUT_SCHEMA = _object_schema( + required=("kb",), + kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA), +) +KB_CREATE_INPUT_SCHEMA = _object_schema( + required=("kb",), + kb=KNOWLEDGE_BASE_CREATE_SCHEMA, +) +KB_CREATE_OUTPUT_SCHEMA = _object_schema( + required=("kb",), + kb=KNOWLEDGE_BASE_RECORD_SCHEMA, +) +KB_UPDATE_INPUT_SCHEMA = _object_schema( + required=("kb_id", "kb"), + kb_id={"type": "string"}, + kb=KNOWLEDGE_BASE_UPDATE_SCHEMA, +) +KB_UPDATE_OUTPUT_SCHEMA = _object_schema( + required=("kb",), + kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA), +) +KB_DELETE_INPUT_SCHEMA = _object_schema( + required=("kb_id",), + kb_id={"type": "string"}, +) +KB_DELETE_OUTPUT_SCHEMA = _object_schema( + required=("deleted",), + deleted={"type": "boolean"}, +) +KB_RETRIEVE_INPUT_SCHEMA = _object_schema( + required=("query",), + query={"type": "string"}, + kb_ids={"type": "array", "items": {"type": "string"}}, + kb_names={"type": "array", "items": {"type": "string"}}, + top_k_fusion={"type": "integer"}, + top_m_final={"type": "integer"}, +) +KB_RETRIEVE_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result=_nullable( + _object_schema( + required=("context_text", "results"), + context_text={"type": "string"}, + results={ + "type": "array", + "items": KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA, + }, + ) + ), +) +KB_DOCUMENT_UPLOAD_INPUT_SCHEMA = _object_schema( + required=("kb_id", "document"), + kb_id={"type": "string"}, + document=KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA, +) +KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA = _object_schema( + required=("document",), + document=KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA, +) +KB_DOCUMENT_LIST_INPUT_SCHEMA = _object_schema( + required=("kb_id",), + kb_id={"type": "string"}, + offset={"type": "integer"}, + limit={"type": "integer"}, +) +KB_DOCUMENT_LIST_OUTPUT_SCHEMA = _object_schema( + required=("documents",), + documents={"type": "array", "items": KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA}, +) +KB_DOCUMENT_GET_INPUT_SCHEMA = _object_schema( + required=("kb_id", "doc_id"), + kb_id={"type": "string"}, + doc_id={"type": "string"}, +) +KB_DOCUMENT_GET_OUTPUT_SCHEMA = _object_schema( + required=("document",), + document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA), +) +KB_DOCUMENT_DELETE_INPUT_SCHEMA = _object_schema( + required=("kb_id", "doc_id"), + kb_id={"type": "string"}, + doc_id={"type": "string"}, +) +KB_DOCUMENT_DELETE_OUTPUT_SCHEMA = _object_schema( + required=("deleted",), + deleted={"type": "boolean"}, +) +KB_DOCUMENT_REFRESH_INPUT_SCHEMA = _object_schema( + required=("kb_id", "doc_id"), + kb_id={"type": "string"}, + doc_id={"type": "string"}, +) +KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA = _object_schema( + required=("document",), + document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA), +) +REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA = _object_schema( + required=("command_name", "handler_full_name"), + command_name={"type": "string"}, + handler_full_name={"type": "string"}, + source_event_type={"type": "string"}, + desc={"type": "string"}, + priority={"type": "integer"}, + use_regex={"type": "boolean"}, + ignore_prefix={"type": "boolean"}, +) +REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA = _object_schema() +SKILL_REGISTER_INPUT_SCHEMA = _object_schema( + required=("name", "path"), + name={"type": "string"}, + path={"type": "string"}, + description={"type": "string"}, +) +SKILL_REGISTER_OUTPUT_SCHEMA = _object_schema( + required=("name", "description", "path", "skill_dir"), + name={"type": "string"}, + description={"type": "string"}, + path={"type": "string"}, + skill_dir={"type": "string"}, +) +SKILL_UNREGISTER_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +SKILL_UNREGISTER_OUTPUT_SCHEMA = _object_schema( + required=("removed",), + removed={"type": "boolean"}, +) +SKILL_LIST_INPUT_SCHEMA = _object_schema() +SKILL_LIST_OUTPUT_SCHEMA = _object_schema( + required=("skills",), + skills={ + "type": "array", + "items": SKILL_REGISTER_OUTPUT_SCHEMA, + }, +) +HTTP_REGISTER_API_INPUT_SCHEMA = _object_schema( + required=("route", "methods", "handler_capability"), + route={"type": "string"}, + methods={"type": "array", "items": {"type": "string"}}, + handler_capability={"type": "string"}, + description={"type": "string"}, +) +HTTP_REGISTER_API_OUTPUT_SCHEMA = _object_schema() +HTTP_UNREGISTER_API_INPUT_SCHEMA = _object_schema( + required=("route", "methods"), + route={"type": "string"}, + methods={"type": "array", "items": {"type": "string"}}, +) +HTTP_UNREGISTER_API_OUTPUT_SCHEMA = _object_schema() +HTTP_LIST_APIS_INPUT_SCHEMA = _object_schema() +HTTP_LIST_APIS_OUTPUT_SCHEMA = _object_schema( + required=("apis",), + apis={"type": "array", "items": {"type": "object"}}, +) +METADATA_GET_PLUGIN_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +METADATA_GET_PLUGIN_OUTPUT_SCHEMA = _object_schema( + required=("plugin",), + plugin=_nullable({"type": "object"}), +) +METADATA_LIST_PLUGINS_INPUT_SCHEMA = _object_schema() +METADATA_LIST_PLUGINS_OUTPUT_SCHEMA = _object_schema( + required=("plugins",), + plugins={"type": "array", "items": {"type": "object"}}, +) +METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema( + required=("config",), + config=_nullable({"type": "object"}), +) +METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema( + required=("config",), + config={"type": "object"}, +) +METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema( + required=("config",), + config=_nullable({"type": "object"}), +) +REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA = _object_schema( + required=("event_type",), + event_type={"type": "string"}, +) +REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA = _object_schema( + required=("handlers",), + handlers={"type": "array", "items": {"type": "object"}}, +) +REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA = _object_schema( + required=("full_name",), + full_name={"type": "string"}, +) +REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA = _object_schema( + required=("handler",), + handler=_nullable({"type": "object"}), +) +PROVIDER_META_SCHEMA = _object_schema( + required=("id", "type", "provider_type"), + id={"type": "string"}, + model=_nullable({"type": "string"}), + type={"type": "string"}, + provider_type={"type": "string"}, +) +MANAGED_PROVIDER_RECORD_SCHEMA = _object_schema( + required=("id", "type", "provider_type", "loaded", "enabled"), + id={"type": "string"}, + model=_nullable({"type": "string"}), + type={"type": "string"}, + provider_type={"type": "string"}, + loaded={"type": "boolean"}, + enabled={"type": "boolean"}, + provider_source_id=_nullable({"type": "string"}), +) +PROVIDER_CHANGE_EVENT_SCHEMA = _object_schema( + required=("provider_id", "provider_type"), + provider_id={"type": "string"}, + provider_type={"type": "string"}, + umo=_nullable({"type": "string"}), +) +LLM_TOOL_SPEC_SCHEMA = _object_schema( + required=("name", "description", "parameters_schema", "active"), + name={"type": "string"}, + description={"type": "string"}, + parameters_schema={"type": "object"}, + handler_ref=_nullable({"type": "string"}), + handler_capability=_nullable({"type": "string"}), + active={"type": "boolean"}, +) +AGENT_SPEC_SCHEMA = _object_schema( + required=("name", "description", "tool_names", "runner_class"), + name={"type": "string"}, + description={"type": "string"}, + tool_names={"type": "array", "items": {"type": "string"}}, + runner_class={"type": "string"}, +) +PROVIDER_GET_USING_INPUT_SCHEMA = _object_schema(umo=_nullable({"type": "string"})) +PROVIDER_GET_USING_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(PROVIDER_META_SCHEMA), +) +PROVIDER_GET_BY_ID_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(PROVIDER_META_SCHEMA), +) +PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA = _object_schema( + umo=_nullable({"type": "string"}), +) +PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id=_nullable({"type": "string"}), +) +PROVIDER_LIST_ALL_INPUT_SCHEMA = _object_schema() +PROVIDER_LIST_ALL_OUTPUT_SCHEMA = _object_schema( + required=("providers",), + providers={"type": "array", "items": PROVIDER_META_SCHEMA}, +) +PROVIDER_STT_GET_TEXT_INPUT_SCHEMA = _object_schema( + required=("provider_id", "audio_url"), + provider_id={"type": "string"}, + audio_url={"type": "string"}, +) +PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA = _object_schema( + required=("text",), + text={"type": "string"}, +) +PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA = _object_schema( + required=("provider_id", "text"), + provider_id={"type": "string"}, + text={"type": "string"}, +) +PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA = _object_schema( + required=("audio_path",), + audio_path={"type": "string"}, +) +PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +PROVIDER_TTS_AUDIO_CHUNK_SCHEMA = _object_schema( + required=("audio_base64",), + audio_base64={"type": "string"}, + text=_nullable({"type": "string"}), +) +PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, + text=_nullable({"type": "string"}), + text_chunks={"type": "array", "items": {"type": "string"}}, +) +PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA = PROVIDER_TTS_AUDIO_CHUNK_SCHEMA +PROVIDER_EMBEDDING_GET_INPUT_SCHEMA = _object_schema( + required=("provider_id", "text"), + provider_id={"type": "string"}, + text={"type": "string"}, +) +PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA = _object_schema( + required=("embedding",), + embedding={"type": "array", "items": {"type": "number"}}, +) +PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA = _object_schema( + required=("provider_id", "texts"), + provider_id={"type": "string"}, + texts={"type": "array", "items": {"type": "string"}}, +) +PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA = _object_schema( + required=("embeddings",), + embeddings={ + "type": "array", + "items": {"type": "array", "items": {"type": "number"}}, + }, +) +PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA = _object_schema( + required=("dim",), + dim={"type": "integer"}, +) +PROVIDER_RERANK_RESULT_SCHEMA = _object_schema( + required=("index", "score", "document"), + index={"type": "integer"}, + score={"type": "number"}, + document={"type": "string"}, +) +PROVIDER_RERANK_INPUT_SCHEMA = _object_schema( + required=("provider_id", "query", "documents"), + provider_id={"type": "string"}, + query={"type": "string"}, + documents={"type": "array", "items": {"type": "string"}}, + top_n=_nullable({"type": "integer"}), +) +PROVIDER_RERANK_OUTPUT_SCHEMA = _object_schema( + required=("results",), + results={"type": "array", "items": PROVIDER_RERANK_RESULT_SCHEMA}, +) +PROVIDER_MANAGER_SET_INPUT_SCHEMA = _object_schema( + required=("provider_id", "provider_type"), + provider_id={"type": "string"}, + provider_type={"type": "string"}, + umo=_nullable({"type": "string"}), +) +PROVIDER_MANAGER_SET_OUTPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA = _object_schema( + required=("config",), + config=_nullable({"type": "object"}), +) +PROVIDER_MANAGER_LOAD_INPUT_SCHEMA = _object_schema( + required=("provider_config",), + provider_config={"type": "object"}, +) +PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_CREATE_INPUT_SCHEMA = _object_schema( + required=("provider_config",), + provider_config={"type": "object"}, +) +PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA = _object_schema( + required=("origin_provider_id", "new_config"), + origin_provider_id={"type": "string"}, + new_config={"type": "object"}, +) +PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_DELETE_INPUT_SCHEMA = _object_schema( + provider_id=_nullable({"type": "string"}), + provider_source_id=_nullable({"type": "string"}), +) +PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA = _object_schema( + required=("providers",), + providers={"type": "array", "items": MANAGED_PROVIDER_RECORD_SCHEMA}, +) +PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA = _object_schema( + required=("provider_id", "provider_type"), + provider_id={"type": "string"}, + provider_type={"type": "string"}, + umo=_nullable({"type": "string"}), +) +LLM_TOOL_MANAGER_GET_INPUT_SCHEMA = _object_schema() +LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA = _object_schema( + required=("registered", "active"), + registered={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, + active={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, +) +LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA = _object_schema( + required=("activated",), + activated={"type": "boolean"}, +) +LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA = _object_schema( + required=("deactivated",), + deactivated={"type": "boolean"}, +) +LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA = _object_schema( + required=("tools",), + tools={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, +) +LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA = _object_schema( + required=("names",), + names={"type": "array", "items": {"type": "string"}}, +) +LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA = _object_schema( + required=("removed",), + removed={"type": "boolean"}, +) +AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA = _object_schema( + prompt=_nullable({"type": "string"}), + system_prompt=_nullable({"type": "string"}), + session_id=_nullable({"type": "string"}), + contexts={"type": "array", "items": {"type": "object"}}, + image_urls={"type": "array", "items": {"type": "string"}}, + tool_names=_nullable({"type": "array", "items": {"type": "string"}}), + tool_calls_result={"type": "array", "items": {"type": "object"}}, + provider_id=_nullable({"type": "string"}), + model=_nullable({"type": "string"}), + temperature={"type": "number"}, + max_steps={"type": "integer"}, + tool_call_timeout={"type": "integer"}, +) +AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA = LLM_CHAT_RAW_OUTPUT_SCHEMA +AGENT_REGISTRY_LIST_INPUT_SCHEMA = _object_schema() +AGENT_REGISTRY_LIST_OUTPUT_SCHEMA = _object_schema( + required=("agents",), + agents={"type": "array", "items": AGENT_SPEC_SCHEMA}, +) +AGENT_REGISTRY_GET_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +AGENT_REGISTRY_GET_OUTPUT_SCHEMA = _object_schema( + required=("agent",), + agent=_nullable(AGENT_SPEC_SCHEMA), +) + +BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = { + "llm.chat": {"input": LLM_CHAT_INPUT_SCHEMA, "output": LLM_CHAT_OUTPUT_SCHEMA}, + "llm.chat_raw": { + "input": LLM_CHAT_RAW_INPUT_SCHEMA, + "output": LLM_CHAT_RAW_OUTPUT_SCHEMA, + }, + "llm.stream_chat": { + "input": LLM_STREAM_CHAT_INPUT_SCHEMA, + "output": LLM_STREAM_CHAT_OUTPUT_SCHEMA, + }, + "memory.search": { + "input": MEMORY_SEARCH_INPUT_SCHEMA, + "output": MEMORY_SEARCH_OUTPUT_SCHEMA, + }, + "memory.save": { + "input": MEMORY_SAVE_INPUT_SCHEMA, + "output": MEMORY_SAVE_OUTPUT_SCHEMA, + }, + "memory.get": { + "input": MEMORY_GET_INPUT_SCHEMA, + "output": MEMORY_GET_OUTPUT_SCHEMA, + }, + "memory.list_keys": { + "input": MEMORY_LIST_KEYS_INPUT_SCHEMA, + "output": MEMORY_LIST_KEYS_OUTPUT_SCHEMA, + }, + "memory.exists": { + "input": MEMORY_EXISTS_INPUT_SCHEMA, + "output": MEMORY_EXISTS_OUTPUT_SCHEMA, + }, + "memory.delete": { + "input": MEMORY_DELETE_INPUT_SCHEMA, + "output": MEMORY_DELETE_OUTPUT_SCHEMA, + }, + "memory.clear_namespace": { + "input": MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA, + "output": MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA, + }, + "memory.save_with_ttl": { + "input": MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA, + "output": MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA, + }, + "memory.get_many": { + "input": MEMORY_GET_MANY_INPUT_SCHEMA, + "output": MEMORY_GET_MANY_OUTPUT_SCHEMA, + }, + "memory.delete_many": { + "input": MEMORY_DELETE_MANY_INPUT_SCHEMA, + "output": MEMORY_DELETE_MANY_OUTPUT_SCHEMA, + }, + "memory.count": { + "input": MEMORY_COUNT_INPUT_SCHEMA, + "output": MEMORY_COUNT_OUTPUT_SCHEMA, + }, + "memory.stats": { + "input": MEMORY_STATS_INPUT_SCHEMA, + "output": MEMORY_STATS_OUTPUT_SCHEMA, + }, + "db.get": {"input": DB_GET_INPUT_SCHEMA, "output": DB_GET_OUTPUT_SCHEMA}, + "db.set": {"input": DB_SET_INPUT_SCHEMA, "output": DB_SET_OUTPUT_SCHEMA}, + "db.delete": {"input": DB_DELETE_INPUT_SCHEMA, "output": DB_DELETE_OUTPUT_SCHEMA}, + "db.list": {"input": DB_LIST_INPUT_SCHEMA, "output": DB_LIST_OUTPUT_SCHEMA}, + "db.get_many": { + "input": DB_GET_MANY_INPUT_SCHEMA, + "output": DB_GET_MANY_OUTPUT_SCHEMA, + }, + "db.set_many": { + "input": DB_SET_MANY_INPUT_SCHEMA, + "output": DB_SET_MANY_OUTPUT_SCHEMA, + }, + "db.watch": {"input": DB_WATCH_INPUT_SCHEMA, "output": DB_WATCH_OUTPUT_SCHEMA}, + "platform.send": { + "input": PLATFORM_SEND_INPUT_SCHEMA, + "output": PLATFORM_SEND_OUTPUT_SCHEMA, + }, + "platform.send_image": { + "input": PLATFORM_SEND_IMAGE_INPUT_SCHEMA, + "output": PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA, + }, + "platform.send_chain": { + "input": PLATFORM_SEND_CHAIN_INPUT_SCHEMA, + "output": PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA, + }, + "platform.send_by_session": { + "input": PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA, + "output": PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA, + }, + "platform.get_group": { + "input": PLATFORM_GET_GROUP_INPUT_SCHEMA, + "output": PLATFORM_GET_GROUP_OUTPUT_SCHEMA, + }, + "platform.get_members": { + "input": PLATFORM_GET_MEMBERS_INPUT_SCHEMA, + "output": PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA, + }, + "platform.list_instances": { + "input": PLATFORM_LIST_INSTANCES_INPUT_SCHEMA, + "output": PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA, + }, + "session.plugin.is_enabled": { + "input": SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA, + "output": SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA, + }, + "session.plugin.filter_handlers": { + "input": SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA, + "output": SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA, + }, + "session.service.is_llm_enabled": { + "input": SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA, + "output": SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA, + }, + "session.service.set_llm_status": { + "input": SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA, + "output": SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA, + }, + "session.service.is_tts_enabled": { + "input": SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA, + "output": SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA, + }, + "session.service.set_tts_status": { + "input": SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA, + "output": SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA, + }, + "persona.get": { + "input": PERSONA_GET_INPUT_SCHEMA, + "output": PERSONA_GET_OUTPUT_SCHEMA, + }, + "persona.list": { + "input": PERSONA_LIST_INPUT_SCHEMA, + "output": PERSONA_LIST_OUTPUT_SCHEMA, + }, + "persona.create": { + "input": PERSONA_CREATE_INPUT_SCHEMA, + "output": PERSONA_CREATE_OUTPUT_SCHEMA, + }, + "persona.update": { + "input": PERSONA_UPDATE_INPUT_SCHEMA, + "output": PERSONA_UPDATE_OUTPUT_SCHEMA, + }, + "persona.delete": { + "input": PERSONA_DELETE_INPUT_SCHEMA, + "output": PERSONA_DELETE_OUTPUT_SCHEMA, + }, + "conversation.new": { + "input": CONVERSATION_NEW_INPUT_SCHEMA, + "output": CONVERSATION_NEW_OUTPUT_SCHEMA, + }, + "conversation.switch": { + "input": CONVERSATION_SWITCH_INPUT_SCHEMA, + "output": CONVERSATION_SWITCH_OUTPUT_SCHEMA, + }, + "conversation.delete": { + "input": CONVERSATION_DELETE_INPUT_SCHEMA, + "output": CONVERSATION_DELETE_OUTPUT_SCHEMA, + }, + "conversation.get": { + "input": CONVERSATION_GET_INPUT_SCHEMA, + "output": CONVERSATION_GET_OUTPUT_SCHEMA, + }, + "conversation.get_current": { + "input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA, + "output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA, + }, + "conversation.list": { + "input": CONVERSATION_LIST_INPUT_SCHEMA, + "output": CONVERSATION_LIST_OUTPUT_SCHEMA, + }, + "conversation.update": { + "input": CONVERSATION_UPDATE_INPUT_SCHEMA, + "output": CONVERSATION_UPDATE_OUTPUT_SCHEMA, + }, + "conversation.unset_persona": { + "input": CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA, + "output": CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA, + }, + "message_history.list": { + "input": MESSAGE_HISTORY_LIST_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA, + }, + "message_history.get_by_id": { + "input": MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA, + }, + "message_history.append": { + "input": MESSAGE_HISTORY_APPEND_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA, + }, + "message_history.delete_before": { + "input": MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA, + }, + "message_history.delete_after": { + "input": MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA, + }, + "message_history.delete_all": { + "input": MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA, + }, + "mcp.local.get": { + "input": MCP_LOCAL_GET_INPUT_SCHEMA, + "output": MCP_LOCAL_GET_OUTPUT_SCHEMA, + }, + "mcp.local.list": { + "input": MCP_LOCAL_LIST_INPUT_SCHEMA, + "output": MCP_LOCAL_LIST_OUTPUT_SCHEMA, + }, + "mcp.local.enable": { + "input": MCP_LOCAL_ENABLE_INPUT_SCHEMA, + "output": MCP_LOCAL_ENABLE_OUTPUT_SCHEMA, + }, + "mcp.local.disable": { + "input": MCP_LOCAL_DISABLE_INPUT_SCHEMA, + "output": MCP_LOCAL_DISABLE_OUTPUT_SCHEMA, + }, + "mcp.local.wait_until_ready": { + "input": MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA, + "output": MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA, + }, + "mcp.session.open": { + "input": MCP_SESSION_OPEN_INPUT_SCHEMA, + "output": MCP_SESSION_OPEN_OUTPUT_SCHEMA, + }, + "mcp.session.list_tools": { + "input": MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA, + "output": MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA, + }, + "mcp.session.call_tool": { + "input": MCP_SESSION_CALL_TOOL_INPUT_SCHEMA, + "output": MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA, + }, + "mcp.session.close": { + "input": MCP_SESSION_CLOSE_INPUT_SCHEMA, + "output": MCP_SESSION_CLOSE_OUTPUT_SCHEMA, + }, + "mcp.global.register": { + "input": MCP_GLOBAL_REGISTER_INPUT_SCHEMA, + "output": MCP_GLOBAL_REGISTER_OUTPUT_SCHEMA, + }, + "mcp.global.get": { + "input": MCP_GLOBAL_GET_INPUT_SCHEMA, + "output": MCP_GLOBAL_GET_OUTPUT_SCHEMA, + }, + "mcp.global.list": { + "input": MCP_GLOBAL_LIST_INPUT_SCHEMA, + "output": MCP_GLOBAL_LIST_OUTPUT_SCHEMA, + }, + "mcp.global.enable": { + "input": MCP_GLOBAL_ENABLE_INPUT_SCHEMA, + "output": MCP_GLOBAL_ENABLE_OUTPUT_SCHEMA, + }, + "mcp.global.disable": { + "input": MCP_GLOBAL_DISABLE_INPUT_SCHEMA, + "output": MCP_GLOBAL_DISABLE_OUTPUT_SCHEMA, + }, + "mcp.global.unregister": { + "input": MCP_GLOBAL_UNREGISTER_INPUT_SCHEMA, + "output": MCP_GLOBAL_UNREGISTER_OUTPUT_SCHEMA, + }, + "internal.mcp.local.execute": { + "input": INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA, + "output": INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA, + }, + "kb.list": {"input": KB_LIST_INPUT_SCHEMA, "output": KB_LIST_OUTPUT_SCHEMA}, + "kb.get": {"input": KB_GET_INPUT_SCHEMA, "output": KB_GET_OUTPUT_SCHEMA}, + "kb.create": { + "input": KB_CREATE_INPUT_SCHEMA, + "output": KB_CREATE_OUTPUT_SCHEMA, + }, + "kb.update": { + "input": KB_UPDATE_INPUT_SCHEMA, + "output": KB_UPDATE_OUTPUT_SCHEMA, + }, + "kb.delete": { + "input": KB_DELETE_INPUT_SCHEMA, + "output": KB_DELETE_OUTPUT_SCHEMA, + }, + "kb.retrieve": { + "input": KB_RETRIEVE_INPUT_SCHEMA, + "output": KB_RETRIEVE_OUTPUT_SCHEMA, + }, + "kb.document.upload": { + "input": KB_DOCUMENT_UPLOAD_INPUT_SCHEMA, + "output": KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA, + }, + "kb.document.list": { + "input": KB_DOCUMENT_LIST_INPUT_SCHEMA, + "output": KB_DOCUMENT_LIST_OUTPUT_SCHEMA, + }, + "kb.document.get": { + "input": KB_DOCUMENT_GET_INPUT_SCHEMA, + "output": KB_DOCUMENT_GET_OUTPUT_SCHEMA, + }, + "kb.document.delete": { + "input": KB_DOCUMENT_DELETE_INPUT_SCHEMA, + "output": KB_DOCUMENT_DELETE_OUTPUT_SCHEMA, + }, + "kb.document.refresh": { + "input": KB_DOCUMENT_REFRESH_INPUT_SCHEMA, + "output": KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA, + }, + "registry.command.register": { + "input": REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA, + "output": REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA, + }, + "skill.register": { + "input": SKILL_REGISTER_INPUT_SCHEMA, + "output": SKILL_REGISTER_OUTPUT_SCHEMA, + }, + "skill.unregister": { + "input": SKILL_UNREGISTER_INPUT_SCHEMA, + "output": SKILL_UNREGISTER_OUTPUT_SCHEMA, + }, + "skill.list": { + "input": SKILL_LIST_INPUT_SCHEMA, + "output": SKILL_LIST_OUTPUT_SCHEMA, + }, + "http.register_api": { + "input": HTTP_REGISTER_API_INPUT_SCHEMA, + "output": HTTP_REGISTER_API_OUTPUT_SCHEMA, + }, + "http.unregister_api": { + "input": HTTP_UNREGISTER_API_INPUT_SCHEMA, + "output": HTTP_UNREGISTER_API_OUTPUT_SCHEMA, + }, + "http.list_apis": { + "input": HTTP_LIST_APIS_INPUT_SCHEMA, + "output": HTTP_LIST_APIS_OUTPUT_SCHEMA, + }, + "metadata.get_plugin": { + "input": METADATA_GET_PLUGIN_INPUT_SCHEMA, + "output": METADATA_GET_PLUGIN_OUTPUT_SCHEMA, + }, + "metadata.list_plugins": { + "input": METADATA_LIST_PLUGINS_INPUT_SCHEMA, + "output": METADATA_LIST_PLUGINS_OUTPUT_SCHEMA, + }, + "metadata.get_plugin_config": { + "input": METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA, + "output": METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA, + }, + "metadata.save_plugin_config": { + "input": METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA, + "output": METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA, + }, + "registry.get_handlers_by_event_type": { + "input": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA, + "output": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA, + }, + "registry.get_handler_by_full_name": { + "input": REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA, + "output": REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA, + }, + "provider.get_using": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.get_by_id": { + "input": PROVIDER_GET_BY_ID_INPUT_SCHEMA, + "output": PROVIDER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "provider.get_current_chat_provider_id": { + "input": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA, + "output": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA, + }, + "provider.list_all": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_tts": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_stt": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_embedding": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_rerank": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.get_using_tts": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.get_using_stt": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.stt.get_text": { + "input": PROVIDER_STT_GET_TEXT_INPUT_SCHEMA, + "output": PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA, + }, + "provider.tts.get_audio": { + "input": PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA, + "output": PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA, + }, + "provider.tts.support_stream": { + "input": PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA, + "output": PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA, + }, + "provider.tts.get_audio_stream": { + "input": PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA, + "output": PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA, + }, + "provider.embedding.get_embedding": { + "input": PROVIDER_EMBEDDING_GET_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA, + }, + "provider.embedding.get_embeddings": { + "input": PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA, + }, + "provider.embedding.get_dim": { + "input": PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA, + }, + "provider.rerank.rerank": { + "input": PROVIDER_RERANK_INPUT_SCHEMA, + "output": PROVIDER_RERANK_OUTPUT_SCHEMA, + }, + "provider.manager.set": { + "input": PROVIDER_MANAGER_SET_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_SET_OUTPUT_SCHEMA, + }, + "provider.manager.get_by_id": { + "input": PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "provider.manager.get_merged_provider_config": { + "input": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA, + }, + "provider.manager.load": { + "input": PROVIDER_MANAGER_LOAD_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA, + }, + "provider.manager.terminate": { + "input": PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA, + }, + "provider.manager.create": { + "input": PROVIDER_MANAGER_CREATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA, + }, + "provider.manager.update": { + "input": PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA, + }, + "provider.manager.delete": { + "input": PROVIDER_MANAGER_DELETE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA, + }, + "provider.manager.get_insts": { + "input": PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA, + }, + "provider.manager.watch_changes": { + "input": PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA, + }, + "platform.manager.get_by_id": { + "input": PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "platform.manager.clear_errors": { + "input": PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA, + }, + "platform.manager.get_stats": { + "input": PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA, + }, + "permission.check": { + "input": PERMISSION_CHECK_INPUT_SCHEMA, + "output": PERMISSION_CHECK_OUTPUT_SCHEMA, + }, + "permission.get_admins": { + "input": PERMISSION_GET_ADMINS_INPUT_SCHEMA, + "output": PERMISSION_GET_ADMINS_OUTPUT_SCHEMA, + }, + "permission.manager.add_admin": { + "input": PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA, + "output": PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA, + }, + "permission.manager.remove_admin": { + "input": PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA, + "output": PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA, + }, + "llm_tool.manager.get": { + "input": LLM_TOOL_MANAGER_GET_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA, + }, + "llm_tool.manager.activate": { + "input": LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA, + }, + "llm_tool.manager.deactivate": { + "input": LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA, + }, + "llm_tool.manager.add": { + "input": LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA, + }, + "llm_tool.manager.remove": { + "input": LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA, + }, + "agent.tool_loop.run": { + "input": AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA, + "output": AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA, + }, + "agent.registry.list": { + "input": AGENT_REGISTRY_LIST_INPUT_SCHEMA, + "output": AGENT_REGISTRY_LIST_OUTPUT_SCHEMA, + }, + "agent.registry.get": { + "input": AGENT_REGISTRY_GET_INPUT_SCHEMA, + "output": AGENT_REGISTRY_GET_OUTPUT_SCHEMA, + }, + "system.get_data_dir": { + "input": SYSTEM_GET_DATA_DIR_INPUT_SCHEMA, + "output": SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA, + }, + "system.text_to_image": { + "input": SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA, + "output": SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA, + }, + "system.html_render": { + "input": SYSTEM_HTML_RENDER_INPUT_SCHEMA, + "output": SYSTEM_HTML_RENDER_OUTPUT_SCHEMA, + }, + "system.file.register": { + "input": SYSTEM_FILE_REGISTER_INPUT_SCHEMA, + "output": SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA, + }, + "system.file.handle": { + "input": SYSTEM_FILE_HANDLE_INPUT_SCHEMA, + "output": SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA, + }, + "system.session_waiter.register": { + "input": SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA, + "output": SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA, + }, + "system.session_waiter.unregister": { + "input": SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA, + "output": SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA, + }, + "system.event.react": { + "input": SYSTEM_EVENT_REACT_INPUT_SCHEMA, + "output": SYSTEM_EVENT_REACT_OUTPUT_SCHEMA, + }, + "system.event.send_typing": { + "input": SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA, + }, + "system.event.send_streaming": { + "input": SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA, + }, + "system.event.send_streaming_chunk": { + "input": SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA, + }, + "system.event.send_streaming_close": { + "input": SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA, + }, + "system.event.llm.get_state": { + "input": SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA, + "output": SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA, + }, + "system.event.llm.request": { + "input": SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA, + "output": SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA, + }, + "system.event.result.get": { + "input": SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA, + }, + "system.event.result.set": { + "input": SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA, + }, + "system.event.result.clear": { + "input": SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA, + }, + "system.event.handler_whitelist.get": { + "input": SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA, + }, + "system.event.handler_whitelist.set": { + "input": SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA, + }, +} + + +__all__ = [ + "BUILTIN_CAPABILITY_SCHEMAS", + "DB_DELETE_INPUT_SCHEMA", + "DB_DELETE_OUTPUT_SCHEMA", + "DB_GET_INPUT_SCHEMA", + "DB_GET_MANY_INPUT_SCHEMA", + "DB_GET_MANY_OUTPUT_SCHEMA", + "DB_GET_OUTPUT_SCHEMA", + "DB_LIST_INPUT_SCHEMA", + "DB_LIST_OUTPUT_SCHEMA", + "DB_SET_INPUT_SCHEMA", + "DB_SET_MANY_INPUT_SCHEMA", + "DB_SET_MANY_OUTPUT_SCHEMA", + "DB_SET_OUTPUT_SCHEMA", + "DB_WATCH_INPUT_SCHEMA", + "DB_WATCH_OUTPUT_SCHEMA", + "HTTP_LIST_APIS_INPUT_SCHEMA", + "HTTP_LIST_APIS_OUTPUT_SCHEMA", + "HTTP_REGISTER_API_INPUT_SCHEMA", + "HTTP_REGISTER_API_OUTPUT_SCHEMA", + "HTTP_UNREGISTER_API_INPUT_SCHEMA", + "HTTP_UNREGISTER_API_OUTPUT_SCHEMA", + "JSONSchema", + "LLM_CHAT_INPUT_SCHEMA", + "LLM_CHAT_OUTPUT_SCHEMA", + "LLM_CHAT_RAW_INPUT_SCHEMA", + "LLM_CHAT_RAW_OUTPUT_SCHEMA", + "LLM_STREAM_CHAT_INPUT_SCHEMA", + "LLM_STREAM_CHAT_OUTPUT_SCHEMA", + "MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA", + "MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA", + "MEMORY_COUNT_INPUT_SCHEMA", + "MEMORY_COUNT_OUTPUT_SCHEMA", + "MEMORY_DELETE_INPUT_SCHEMA", + "MEMORY_DELETE_MANY_INPUT_SCHEMA", + "MEMORY_DELETE_MANY_OUTPUT_SCHEMA", + "MEMORY_DELETE_OUTPUT_SCHEMA", + "MEMORY_EXISTS_INPUT_SCHEMA", + "MEMORY_EXISTS_OUTPUT_SCHEMA", + "MEMORY_GET_INPUT_SCHEMA", + "MEMORY_GET_MANY_INPUT_SCHEMA", + "MEMORY_GET_MANY_OUTPUT_SCHEMA", + "MEMORY_GET_OUTPUT_SCHEMA", + "MEMORY_LIST_KEYS_INPUT_SCHEMA", + "MEMORY_LIST_KEYS_OUTPUT_SCHEMA", + "MEMORY_SAVE_INPUT_SCHEMA", + "MEMORY_SAVE_OUTPUT_SCHEMA", + "MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA", + "MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA", + "MEMORY_SEARCH_INPUT_SCHEMA", + "MEMORY_SEARCH_OUTPUT_SCHEMA", + "MEMORY_STATS_INPUT_SCHEMA", + "MEMORY_STATS_OUTPUT_SCHEMA", + "METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA", + "METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA", + "METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA", + "METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA", + "METADATA_GET_PLUGIN_INPUT_SCHEMA", + "METADATA_GET_PLUGIN_OUTPUT_SCHEMA", + "METADATA_LIST_PLUGINS_INPUT_SCHEMA", + "METADATA_LIST_PLUGINS_OUTPUT_SCHEMA", + "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA", + "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA", + "PROVIDER_GET_BY_ID_INPUT_SCHEMA", + "PROVIDER_GET_BY_ID_OUTPUT_SCHEMA", + "PROVIDER_GET_USING_INPUT_SCHEMA", + "PROVIDER_GET_USING_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA", + "PROVIDER_CHANGE_EVENT_SCHEMA", + "PROVIDER_LIST_ALL_INPUT_SCHEMA", + "PROVIDER_LIST_ALL_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_CREATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_DELETE_INPUT_SCHEMA", + "PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_LOAD_INPUT_SCHEMA", + "PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_SET_INPUT_SCHEMA", + "PROVIDER_MANAGER_SET_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA", + "PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA", + "PROVIDER_META_SCHEMA", + "PROVIDER_RERANK_INPUT_SCHEMA", + "PROVIDER_RERANK_OUTPUT_SCHEMA", + "PROVIDER_RERANK_RESULT_SCHEMA", + "PROVIDER_STT_GET_TEXT_INPUT_SCHEMA", + "PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA", + "PROVIDER_TTS_AUDIO_CHUNK_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA", + "PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA", + "PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_GET_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA", + "LLM_TOOL_SPEC_SCHEMA", + "AGENT_REGISTRY_GET_INPUT_SCHEMA", + "AGENT_REGISTRY_GET_OUTPUT_SCHEMA", + "AGENT_REGISTRY_LIST_INPUT_SCHEMA", + "AGENT_REGISTRY_LIST_OUTPUT_SCHEMA", + "AGENT_SPEC_SCHEMA", + "AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA", + "AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA", + "MANAGED_PROVIDER_RECORD_SCHEMA", + "PLATFORM_ERROR_SCHEMA", + "PLATFORM_GET_MEMBERS_INPUT_SCHEMA", + "PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA", + "PLATFORM_GET_GROUP_INPUT_SCHEMA", + "PLATFORM_GET_GROUP_OUTPUT_SCHEMA", + "PLATFORM_INSTANCE_SCHEMA", + "PLATFORM_LIST_INSTANCES_INPUT_SCHEMA", + "PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA", + "PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA", + "PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA", + "PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_STATE_SCHEMA", + "PERMISSION_CHECK_INPUT_SCHEMA", + "PERMISSION_CHECK_OUTPUT_SCHEMA", + "PERMISSION_CHECK_RESULT_SCHEMA", + "PERMISSION_GET_ADMINS_INPUT_SCHEMA", + "PERMISSION_GET_ADMINS_OUTPUT_SCHEMA", + "PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA", + "PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA", + "PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA", + "PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA", + "PERMISSION_ROLE_SCHEMA", + "PLATFORM_SEND_CHAIN_INPUT_SCHEMA", + "PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA", + "PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA", + "PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA", + "PLATFORM_SEND_IMAGE_INPUT_SCHEMA", + "PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA", + "PLATFORM_SEND_INPUT_SCHEMA", + "PLATFORM_SEND_OUTPUT_SCHEMA", + "PLATFORM_STATS_SCHEMA", + "PERSONA_CREATE_INPUT_SCHEMA", + "PERSONA_CREATE_OUTPUT_SCHEMA", + "PERSONA_CREATE_SCHEMA", + "PERSONA_DELETE_INPUT_SCHEMA", + "PERSONA_DELETE_OUTPUT_SCHEMA", + "PERSONA_GET_INPUT_SCHEMA", + "PERSONA_GET_OUTPUT_SCHEMA", + "PERSONA_LIST_INPUT_SCHEMA", + "PERSONA_LIST_OUTPUT_SCHEMA", + "PERSONA_RECORD_SCHEMA", + "PERSONA_UPDATE_INPUT_SCHEMA", + "PERSONA_UPDATE_OUTPUT_SCHEMA", + "PERSONA_UPDATE_SCHEMA", + "CONVERSATION_CREATE_SCHEMA", + "CONVERSATION_DELETE_INPUT_SCHEMA", + "CONVERSATION_DELETE_OUTPUT_SCHEMA", + "CONVERSATION_GET_CURRENT_INPUT_SCHEMA", + "CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA", + "CONVERSATION_GET_INPUT_SCHEMA", + "CONVERSATION_GET_OUTPUT_SCHEMA", + "CONVERSATION_LIST_INPUT_SCHEMA", + "CONVERSATION_LIST_OUTPUT_SCHEMA", + "CONVERSATION_NEW_INPUT_SCHEMA", + "CONVERSATION_NEW_OUTPUT_SCHEMA", + "CONVERSATION_RECORD_SCHEMA", + "CONVERSATION_SWITCH_INPUT_SCHEMA", + "CONVERSATION_SWITCH_OUTPUT_SCHEMA", + "CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA", + "CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA", + "CONVERSATION_UPDATE_INPUT_SCHEMA", + "CONVERSATION_UPDATE_OUTPUT_SCHEMA", + "CONVERSATION_UPDATE_SCHEMA", + "MESSAGE_HISTORY_APPEND_INPUT_SCHEMA", + "MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA", + "MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_LIST_INPUT_SCHEMA", + "MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_PAGE_SCHEMA", + "MESSAGE_HISTORY_RECORD_SCHEMA", + "MESSAGE_HISTORY_SENDER_SCHEMA", + "MESSAGE_HISTORY_SESSION_SCHEMA", + "KB_CREATE_INPUT_SCHEMA", + "KB_CREATE_OUTPUT_SCHEMA", + "KB_DOCUMENT_DELETE_INPUT_SCHEMA", + "KB_DOCUMENT_DELETE_OUTPUT_SCHEMA", + "KB_DOCUMENT_GET_INPUT_SCHEMA", + "KB_DOCUMENT_GET_OUTPUT_SCHEMA", + "KB_DOCUMENT_LIST_INPUT_SCHEMA", + "KB_DOCUMENT_LIST_OUTPUT_SCHEMA", + "KB_DOCUMENT_REFRESH_INPUT_SCHEMA", + "KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA", + "KB_DOCUMENT_UPLOAD_INPUT_SCHEMA", + "KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA", + "KB_DELETE_INPUT_SCHEMA", + "KB_DELETE_OUTPUT_SCHEMA", + "KB_GET_INPUT_SCHEMA", + "KB_GET_OUTPUT_SCHEMA", + "KB_LIST_INPUT_SCHEMA", + "KB_LIST_OUTPUT_SCHEMA", + "KB_RETRIEVE_INPUT_SCHEMA", + "KB_RETRIEVE_OUTPUT_SCHEMA", + "KB_UPDATE_INPUT_SCHEMA", + "KB_UPDATE_OUTPUT_SCHEMA", + "KNOWLEDGE_BASE_CREATE_SCHEMA", + "KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA", + "KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA", + "KNOWLEDGE_BASE_RECORD_SCHEMA", + "KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA", + "KNOWLEDGE_BASE_UPDATE_SCHEMA", + "REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA", + "REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA", + "SKILL_REGISTER_INPUT_SCHEMA", + "SKILL_REGISTER_OUTPUT_SCHEMA", + "SKILL_UNREGISTER_INPUT_SCHEMA", + "SKILL_UNREGISTER_OUTPUT_SCHEMA", + "SKILL_LIST_INPUT_SCHEMA", + "SKILL_LIST_OUTPUT_SCHEMA", + "REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA", + "REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA", + "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA", + "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA", + "SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA", + "SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA", + "SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA", + "SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA", + "SESSION_REF_SCHEMA", + "SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA", + "SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA", + "SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA", + "SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA", + "SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA", + "SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA", + "SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA", + "SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA", + "SYSTEM_EVENT_REACT_INPUT_SCHEMA", + "SYSTEM_EVENT_REACT_OUTPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA", + "SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA", + "SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA", + "SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA", + "SYSTEM_FILE_HANDLE_INPUT_SCHEMA", + "SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA", + "SYSTEM_FILE_REGISTER_INPUT_SCHEMA", + "SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA", +] diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py new file mode 100644 index 0000000000..abe8b92b2d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py @@ -0,0 +1,413 @@ +"""s5r 协议描述符模型。 + +`protocol` 是 s5r 新引入的协议层抽象,不对应旧树(圣诞树)中的一个同名目录。这里 +定义的是跨进程握手和调度时使用的声明式元数据,而不是运行时的具体处理器/ +能力实现。 +""" + +from __future__ import annotations + +from typing import Annotated, Any, Literal + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from . import _builtin_schemas +from ._builtin_schemas import * # noqa: F403 + +JSONSchema = _builtin_schemas.JSONSchema +RESERVED_CAPABILITY_NAMESPACES = ("handler", "system", "internal") +RESERVED_CAPABILITY_PREFIXES = tuple( + f"{namespace}." for namespace in RESERVED_CAPABILITY_NAMESPACES +) +BUILTIN_CAPABILITY_SCHEMAS = _builtin_schemas.BUILTIN_CAPABILITY_SCHEMAS +_BUILTIN_SCHEMA_EXPORTS = frozenset(_builtin_schemas.__all__) + + +def __getattr__(name: str) -> Any: + if name in _BUILTIN_SCHEMA_EXPORTS: + return getattr(_builtin_schemas, name) + raise AttributeError(name) + + +def __dir__() -> list[str]: + return sorted(set(globals()) | _BUILTIN_SCHEMA_EXPORTS) + + +class _DescriptorBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class Permissions(_DescriptorBase): + """权限配置,控制处理器的访问权限。 + + Attributes: + require_admin: 是否需要管理员权限 + required_role: 处理器要求的最小角色,v1 支持 member/admin + level: 权限等级,数值越高权限越大 + """ + + require_admin: bool = False + required_role: Literal["member", "admin"] | None = None + level: int = 0 + + @model_validator(mode="after") + def normalize_required_role(self) -> Permissions: + if self.require_admin: + if self.required_role not in {None, "admin"}: + raise ValueError( + "permissions.require_admin=True conflicts with required_role=" + f"{self.required_role!r}" + ) + self.required_role = "admin" + return self + if self.required_role == "admin": + self.require_admin = True + return self + + +class SessionRef(_DescriptorBase): + """结构化会话目标。 + + s5r 运行时内部仍然保留 legacy `session` 字符串作为最低兼容层, + 但对外模型允许同时携带平台与原始寻址信息,避免平台发送接口长期 + 只依赖一个不透明字符串。 + """ + + conversation_id: str = Field( + validation_alias=AliasChoices("conversation_id", "session"), + ) + platform: str | None = None + raw: dict[str, Any] | None = None + + @property + def session(self) -> str: + return self.conversation_id + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class CommandTrigger(_DescriptorBase): + """命令触发器,响应特定命令。 + + Attributes: + type: 触发器类型,固定为 "command" + command: 命令名称(不含前缀,如 "help") + aliases: 命令别名列表 + description: 命令描述,用于帮助文档 + platforms: 允许的平台列表,为空表示所有平台 + message_types: 限定的消息类型列表,为空表示不限 + """ + + type: Literal["command"] = "command" + command: str + aliases: list[str] = Field(default_factory=list) + description: str | None = None + platforms: list[str] = Field(default_factory=list) + message_types: list[str] = Field(default_factory=list) + + +class MessageTrigger(_DescriptorBase): + """消息触发器,描述消息类处理器的订阅条件。 + + Attributes: + type: 触发器类型,固定为 "message" + regex: 正则表达式模式,匹配消息文本 + keywords: 关键词列表,消息包含任一关键词即触发 + platforms: 目标平台列表,为空表示所有平台 + message_types: 限定的消息类型列表,为空表示不限 + + Note: + `regex` 和 `keywords` 可以同时为空,此时表示 "任意消息均可触发", + 仅由平台过滤或上层运行时进一步筛选。 + """ + + type: Literal["message"] = "message" + regex: str | None = None + keywords: list[str] = Field(default_factory=list) + platforms: list[str] = Field(default_factory=list) + message_types: list[str] = Field(default_factory=list) + + +class EventTrigger(_DescriptorBase): + """事件触发器,响应特定类型的事件。 + + Attributes: + type: 触发器类型,固定为 "event" + event_type: 事件类型,字符串形式(如 "message"、"notice") + """ + + type: Literal["event"] = "event" + event_type: str + + +class ScheduleTrigger(_DescriptorBase): + """定时触发器,按 cron 表达式或固定间隔执行。 + + Attributes: + type: 触发器类型,固定为 "schedule" + name: 调度任务名称,默认回退为插件 ID 与 handler ID 组合 + cron: cron 表达式(如 "0 9 * * *" 表示每天 9 点) + interval_seconds: 执行间隔(秒) + timezone: IANA 时区名称(如 "Asia/Shanghai") + + Note: + cron 和 interval_seconds 必须且只能有一个非空。 + """ + + type: Literal["schedule"] = "schedule" + name: str | None = None + cron: str | None = Field( + default=None, + validation_alias=AliasChoices("cron", "schedule"), + ) + interval_seconds: int | None = None + timezone: str | None = None + + @property + def schedule(self) -> str | None: + return self.cron + + @model_validator(mode="after") + def validate_schedule(self) -> ScheduleTrigger: + has_cron = self.cron is not None + has_interval = self.interval_seconds is not None + if has_cron == has_interval: + raise ValueError("cron 和 interval_seconds 必须且只能有一个非 null") + return self + + +class PlatformFilterSpec(_DescriptorBase): + kind: Literal["platform"] = "platform" + platforms: list[str] = Field(default_factory=list) + + +class MessageTypeFilterSpec(_DescriptorBase): + kind: Literal["message_type"] = "message_type" + message_types: list[str] = Field(default_factory=list) + + +class LocalFilterRefSpec(_DescriptorBase): + kind: Literal["local"] = "local" + filter_id: str + args: dict[str, Any] = Field(default_factory=dict) + + +class CompositeFilterSpec(_DescriptorBase): + kind: Literal["and", "or"] + children: list[FilterSpec] = Field(default_factory=list) + + +FilterSpec = Annotated[ + PlatformFilterSpec + | MessageTypeFilterSpec + | LocalFilterRefSpec + | CompositeFilterSpec, + Field(discriminator="kind"), +] + + +class ParamSpec(_DescriptorBase): + name: str + type: Literal["str", "int", "float", "bool", "optional", "greedy_str"] + required: bool = True + inner_type: Literal["str", "int", "float", "bool"] | None = None + + +class CommandRouteSpec(_DescriptorBase): + group_path: list[str] = Field(default_factory=list) + display_command: str + group_help: str | None = None + + +CompositeFilterSpec.model_rebuild() + + +Trigger = Annotated[ + CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger, + Field(discriminator="type"), +] +"""触发器联合类型,使用 type 字段作为判别器自动解析具体类型。""" + + +class HandlerDescriptor(_DescriptorBase): + """处理器描述符,描述一个事件处理函数的元信息。 + + Attributes: + id: 处理器唯一标识,通常是 "模块.函数名" 格式 + trigger: 触发器配置,决定何时执行该处理器 + kind: 处理器类别,默认普通 handler + contract: 运行时契约名,描述入参/执行语义 + priority: 优先级,数值越大越先执行 + permissions: 权限配置,控制谁可以触发该处理器 + + 使用场景: + HandlerDescriptor 通常由 `@on_command`、`@on_message` 等装饰器自动创建, + 插件作者一般不需要手动实例化。但了解其结构有助于理解插件注册机制。 + + 触发器类型: + - CommandTrigger: 响应特定命令,如 `/help` + - MessageTrigger: 响应消息(正则/关键词匹配) + - EventTrigger: 响应特定事件类型 + - ScheduleTrigger: 定时触发 + + 示例: + 插件作者通常通过装饰器声明处理器,框架会自动生成 HandlerDescriptor: + + ```python + from astrbot_sdk.decorators import on_command, on_message + + # 命令处理器 + @on_command("hello") + async def hello_handler(ctx: Context): + await ctx.reply("Hello!") + + # 消息处理器(正则匹配) + @on_message(regex=r"^test\\s+(.+)$") + async def test_handler(ctx: Context): + await ctx.reply(f"收到: {ctx.match.group(1)}") + ``` + + See Also: + Trigger: 触发器联合类型 + Permissions: 权限配置 + """ + + id: str + trigger: Trigger + kind: Literal["handler", "hook", "tool", "session"] = "handler" + contract: str | None = None + description: str | None = None + priority: int = 0 + permissions: Permissions = Field(default_factory=Permissions) + filters: list[FilterSpec] = Field(default_factory=list) + param_specs: list[ParamSpec] = Field(default_factory=list) + command_route: CommandRouteSpec | None = None + + @model_validator(mode="after") + def validate_contract_defaults(self) -> HandlerDescriptor: + if self.contract is None: + if isinstance(self.trigger, ScheduleTrigger): + self.contract = "schedule" + else: + self.contract = "message_event" + return self + + +class CapabilityDescriptor(_DescriptorBase): + """能力描述符,描述一个可调用的远程能力。 + + 能力命名规范: + - 使用 "namespace.action" 格式,如 "llm.chat"、"db.set" + - 支持多级命名空间,如 "llm_tool.manager.activate" + - 内置能力以 "internal." 开头,如 "internal.legacy.call_context_function" + + 保留命名空间(插件不可使用): + - `handler.` - 处理器相关 + - `system.` - 系统内部能力 + - `internal.` - 内部实现细节 + + Attributes: + name: 能力名称,格式为 "namespace.action" + description: 能力描述,用于文档和调试 + input_schema: 输入参数的 JSON Schema,用于验证 + output_schema: 输出结果的 JSON Schema,用于验证 + supports_stream: 是否支持流式响应 + cancelable: 是否支持取消 + + 使用场景: + 当你的插件需要**暴露**一个可被其他插件调用的能力时,使用此类声明。 + + 示例: + ```python + from astrbot_sdk.protocol import CapabilityDescriptor + + # 声明一个翻译能力 + translate_desc = CapabilityDescriptor( + name="my_plugin.translate", + description="翻译文本到指定语言", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "要翻译的文本"}, + "target_lang": {"type": "string", "description": "目标语言"}, + }, + "required": ["text", "target_lang"], + }, + output_schema={ + "type": "object", + "properties": { + "translated": {"type": "string"}, + }, + }, + ) + + # 声明一个流式数据能力 + stream_desc = CapabilityDescriptor( + name="my_plugin.stream_data", + description="流式返回数据", + supports_stream=True, + cancelable=True, + input_schema={"type": "object", "properties": {"count": {"type": "integer"}}}, + output_schema={"type": "object", "properties": {"items": {"type": "array"}}}, + ) + ``` + + 注意: + 如果你要调用**内置能力**(如 `llm.chat`、`db.set`),不需要手动创建 + CapabilityDescriptor,而是直接通过 `Context.invoke()` 调用,或查阅 + `BUILTIN_CAPABILITY_SCHEMAS` 了解参数格式。 + + See Also: + BUILTIN_CAPABILITY_SCHEMAS: 内置能力的 schema 定义,用于查询参数格式 + """ + + name: str + description: str + input_schema: JSONSchema | None = None + output_schema: JSONSchema | None = None + supports_stream: bool = False + cancelable: bool = False + + @model_validator(mode="after") + def validate_builtin_schema_governance(self) -> CapabilityDescriptor: + builtin_schema = BUILTIN_CAPABILITY_SCHEMAS.get(self.name) + if builtin_schema is None: + return self + if self.input_schema is None or self.output_schema is None: + raise ValueError( + f"内建 capability {self.name} 必须同时提供 input_schema 和 output_schema" + ) + if ( + self.input_schema != builtin_schema["input"] + or self.output_schema != builtin_schema["output"] + ): + raise ValueError( + f"内建 capability {self.name} 的 schema 必须与协议注册表保持一致" + ) + return self + + +__all__ = [ + "Trigger", + "BUILTIN_CAPABILITY_SCHEMAS", + "CapabilityDescriptor", + "CommandRouteSpec", + "CommandTrigger", + "CompositeFilterSpec", + "EventTrigger", + "FilterSpec", + "HandlerDescriptor", + "JSONSchema", + "LocalFilterRefSpec", + "MessageTrigger", + "MessageTypeFilterSpec", + "ParamSpec", + "Permissions", + "PlatformFilterSpec", + "RESERVED_CAPABILITY_NAMESPACES", + "RESERVED_CAPABILITY_PREFIXES", + "ScheduleTrigger", + "SessionRef", +] +__all__ += list(_BUILTIN_SCHEMA_EXPORTS) diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/messages.py b/astrbot-sdk/src/astrbot_sdk/protocol/messages.py new file mode 100644 index 0000000000..c249bf16bd --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/messages.py @@ -0,0 +1,323 @@ +"""s5r 协议消息模型。 + +这些模型描述的是 `Peer` 与 `Peer` 之间的线协议。握手阶段通过 +`InitializeMessage` 发起,再由 `ResultMessage(kind="initialize_result")` +返回 `InitializeOutput`;能力调用阶段则使用 `InvokeMessage` / `ResultMessage` +或 `EventMessage` 序列。 + +TODO: Batch Invoke(协议 v1.1 候选特性) +========================================== + +设计概要: + 新增 BatchInvokeMessage / BatchResultMessage,将多个独立非流式调用 + 打包为单次 IPC 传输,减少序列化和 I/O syscall 开销。 + +约束: + - 只支持非流式子调用(stream=false) + - 结果保序返回,但服务端内部可 asyncio.gather 并发处理 + - 单个子调用失败不拖垮整个 batch,各自返回独立的 success/error + - 仅协议级错误(空 calls、重复 id、子项带 stream=true)整体失败 + - 取消只到 batch 粒度:取消 batch ID → 取消全部未完成子调用 + +改动范围: + - messages.py : 加 BatchInvokeMessage / BatchResultMessage + - peer.py : 加 invoke_batch() 和 _handle_batch_invoke() + - clients/_proxy.py : 加 call_batch() + - transport.py : 不动(batch 仍然是一行 JSON) + +暂不实现的原因(2026-03-28): + 1. SDK 集成(feat/sdk-integration)尚在主干开发期,协议层应保持简单稳定 + 2. 现有 pipelining(asyncio.gather + 多行 InvokeMessage)已覆盖并发场景, + 单次 stdio IPC 延迟在微秒级,实测中不构成瓶颈 + 3. peer.py 已 776 行,是协议栈核心文件,batch 会引入子调用生命周期管理、 + 超时聚合等额外复杂度 + 4. 目前无真实插件在单次 handler 中发出 10+ 独立 capability 调用, + 缺乏可测量的性能收益数据 + +触发条件(何时重新评估): + - 有插件在单次 handler 中 gather 10+ 独立 capability 调用 + - IPC 序列化/解析耗时经 profile 确认占总延迟 >5% + - 需要 WebSocket 传输场景下的带宽优化 +""" + +from __future__ import annotations + +import json +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from .descriptors import CapabilityDescriptor, HandlerDescriptor + + +class _MessageBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class ErrorPayload(_MessageBase): + """错误载荷,用于 ResultMessage 和 EventMessage 中传递错误信息。 + + Attributes: + code: 错误码,字符串类型,便于语义化错误分类 + message: 错误消息,人类可读的错误描述 + hint: 错误提示,可选的解决方案或建议 + retryable: 是否可重试,标识该错误是否可通过重试解决 + docs_url: 可选的文档链接,帮助调用方定位更多说明 + details: 可选的结构化细节,便于调试和日志展示 + """ + + code: str + message: str + hint: str = "" + retryable: bool = False + docs_url: str = "" + details: dict[str, Any] | None = None + + +class PeerInfo(_MessageBase): + """对等节点信息,标识消息发送方的身份。 + + Attributes: + name: 节点名称,通常是插件 ID 或核心标识 + role: 节点角色,"plugin" 或 "core" + version: 节点版本号,可选 + """ + + name: str + role: Literal["plugin", "core"] + version: str | None = None + + +class InitializeMessage(_MessageBase): + """初始化消息,用于建立连接时交换信息。 + + Attributes: + type: 消息类型,固定为 "initialize" + id: 消息 ID,用于关联响应 + protocol_version: 协议版本号 + peer: 发送方节点信息 + handlers: 注册的处理器描述符列表 + provided_capabilities: 发送方对外暴露的能力描述符列表 + metadata: 扩展元数据,可存储插件配置等信息 + """ + + type: Literal["initialize"] = "initialize" + id: str + protocol_version: str + peer: PeerInfo + handlers: list[HandlerDescriptor] = Field(default_factory=list) + provided_capabilities: list[CapabilityDescriptor] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class InitializeOutput(_MessageBase): + """初始化输出,作为 InitializeMessage 的响应数据。 + + Attributes: + peer: 接收方(核心)节点信息 + protocol_version: 协商后的协议版本;未协商时可为空 + capabilities: 核心提供的能力描述符列表 + metadata: 扩展元数据 + """ + + peer: PeerInfo + protocol_version: str | None = None + capabilities: list[CapabilityDescriptor] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ResultMessage(_MessageBase): + """结果消息,用于返回能力调用的结果。 + + Attributes: + type: 消息类型,固定为 "result" + id: 关联的请求 ID + kind: 结果类型,可选,如 "initialize_result" 标识初始化结果 + success: 是否成功 + output: 成功时的输出数据 + error: 失败时的错误信息 + """ + + type: Literal["result"] = "result" + id: str + kind: str | None = None + success: bool + output: dict[str, Any] = Field(default_factory=dict) + error: ErrorPayload | None = None + + @model_validator(mode="after") + def validate_result_state(self) -> ResultMessage: + """约束 success / output / error 的组合状态。""" + if self.success: + if self.error is not None: + raise ValueError("success=true 时 error 必须为空") + return self + if self.error is None: + raise ValueError("success=false 时必须提供 error") + if self.output: + raise ValueError("success=false 时 output 必须为空") + return self + + +class InvokeMessage(_MessageBase): + """调用消息,用于请求执行远程能力。 + + Attributes: + type: 消息类型,固定为 "invoke" + id: 请求 ID,用于关联响应 + capability: 目标能力名称,格式为 "namespace.action" + input: 调用输入参数 + stream: 是否期望流式响应,若为 True 将收到 EventMessage 序列 + caller_plugin_id: 运行时透传的调用方插件 ID,不属于业务 payload + """ + + type: Literal["invoke"] = "invoke" + id: str + capability: str + input: dict[str, Any] = Field(default_factory=dict) + stream: bool = False + caller_plugin_id: str | None = None + + +class EventMessage(_MessageBase): + """事件消息,用于流式调用的状态通知。 + + 流式调用生命周期: + 1. started: 调用开始,所有字段为空 + 2. delta: 数据增量更新,包含 data 字段 + 3. completed: 调用完成,包含 output 字段 + 4. failed: 调用失败,包含 error 字段 + + Attributes: + type: 消息类型,固定为 "event" + id: 关联的请求 ID + phase: 事件阶段,started/delta/completed/failed + data: 增量数据,仅 delta 阶段有效 + output: 最终输出,仅 completed 阶段有效 + error: 错误信息,仅 failed 阶段有效 + """ + + type: Literal["event"] = "event" + id: str + phase: Literal["started", "delta", "completed", "failed"] + data: dict[str, Any] = Field(default_factory=dict) + output: dict[str, Any] = Field(default_factory=dict) + error: ErrorPayload | None = None + + @model_validator(mode="after") + def validate_phase_constraints(self) -> EventMessage: + """验证各 phase 的字段约束。 + + - started: 所有字段必须为空 + - delta: 必须有 data,output/error 必须为空 + - completed: 必须有 output,data/error 必须为空 + - failed: 必须有 error,data/output 必须为空 + """ + phase = self.phase + if phase == "started": + if self.data or self.output or self.error: + raise ValueError("started phase 必须所有字段为空") + elif phase == "delta": + if not self.data: + raise ValueError("delta phase 需要 data") + if self.output or self.error: + raise ValueError("delta phase 的 output/error 必须为空") + elif phase == "completed": + if not self.output: + raise ValueError("completed phase 需要 output") + if self.data or self.error: + raise ValueError("completed phase 的 data/error 必须为空") + elif phase == "failed": + if self.error is None: + raise ValueError("failed phase 需要 error") + if self.data or self.output: + raise ValueError("failed phase 的 data/output 必须为空") + return self + + +class CancelMessage(_MessageBase): + """取消消息,用于取消正在进行的调用。 + + Attributes: + type: 消息类型,固定为 "cancel" + id: 要取消的请求 ID + reason: 取消原因,默认为 "user_cancelled" + """ + + type: Literal["cancel"] = "cancel" + id: str + reason: str = "user_cancelled" + + +ProtocolMessage = ( + InitializeMessage | ResultMessage | InvokeMessage | EventMessage | CancelMessage +) +"""协议消息联合类型,所有有效消息类型的联合。""" + +_PROTOCOL_MESSAGE_MODELS = { + "initialize": InitializeMessage, + "result": ResultMessage, + "invoke": InvokeMessage, + "event": EventMessage, + "cancel": CancelMessage, +} + + +def parse_message( + payload: ProtocolMessage | str | bytes | dict[str, Any], +) -> ProtocolMessage: + """解析协议消息。 + + 从原始载荷(字符串、字节或字典)解析为对应的 ProtocolMessage 类型。 + 根据 "type" 字段自动识别消息类型并验证。 + + Args: + payload: 原始消息载荷,支持已解析模型、JSON 字符串、字节或字典 + + Returns: + 解析后的协议消息对象 + + Raises: + ValueError: 未知的消息类型 + + Example: + >>> msg = parse_message('{"type": "invoke", "id": "1", "capability": "test"}') + >>> isinstance(msg, InvokeMessage) + True + """ + if isinstance( + payload, + ( + InitializeMessage, + ResultMessage, + InvokeMessage, + EventMessage, + CancelMessage, + ), + ): + return payload + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + if isinstance(payload, str): + payload = json.loads(payload) + if not isinstance(payload, dict): + raise ValueError("协议消息必须是 JSON object") + message_type = payload.get("type") + model = _PROTOCOL_MESSAGE_MODELS.get(str(message_type)) + if model is not None: + return model.model_validate(payload) + raise ValueError(f"未知消息类型:{message_type}") + + +__all__ = [ + "CancelMessage", + "ErrorPayload", + "EventMessage", + "InitializeMessage", + "InitializeOutput", + "InvokeMessage", + "PeerInfo", + "ProtocolMessage", + "ResultMessage", + "parse_message", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py new file mode 100644 index 0000000000..7601f745c2 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py @@ -0,0 +1,63 @@ +"""AstrBot SDK runtime public exports. + +本模块提供运行时核心组件的公共导出,包括: +- CapabilityRouter: 能力路由器,处理能力调用的分发和路由 +- HandlerDispatcher: 事件处理器分发器,将事件分发到注册的 handler +- Peer: 与 AstrBot 核心通信的对等端抽象 +- Transport 系列: 进程间通信传输层实现(stdio/websocket) + +延迟加载策略: +为避免导入时触发 websocket/aiohttp 等重型依赖,采用 __getattr__ 实现按需加载。 +这样轻量级导入(如仅使用类型提示)不会产生不必要的依赖开销。 +""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .capability_router import CapabilityRouter, StreamExecution + from .handler_dispatcher import HandlerDispatcher + from .peer import Peer + from .transport import ( + MessageHandler, + StdioTransport, + Transport, + WebSocketClientTransport, + WebSocketServerTransport, + ) + +__all__ = [ + "CapabilityRouter", + "HandlerDispatcher", + "MessageHandler", + "Peer", + "StdioTransport", + "StreamExecution", + "Transport", + "WebSocketClientTransport", + "WebSocketServerTransport", +] + + +def __getattr__(name: str) -> Any: + if name in {"CapabilityRouter", "StreamExecution"}: + module = import_module(".capability_router", __name__) + return getattr(module, name) + if name == "HandlerDispatcher": + module = import_module(".handler_dispatcher", __name__) + return getattr(module, name) + if name == "Peer": + module = import_module(".peer", __name__) + return getattr(module, name) + if name in { + "MessageHandler", + "StdioTransport", + "Transport", + "WebSocketClientTransport", + "WebSocketServerTransport", + }: + module = import_module(".transport", __name__) + return getattr(module, name) + raise AttributeError(name) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py new file mode 100644 index 0000000000..b0af66d417 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from .bridge_base import CapabilityRouterBridgeBase +from .capabilities import ( + ConversationCapabilityMixin, + DBCapabilityMixin, + HttpCapabilityMixin, + KnowledgeBaseCapabilityMixin, + LLMCapabilityMixin, + McpCapabilityMixin, + MemoryCapabilityMixin, + MessageHistoryCapabilityMixin, + MetadataCapabilityMixin, + PermissionCapabilityMixin, + PersonaCapabilityMixin, + PlatformCapabilityMixin, + ProviderCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + SystemCapabilityMixin, +) + + +class BuiltinCapabilityRouterMixin( + LLMCapabilityMixin, + MemoryCapabilityMixin, + DBCapabilityMixin, + PlatformCapabilityMixin, + HttpCapabilityMixin, + MetadataCapabilityMixin, + PermissionCapabilityMixin, + ProviderCapabilityMixin, + McpCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + PersonaCapabilityMixin, + ConversationCapabilityMixin, + MessageHistoryCapabilityMixin, + KnowledgeBaseCapabilityMixin, + SystemCapabilityMixin, + CapabilityRouterBridgeBase, +): + def _register_builtin_capabilities(self) -> None: + self._register_llm_capabilities() + self._register_memory_capabilities() + self._register_db_capabilities() + self._register_platform_capabilities() + self._register_http_capabilities() + self._register_metadata_capabilities() + self._register_permission_capabilities() + self._register_provider_capabilities() + self._register_agent_tool_capabilities() + self._register_mcp_capabilities() + self._register_session_capabilities() + self._register_skill_capabilities() + self._register_persona_capabilities() + self._register_conversation_capabilities() + self._register_message_history_capabilities() + self._register_kb_capabilities() + self._register_provider_manager_capabilities() + self._register_platform_manager_capabilities() + self._register_system_capabilities() + + +__all__ = ["BuiltinCapabilityRouterMixin"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py new file mode 100644 index 0000000000..6d31ba6f2c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Any + +from ...protocol.descriptors import CapabilityDescriptor + + +class CapabilityRouterHost: + memory_store: dict[str, dict[str, Any]] + _memory_backends: dict[str, Any] + _memory_index: dict[str, dict[str, Any]] + _memory_dirty_keys: set[str] + _memory_expires_at: dict[str, datetime | None] + db_store: dict[str, Any] + sent_messages: list[dict[str, Any]] + event_actions: list[dict[str, Any]] + http_api_store: list[dict[str, Any]] + _event_streams: dict[str, dict[str, Any]] + _plugins: dict[str, Any] + _request_overlays: dict[str, dict[str, Any]] + _provider_catalog: dict[str, list[dict[str, Any]]] + _provider_configs: dict[str, dict[str, Any]] + _active_provider_ids: dict[str, str | None] + _provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]] + _system_data_root: Path + _session_waiters: dict[str, set[str]] + _session_plugin_configs: dict[str, dict[str, Any]] + _session_service_configs: dict[str, dict[str, Any]] + _db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]] + _dynamic_command_routes: dict[str, list[dict[str, Any]]] + _file_token_store: dict[str, str] + _platform_instances: list[dict[str, Any]] + _persona_store: dict[str, dict[str, Any]] + _conversation_store: dict[str, dict[str, Any]] + _session_current_conversation_ids: dict[str, str] + _kb_store: dict[str, dict[str, Any]] + _kb_document_store: dict[str, dict[str, dict[str, Any]]] + _kb_document_content_store: dict[str, str] + + def register( + self, + descriptor: CapabilityDescriptor, + *, + call_handler=None, + stream_handler=None, + finalize=None, + exposed: bool = True, + ) -> None: + raise NotImplementedError + + def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None: + raise NotImplementedError + + @staticmethod + def _require_caller_plugin_id(capability_name: str) -> str: + raise NotImplementedError + + @staticmethod + def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str: + raise NotImplementedError + + def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path: + raise NotImplementedError + + def register_dynamic_command_route( + self, + *, + plugin_id: str, + command_name: str, + handler_full_name: str, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ) -> None: + raise NotImplementedError + + def get_platform_instances(self) -> list[dict[str, Any]]: + raise NotImplementedError + + @staticmethod + def _normalize_platform_name(value: Any) -> str: + raise NotImplementedError + + @classmethod + def _normalized_platform_names(cls, values: Any) -> set[str]: + raise NotImplementedError + + def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + raise NotImplementedError + + def _platform_name_from_id(self, platform_id: str) -> str: + raise NotImplementedError + + def _session_platform_name(self, session: str) -> str: + raise NotImplementedError + + def _require_platform_support_for_session( + self, + capability_name: str, + session: str, + ) -> str: + raise NotImplementedError + + def _register_agent_tool_capabilities(self) -> None: + raise NotImplementedError + + def _provider_entry( + self, + payload: dict[str, Any], + capability_name: str, + expected_kind: str | None = None, + ) -> dict[str, Any]: + raise NotImplementedError + + async def _provider_embedding_get_embedding( + self, request_id: str, payload: dict[str, Any], token + ) -> dict[str, Any]: + raise NotImplementedError + + async def _provider_embedding_get_embeddings( + self, request_id: str, payload: dict[str, Any], token + ) -> dict[str, Any]: + raise NotImplementedError diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py new file mode 100644 index 0000000000..f1e36516fe --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import copy +import hashlib +import math +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from ..._internal.plugin_ids import resolve_plugin_data_dir, validate_plugin_id +from ...errors import AstrBotError +from ...protocol.descriptors import ( + BUILTIN_CAPABILITY_SCHEMAS, + CapabilityDescriptor, + SessionRef, +) +from ._host import CapabilityRouterHost + + +def _clone_target_payload(value: Any) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + return {str(key): item for key, item in value.items()} + + +def _clone_chain_payload(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [ + {str(key): item for key, item in chunk.items()} + for chunk in value + if isinstance(chunk, dict) + ] + + +_MOCK_EMBEDDING_DIM = 24 + + +def _embedding_terms(text: str) -> list[str]: + """Build stable tokens for the mock embedding implementation.""" + normalized = re.sub(r"\s+", " ", str(text).strip().casefold()) + compact = normalized.replace(" ", "") + if not normalized: + return [] + + terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word] + if compact: + if len(compact) == 1: + terms.append(compact) + else: + terms.extend( + compact[index : index + 2] for index in range(len(compact) - 1) + ) + terms.append(compact) + return terms or [normalized] + + +def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]: + """Generate a deterministic normalized mock embedding vector.""" + values = [0.0] * _MOCK_EMBEDDING_DIM + for term in _embedding_terms(text): + digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest() + index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM + values[index] += 1.0 + min(len(term), 8) * 0.05 + norm = math.sqrt(sum(value * value for value in values)) + if norm <= 0: + return values + return [value / norm for value in values] + + +class CapabilityRouterBridgeBase(CapabilityRouterHost): + _memory_backends: dict[str, Any] + + @staticmethod + def _normalize_platform_name(value: Any) -> str: + return str(value or "").strip().lower() + + @classmethod + def _normalized_platform_names(cls, values: Any) -> set[str]: + if not isinstance(values, list): + return set() + return { + cls._normalize_platform_name(item) + for item in values + if cls._normalize_platform_name(item) + } + + @staticmethod + def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str: + try: + return validate_plugin_id(plugin_id) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires a safe plugin_id: {exc}" + ) from exc + + def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path: + try: + return resolve_plugin_data_dir(self._system_data_root, plugin_id) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires a safe plugin_id: {exc}" + ) from exc + + def _builtin_descriptor( + self, + name: str, + description: str, + *, + supports_stream: bool = False, + cancelable: bool = False, + ) -> CapabilityDescriptor: + schema = BUILTIN_CAPABILITY_SCHEMAS[name] + return CapabilityDescriptor( + name=name, + description=description, + input_schema=copy.deepcopy(schema["input"]), + output_schema=copy.deepcopy(schema["output"]), + supports_stream=supports_stream, + cancelable=cancelable, + ) + + def _resolve_target( + self, payload: dict[str, Any] + ) -> tuple[str, dict[str, Any] | None]: + target_payload = payload.get("target") + if isinstance(target_payload, dict): + target = SessionRef.model_validate(target_payload) + return target.session, target.to_payload() + return str(payload.get("session", "")), None + + @staticmethod + def _is_group_session(session: str) -> bool: + normalized = str(session).lower() + return ":group:" in normalized or ":groupmessage:" in normalized + + @staticmethod + def _mock_group_payload(session: str) -> dict[str, Any] | None: + if not CapabilityRouterBridgeBase._is_group_session(session): + return None + members = [ + { + "user_id": f"{session}:member-1", + "nickname": "Member 1", + "role": "member", + }, + { + "user_id": f"{session}:member-2", + "nickname": "Member 2", + "role": "admin", + }, + ] + return { + "group_id": session.rsplit(":", maxsplit=1)[-1], + "group_name": f"Mock Group {session.rsplit(':', maxsplit=1)[-1]}", + "group_avatar": "", + "group_owner": members[0]["user_id"], + "group_admins": [members[1]["user_id"]], + "members": members, + } + + def _session_plugin_config(self, session: str) -> dict[str, Any]: + config = self._session_plugin_configs.get(str(session), {}) + return dict(config) if isinstance(config, dict) else {} + + def _session_service_config(self, session: str) -> dict[str, Any]: + config = self._session_service_configs.get(str(session), {}) + return dict(config) if isinstance(config, dict) else {} + + @staticmethod + def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + @staticmethod + def _session_platform_id(session: str) -> str: + parts = str(session).split(":", maxsplit=1) + if parts and parts[0].strip(): + return parts[0].strip() + return "unknown" + + def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + normalized_platform = self._normalize_platform_name(platform_name) + if not normalized_platform: + return True + plugin = self._plugins.get(str(plugin_id)) + if plugin is None: + return True + metadata = getattr(plugin, "metadata", None) + if not isinstance(metadata, dict): + return True + supported = self._normalized_platform_names(metadata.get("support_platforms")) + if not supported: + return True + return normalized_platform in supported + + def _platform_name_from_id(self, platform_id: str) -> str: + normalized_platform_id = str(platform_id).strip() + if not normalized_platform_id: + return "" + for item in self.get_platform_instances(): + if not isinstance(item, dict): + continue + if str(item.get("id", "")).strip() != normalized_platform_id: + continue + return self._normalize_platform_name(item.get("type")) + return "" + + def _session_platform_name(self, session: str) -> str: + return self._platform_name_from_id(self._session_platform_id(session)) + + def _require_platform_support_for_session( + self, + capability_name: str, + session: str, + ) -> str: + plugin_id = self._require_caller_plugin_id(capability_name) + platform_name = self._session_platform_name(session) + if not platform_name or self._plugin_supports_platform( + plugin_id, platform_name + ): + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} does not support platform '{platform_name}' for plugin '{plugin_id}'" + ) + + @staticmethod + def _normalize_history_payload(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [dict(item) for item in value if isinstance(item, dict)] + + @staticmethod + def _normalize_persona_dialogs_payload(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item) for item in value if isinstance(item, str)] + + @staticmethod + def _optional_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py new file mode 100644 index 0000000000..1b765697d7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py @@ -0,0 +1,35 @@ +from .conversation import ConversationCapabilityMixin +from .db import DBCapabilityMixin +from .http import HttpCapabilityMixin +from .kb import KnowledgeBaseCapabilityMixin +from .llm import LLMCapabilityMixin +from .mcp import McpCapabilityMixin +from .memory import MemoryCapabilityMixin +from .message_history import MessageHistoryCapabilityMixin +from .metadata import MetadataCapabilityMixin +from .permission import PermissionCapabilityMixin +from .persona import PersonaCapabilityMixin +from .platform import PlatformCapabilityMixin +from .provider import ProviderCapabilityMixin +from .session import SessionCapabilityMixin +from .skill import SkillCapabilityMixin +from .system import SystemCapabilityMixin + +__all__ = [ + "ConversationCapabilityMixin", + "DBCapabilityMixin", + "HttpCapabilityMixin", + "KnowledgeBaseCapabilityMixin", + "LLMCapabilityMixin", + "McpCapabilityMixin", + "MemoryCapabilityMixin", + "MessageHistoryCapabilityMixin", + "MetadataCapabilityMixin", + "PermissionCapabilityMixin", + "PersonaCapabilityMixin", + "PlatformCapabilityMixin", + "ProviderCapabilityMixin", + "SessionCapabilityMixin", + "SkillCapabilityMixin", + "SystemCapabilityMixin", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py new file mode 100644 index 0000000000..a250f43e5a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class ConversationCapabilityMixin(CapabilityRouterBridgeBase): + async def _conversation_new( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.new requires session") + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.new requires conversation object" + ) + conversation_id = uuid.uuid4().hex + now = self._now_iso() + record = { + "conversation_id": conversation_id, + "session": session, + "platform_id": ( + str(raw_conversation.get("platform_id")) + if raw_conversation.get("platform_id") is not None + else self._session_platform_id(session) + ), + "history": self._normalize_history_payload(raw_conversation.get("history")), + "title": ( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + "persona_id": ( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + "created_at": now, + "updated_at": now, + "token_usage": None, + } + self._conversation_store[conversation_id] = record + self._session_current_conversation_ids[session] = conversation_id + return {"conversation_id": conversation_id} + + async def _conversation_switch( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + record = self._conversation_store.get(conversation_id) + if record is None or str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.switch requires a conversation in the same session" + ) + self._session_current_conversation_ids[session] = conversation_id + return {} + + async def _conversation_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.delete requires a conversation in the same session" + ) + del self._conversation_store[normalized_conversation_id] + current_conversation_id = self._session_current_conversation_ids.get(session) + if current_conversation_id == normalized_conversation_id: + replacement = next( + ( + conversation_id + for conversation_id, item in self._conversation_store.items() + if str(item.get("session", "")) == session + ), + None, + ) + if replacement is None: + self._session_current_conversation_ids.pop(session, None) + else: + self._session_current_conversation_ids[session] = replacement + return {} + + async def _conversation_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + record = self._conversation_store.get(conversation_id) + if record is None and bool(payload.get("create_if_not_exists", False)): + created = await self._conversation_new( + _request_id, + {"session": session, "conversation": {}}, + _token, + ) + record = self._conversation_store.get( + str(created.get("conversation_id", "")).strip() + ) + if record is None: + return {"conversation": None} + if str(record.get("session", "")) != session: + return {"conversation": None} + return {"conversation": dict(record)} + + async def _conversation_get_current( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = self._session_current_conversation_ids.get(session, "") + if not conversation_id and bool(payload.get("create_if_not_exists", False)): + created = await self._conversation_new( + _request_id, + {"session": session, "conversation": {}}, + _token, + ) + conversation_id = str(created.get("conversation_id", "")).strip() + if not conversation_id: + return {"conversation": None} + record = self._conversation_store.get(conversation_id) + if record is None or str(record.get("session", "")) != session: + return {"conversation": None} + return {"conversation": dict(record)} + + async def _conversation_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = payload.get("session") + platform_id = payload.get("platform_id") + conversations = [] + for conversation_id in sorted(self._conversation_store.keys()): + item = self._conversation_store[conversation_id] + if session is not None and str(item.get("session", "")) != str(session): + continue + if platform_id is not None and str(item.get("platform_id", "")) != str( + platform_id + ): + continue + conversations.append(dict(item)) + return {"conversations": conversations} + + async def _conversation_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.update requires a conversation in the same session" + ) + raw_conversation = payload.get("conversation") + if not isinstance(raw_conversation, dict): + raw_conversation = {} + if "history" in raw_conversation: + history = raw_conversation.get("history") + record["history"] = ( + self._normalize_history_payload(history) if history is not None else [] + ) + if "title" in raw_conversation: + title = raw_conversation.get("title") + record["title"] = str(title) if title is not None else None + if "persona_id" in raw_conversation: + persona_id = raw_conversation.get("persona_id") + record["persona_id"] = str(persona_id) if persona_id is not None else None + if "token_usage" in raw_conversation: + token_usage = raw_conversation.get("token_usage") + record["token_usage"] = ( + int(token_usage) if token_usage is not None else None + ) + record["updated_at"] = self._now_iso() + return {} + + async def _conversation_unset_persona( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.unset_persona requires a conversation in the same session" + ) + record["persona_id"] = None + record["updated_at"] = self._now_iso() + return {} + + def _register_conversation_capabilities(self) -> None: + self.register( + self._builtin_descriptor("conversation.new", "新建对话"), + call_handler=self._conversation_new, + ) + self.register( + self._builtin_descriptor("conversation.switch", "切换对话"), + call_handler=self._conversation_switch, + ) + self.register( + self._builtin_descriptor("conversation.delete", "删除对话"), + call_handler=self._conversation_delete, + ) + self.register( + self._builtin_descriptor("conversation.get", "获取对话"), + call_handler=self._conversation_get, + ) + self.register( + self._builtin_descriptor("conversation.get_current", "获取当前对话"), + call_handler=self._conversation_get_current, + ) + self.register( + self._builtin_descriptor("conversation.list", "列出对话"), + call_handler=self._conversation_list, + ) + self.register( + self._builtin_descriptor("conversation.update", "更新对话"), + call_handler=self._conversation_update, + ) + self.register( + self._builtin_descriptor("conversation.unset_persona", "清空对话人格"), + call_handler=self._conversation_unset_persona, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py new file mode 100644 index 0000000000..f8bdfedf9a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from ....errors import AstrBotError +from ..._streaming import StreamExecution +from ..bridge_base import CapabilityRouterBridgeBase + + +class DBCapabilityMixin(CapabilityRouterBridgeBase): + def _db_scoped_key(self, plugin_id: str, key: str) -> str: + """将用户提供的 key 加上插件命名空间前缀,防止跨插件越权访问。""" + return f"{plugin_id}:{key}" + + def _db_strip_scope(self, plugin_id: str, scoped_key: str) -> str: + """去掉命名空间前缀,返回插件视角的原始 key。""" + prefix = f"{plugin_id}:" + return ( + scoped_key[len(prefix) :] if scoped_key.startswith(prefix) else scoped_key + ) + + def _db_public_event( + self, plugin_id: str, raw_event: dict[str, Any] + ) -> dict[str, Any]: + """将内部事件转换回插件可见的 key 视图。""" + event = dict(raw_event) + key = event.get("key") + if isinstance(key, str): + event["key"] = self._db_strip_scope(plugin_id, key) + return event + + async def _db_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.get") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + return {"value": self.db_store.get(key)} + + async def _db_set( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.set") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + value = payload.get("value") + self.db_store[key] = value + self._emit_db_change(op="set", key=key, value=value) + return {} + + async def _db_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.delete") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + self.db_store.pop(key, None) + self._emit_db_change(op="delete", key=key, value=None) + return {} + + async def _db_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.list") + ns_prefix = f"{plugin_id}:" + # 只列出属于当前插件命名空间的 key,并去掉命名空间前缀返回给插件 + user_prefix = payload.get("prefix") + all_keys = sorted( + key for key in self.db_store.keys() if key.startswith(ns_prefix) + ) + stripped = [self._db_strip_scope(plugin_id, k) for k in all_keys] + if isinstance(user_prefix, str): + stripped = [k for k in stripped if k.startswith(user_prefix)] + return {"keys": stripped} + + async def _db_get_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.get_many") + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组") + items = [ + { + "key": str(k), + "value": self.db_store.get(self._db_scoped_key(plugin_id, str(k))), + } + for k in keys_payload + ] + return {"items": items} + + async def _db_set_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.set_many") + items_payload = payload.get("items") + if not isinstance(items_payload, (list, tuple)): + raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组") + for entry in items_payload: + if not isinstance(entry, dict): + raise AstrBotError.invalid_input( + "db.set_many 的 items 必须是 object 数组" + ) + key = self._db_scoped_key(plugin_id, str(entry.get("key", ""))) + value = entry.get("value") + self.db_store[key] = value + self._emit_db_change(op="set", key=key, value=value) + return {} + + async def _db_watch( + self, request_id: str, payload: dict[str, Any], _token + ) -> StreamExecution: + plugin_id = self._require_caller_plugin_id("db.watch") + prefix = payload.get("prefix") + prefix_value: str | None + if isinstance(prefix, str): + # 将用户传入的前缀也加上命名空间,只监听本插件的 key 变更 + prefix_value = self._db_scoped_key(plugin_id, prefix) + elif prefix is None: + # 无前缀时默认监听整个命名空间 + prefix_value = f"{plugin_id}:" + else: + raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null") + + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + self._db_watch_subscriptions[request_id] = (prefix_value, queue) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + yield self._db_public_event(plugin_id, await queue.get()) + finally: + self._db_watch_subscriptions.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + def _register_db_capabilities(self) -> None: + self.register( + self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get + ) + self.register( + self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set + ) + self.register( + self._builtin_descriptor("db.delete", "删除 KV"), + call_handler=self._db_delete, + ) + self.register( + self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list + ) + self.register( + self._builtin_descriptor("db.get_many", "批量读取 KV"), + call_handler=self._db_get_many, + ) + self.register( + self._builtin_descriptor("db.set_many", "批量写入 KV"), + call_handler=self._db_set_many, + ) + self.register( + self._builtin_descriptor( + "db.watch", + "订阅 KV 变更", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._db_watch, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py new file mode 100644 index 0000000000..d884c4d9cf --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import re +from typing import Any + +from ...._internal.plugin_ids import ( + capability_belongs_to_plugin, + http_route_belongs_to_plugin, + plugin_capability_prefix, + plugin_http_route_root, +) +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + +# 路由只允许字母、数字、/, -, _, . 以及路径参数 {param},且必须以 / 开头。 +# 参数段必须完整地形如 {param},同时禁止空段(例如连续斜杠)。 +_ROUTE_SEGMENT_RE = re.compile(r"^(?:[\w\-._]+|\{[\w\-._]+\})$") + + +def _validate_route(route: str, capability_name: str) -> None: + """校验 HTTP 路由路径格式,阻止路径遍历和非法字符。""" + if ".." in route: + raise AstrBotError.invalid_input(f"{capability_name}: 路由路径不允许包含 '..'") + if not route.startswith("/"): + raise AstrBotError.invalid_input( + f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段," + "且必须以 / 开头,如 /foo/bar" + ) + if route == "/": + return + segments = route.split("/")[1:] + if any( + not segment or not _ROUTE_SEGMENT_RE.fullmatch(segment) for segment in segments + ): + raise AstrBotError.invalid_input( + f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段," + "禁止连续斜杠,且必须以 / 开头,如 /foo/bar" + ) + + +def _validate_plugin_route_namespace(route: str, plugin_id: str) -> None: + if http_route_belongs_to_plugin(route, plugin_id): + return + route_root = plugin_http_route_root(plugin_id) + raise AstrBotError.invalid_input( + "http.register_api 要求 route 使用当前插件的公开命名空间前缀:" + f" route={route!r}, plugin_id={plugin_id!r}, expected={route_root!r} " + f"或 {route_root + '/...'}" + ) + + +def _validate_handler_capability_namespace( + handler_capability: str, + plugin_id: str, +) -> None: + if capability_belongs_to_plugin(handler_capability, plugin_id): + return + expected_prefix = plugin_capability_prefix(plugin_id) + raise AstrBotError.invalid_input( + "http.register_api 要求 handler_capability 属于当前插件:" + f" capability={handler_capability!r}, plugin_id={plugin_id!r}, " + f"expected_prefix={expected_prefix!r}" + ) + + +class HttpCapabilityMixin(CapabilityRouterBridgeBase): + async def _http_register_api( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + methods_payload = payload.get("methods") + if not isinstance(methods_payload, list) or not all( + isinstance(item, str) for item in methods_payload + ): + raise AstrBotError.invalid_input( + "http.register_api 的 methods 必须是 string 数组" + ) + route = str(payload.get("route", "")).strip() + handler_capability = str(payload.get("handler_capability", "")).strip() + if not route or not handler_capability: + raise AstrBotError.invalid_input( + "http.register_api 需要 route 和 handler_capability" + ) + _validate_route(route, "http.register_api") + plugin_name = self._require_caller_plugin_id("http.register_api") + _validate_plugin_route_namespace(route, plugin_name) + _validate_handler_capability_namespace(handler_capability, plugin_name) + methods = sorted({method.upper() for method in methods_payload if method}) + entry: dict[str, Any] = { + "route": route, + "methods": methods, + "handler_capability": handler_capability, + "description": str(payload.get("description", "")), + "plugin_id": plugin_name, + } + self.http_api_store = [ + item + for item in self.http_api_store + if not ( + item.get("route") == route + and item.get("plugin_id") == entry["plugin_id"] + and item.get("methods") == methods + ) + ] + self.http_api_store.append(entry) + return {} + + async def _http_unregister_api( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + route = str(payload.get("route", "")).strip() + methods_payload = payload.get("methods") + if not isinstance(methods_payload, list) or not all( + isinstance(item, str) for item in methods_payload + ): + raise AstrBotError.invalid_input( + "http.unregister_api 的 methods 必须是 string 数组" + ) + plugin_name = self._require_caller_plugin_id("http.unregister_api") + methods = {method.upper() for method in methods_payload if method} + updated: list[dict[str, Any]] = [] + for entry in self.http_api_store: + if entry.get("route") != route: + updated.append(entry) + continue + if entry.get("plugin_id") != plugin_name: + updated.append(entry) + continue + if not methods: + # `HTTPClient.unregister_api(methods=None)` 会归一化为空列表, + # 公开语义就是“移除当前插件在该 route 下注册的全部方法”。 + continue + remaining_methods = [ + method for method in entry.get("methods", []) if method not in methods + ] + if remaining_methods: + updated.append({**entry, "methods": remaining_methods}) + self.http_api_store = updated + return {} + + async def _http_list_apis( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_name = self._require_caller_plugin_id("http.list_apis") + apis = [ + dict(entry) + for entry in self.http_api_store + if entry.get("plugin_id") == plugin_name + ] + return {"apis": apis} + + def _register_http_capabilities(self) -> None: + self.register( + self._builtin_descriptor("http.register_api", "注册 HTTP 路由"), + call_handler=self._http_register_api, + ) + self.register( + self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"), + call_handler=self._http_unregister_api, + ) + self.register( + self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"), + call_handler=self._http_list_apis, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py new file mode 100644 index 0000000000..77a03d86c7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py @@ -0,0 +1,427 @@ +from __future__ import annotations + +import math +import uuid +from pathlib import Path +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +def _term_set(text: str) -> set[str]: + normalized = " ".join(str(text).strip().casefold().split()) + compact = normalized.replace(" ", "") + if not normalized: + return set() + terms = {item for item in normalized.split(" ") if item} + if compact: + terms.add(compact) + if len(compact) > 1: + terms.update( + compact[index : index + 2] for index in range(len(compact) - 1) + ) + return terms + + +class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase): + def _kb_documents(self, kb_id: str) -> dict[str, dict[str, Any]]: + return self._kb_document_store.setdefault(kb_id, {}) + + def _refresh_mock_kb_stats(self, kb_id: str) -> None: + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + return + documents = self._kb_documents(kb_id) + kb["doc_count"] = len(documents) + kb["chunk_count"] = sum( + int(document.get("chunk_count", 0) or 0) for document in documents.values() + ) + kb["updated_at"] = self._now_iso() + + def _resolve_mock_kb_ids(self, payload: dict[str, Any]) -> list[str]: + kb_ids = [ + str(item).strip() for item in payload.get("kb_ids", []) if str(item).strip() + ] + if kb_ids: + return [kb_id for kb_id in kb_ids if kb_id in self._kb_store] + + kb_names = [ + str(item).strip() + for item in payload.get("kb_names", []) + if str(item).strip() + ] + if not kb_names: + return [] + name_set = set(kb_names) + return [ + kb_id + for kb_id, kb in self._kb_store.items() + if str(kb.get("kb_name", "")).strip() in name_set + ] + + @staticmethod + def _score_mock_document(query: str, content: str) -> float: + query_terms = _term_set(query) + content_terms = _term_set(content) + if not query_terms or not content_terms: + return 0.0 + overlap = len(query_terms & content_terms) + if overlap <= 0: + return 0.0 + score = overlap / len(query_terms) + if query.strip().casefold() in str(content).casefold(): + score += 0.25 + return min(score, 1.0) + + @staticmethod + def _build_mock_context_text(results: list[dict[str, Any]]) -> str: + lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"] + for index, item in enumerate(results, start=1): + lines.append(f"【知识 {index}】") + lines.append(f"来源: {item['kb_name']} / {item['doc_name']}") + lines.append(f"内容: {item['content']}") + lines.append(f"相关度: {float(item['score']):.2f}") + lines.append("") + return "\n".join(lines) + + async def _kb_list( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return { + "kbs": [ + dict(record) + for record in sorted( + self._kb_store.values(), + key=lambda item: str(item.get("created_at", "")), + ) + ] + } + + async def _kb_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + record = self._kb_store.get(kb_id) + return {"kb": dict(record) if isinstance(record, dict) else None} + + async def _kb_create( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.create requires kb object") + embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip() + if not embedding_provider_id: + raise AstrBotError.invalid_input("kb.create requires embedding_provider_id") + kb_id = uuid.uuid4().hex + now = self._now_iso() + record = { + "kb_id": kb_id, + "kb_name": str(raw_kb.get("kb_name", "")), + "description": ( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ), + "emoji": ( + str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None + ), + "embedding_provider_id": embedding_provider_id, + "rerank_provider_id": ( + str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ), + "chunk_size": self._optional_int(raw_kb.get("chunk_size")), + "chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")), + "top_k_dense": self._optional_int(raw_kb.get("top_k_dense")), + "top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")), + "top_m_final": self._optional_int(raw_kb.get("top_m_final")), + "doc_count": 0, + "chunk_count": 0, + "created_at": now, + "updated_at": now, + } + self._kb_store[kb_id] = record + self._kb_document_store[kb_id] = {} + return {"kb": dict(record)} + + async def _kb_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.update requires kb object") + record = self._kb_store.get(kb_id) + if not isinstance(record, dict): + return {"kb": None} + + for field_name in ( + "kb_name", + "description", + "emoji", + "embedding_provider_id", + "rerank_provider_id", + ): + if field_name in raw_kb: + value = raw_kb.get(field_name) + record[field_name] = str(value) if value is not None else None + for field_name in ( + "chunk_size", + "chunk_overlap", + "top_k_dense", + "top_k_sparse", + "top_m_final", + ): + if field_name in raw_kb: + record[field_name] = self._optional_int(raw_kb.get(field_name)) + record["updated_at"] = self._now_iso() + return {"kb": dict(record)} + + async def _kb_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + documents = self._kb_document_store.pop(kb_id, {}) + for document in documents.values(): + doc_id = str(document.get("doc_id", "")).strip() + if doc_id: + self._kb_document_content_store.pop(doc_id, None) + deleted = self._kb_store.pop(kb_id, None) is not None + return {"deleted": deleted} + + async def _kb_retrieve( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + query = str(payload.get("query", "")).strip() + if not query: + raise AstrBotError.invalid_input("kb.retrieve requires query") + kb_ids = self._resolve_mock_kb_ids(payload) + if not kb_ids: + raise AstrBotError.invalid_input("kb.retrieve requires kb_ids or kb_names") + + top_m_final = self._optional_int(payload.get("top_m_final")) or 5 + results: list[dict[str, Any]] = [] + for kb_id in kb_ids: + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + continue + for document in self._kb_documents(kb_id).values(): + doc_id = str(document.get("doc_id", "")).strip() + if not doc_id: + continue + content = self._kb_document_content_store.get(doc_id, "") + score = self._score_mock_document(query, content) + if score <= 0: + continue + results.append( + { + "chunk_id": f"{doc_id}:0", + "doc_id": doc_id, + "kb_id": kb_id, + "kb_name": str(kb.get("kb_name", "")), + "doc_name": str(document.get("doc_name", "")), + "chunk_index": 0, + "content": content, + "score": score, + "char_count": len(content), + } + ) + results.sort(key=lambda item: float(item["score"]), reverse=True) + results = results[:top_m_final] + if not results: + return {"result": None} + return { + "result": { + "context_text": self._build_mock_context_text(results), + "results": results, + } + } + + async def _kb_document_upload( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + raise AstrBotError.invalid_input(f"Unknown knowledge base: {kb_id}") + raw_document = payload.get("document") + if not isinstance(raw_document, dict): + raise AstrBotError.invalid_input( + "kb.document.upload requires document object" + ) + + file_name = str(raw_document.get("file_name", "")).strip() + file_type = str(raw_document.get("file_type", "")).strip() + file_path = "" + content_text = "" + file_size = 0 + + text_value = raw_document.get("text") + url_value = raw_document.get("url") + file_token = str(raw_document.get("file_token", "")).strip() + + if isinstance(text_value, str) and text_value.strip(): + content_text = text_value + if not file_name: + file_name = "document.txt" + if not file_type: + file_type = "txt" + file_size = len(content_text.encode("utf-8")) + elif isinstance(url_value, str) and url_value.strip(): + url_text = url_value.strip() + content_text = f"Imported from {url_text}" + if not file_name: + file_name = ( + Path(url_text.split("?", maxsplit=1)[0]).name or "document.url" + ) + if not file_type: + suffix = Path(file_name).suffix.lstrip(".") + file_type = suffix or "url" + file_path = url_text + file_size = len(content_text.encode("utf-8")) + elif file_token: + file_path = self._file_token_store.pop(file_token, "") + if not file_path: + raise AstrBotError.invalid_input(f"Unknown file token: {file_token}") + path = Path(file_path) + if not path.exists(): + raise AstrBotError.invalid_input(f"File does not exist: {file_path}") + raw_bytes = path.read_bytes() + content_text = raw_bytes.decode("utf-8", errors="ignore") + if not file_name: + file_name = path.name + if not file_type: + file_type = path.suffix.lstrip(".") + if not file_type: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_type when the file has no suffix" + ) + file_size = len(raw_bytes) + else: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_token, url, or text" + ) + + chunk_size = self._optional_int(raw_document.get("chunk_size")) + if chunk_size is None or chunk_size <= 0: + chunk_size = self._optional_int(kb.get("chunk_size")) or 512 + chunk_count = max(1, math.ceil(max(len(content_text), 1) / chunk_size)) + doc_id = uuid.uuid4().hex + now = self._now_iso() + document = { + "doc_id": doc_id, + "kb_id": kb_id, + "doc_name": file_name, + "file_type": file_type, + "file_size": file_size, + "file_path": file_path, + "chunk_count": chunk_count, + "media_count": 0, + "created_at": now, + "updated_at": now, + } + self._kb_documents(kb_id)[doc_id] = document + self._kb_document_content_store[doc_id] = content_text + self._refresh_mock_kb_stats(kb_id) + return {"document": dict(document)} + + async def _kb_document_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + offset = max(self._optional_int(payload.get("offset")) or 0, 0) + limit = max(self._optional_int(payload.get("limit")) or 100, 0) + documents = list(self._kb_documents(kb_id).values()) + documents.sort(key=lambda item: str(item.get("created_at", ""))) + return { + "documents": [dict(item) for item in documents[offset : offset + limit]] + } + + async def _kb_document_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + document = self._kb_documents(kb_id).get(doc_id) + return {"document": dict(document) if isinstance(document, dict) else None} + + async def _kb_document_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + deleted = self._kb_documents(kb_id).pop(doc_id, None) is not None + if deleted: + self._kb_document_content_store.pop(doc_id, None) + self._refresh_mock_kb_stats(kb_id) + return {"deleted": deleted} + + async def _kb_document_refresh( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + document = self._kb_documents(kb_id).get(doc_id) + if not isinstance(document, dict): + return {"document": None} + kb = self._kb_store.get(kb_id, {}) + chunk_size = self._optional_int(kb.get("chunk_size")) or 512 + content_text = self._kb_document_content_store.get(doc_id, "") + document["chunk_count"] = max( + 1, math.ceil(max(len(content_text), 1) / chunk_size) + ) + document["updated_at"] = self._now_iso() + self._refresh_mock_kb_stats(kb_id) + return {"document": dict(document)} + + def _register_kb_capabilities(self) -> None: + self.register( + self._builtin_descriptor("kb.list", "列出知识库"), + call_handler=self._kb_list, + ) + self.register( + self._builtin_descriptor("kb.get", "获取知识库"), + call_handler=self._kb_get, + ) + self.register( + self._builtin_descriptor("kb.create", "创建知识库"), + call_handler=self._kb_create, + ) + self.register( + self._builtin_descriptor("kb.update", "更新知识库"), + call_handler=self._kb_update, + ) + self.register( + self._builtin_descriptor("kb.delete", "删除知识库"), + call_handler=self._kb_delete, + ) + self.register( + self._builtin_descriptor("kb.retrieve", "检索知识库"), + call_handler=self._kb_retrieve, + ) + self.register( + self._builtin_descriptor("kb.document.upload", "上传知识库文档"), + call_handler=self._kb_document_upload, + ) + self.register( + self._builtin_descriptor("kb.document.list", "列出知识库文档"), + call_handler=self._kb_document_list, + ) + self.register( + self._builtin_descriptor("kb.document.get", "获取知识库文档"), + call_handler=self._kb_document_get, + ) + self.register( + self._builtin_descriptor("kb.document.delete", "删除知识库文档"), + call_handler=self._kb_document_delete, + ) + self.register( + self._builtin_descriptor("kb.document.refresh", "刷新知识库文档"), + call_handler=self._kb_document_refresh, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py new file mode 100644 index 0000000000..daf1621128 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from ..bridge_base import CapabilityRouterBridgeBase + + +class LLMCapabilityMixin(CapabilityRouterBridgeBase): + async def _llm_chat( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + prompt = str(payload.get("prompt", "")) + return {"text": f"Echo: {prompt}"} + + async def _llm_chat_raw( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + prompt = str(payload.get("prompt", "")) + text = f"Echo: {prompt}" + return { + "text": text, + "usage": { + "input_tokens": len(prompt), + "output_tokens": len(text), + }, + "finish_reason": "stop", + "tool_calls": [], + } + + async def _llm_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> AsyncIterator[dict[str, Any]]: + text = f"Echo: {str(payload.get('prompt', ''))}" + for char in text: + token.raise_if_cancelled() + await asyncio.sleep(0) + yield {"text": char} + + def _register_llm_capabilities(self) -> None: + self.register( + self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"), + call_handler=self._llm_chat, + ) + self.register( + self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"), + call_handler=self._llm_chat_raw, + ) + self.register( + self._builtin_descriptor( + "llm.stream_chat", + "流式对话", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._llm_stream, + finalize=lambda chunks: { + "text": "".join(item.get("text", "") for item in chunks) + }, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py new file mode 100644 index 0000000000..33582f5b44 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import asyncio +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +def _mock_tools_from_config(name: str, config: dict[str, Any]) -> list[str]: + configured = config.get("mock_tools") + if isinstance(configured, list): + tools = [str(item) for item in configured if str(item).strip()] + if tools: + return tools + return [f"{name}_tool"] + + +def _mock_server_record( + *, + name: str, + scope: str, + active: bool, + running: bool, + config: dict[str, Any], + tools: list[str], + errlogs: list[str] | None = None, + last_error: str | None = None, +) -> dict[str, Any]: + return { + "name": name, + "scope": scope, + "active": bool(active), + "running": bool(running), + "config": dict(config), + "tools": list(tools), + "errlogs": list(errlogs or []), + "last_error": last_error, + } + + +class McpCapabilityMixin(CapabilityRouterBridgeBase): + def _plugin_local_mcp_servers(self, plugin_id: str) -> dict[str, dict[str, Any]]: + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + return plugin.local_mcp_servers + + @staticmethod + def _require_server_name(payload: dict[str, Any], capability_name: str) -> str: + name = str(payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input(f"{capability_name} requires name") + return name + + @staticmethod + def _normalized_timeout(payload: dict[str, Any], default: float = 30.0) -> float: + raw_value = payload.get("timeout", default) + try: + timeout = float(raw_value) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input("timeout must be numeric") from exc + if timeout <= 0: + raise AstrBotError.invalid_input("timeout must be greater than 0") + return timeout + + def _mock_connect_outcome( + self, + *, + name: str, + config: dict[str, Any], + scope: str, + ) -> dict[str, Any]: + if bool(config.get("mock_fail", False)): + last_error = str(config.get("mock_error") or f"{name} failed") + return _mock_server_record( + name=name, + scope=scope, + active=bool(config.get("active", True)), + running=False, + config=config, + tools=[], + errlogs=[last_error], + last_error=last_error, + ) + return _mock_server_record( + name=name, + scope=scope, + active=bool(config.get("active", True)), + running=True, + config=config, + tools=_mock_tools_from_config(name, config), + errlogs=[], + last_error=None, + ) + + async def _mcp_local_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.get") + name = self._require_server_name(payload, "mcp.local.get") + return { + "server": self._plugin_local_mcp_servers(plugin_id).get(name), + } + + async def _mcp_local_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.list") + servers = sorted( + self._plugin_local_mcp_servers(plugin_id).values(), + key=lambda item: str(item.get("name", "")), + ) + return {"servers": [dict(item) for item in servers]} + + async def _mcp_local_enable( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.enable") + name = self._require_server_name(payload, "mcp.local.enable") + servers = self._plugin_local_mcp_servers(plugin_id) + server = servers.get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if bool(server.get("active", False)) and bool(server.get("running", False)): + return {"server": dict(server)} + updated = self._mock_connect_outcome( + name=name, + config=dict(server.get("config", {})), + scope="local", + ) + updated["active"] = True + servers[name] = updated + return {"server": dict(updated)} + + async def _mcp_local_disable( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.disable") + name = self._require_server_name(payload, "mcp.local.disable") + servers = self._plugin_local_mcp_servers(plugin_id) + server = servers.get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if not bool(server.get("active", False)) and not bool( + server.get("running", False) + ): + return {"server": dict(server)} + updated = dict(server) + updated["active"] = False + updated["running"] = False + servers[name] = updated + return {"server": updated} + + async def _mcp_local_wait_until_ready( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.wait_until_ready") + name = self._require_server_name(payload, "mcp.local.wait_until_ready") + timeout = self._normalized_timeout(payload) + server = self._plugin_local_mcp_servers(plugin_id).get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if bool(server.get("running", False)): + return {"server": dict(server)} + delay = float(server.get("config", {}).get("mock_connect_delay", 0.0) or 0.0) + if delay > timeout: + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + if delay > 0: + await asyncio.sleep(delay) + if bool(server.get("active", False)) and not bool( + server.get("config", {}).get("mock_fail", False) + ): + refreshed = self._mock_connect_outcome( + name=name, + config=dict(server.get("config", {})), + scope="local", + ) + refreshed["active"] = bool(server.get("active", False)) + self._plugin_local_mcp_servers(plugin_id)[name] = refreshed + refreshed = self._plugin_local_mcp_servers(plugin_id).get(name) + if refreshed is None or not bool(refreshed.get("running", False)): + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + return {"server": dict(refreshed)} + + async def _mcp_session_open( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.session.open") + name = self._require_server_name(payload, "mcp.session.open") + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input("mcp.session.open requires config object") + timeout = self._normalized_timeout(payload) + delay = float(config.get("mock_connect_delay", 0.0) or 0.0) + if bool(config.get("mock_fail", False)) or delay > timeout: + raise TimeoutError(f"MCP session '{name}' failed to connect in time") + if delay > 0: + await asyncio.sleep(delay) + session_id = f"{plugin_id}:{uuid.uuid4().hex}" + tools = _mock_tools_from_config(name, dict(config)) + self._mcp_session_store[session_id] = { + "plugin_id": plugin_id, + "name": name, + "config": dict(config), + "tools": tools, + "tool_results": dict(config.get("mock_tool_results", {})) + if isinstance(config.get("mock_tool_results"), dict) + else {}, + } + return {"session_id": session_id, "tools": list(tools)} + + async def _mcp_session_list_tools( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + session = self._mcp_session_store.get(session_id) + if session is None: + raise AstrBotError.invalid_input("Unknown MCP session") + return {"tools": list(session.get("tools", []))} + + async def _mcp_session_call_tool( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + session = self._mcp_session_store.get(session_id) + if session is None: + raise AstrBotError.invalid_input("Unknown MCP session") + tool_name = str(payload.get("tool_name", "")).strip() + if not tool_name: + raise AstrBotError.invalid_input("mcp.session.call_tool requires tool_name") + args = payload.get("args") + if not isinstance(args, dict): + raise AstrBotError.invalid_input( + "mcp.session.call_tool requires args object" + ) + tool_results = session.get("tool_results", {}) + if isinstance(tool_results, dict) and tool_name in tool_results: + result = tool_results[tool_name] + return { + "result": dict(result) + if isinstance(result, dict) + else {"value": result} + } + return { + "result": { + "tool_name": tool_name, + "arguments": dict(args), + "content": f"mock:{session['name']}:{tool_name}", + } + } + + async def _mcp_session_close( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + self._mcp_session_store.pop(session_id, None) + return {} + + def _require_global_mcp_risk_ack( + self, + plugin_id: str, + capability_name: str, + ) -> None: + plugin = self._plugins.get(plugin_id) + metadata = plugin.metadata if plugin is not None else {} + if bool(metadata.get("acknowledge_global_mcp_risk", False)): + return + raise PermissionError( + f"{capability_name} requires @acknowledge_global_mcp_risk" + ) + + def _audit_global_mcp_mutation( + self, + *, + plugin_id: str, + action: str, + server_name: str, + request_id: str, + ) -> None: + self._mcp_audit_logs.append( + { + "plugin_id": plugin_id, + "action": action, + "server_name": server_name, + "request_id": request_id, + } + ) + + async def _mcp_global_register( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.register") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.register") + name = self._require_server_name(payload, "mcp.global.register") + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + "mcp.global.register requires config object" + ) + if name in self._mcp_global_servers: + raise AstrBotError.invalid_input( + f"Global MCP server already exists: {name}" + ) + record = self._mock_connect_outcome( + name=name, + config=dict(config), + scope="global", + ) + self._mcp_global_servers[name] = record + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="register", + server_name=name, + request_id=request_id, + ) + return {"server": dict(record)} + + async def _mcp_global_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.get") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.get") + name = self._require_server_name(payload, "mcp.global.get") + return {"server": self._mcp_global_servers.get(name)} + + async def _mcp_global_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.list") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.list") + servers = sorted( + self._mcp_global_servers.values(), + key=lambda item: str(item.get("name", "")), + ) + return {"servers": [dict(item) for item in servers]} + + async def _mcp_global_enable( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.enable") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.enable") + name = self._require_server_name(payload, "mcp.global.enable") + record = self._mcp_global_servers.get(name) + if record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + updated = self._mock_connect_outcome( + name=name, + config=dict(record.get("config", {})), + scope="global", + ) + updated["active"] = True + self._mcp_global_servers[name] = updated + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="enable", + server_name=name, + request_id=request_id, + ) + return {"server": dict(updated)} + + async def _mcp_global_disable( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.disable") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.disable") + name = self._require_server_name(payload, "mcp.global.disable") + record = self._mcp_global_servers.get(name) + if record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + updated = dict(record) + updated["active"] = False + updated["running"] = False + self._mcp_global_servers[name] = updated + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="disable", + server_name=name, + request_id=request_id, + ) + return {"server": dict(updated)} + + async def _mcp_global_unregister( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.unregister") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.unregister") + name = self._require_server_name(payload, "mcp.global.unregister") + record = self._mcp_global_servers.pop(name, None) + if record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="unregister", + server_name=name, + request_id=request_id, + ) + return {"server": dict(record)} + + async def _internal_mcp_local_execute( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = str(payload.get("plugin_id", "")).strip() + server_name = str(payload.get("server_name", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + tool_args = payload.get("tool_args") + if not plugin_id or not server_name or not tool_name: + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires plugin_id, server_name, and tool_name" + ) + if not isinstance(tool_args, dict): + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires tool_args object" + ) + plugin = self._plugins.get(plugin_id) + server = ( + plugin.local_mcp_servers.get(server_name) if plugin is not None else None + ) + if server is None or not bool(server.get("running", False)): + return { + "content": f"Local MCP server unavailable: {server_name}", + "success": False, + } + if tool_name not in server.get("tools", []): + return { + "content": f"Local MCP tool not found: {server_name}.{tool_name}", + "success": False, + } + return { + "content": f"mock:{server_name}:{tool_name}:{tool_args}", + "success": True, + } + + def _register_mcp_capabilities(self) -> None: + self.register( + self._builtin_descriptor("mcp.local.get", "Get local MCP server"), + call_handler=self._mcp_local_get, + ) + self.register( + self._builtin_descriptor("mcp.local.list", "List local MCP servers"), + call_handler=self._mcp_local_list, + ) + self.register( + self._builtin_descriptor("mcp.local.enable", "Enable local MCP server"), + call_handler=self._mcp_local_enable, + ) + self.register( + self._builtin_descriptor("mcp.local.disable", "Disable local MCP server"), + call_handler=self._mcp_local_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.local.wait_until_ready", + "Wait until local MCP server is ready", + ), + call_handler=self._mcp_local_wait_until_ready, + ) + self.register( + self._builtin_descriptor("mcp.session.open", "Open temporary MCP session"), + call_handler=self._mcp_session_open, + ) + self.register( + self._builtin_descriptor( + "mcp.session.list_tools", + "List tools in temporary MCP session", + ), + call_handler=self._mcp_session_list_tools, + ) + self.register( + self._builtin_descriptor( + "mcp.session.call_tool", + "Call tool in temporary MCP session", + ), + call_handler=self._mcp_session_call_tool, + ) + self.register( + self._builtin_descriptor( + "mcp.session.close", "Close temporary MCP session" + ), + call_handler=self._mcp_session_close, + ) + self.register( + self._builtin_descriptor( + "mcp.global.register", + "Register global MCP server", + ), + call_handler=self._mcp_global_register, + ) + self.register( + self._builtin_descriptor("mcp.global.get", "Get global MCP server"), + call_handler=self._mcp_global_get, + ) + self.register( + self._builtin_descriptor("mcp.global.list", "List global MCP servers"), + call_handler=self._mcp_global_list, + ) + self.register( + self._builtin_descriptor("mcp.global.enable", "Enable global MCP server"), + call_handler=self._mcp_global_enable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.disable", + "Disable global MCP server", + ), + call_handler=self._mcp_global_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.unregister", + "Unregister global MCP server", + ), + call_handler=self._mcp_global_unregister, + ) + self.register( + self._builtin_descriptor( + "internal.mcp.local.execute", + "Execute local MCP tool", + ), + call_handler=self._internal_mcp_local_execute, + exposed=False, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py new file mode 100644 index 0000000000..f55ef7ccf0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py @@ -0,0 +1,655 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ...._internal.invocation_context import current_caller_plugin_id +from ...._internal.memory_utils import ( + cosine_similarity, + extract_memory_text, + is_ttl_memory_entry, + memory_expiration_from_ttl, + memory_index_entry, + memory_keyword_score, + memory_value_for_search, +) +from ...._memory_backend import PluginMemoryBackend +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class MemoryCapabilityMixin(CapabilityRouterBridgeBase): + def _memory_plugin_id(self) -> str: + plugin_id = current_caller_plugin_id() + return self._validated_plugin_id( + str(plugin_id).strip() or "__anonymous__", + capability_name="memory.*", + ) + + def _memory_backend_for_plugin(self, plugin_id: str) -> PluginMemoryBackend: + backend = self._memory_backends.get(plugin_id) + if backend is None: + backend = PluginMemoryBackend( + self._plugin_data_dir(plugin_id, capability_name="memory.*") + ) + self._memory_backends[plugin_id] = backend + return backend + + @staticmethod + def _is_ttl_memory_entry(value: Any) -> bool: + """判断存储值是否使用了 TTL 包装结构。 + + Args: + value: 待检查的存储值。 + + Returns: + bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。 + """ + return is_ttl_memory_entry(value) + + @classmethod + def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None: + """提取用于检索的原始 memory payload。 + + Args: + stored: memory_store 中保存的原始值。 + + Returns: + dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。 + """ + return memory_value_for_search(stored) + + @classmethod + def _extract_memory_text(cls, stored: Any) -> str: + """提取用于检索索引的首选文本。 + + Args: + stored: memory_store 中保存的原始值。 + + Returns: + str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。 + """ + return extract_memory_text(stored) + + @staticmethod + def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None: + """将 TTL 秒数转换为 UTC 过期时间。 + + Args: + ttl_seconds: TTL 秒数。 + + Returns: + datetime | None: 绝对过期时间;当输入无效时返回 ``None``。 + """ + return memory_expiration_from_ttl(ttl_seconds) + + @staticmethod + def _memory_keyword_score(query: str, key: str, text: str) -> float: + """计算关键词匹配分数。 + + Args: + query: 查询文本。 + key: memory 条目的键。 + text: 已索引的检索文本。 + + Returns: + float: 基于键名和文本命中的粗粒度关键词分数。 + """ + return memory_keyword_score(query, key, text) + + @staticmethod + def _cosine_similarity(left: list[float], right: list[float]) -> float: + """计算两个向量之间的余弦相似度。 + + Args: + left: 左侧向量。 + right: 右侧向量。 + + Returns: + float: 余弦相似度;输入不合法时返回 ``0.0``。 + """ + return cosine_similarity(left, right) + + def _resolve_memory_embedding_provider_id( + self, + provider_id: Any, + *, + required: bool, + ) -> str | None: + """解析 memory.search 要使用的 embedding provider。 + + Args: + provider_id: 调用方显式传入的 provider 标识。 + required: 当前检索模式是否强制要求 embedding provider。 + + Returns: + str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。 + """ + normalized = str(provider_id).strip() if provider_id is not None else "" + if normalized: + self._provider_entry( + {"provider_id": normalized}, + "memory.search", + "embedding", + ) + return normalized + active_id = self._active_provider_ids.get("embedding") + if active_id is not None: + normalized_active = str(active_id).strip() + if normalized_active: + self._provider_entry( + {"provider_id": normalized_active}, + "memory.search", + "embedding", + ) + return normalized_active + if required: + raise AstrBotError.invalid_input( + "memory.search requires an embedding provider", + ) + return None + + @staticmethod + def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]: + """将原始索引项规范化为内部统一结构。 + + Args: + entry: 当前索引表中的原始项。 + text: 当前条目的索引文本。 + + Returns: + dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。 + """ + return memory_index_entry(entry, text=text) + + def _clear_memory_sidecars(self, key: str) -> None: + """清理指定 memory 键对应的所有 sidecar 状态。 + + Args: + key: memory 条目的键。 + + Returns: + None + """ + self._memory_index.pop(key, None) + self._memory_expires_at.pop(key, None) + self._memory_dirty_keys.discard(key) + + def _delete_memory_entry(self, key: str) -> bool: + """删除 memory 条目并同步清理 sidecar 状态。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 条目存在并删除成功时返回 ``True``。 + """ + deleted = self.memory_store.pop(key, None) is not None + self._clear_memory_sidecars(key) + return deleted + + def _upsert_memory_sidecars( + self, + key: str, + stored: dict[str, Any], + *, + expires_at: datetime | None = None, + ) -> None: + """创建或更新单条 memory 的 sidecar 索引状态。 + + Args: + key: memory 条目的键。 + stored: 需要建立索引的原始存储值。 + expires_at: 可选的绝对过期时间。 + + Returns: + None + """ + self._memory_index[key] = { + "text": self._extract_memory_text(stored), + "embedding": None, + "provider_id": None, + } + if expires_at is None: + self._memory_expires_at.pop(key, None) + else: + self._memory_expires_at[key] = expires_at + self._memory_dirty_keys.add(key) + + def _ensure_memory_sidecars(self, key: str, stored: Any) -> None: + """确保 sidecar 状态与当前存储值保持一致。 + + Args: + key: memory 条目的键。 + stored: memory_store 中的当前存储值。 + + Returns: + None + """ + if not isinstance(stored, dict): + return + text = self._extract_memory_text(stored) + existed = key in self._memory_index + entry = self._memory_index_entry(self._memory_index.get(key), text=text) + if entry["text"] != text: + entry["text"] = text + entry["embedding"] = None + entry["provider_id"] = None + self._memory_dirty_keys.add(key) + self._memory_index[key] = entry + if not existed: + self._memory_dirty_keys.add(key) + + def _is_memory_expired(self, key: str) -> bool: + """判断 memory 条目是否已过期。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果当前时间已超过记录的过期时间则返回 ``True``。 + """ + expires_at = self._memory_expires_at.get(key) + return expires_at is not None and expires_at <= datetime.now(timezone.utc) + + def _purge_expired_memory_entry(self, key: str) -> bool: + """在单条 memory 已过期时立即清理它。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果条目已过期并被成功清理则返回 ``True``。 + """ + if not self._is_memory_expired(key): + return False + self._delete_memory_entry(key) + return True + + def _purge_expired_memory_entries(self) -> None: + """批量清理所有已跟踪的过期 TTL 条目。 + + Returns: + None + """ + for key in list(self._memory_expires_at): + self._purge_expired_memory_entry(key) + + async def _embedding_for_text( + self, + *, + provider_id: str, + text: str, + ) -> list[float]: + """通过 embedding capability 获取单条文本向量。 + + Args: + provider_id: 使用的 embedding provider 标识。 + text: 待向量化的文本。 + + Returns: + list[float]: provider 返回的向量;异常场景下返回空列表。 + """ + output = await self._provider_embedding_get_embedding( + "", + {"provider_id": provider_id, "text": text}, + None, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def _embeddings_for_texts( + self, + *, + provider_id: str, + texts: list[str], + ) -> list[list[float]]: + """批量获取多条文本的 embedding 向量。 + + Args: + provider_id: 使用的 embedding provider 标识。 + texts: 待向量化的文本列表。 + + Returns: + list[list[float]]: 与输入顺序对应的向量列表。 + """ + if not texts: + return [] + output = await self._provider_embedding_get_embeddings( + "", + {"provider_id": provider_id, "texts": texts}, + None, + ) + embeddings = output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def _refresh_memory_embeddings(self, *, provider_id: str) -> None: + """刷新当前 provider 下脏或过期的 memory 向量索引。 + + Args: + provider_id: 当前使用的 embedding provider 标识。 + + Returns: + None + """ + keys_to_refresh: list[str] = [] + texts_to_refresh: list[str] = [] + for key, stored in self.memory_store.items(): + self._ensure_memory_sidecars(key, stored) + entry = self._memory_index_entry( + self._memory_index.get(key), + text=self._extract_memory_text(stored), + ) + should_refresh = ( + key in self._memory_dirty_keys + or entry["embedding"] is None + or entry["provider_id"] != provider_id + ) + self._memory_index[key] = entry + if should_refresh: + keys_to_refresh.append(key) + texts_to_refresh.append(str(entry["text"])) + # 分批请求,避免单次 payload 过大导致 OOM 或 413 + _BATCH_SIZE = 64 + embeddings: list[list[float]] = [] + for batch_start in range(0, len(texts_to_refresh), _BATCH_SIZE): + batch = texts_to_refresh[batch_start : batch_start + _BATCH_SIZE] + embeddings.extend( + await self._embeddings_for_texts( + provider_id=provider_id, + texts=batch, + ) + ) + for index, key in enumerate(keys_to_refresh): + entry = self._memory_index_entry( + self._memory_index.get(key), + text=str(texts_to_refresh[index]), + ) + entry["embedding"] = embeddings[index] if index < len(embeddings) else [] + entry["provider_id"] = provider_id + self._memory_index[key] = entry + self._memory_dirty_keys.discard(key) + + async def _memory_search( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + query = str(payload.get("query", "")) + mode = str(payload.get("mode", "auto")).strip().lower() or "auto" + limit = self._optional_int(payload.get("limit")) + raw_min_score = payload.get("min_score") + min_score = float(raw_min_score) if raw_min_score is not None else None + namespace = payload.get("namespace") + include_descendants = bool(payload.get("include_descendants", True)) + provider_id = self._resolve_memory_embedding_provider_id( + payload.get("provider_id"), + required=mode in {"vector", "hybrid"}, + ) + effective_mode = mode + if effective_mode == "auto": + effective_mode = "hybrid" if provider_id is not None else "keyword" + backend = self._memory_backend_for_plugin(plugin_id) + items = await backend.search( + query, + namespace=str(namespace) if namespace is not None else None, + include_descendants=include_descendants, + mode=effective_mode, + limit=limit, + min_score=min_score, + provider_id=provider_id, + embed_one=( + ( + lambda text: self._embedding_for_text( + provider_id=provider_id, text=text + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + embed_many=( + ( + lambda texts: self._embeddings_for_texts( + provider_id=provider_id, + texts=texts, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + ) + return {"items": items} + + async def _memory_save( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + key = str(payload.get("key", "")) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input("memory.save 的 value 必须是 object") + await self._memory_backend_for_plugin(plugin_id).save( + key, + value, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + key = str(payload.get("key", "")) + value = await self._memory_backend_for_plugin(plugin_id).get( + key, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"value": value} + + async def _memory_list_keys( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys = await self._memory_backend_for_plugin(plugin_id).list_keys( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"keys": keys} + + async def _memory_exists( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + exists = await self._memory_backend_for_plugin(plugin_id).exists( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"exists": exists} + + async def _memory_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + await self._memory_backend_for_plugin(plugin_id).delete( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_clear_namespace( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + deleted_count = await self._memory_backend_for_plugin( + plugin_id + ).clear_namespace( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"deleted_count": deleted_count} + + async def _memory_save_with_ttl( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + key = str(payload.get("key", "")) + value = payload.get("value") + ttl_seconds = payload.get("ttl_seconds", 0) + if not isinstance(value, dict): + raise AstrBotError.invalid_input( + "memory.save_with_ttl 的 value 必须是 object" + ) + await self._memory_backend_for_plugin(plugin_id).save_with_ttl( + key, + value, + int(ttl_seconds), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组") + items = await self._memory_backend_for_plugin(plugin_id).get_many( + [str(item) for item in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"items": items} + + async def _memory_delete_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组") + deleted_count = await self._memory_backend_for_plugin(plugin_id).delete_many( + [str(item) for item in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"deleted_count": deleted_count} + + async def _memory_count( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + count = await self._memory_backend_for_plugin(plugin_id).count( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"count": count} + + async def _memory_stats( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + stats = await self._memory_backend_for_plugin(plugin_id).stats( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", True)), + ) + stats["plugin_id"] = plugin_id + return stats + + def _register_memory_capabilities(self) -> None: + self.register( + self._builtin_descriptor("memory.search", "搜索记忆"), + call_handler=self._memory_search, + ) + self.register( + self._builtin_descriptor("memory.save", "保存记忆"), + call_handler=self._memory_save, + ) + self.register( + self._builtin_descriptor("memory.get", "读取单条记忆"), + call_handler=self._memory_get, + ) + self.register( + self._builtin_descriptor("memory.list_keys", "列出命名空间内的记忆键"), + call_handler=self._memory_list_keys, + ) + self.register( + self._builtin_descriptor("memory.exists", "检查记忆键是否存在"), + call_handler=self._memory_exists, + ) + self.register( + self._builtin_descriptor("memory.delete", "删除记忆"), + call_handler=self._memory_delete, + ) + self.register( + self._builtin_descriptor("memory.clear_namespace", "清理记忆命名空间"), + call_handler=self._memory_clear_namespace, + ) + self.register( + self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"), + call_handler=self._memory_save_with_ttl, + ) + self.register( + self._builtin_descriptor("memory.get_many", "批量获取记忆"), + call_handler=self._memory_get_many, + ) + self.register( + self._builtin_descriptor("memory.delete_many", "批量删除记忆"), + call_handler=self._memory_delete_many, + ) + self.register( + self._builtin_descriptor("memory.count", "统计命名空间内的记忆数量"), + call_handler=self._memory_count, + ) + self.register( + self._builtin_descriptor("memory.stats", "获取记忆统计信息"), + call_handler=self._memory_stats, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py new file mode 100644 index 0000000000..3e2b6666bc --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ....errors import AstrBotError +from ....message.session import MessageSession +from ..bridge_base import CapabilityRouterBridgeBase + + +def _session_payload(session: MessageSession) -> dict[str, str]: + return { + "platform_id": str(session.platform_id), + "message_type": str(session.message_type), + "session_id": str(session.session_id), + } + + +class MessageHistoryCapabilityMixin(CapabilityRouterBridgeBase): + @staticmethod + def _normalize_timestamp(raw_value: Any) -> datetime: + normalized = str(raw_value or "").strip() + if normalized.endswith("Z"): + normalized = f"{normalized[:-1]}+00:00" + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + @staticmethod + def _typed_session_from_payload(payload: Any) -> MessageSession: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + "message_history capabilities require a session object" + ) + platform_id = str(payload.get("platform_id", "")).strip() + message_type = str(payload.get("message_type", "")).strip() + session_id = str(payload.get("session_id", "")).strip() + if not platform_id or not message_type or not session_id: + raise AstrBotError.invalid_input( + "message_history session requires platform_id, message_type, and session_id" + ) + return MessageSession( + platform_id=platform_id, + message_type=message_type, + session_id=session_id, + ) + + @staticmethod + def _typed_key(session: MessageSession) -> str: + return ( + f"{str(session.platform_id)}:{str(session.message_type).lower()}:" + f"{str(session.session_id)}" + ) + + def _message_history_records(self, session: MessageSession) -> list[dict[str, Any]]: + key = self._typed_key(session) + records = self._message_history_store.get(key) + if records is None: + records = [] + self._message_history_store[key] = records + return records + + def _next_message_history_id(self) -> int: + next_id = int(self._message_history_next_id) + self._message_history_next_id += 1 + return next_id + + def _create_message_history_record( + self, + *, + session: MessageSession, + sender_payload: dict[str, Any], + parts_payload: list[dict[str, Any]], + metadata: dict[str, Any], + idempotency_key: str | None, + ) -> dict[str, Any]: + now = self._now_iso() + return { + "id": self._next_message_history_id(), + "session": _session_payload(session), + "sender": { + "sender_id": ( + str(sender_payload.get("sender_id")) + if sender_payload.get("sender_id") is not None + else None + ), + "sender_name": ( + str(sender_payload.get("sender_name")) + if sender_payload.get("sender_name") is not None + else None + ), + }, + "parts": [dict(item) for item in parts_payload if isinstance(item, dict)], + "metadata": dict(metadata), + "created_at": now, + "updated_at": now, + "idempotency_key": idempotency_key, + } + + @staticmethod + def _serialize_record(record: dict[str, Any]) -> dict[str, Any]: + return { + "id": int(record.get("id", 0) or 0), + "session": ( + dict(record.get("session")) + if isinstance(record.get("session"), dict) + else {} + ), + "sender": ( + dict(record.get("sender")) + if isinstance(record.get("sender"), dict) + else {} + ), + "parts": ( + [ + dict(item) + for item in record.get("parts", []) + if isinstance(item, dict) + ] + if isinstance(record.get("parts"), list) + else [] + ), + "metadata": ( + dict(record.get("metadata")) + if isinstance(record.get("metadata"), dict) + else {} + ), + "created_at": record.get("created_at"), + "updated_at": record.get("updated_at"), + "idempotency_key": ( + str(record.get("idempotency_key")) + if record.get("idempotency_key") is not None + else None + ), + } + + @staticmethod + def _parse_boundary(raw_value: Any, field_name: str) -> datetime: + text = str(raw_value or "").strip() + if not text: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires {field_name}" + ) + try: + return MessageHistoryCapabilityMixin._normalize_timestamp(text) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires an ISO datetime string" + ) from exc + + async def _message_history_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + raw_limit = self._optional_int(payload.get("limit")) + limit = 50 if raw_limit is None else raw_limit + if limit < 1: + raise AstrBotError.invalid_input("message_history.list requires limit >= 1") + raw_cursor = payload.get("cursor") + cursor_id = ( + self._optional_int(raw_cursor) if raw_cursor not in (None, "") else None + ) + if raw_cursor not in (None, "") and (cursor_id is None or cursor_id < 1): + raise AstrBotError.invalid_input( + "message_history.list requires cursor to be a positive integer string" + ) + records = list(reversed(self._message_history_records(session))) + total = len(records) + if cursor_id is not None: + records = [ + record for record in records if int(record.get("id", 0)) < cursor_id + ] + page_records = records[:limit] + next_cursor = ( + str(page_records[-1]["id"]) + if len(records) > limit and page_records + else None + ) + return { + "page": { + "records": [self._serialize_record(record) for record in page_records], + "next_cursor": next_cursor, + "total": total, + } + } + + async def _message_history_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + record_id = self._optional_int(payload.get("record_id")) + if record_id is None or record_id < 1: + raise AstrBotError.invalid_input( + "message_history.get_by_id requires record_id >= 1" + ) + record = next( + ( + item + for item in self._message_history_records(session) + if int(item.get("id", 0) or 0) == record_id + ), + None, + ) + return { + "record": self._serialize_record(record) if record is not None else None + } + + async def _message_history_append( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + sender_payload = payload.get("sender") + if not isinstance(sender_payload, dict): + raise AstrBotError.invalid_input( + "message_history.append requires sender object" + ) + parts_payload = payload.get("parts") + if not isinstance(parts_payload, list) or any( + not isinstance(item, dict) for item in parts_payload + ): + raise AstrBotError.invalid_input( + "message_history.append requires parts array" + ) + metadata = payload.get("metadata") + if metadata is not None and not isinstance(metadata, dict): + raise AstrBotError.invalid_input( + "message_history.append requires metadata object when provided" + ) + idempotency_key = ( + str(payload.get("idempotency_key")) + if payload.get("idempotency_key") is not None + else None + ) + records = self._message_history_records(session) + if idempotency_key: + existing = next( + ( + record + for record in records + if str(record.get("idempotency_key") or "") == idempotency_key + ), + None, + ) + if existing is not None: + return {"record": self._serialize_record(existing)} + record = self._create_message_history_record( + session=session, + sender_payload=sender_payload, + parts_payload=parts_payload, + metadata=dict(metadata or {}), + idempotency_key=idempotency_key, + ) + records.append(record) + return {"record": self._serialize_record(record)} + + async def _message_history_delete_before( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + before = self._parse_boundary(payload.get("before"), "delete_before") + records = self._message_history_records(session) + retained: list[dict[str, Any]] = [] + deleted_count = 0 + for record in records: + created_at = self._normalize_timestamp(record.get("created_at")) + if created_at < before: + deleted_count += 1 + continue + retained.append(record) + self._message_history_store[self._typed_key(session)] = retained + return {"deleted_count": deleted_count} + + async def _message_history_delete_after( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + after = self._parse_boundary(payload.get("after"), "delete_after") + records = self._message_history_records(session) + retained: list[dict[str, Any]] = [] + deleted_count = 0 + for record in records: + created_at = self._normalize_timestamp(record.get("created_at")) + if created_at > after: + deleted_count += 1 + continue + retained.append(record) + self._message_history_store[self._typed_key(session)] = retained + return {"deleted_count": deleted_count} + + async def _message_history_delete_all( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = self._typed_session_from_payload(payload.get("session")) + key = self._typed_key(session) + deleted_count = len(self._message_history_store.get(key, [])) + self._message_history_store[key] = [] + return {"deleted_count": deleted_count} + + def _register_message_history_capabilities(self) -> None: + self.register( + self._builtin_descriptor("message_history.list", "List message history"), + call_handler=self._message_history_list, + ) + self.register( + self._builtin_descriptor( + "message_history.get_by_id", + "Get message history by id", + ), + call_handler=self._message_history_get_by_id, + ) + self.register( + self._builtin_descriptor( + "message_history.append", "Append message history" + ), + call_handler=self._message_history_append, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_before", + "Delete message history before timestamp", + ), + call_handler=self._message_history_delete_before, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_after", + "Delete message history after timestamp", + ), + call_handler=self._message_history_delete_after, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_all", + "Delete all message history in session", + ), + call_handler=self._message_history_delete_all, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py new file mode 100644 index 0000000000..787f63369b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any + +from ..bridge_base import CapabilityRouterBridgeBase + + +class MetadataCapabilityMixin(CapabilityRouterBridgeBase): + async def _metadata_get_plugin( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + name = str(payload.get("name", "")).strip() + plugin = self._plugins.get(name) + if plugin is None: + return {"plugin": None} + return {"plugin": dict(plugin.metadata)} + + async def _metadata_list_plugins( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugins = [ + dict(self._plugins[name].metadata) for name in sorted(self._plugins.keys()) + ] + return {"plugins": plugins} + + async def _metadata_get_plugin_config( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + name = str(payload.get("name", "")).strip() + caller_plugin_id = self._require_caller_plugin_id("metadata.get_plugin_config") + if name != caller_plugin_id: + return {"config": None} + plugin = self._plugins.get(name) + if plugin is None: + return {"config": None} + return {"config": dict(plugin.config)} + + async def _metadata_save_plugin_config( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + caller_plugin_id = self._require_caller_plugin_id("metadata.save_plugin_config") + plugin = self._plugins.get(caller_plugin_id) + if plugin is None: + return {"config": None} + config = payload.get("config") + if not isinstance(config, dict): + return {"config": dict(plugin.config)} + plugin.config = dict(config) + return {"config": dict(plugin.config)} + + def _register_metadata_capabilities(self) -> None: + self.register( + self._builtin_descriptor("metadata.get_plugin", "获取单个插件元数据"), + call_handler=self._metadata_get_plugin, + ) + self.register( + self._builtin_descriptor("metadata.list_plugins", "列出插件元数据"), + call_handler=self._metadata_list_plugins, + ) + self.register( + self._builtin_descriptor( + "metadata.get_plugin_config", + "获取插件配置", + ), + call_handler=self._metadata_get_plugin_config, + ) + self.register( + self._builtin_descriptor( + "metadata.save_plugin_config", + "保存当前插件配置", + ), + call_handler=self._metadata_save_plugin_config, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py new file mode 100644 index 0000000000..063ab840c9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class PermissionCapabilityMixin(CapabilityRouterBridgeBase): + def _register_permission_capabilities(self) -> None: + self.register( + self._builtin_descriptor("permission.check", "查询用户权限角色"), + call_handler=self._permission_check, + ) + self.register( + self._builtin_descriptor("permission.get_admins", "列出管理员 ID"), + call_handler=self._permission_get_admins, + ) + self.register( + self._builtin_descriptor( + "permission.manager.add_admin", + "添加管理员 ID", + ), + call_handler=self._permission_manager_add_admin, + ) + self.register( + self._builtin_descriptor( + "permission.manager.remove_admin", + "移除管理员 ID", + ), + call_handler=self._permission_manager_remove_admin, + ) + + @staticmethod + def _normalize_admin_ids(values: Any) -> list[str]: + if not isinstance(values, list): + return [] + normalized: list[str] = [] + for item in values: + user_id = str(item).strip() + if user_id: + normalized.append(user_id) + return normalized + + def _admin_ids_snapshot(self) -> list[str]: + normalized = self._normalize_admin_ids( + getattr(self, "_permission_admin_ids", []) + ) + self._permission_admin_ids = list(normalized) + return normalized + + @staticmethod + def _required_user_id(payload: dict[str, Any], capability_name: str) -> str: + user_id = str(payload.get("user_id", "")).strip() + if not user_id: + raise AstrBotError.invalid_input(f"{capability_name} requires user_id") + return user_id + + def _require_reserved_plugin(self, capability_name: str) -> str: + plugin_id = self._require_caller_plugin_id(capability_name) + plugin = self._plugins.get(plugin_id) + if plugin is not None and bool(plugin.metadata.get("reserved", False)): + return plugin_id + if plugin_id in {"system", "__system__"}: + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} is restricted to reserved/system plugins" + ) + + @staticmethod + def _require_admin_event_context( + payload: dict[str, Any], + capability_name: str, + ) -> None: + if bool(payload.get("_caller_is_admin", False)): + return + raise AstrBotError.invalid_input( + f"{capability_name} requires an active admin event context" + ) + + async def _permission_check( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + user_id = self._required_user_id(payload, "permission.check") + admins = self._admin_ids_snapshot() + is_admin = user_id in admins + return { + "is_admin": is_admin, + "role": "admin" if is_admin else "member", + } + + async def _permission_get_admins( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return {"admins": self._admin_ids_snapshot()} + + async def _permission_manager_add_admin( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin("permission.manager.add_admin") + self._require_admin_event_context(payload, "permission.manager.add_admin") + user_id = self._required_user_id(payload, "permission.manager.add_admin") + admins = self._admin_ids_snapshot() + if user_id in admins: + return {"changed": False} + admins.append(user_id) + self._permission_admin_ids = admins + return {"changed": True} + + async def _permission_manager_remove_admin( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin("permission.manager.remove_admin") + self._require_admin_event_context(payload, "permission.manager.remove_admin") + user_id = self._required_user_id(payload, "permission.manager.remove_admin") + admins = self._admin_ids_snapshot() + if user_id not in admins: + return {"changed": False} + admins.remove(user_id) + self._permission_admin_ids = admins + return {"changed": True} diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py new file mode 100644 index 0000000000..6d7b3b3531 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class PersonaCapabilityMixin(CapabilityRouterBridgeBase): + async def _persona_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + persona_id = str(payload.get("persona_id", "")).strip() + record = self._persona_store.get(persona_id) + if record is None: + raise AstrBotError.invalid_input(f"persona not found: {persona_id}") + return {"persona": dict(record)} + + async def _persona_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + personas = [ + dict(self._persona_store[persona_id]) + for persona_id in sorted(self._persona_store.keys()) + ] + return {"personas": personas} + + async def _persona_create( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + raw_persona = payload.get("persona") + if not isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.create requires persona object") + persona_id = str(raw_persona.get("persona_id", "")).strip() + if not persona_id: + raise AstrBotError.invalid_input("persona.create requires persona_id") + if persona_id in self._persona_store: + raise AstrBotError.invalid_input(f"persona already exists: {persona_id}") + now = self._now_iso() + record = { + "persona_id": persona_id, + "system_prompt": str(raw_persona.get("system_prompt", "")), + "begin_dialogs": self._normalize_persona_dialogs_payload( + raw_persona.get("begin_dialogs") + ), + "tools": ( + [str(item) for item in raw_persona.get("tools", [])] + if isinstance(raw_persona.get("tools"), list) + else None + ), + "skills": ( + [str(item) for item in raw_persona.get("skills", [])] + if isinstance(raw_persona.get("skills"), list) + else None + ), + "custom_error_message": ( + str(raw_persona.get("custom_error_message")) + if raw_persona.get("custom_error_message") is not None + else None + ), + "folder_id": ( + str(raw_persona.get("folder_id")) + if raw_persona.get("folder_id") is not None + else None + ), + "sort_order": int(raw_persona.get("sort_order", 0)), + "created_at": now, + "updated_at": now, + } + self._persona_store[persona_id] = record + return {"persona": dict(record)} + + async def _persona_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + persona_id = str(payload.get("persona_id", "")).strip() + record = self._persona_store.get(persona_id) + if record is None: + return {"persona": None} + raw_persona = payload.get("persona") + if not isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.update requires persona object") + if ( + "system_prompt" in raw_persona + and raw_persona.get("system_prompt") is not None + ): + record["system_prompt"] = str(raw_persona.get("system_prompt", "")) + if "begin_dialogs" in raw_persona: + begin_dialogs = raw_persona.get("begin_dialogs") + record["begin_dialogs"] = ( + self._normalize_persona_dialogs_payload(begin_dialogs) + if begin_dialogs is not None + else [] + ) + if "tools" in raw_persona: + tools = raw_persona.get("tools") + record["tools"] = ( + [str(item) for item in tools] if isinstance(tools, list) else None + ) + if "skills" in raw_persona: + skills = raw_persona.get("skills") + record["skills"] = ( + [str(item) for item in skills] if isinstance(skills, list) else None + ) + if "custom_error_message" in raw_persona: + custom_error_message = raw_persona.get("custom_error_message") + record["custom_error_message"] = ( + str(custom_error_message) if custom_error_message is not None else None + ) + record["updated_at"] = self._now_iso() + return {"persona": dict(record)} + + async def _persona_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + persona_id = str(payload.get("persona_id", "")).strip() + if persona_id not in self._persona_store: + raise AstrBotError.invalid_input(f"persona not found: {persona_id}") + del self._persona_store[persona_id] + return {} + + def _register_persona_capabilities(self) -> None: + self.register( + self._builtin_descriptor("persona.get", "获取人格"), + call_handler=self._persona_get, + ) + self.register( + self._builtin_descriptor("persona.list", "列出人格"), + call_handler=self._persona_list, + ) + self.register( + self._builtin_descriptor("persona.create", "创建人格"), + call_handler=self._persona_create, + ) + self.register( + self._builtin_descriptor("persona.update", "更新人格"), + call_handler=self._persona_update, + ) + self.register( + self._builtin_descriptor("persona.delete", "删除人格"), + call_handler=self._persona_delete, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py new file mode 100644 index 0000000000..dbc565a013 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class PlatformCapabilityMixin(CapabilityRouterBridgeBase): + async def _platform_send( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session, target = self._resolve_target(payload) + self._require_platform_support_for_session("platform.send", session) + text = str(payload.get("text", "")) + message_id = f"msg_{len(self.sent_messages) + 1}" + sent: dict[str, Any] = { + "message_id": message_id, + "session": session, + "text": text, + } + if target is not None: + sent["target"] = target + self.sent_messages.append(sent) + return {"message_id": message_id} + + async def _platform_send_image( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session, target = self._resolve_target(payload) + self._require_platform_support_for_session("platform.send_image", session) + image_url = str(payload.get("image_url", "")) + message_id = f"img_{len(self.sent_messages) + 1}" + sent: dict[str, Any] = { + "message_id": message_id, + "session": session, + "image_url": image_url, + } + if target is not None: + sent["target"] = target + self.sent_messages.append(sent) + return {"message_id": message_id} + + async def _platform_send_chain( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session, target = self._resolve_target(payload) + self._require_platform_support_for_session("platform.send_chain", session) + chain = payload.get("chain") + if not isinstance(chain, list) or not all( + isinstance(item, dict) for item in chain + ): + raise AstrBotError.invalid_input( + "platform.send_chain 的 chain 必须是 object 数组" + ) + message_id = f"chain_{len(self.sent_messages) + 1}" + sent: dict[str, Any] = { + "message_id": message_id, + "session": session, + "chain": [dict(item) for item in chain], + } + if target is not None: + sent["target"] = target + self.sent_messages.append(sent) + return {"message_id": message_id} + + async def _platform_send_by_session( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + chain = payload.get("chain") + if not isinstance(chain, list) or not all( + isinstance(item, dict) for item in chain + ): + raise AstrBotError.invalid_input( + "platform.send_by_session 的 chain 必须是 object 数组" + ) + session = str(payload.get("session", "")) + self._require_platform_support_for_session("platform.send_by_session", session) + message_id = f"proactive_{len(self.sent_messages) + 1}" + self.sent_messages.append( + { + "message_id": message_id, + "session": session, + "chain": [dict(item) for item in chain], + } + ) + return {"message_id": message_id} + + async def _platform_get_group( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session, _target = self._resolve_target(payload) + return {"group": self._mock_group_payload(session)} + + async def _platform_get_members( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session, _target = self._resolve_target(payload) + group = self._mock_group_payload(session) + if group is None: + return {"members": []} + return {"members": list(group.get("members", []))} + + async def _platform_list_instances( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("platform.list_instances") + return { + "platforms": [ + { + "id": str(item.get("id", "")), + "name": str(item.get("name", "")), + "type": str(item.get("type", "")), + "status": str(item.get("status", "unknown")), + } + for item in self.get_platform_instances() + if isinstance(item, dict) + and self._plugin_supports_platform(plugin_id, str(item.get("type", ""))) + ] + } + + def _register_platform_capabilities(self) -> None: + self.register( + self._builtin_descriptor("platform.send", "发送消息"), + call_handler=self._platform_send, + ) + self.register( + self._builtin_descriptor("platform.send_image", "发送图片"), + call_handler=self._platform_send_image, + ) + self.register( + self._builtin_descriptor("platform.send_chain", "发送消息链"), + call_handler=self._platform_send_chain, + ) + self.register( + self._builtin_descriptor( + "platform.send_by_session", "按会话主动发送消息链" + ), + call_handler=self._platform_send_by_session, + ) + self.register( + self._builtin_descriptor("platform.get_group", "获取当前群信息"), + call_handler=self._platform_get_group, + ) + self.register( + self._builtin_descriptor("platform.get_members", "获取群成员"), + call_handler=self._platform_get_members, + ) + self.register( + self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"), + call_handler=self._platform_list_instances, + ) + + async def _platform_manager_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_by_id") + platform_id = str(payload.get("platform_id", "")).strip() + platform = next( + ( + dict(item) + for item in self._platform_instances + if str(item.get("id", "")) == platform_id + ), + None, + ) + return {"platform": platform} + + async def _platform_manager_clear_errors( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.clear_errors") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + item["errors"] = [] + item["last_error"] = None + if str(item.get("status", "")) == "error": + item["status"] = "running" + break + return {} + + async def _platform_manager_get_stats( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_stats") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + stats = item.get("stats") + if isinstance(stats, dict): + return {"stats": dict(stats)} + errors = item.get("errors") + last_error = item.get("last_error") + meta = item.get("meta") + return { + "stats": { + "id": platform_id, + "type": str(item.get("type", "")), + "display_name": str(item.get("name", platform_id)), + "status": str(item.get("status", "pending")), + "started_at": item.get("started_at"), + "error_count": len(errors) if isinstance(errors, list) else 0, + "last_error": dict(last_error) + if isinstance(last_error, dict) + else None, + "unified_webhook": bool(item.get("unified_webhook", False)), + "meta": dict(meta) if isinstance(meta, dict) else {}, + } + } + return {"stats": None} + + def _register_platform_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "platform.manager.get_by_id", + "按 ID 获取平台管理快照", + ), + call_handler=self._platform_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "platform.manager.clear_errors", + "清除平台错误", + ), + call_handler=self._platform_manager_clear_errors, + ) + self.register( + self._builtin_descriptor( + "platform.manager.get_stats", + "获取平台统计信息", + ), + call_handler=self._platform_manager_get_stats, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py new file mode 100644 index 0000000000..937373a0a0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py @@ -0,0 +1,1080 @@ +from __future__ import annotations + +import asyncio +import base64 +from collections.abc import AsyncIterator +from typing import Any + +from ....errors import AstrBotError +from ..._streaming import StreamExecution +from ..bridge_base import ( + _MOCK_EMBEDDING_DIM, + CapabilityRouterBridgeBase, + _mock_embedding_vector, +) + + +class ProviderCapabilityMixin(CapabilityRouterBridgeBase): + @staticmethod + def _active_local_mcp_tool_names(plugin: Any | None) -> list[str]: + if plugin is None: + return [] + local_tools: list[str] = [] + for server in plugin.local_mcp_servers.values(): + if not bool(server.get("active", False)): + continue + if not bool(server.get("running", False)): + continue + server_name = str(server.get("name", "")).strip() + if not server_name: + continue + for tool_name in server.get("tools", []): + if not isinstance(tool_name, str) or not tool_name.strip(): + continue + local_tools.append(f"mcp.{server_name}.{tool_name}") + return local_tools + + def _provider_payload( + self, kind: str, provider_id: str | None + ) -> dict[str, Any] | None: + if not provider_id: + return None + for item in self._provider_catalog.get(kind, []): + if str(item.get("id", "")) == provider_id: + return dict(item) + return None + + def _provider_payload_by_id(self, provider_id: str) -> dict[str, Any] | None: + normalized = str(provider_id).strip() + if not normalized: + return None + for items in self._provider_catalog.values(): + for item in items: + if str(item.get("id", "")) == normalized: + return dict(item) + return None + + @staticmethod + def _provider_kind_from_type(provider_type: str) -> str: + mapping = { + "chat_completion": "chat", + "text_to_speech": "tts", + "speech_to_text": "stt", + "embedding": "embedding", + "rerank": "rerank", + } + normalized = str(provider_type).strip().lower() + if normalized not in mapping: + raise AstrBotError.invalid_input(f"unknown provider_type: {provider_type}") + return mapping[normalized] + + def _provider_config_by_id(self, provider_id: str) -> dict[str, Any] | None: + record = self._provider_configs.get(str(provider_id).strip()) + return dict(record) if isinstance(record, dict) else None + + @staticmethod + def _managed_provider_record( + payload: dict[str, Any], + *, + loaded: bool, + ) -> dict[str, Any]: + return { + "id": str(payload.get("id", "")), + "model": ( + str(payload.get("model")) if payload.get("model") is not None else None + ), + "type": str(payload.get("type", "")), + "provider_type": str(payload.get("provider_type", "chat_completion")), + "loaded": bool(loaded), + "enabled": bool(payload.get("enable", True)), + "provider_source_id": ( + str(payload.get("provider_source_id")) + if payload.get("provider_source_id") is not None + else None + ), + } + + def _managed_provider_record_by_id(self, provider_id: str) -> dict[str, Any] | None: + provider = self._provider_payload_by_id(provider_id) + if provider is not None: + config = self._provider_config_by_id(provider_id) or provider + merged = dict(provider) + merged.update( + { + "enable": config.get("enable", True), + "provider_source_id": config.get("provider_source_id"), + } + ) + return self._managed_provider_record(merged, loaded=True) + config = self._provider_config_by_id(provider_id) + if config is None: + return None + return self._managed_provider_record(config, loaded=False) + + def _emit_provider_change( + self, + provider_id: str, + provider_type: str, + umo: str | None, + ) -> None: + event = { + "provider_id": str(provider_id), + "provider_type": str(provider_type), + "umo": str(umo) if umo is not None else None, + } + for queue in list(self._provider_change_subscriptions.values()): + queue.put_nowait(dict(event)) + + def _require_reserved_plugin(self, capability_name: str) -> str: + plugin_id = self._require_caller_plugin_id(capability_name) + plugin = self._plugins.get(plugin_id) + if plugin is not None and bool(plugin.metadata.get("reserved", False)): + return plugin_id + if plugin_id in {"system", "__system__"}: + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} is restricted to reserved/system plugins" + ) + + def _provider_entry( + self, + payload: dict[str, Any], + capability_name: str, + expected_kind: str | None = None, + ) -> dict[str, Any]: + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + f"{capability_name} requires provider_id", + ) + provider = self._provider_payload_by_id(provider_id) + if provider is None: + raise AstrBotError.invalid_input( + f"{capability_name} unknown provider_id: {provider_id}", + ) + if ( + expected_kind is not None + and str(provider.get("provider_type")) != expected_kind + ): + raise AstrBotError.invalid_input( + f"{capability_name} requires a {expected_kind} provider", + ) + return provider + + async def _provider_get_using( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider_id = self._active_provider_ids.get("chat") + return {"provider": self._provider_payload("chat", provider_id)} + + async def _provider_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + return { + "provider": self._provider_payload_by_id( + str(payload.get("provider_id", "")) + ) + } + + async def _provider_get_current_chat_provider_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + return {"provider_id": self._active_provider_ids.get("chat")} + + def _provider_list_payload(self, kind: str) -> dict[str, Any]: + return { + "providers": [dict(item) for item in self._provider_catalog.get(kind, [])] + } + + async def _provider_list_all( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("chat") + + async def _provider_list_all_tts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("tts") + + async def _provider_list_all_stt( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("stt") + + async def _provider_list_all_embedding( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("embedding") + + async def _provider_list_all_rerank( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("rerank") + + async def _provider_get_using_tts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider_id = self._active_provider_ids.get("tts") + return {"provider": self._provider_payload("tts", provider_id)} + + async def _provider_get_using_stt( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider_id = self._active_provider_ids.get("stt") + return {"provider": self._provider_payload("stt", provider_id)} + + async def _provider_stt_get_text( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.stt.get_text", + "speech_to_text", + ) + return {"text": f"Mock transcript: {str(payload.get('audio_url', ''))}"} + + async def _provider_tts_get_audio( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.tts.get_audio", + "text_to_speech", + ) + return { + "audio_path": ( + f"mock://tts/{provider.get('id', '')}/{str(payload.get('text', ''))}" + ) + } + + async def _provider_tts_support_stream( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.tts.support_stream", + "text_to_speech", + ) + return {"supported": bool(provider.get("support_stream", True))} + + async def _provider_tts_get_audio_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> StreamExecution: + self._provider_entry( + payload, + "provider.tts.get_audio_stream", + "text_to_speech", + ) + text = payload.get("text") + text_chunks = payload.get("text_chunks") + if isinstance(text, str): + chunks = [text] + elif isinstance(text_chunks, list) and text_chunks: + chunks = [str(item) for item in text_chunks] + else: + raise AstrBotError.invalid_input( + "provider.tts.get_audio_stream requires text or text_chunks" + ) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + for chunk in chunks: + token.raise_if_cancelled() + await asyncio.sleep(0) + yield { + "audio_base64": base64.b64encode( + f"mock-audio:{chunk}".encode() + ).decode("ascii"), + "text": chunk, + } + + return StreamExecution( + iterator=iterator(), + finalize=lambda items: ( + items[-1] if items else {"audio_base64": "", "text": None} + ), + ) + + async def _provider_embedding_get_embedding( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.embedding.get_embedding", + "embedding", + ) + return { + "embedding": _mock_embedding_vector( + str(payload.get("text", "")), + provider_id=str(provider.get("id", "")), + ) + } + + async def _provider_embedding_get_embeddings( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.embedding.get_embeddings", + "embedding", + ) + texts = payload.get("texts") + if not isinstance(texts, list): + raise AstrBotError.invalid_input( + "provider.embedding.get_embeddings requires texts", + ) + return { + "embeddings": [ + _mock_embedding_vector( + str(text), + provider_id=str(provider.get("id", "")), + ) + for text in texts + ], + } + + async def _provider_embedding_get_dim( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.embedding.get_dim", + "embedding", + ) + return {"dim": _MOCK_EMBEDDING_DIM} + + async def _provider_rerank_rerank( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.rerank.rerank", + "rerank", + ) + documents = payload.get("documents") + if not isinstance(documents, list): + raise AstrBotError.invalid_input( + "provider.rerank.rerank requires documents", + ) + scored = [ + { + "index": index, + "score": 1.0, + "document": str(raw_document), + } + for index, raw_document in enumerate(documents) + ] + top_n = payload.get("top_n") + if top_n is not None: + scored = scored[: max(int(top_n), 0)] + return {"results": scored} + + async def _provider_manager_set( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.set") + provider_id = str(payload.get("provider_id", "")).strip() + provider_type = str(payload.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.set requires provider_id" + ) + if self._provider_payload(kind, provider_id) is None: + raise AstrBotError.invalid_input( + f"provider.manager.set unknown provider_id: {provider_id}" + ) + self._active_provider_ids[kind] = provider_id + self._emit_provider_change( + provider_id, + provider_type, + str(payload.get("umo")) if payload.get("umo") is not None else None, + ) + return {} + + async def _provider_manager_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_by_id") + return { + "provider": self._managed_provider_record_by_id( + str(payload.get("provider_id", "")) + ) + } + + async def _provider_manager_get_merged_provider_config( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_merged_provider_config") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config requires provider_id" + ) + provider = self._provider_payload_by_id(provider_id) + config = self._provider_config_by_id(provider_id) + if provider is None and config is None: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config " + f"unknown provider_id: {provider_id}" + ) + if provider is None: + return {"config": dict(config) if isinstance(config, dict) else config} + if config is None: + return {"config": dict(provider)} + merged_config = dict(provider) + merged_config.update(config) + return {"config": merged_config} + + @staticmethod + def _normalize_provider_config_object( + payload: Any, + capability_name: str, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires {field_name} object" + ) + return dict(payload) + + async def _provider_manager_load( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.load") + provider_config = self._normalize_provider_config_object( + payload.get("provider_config"), + "provider.manager.load", + "provider_config", + ) + provider_id = str(provider_config.get("id", "")).strip() + provider_type = str(provider_config.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.load requires provider id" + ) + if bool(provider_config.get("enable", True)): + record = { + "id": provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": str(provider_config.get("type", "")), + "provider_type": provider_type, + } + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + self._provider_catalog[kind].append(record) + self._emit_provider_change(provider_id, provider_type, None) + return { + "provider": self._managed_provider_record( + provider_config, + loaded=bool(provider_config.get("enable", True)), + ) + } + + async def _provider_manager_terminate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.terminate") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.terminate requires provider_id" + ) + managed = self._managed_provider_record_by_id(provider_id) + if managed is None: + raise AstrBotError.invalid_input( + f"provider.manager.terminate unknown provider_id: {provider_id}" + ) + kind = self._provider_kind_from_type(str(managed.get("provider_type", ""))) + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + if self._active_provider_ids.get(kind) == provider_id: + catalog = self._provider_catalog.get(kind, []) + self._active_provider_ids[kind] = ( + str(catalog[0].get("id")) if catalog else None + ) + self._emit_provider_change( + provider_id, str(managed.get("provider_type", "")), None + ) + return {} + + async def _provider_manager_create( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.create") + provider_config = self._normalize_provider_config_object( + payload.get("provider_config"), + "provider.manager.create", + "provider_config", + ) + provider_id = str(provider_config.get("id", "")).strip() + provider_type = str(provider_config.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.create requires provider id" + ) + self._provider_configs[provider_id] = dict(provider_config) + if bool(provider_config.get("enable", True)): + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + self._provider_catalog[kind].append( + { + "id": provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": str(provider_config.get("type", "")), + "provider_type": provider_type, + } + ) + self._emit_provider_change(provider_id, provider_type, None) + return {"provider": self._managed_provider_record_by_id(provider_id)} + + async def _provider_manager_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.update") + origin_provider_id = str(payload.get("origin_provider_id", "")).strip() + new_config = self._normalize_provider_config_object( + payload.get("new_config"), + "provider.manager.update", + "new_config", + ) + if not origin_provider_id: + raise AstrBotError.invalid_input( + "provider.manager.update requires origin_provider_id" + ) + current = self._provider_config_by_id(origin_provider_id) + if current is None: + current = self._managed_provider_record_by_id(origin_provider_id) + if current is None: + raise AstrBotError.invalid_input( + f"provider.manager.update unknown provider_id: {origin_provider_id}" + ) + target_provider_id = str(new_config.get("id") or origin_provider_id).strip() + provider_type = str( + new_config.get("provider_type") or current.get("provider_type", "") + ).strip() + kind = self._provider_kind_from_type(provider_type) + self._provider_configs.pop(origin_provider_id, None) + merged = dict(current) + merged.update(new_config) + merged["id"] = target_provider_id + merged["provider_type"] = provider_type + self._provider_configs[target_provider_id] = merged + for catalog_kind, items in list(self._provider_catalog.items()): + self._provider_catalog[catalog_kind] = [ + item for item in items if str(item.get("id", "")) != origin_provider_id + ] + if bool(merged.get("enable", True)): + self._provider_catalog[kind].append( + { + "id": target_provider_id, + "model": ( + str(merged.get("model")) + if merged.get("model") is not None + else None + ), + "type": str(merged.get("type", "")), + "provider_type": provider_type, + } + ) + for active_kind, active_id in list(self._active_provider_ids.items()): + if active_id == origin_provider_id: + self._active_provider_ids[active_kind] = ( + target_provider_id if active_kind == kind else None + ) + self._emit_provider_change(target_provider_id, provider_type, None) + return {"provider": self._managed_provider_record_by_id(target_provider_id)} + + async def _provider_manager_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.delete") + provider_id = ( + str(payload.get("provider_id")).strip() + if payload.get("provider_id") is not None + else None + ) + provider_source_id = ( + str(payload.get("provider_source_id")).strip() + if payload.get("provider_source_id") is not None + else None + ) + if not provider_id and not provider_source_id: + raise AstrBotError.invalid_input( + "provider.manager.delete requires provider_id or provider_source_id" + ) + deleted: list[dict[str, Any]] = [] + if provider_id: + record = self._managed_provider_record_by_id(provider_id) + if record is not None: + deleted.append(record) + self._provider_configs.pop(provider_id, None) + else: + for record_id, record in list(self._provider_configs.items()): + if ( + str(record.get("provider_source_id", "")).strip() + != provider_source_id + ): + continue + deleted_record = self._managed_provider_record_by_id(record_id) + if deleted_record is not None: + deleted.append(deleted_record) + self._provider_configs.pop(record_id, None) + deleted_ids = {str(item.get("id", "")) for item in deleted} + for kind, items in list(self._provider_catalog.items()): + self._provider_catalog[kind] = [ + item for item in items if str(item.get("id", "")) not in deleted_ids + ] + if self._active_provider_ids.get(kind) in deleted_ids: + catalog = self._provider_catalog.get(kind, []) + self._active_provider_ids[kind] = ( + str(catalog[0].get("id")) if catalog else None + ) + for record in deleted: + self._emit_provider_change( + str(record.get("id", "")), + str(record.get("provider_type", "")), + None, + ) + return {} + + async def _provider_manager_get_insts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_insts") + return { + "providers": [ + self._managed_provider_record(item, loaded=True) + for item in self._provider_catalog.get("chat", []) + ] + } + + async def _provider_manager_watch_changes( + self, request_id: str, _payload: dict[str, Any], _token + ) -> StreamExecution: + self._require_reserved_plugin("provider.manager.watch_changes") + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + self._provider_change_subscriptions[request_id] = queue + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + yield await queue.get() + finally: + self._provider_change_subscriptions.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + async def _platform_manager_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_by_id") + platform_id = str(payload.get("platform_id", "")).strip() + platform = next( + ( + dict(item) + for item in self._platform_instances + if str(item.get("id", "")) == platform_id + ), + None, + ) + return {"platform": platform} + + async def _platform_manager_clear_errors( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.clear_errors") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + item["errors"] = [] + item["last_error"] = None + if str(item.get("status", "")) == "error": + item["status"] = "running" + break + return {} + + async def _platform_manager_get_stats( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_stats") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + stats = item.get("stats") + if isinstance(stats, dict): + return {"stats": dict(stats)} + errors = item.get("errors") + last_error = item.get("last_error") + meta = item.get("meta") + return { + "stats": { + "id": platform_id, + "type": str(item.get("type", "")), + "display_name": str(item.get("name", platform_id)), + "status": str(item.get("status", "pending")), + "started_at": item.get("started_at"), + "error_count": len(errors) if isinstance(errors, list) else 0, + "last_error": dict(last_error) + if isinstance(last_error, dict) + else None, + "unified_webhook": bool(item.get("unified_webhook", False)), + "meta": dict(meta) if isinstance(meta, dict) else {}, + } + } + return {"stats": None} + + async def _llm_tool_manager_get( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.get") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"registered": [], "active": []} + registered = [dict(item) for item in plugin.llm_tools.values()] + active = [ + dict(item) + for name, item in plugin.llm_tools.items() + if name in plugin.active_llm_tools + ] + return {"registered": registered, "active": active} + + async def _llm_tool_manager_activate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.activate") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"activated": False} + name = str(payload.get("name", "")) + spec = plugin.llm_tools.get(name) + if spec is None: + return {"activated": False} + spec["active"] = True + plugin.active_llm_tools.add(name) + return {"activated": True} + + async def _llm_tool_manager_deactivate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.deactivate") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"deactivated": False} + name = str(payload.get("name", "")) + spec = plugin.llm_tools.get(name) + if spec is None: + return {"deactivated": False} + spec["active"] = False + plugin.active_llm_tools.discard(name) + return {"deactivated": True} + + async def _llm_tool_manager_add( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.add") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"names": []} + tools_payload = payload.get("tools") + if not isinstance(tools_payload, list): + raise AstrBotError.invalid_input("llm_tool.manager.add 的 tools 必须是数组") + names: list[str] = [] + for item in tools_payload: + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if not name: + continue + plugin.llm_tools[name] = dict(item) + if bool(item.get("active", True)): + plugin.active_llm_tools.add(name) + else: + plugin.active_llm_tools.discard(name) + names.append(name) + return {"names": names} + + async def _llm_tool_manager_remove( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.remove") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"removed": False} + name = str(payload.get("name", "")).strip() + removed = plugin.llm_tools.pop(name, None) is not None + plugin.active_llm_tools.discard(name) + return {"removed": removed} + + async def _agent_registry_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.registry.list") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"agents": []} + return {"agents": [dict(item) for item in plugin.agents.values()]} + + async def _agent_registry_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.registry.get") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"agent": None} + agent = plugin.agents.get(str(payload.get("name", ""))) + return {"agent": dict(agent) if isinstance(agent, dict) else None} + + async def _agent_tool_loop_run( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.tool_loop.run") + plugin = self._plugins.get(plugin_id) + requested_tools = payload.get("tool_names") + active_tools: list[str] = [] + if plugin is not None: + local_tools = self._active_local_mcp_tool_names(plugin) + if isinstance(requested_tools, list) and requested_tools: + active_tools = [ + name + for name in (str(item) for item in requested_tools) + if name in plugin.active_llm_tools or name in local_tools + ] + else: + active_tools = sorted([*plugin.active_llm_tools, *local_tools]) + prompt = str(payload.get("prompt", "") or "") + suffix = "" + if active_tools: + suffix = f" tools={','.join(active_tools)}" + return { + "text": f"Mock tool loop: {prompt}{suffix}".strip(), + "usage": { + "input_tokens": len(prompt), + "output_tokens": len(prompt) + len(suffix), + }, + "finish_reason": "stop", + "tool_calls": [], + "role": "assistant", + "reasoning_content": None, + "reasoning_signature": None, + } + + def _register_provider_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.get_using", "获取当前聊天 Provider"), + call_handler=self._provider_get_using, + ) + self.register( + self._builtin_descriptor("provider.get_by_id", "按 ID 获取 Provider"), + call_handler=self._provider_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.get_current_chat_provider_id", + "获取当前聊天 Provider ID", + ), + call_handler=self._provider_get_current_chat_provider_id, + ) + self.register( + self._builtin_descriptor("provider.list_all", "列出聊天 Providers"), + call_handler=self._provider_list_all, + ) + self.register( + self._builtin_descriptor("provider.list_all_tts", "列出 TTS Providers"), + call_handler=self._provider_list_all_tts, + ) + self.register( + self._builtin_descriptor("provider.list_all_stt", "列出 STT Providers"), + call_handler=self._provider_list_all_stt, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_embedding", + "列出 Embedding Providers", + ), + call_handler=self._provider_list_all_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_rerank", + "列出 Rerank Providers", + ), + call_handler=self._provider_list_all_rerank, + ) + self.register( + self._builtin_descriptor("provider.get_using_tts", "获取当前 TTS Provider"), + call_handler=self._provider_get_using_tts, + ) + self.register( + self._builtin_descriptor("provider.get_using_stt", "获取当前 STT Provider"), + call_handler=self._provider_get_using_stt, + ) + self.register( + self._builtin_descriptor("provider.stt.get_text", "STT 转写"), + call_handler=self._provider_stt_get_text, + ) + self.register( + self._builtin_descriptor("provider.tts.get_audio", "TTS 合成音频"), + call_handler=self._provider_tts_get_audio, + ) + self.register( + self._builtin_descriptor( + "provider.tts.support_stream", + "检查 TTS 流式支持", + ), + call_handler=self._provider_tts_support_stream, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio_stream", + "流式 TTS 音频输出", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_tts_get_audio_stream, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embedding", + "获取单条向量", + ), + call_handler=self._provider_embedding_get_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embeddings", + "批量获取向量", + ), + call_handler=self._provider_embedding_get_embeddings, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_dim", + "获取向量维度", + ), + call_handler=self._provider_embedding_get_dim, + ) + self.register( + self._builtin_descriptor("provider.rerank.rerank", "文档重排序"), + call_handler=self._provider_rerank_rerank, + ) + + def _register_provider_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.manager.set", "设置当前 Provider"), + call_handler=self._provider_manager_set, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_by_id", + "按 ID 获取 Provider 管理记录", + ), + call_handler=self._provider_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_merged_provider_config", + "获取 Provider 合并配置", + ), + call_handler=self._provider_manager_get_merged_provider_config, + ) + self.register( + self._builtin_descriptor("provider.manager.load", "运行时加载 Provider"), + call_handler=self._provider_manager_load, + ) + self.register( + self._builtin_descriptor( + "provider.manager.terminate", + "终止已加载的 Provider", + ), + call_handler=self._provider_manager_terminate, + ) + self.register( + self._builtin_descriptor("provider.manager.create", "创建 Provider"), + call_handler=self._provider_manager_create, + ) + self.register( + self._builtin_descriptor("provider.manager.update", "更新 Provider"), + call_handler=self._provider_manager_update, + ) + self.register( + self._builtin_descriptor("provider.manager.delete", "删除 Provider"), + call_handler=self._provider_manager_delete, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_insts", + "列出已加载聊天 Provider", + ), + call_handler=self._provider_manager_get_insts, + ) + self.register( + self._builtin_descriptor( + "provider.manager.watch_changes", + "订阅 Provider 变更", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_manager_watch_changes, + ) + + def _register_agent_tool_capabilities(self) -> None: + self.register( + self._builtin_descriptor("llm_tool.manager.get", "获取 LLM 工具状态"), + call_handler=self._llm_tool_manager_get, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.activate", "激活 LLM 工具"), + call_handler=self._llm_tool_manager_activate, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.deactivate", "停用 LLM 工具"), + call_handler=self._llm_tool_manager_deactivate, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.add", "动态添加 LLM 工具"), + call_handler=self._llm_tool_manager_add, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.remove", "动态移除 LLM 工具"), + call_handler=self._llm_tool_manager_remove, + ) + self.register( + self._builtin_descriptor("agent.tool_loop.run", "运行 mock tool loop"), + call_handler=self._agent_tool_loop_run, + ) + self.register( + self._builtin_descriptor("agent.registry.list", "列出 Agent 元数据"), + call_handler=self._agent_registry_list, + ) + self.register( + self._builtin_descriptor("agent.registry.get", "获取 Agent 元数据"), + call_handler=self._agent_registry_get, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py new file mode 100644 index 0000000000..e56f979e9e --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class SessionCapabilityMixin(CapabilityRouterBridgeBase): + async def _session_plugin_is_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + plugin_name = str(payload.get("plugin_name", "")) + config = self._session_plugin_config(session) + enabled_plugins = { + str(item) for item in config.get("enabled_plugins", []) if str(item).strip() + } + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + if plugin_name in enabled_plugins: + return {"enabled": True} + return {"enabled": plugin_name not in disabled_plugins} + + async def _session_plugin_filter_handlers( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + handlers = payload.get("handlers") + if not isinstance(handlers, list): + raise AstrBotError.invalid_input( + "session.plugin.filter_handlers 的 handlers 必须是 object 数组" + ) + disabled_plugins = { + str(item) + for item in self._session_plugin_config(session).get("disabled_plugins", []) + if str(item).strip() + } + reserved_plugins = { + str(plugin.metadata.get("name", "")) + for plugin in self._plugins.values() + if bool(plugin.metadata.get("reserved", False)) + } + filtered = [] + for item in handlers: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_name", "")) + if ( + plugin_name + and plugin_name in disabled_plugins + and plugin_name not in reserved_plugins + ): + continue + filtered.append(dict(item)) + return {"handlers": filtered} + + async def _session_service_is_llm_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + return {"enabled": bool(config.get("llm_enabled", True))} + + async def _session_service_set_llm_status( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + config["llm_enabled"] = bool(payload.get("enabled", False)) + self._session_service_configs[session] = config + return {} + + async def _session_service_is_tts_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + return {"enabled": bool(config.get("tts_enabled", True))} + + async def _session_service_set_tts_status( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + config["tts_enabled"] = bool(payload.get("enabled", False)) + self._session_service_configs[session] = config + return {} + + def _register_session_capabilities(self) -> None: + self.register( + self._builtin_descriptor("session.plugin.is_enabled", "获取会话级插件开关"), + call_handler=self._session_plugin_is_enabled, + ) + self.register( + self._builtin_descriptor( + "session.plugin.filter_handlers", + "按会话过滤 handler 元数据", + ), + call_handler=self._session_plugin_filter_handlers, + ) + self.register( + self._builtin_descriptor( + "session.service.is_llm_enabled", + "获取会话级 LLM 开关", + ), + call_handler=self._session_service_is_llm_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_llm_status", + "写入会话级 LLM 开关", + ), + call_handler=self._session_service_set_llm_status, + ) + self.register( + self._builtin_descriptor( + "session.service.is_tts_enabled", + "获取会话级 TTS 开关", + ), + call_handler=self._session_service_is_tts_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_tts_status", + "写入会话级 TTS 开关", + ), + call_handler=self._session_service_set_tts_status, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py new file mode 100644 index 0000000000..942f696989 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class SkillCapabilityMixin(CapabilityRouterBridgeBase): + def _register_skill_capabilities(self) -> None: + self.register( + self._builtin_descriptor("skill.register", "注册插件 skill"), + call_handler=self._skill_register, + ) + self.register( + self._builtin_descriptor("skill.unregister", "注销插件 skill"), + call_handler=self._skill_unregister, + ) + self.register( + self._builtin_descriptor("skill.list", "列出插件 skill"), + call_handler=self._skill_list, + ) + + async def _skill_register( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, str]: + plugin_id = self._require_caller_plugin_id("skill.register") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + + skill_name = str(payload.get("name", "")).strip() + if not skill_name: + raise AstrBotError.invalid_input("skill.register requires name") + skill_path = str(payload.get("path", "")).strip() + if not skill_path: + raise AstrBotError.invalid_input("skill.register requires path") + + path_obj = Path(skill_path) + skill_dir = path_obj.parent if path_obj.name == "SKILL.md" else path_obj + + entry = { + "name": skill_name, + "description": str(payload.get("description", "") or ""), + "path": skill_path, + "skill_dir": str(skill_dir), + } + plugin.skills[skill_name] = entry + return dict(entry) + + async def _skill_unregister( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, bool]: + plugin_id = self._require_caller_plugin_id("skill.unregister") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + removed = ( + plugin.skills.pop(str(payload.get("name", "")).strip(), None) is not None + ) + return {"removed": removed} + + async def _skill_list( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, list[dict[str, str]]]: + plugin_id = self._require_caller_plugin_id("skill.list") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + return { + "skills": [ + dict(plugin.skills[name]) for name in sorted(plugin.skills.keys()) + ] + } diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py new file mode 100644 index 0000000000..12012e5699 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import json +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import ( + CapabilityRouterBridgeBase, + _clone_chain_payload, + _clone_target_payload, +) + + +class SystemCapabilityMixin(CapabilityRouterBridgeBase): + @staticmethod + def _overlay_request_id(request_id: str, payload: dict[str, Any]) -> str: + scope_request_id = payload.get("_request_scope_id") + if isinstance(scope_request_id, str) and scope_request_id.strip(): + return scope_request_id + return request_id + + def _register_system_capabilities(self) -> None: + self.register( + self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"), + call_handler=self._system_get_data_dir, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.text_to_image", "文本转图片"), + call_handler=self._system_text_to_image, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.html_render", "渲染 HTML 模板"), + call_handler=self._system_html_render, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.register", "注册文件令牌"), + call_handler=self._system_file_register, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.handle", "解析文件令牌"), + call_handler=self._system_file_handle, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.register", + "注册会话等待器", + ), + call_handler=self._system_session_waiter_register, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.unregister", + "注销会话等待器", + ), + call_handler=self._system_session_waiter_unregister, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.react", "发送事件表情回应"), + call_handler=self._system_event_react, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.send_typing", "发送输入中状态"), + call_handler=self._system_event_send_typing, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming", + "发送事件流式消息", + ), + call_handler=self._system_event_send_streaming, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_chunk", + "推送事件流式消息分片", + ), + call_handler=self._system_event_send_streaming_chunk, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_close", + "关闭事件流式消息会话", + ), + call_handler=self._system_event_send_streaming_close, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.get_state", + "读取当前请求的默认 LLM 状态", + ), + call_handler=self._system_event_llm_get_state, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.request", + "请求当前事件继续进入默认 LLM 链路", + ), + call_handler=self._system_event_llm_request, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.get", "读取当前请求结果"), + call_handler=self._system_event_result_get, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.set", "写入当前请求结果"), + call_handler=self._system_event_result_set, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.clear", "清理当前请求结果"), + call_handler=self._system_event_result_clear, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.get", + "读取当前请求 handler 白名单", + ), + call_handler=self._system_event_handler_whitelist_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.set", + "写入当前请求 handler 白名单", + ), + call_handler=self._system_event_handler_whitelist_set, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "registry.get_handlers_by_event_type", + "按事件类型列出 handler 元数据", + ), + call_handler=self._registry_get_handlers_by_event_type, + ) + self.register( + self._builtin_descriptor( + "registry.get_handler_by_full_name", + "按 full name 查询 handler 元数据", + ), + call_handler=self._registry_get_handler_by_full_name, + ) + self.register( + self._builtin_descriptor( + "registry.command.register", + "注册动态命令路由", + ), + call_handler=self._registry_command_register, + ) + + def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]: + overlay = self._request_overlays.get(request_id) + if overlay is None: + overlay = { + "should_call_llm": False, + "requested_llm": False, + "result": None, + "handler_whitelist": None, + } + self._request_overlays[request_id] = overlay + return overlay + + async def _system_get_data_dir( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.get_data_dir") + data_dir = self._plugin_data_dir( + plugin_id, + capability_name="system.get_data_dir", + ) + data_dir.mkdir(parents=True, exist_ok=True) + return {"path": str(data_dir)} + + async def _system_text_to_image( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + text = str(payload.get("text", "")) + if bool(payload.get("return_url", True)): + return {"result": f"mock://text_to_image/{text}"} + return {"result": f"{text}"} + + async def _system_html_render( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + tmpl = str(payload.get("tmpl", "")) + data = payload.get("data") + if not isinstance(data, dict): + raise AstrBotError.invalid_input("system.html_render requires object data") + if bool(payload.get("return_url", True)): + return {"result": f"mock://html_render/{tmpl}"} + return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)} + + async def _system_file_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + path = str(payload.get("path", "")).strip() + if not path: + raise AstrBotError.invalid_input("system.file.register requires path") + file_token = uuid.uuid4().hex + self._file_token_store[file_token] = path + return {"token": file_token, "url": f"mock://file/{file_token}"} + + async def _system_file_handle( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + file_token = str(payload.get("token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input("system.file.handle requires token") + path = self._file_token_store.pop(file_token, None) + if path is None: + raise AstrBotError.invalid_input(f"Unknown file token: {file_token}") + return {"path": path} + + async def _system_event_llm_get_state( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + return { + "should_call_llm": bool(overlay["should_call_llm"]), + "requested_llm": bool(overlay["requested_llm"]), + } + + async def _system_event_llm_request( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._ensure_request_overlay(overlay_request_id) + overlay["requested_llm"] = True + overlay["should_call_llm"] = True + return await self._system_event_llm_get_state( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _system_event_result_get( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + result = overlay.get("result") + return {"result": dict(result) if isinstance(result, dict) else None} + + async def _system_event_result_set( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + result = payload.get("result") + if not isinstance(result, dict): + raise AstrBotError.invalid_input( + "system.event.result.set 的 result 必须是 object" + ) + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + overlay["result"] = dict(result) + return {"result": dict(result)} + + async def _system_event_result_clear( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + overlay["result"] = None + return {} + + async def _system_event_handler_whitelist_get( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + whitelist = overlay.get("handler_whitelist") + if whitelist is None: + return {"plugin_names": None} + return {"plugin_names": sorted(str(item) for item in whitelist)} + + async def _system_event_handler_whitelist_set( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._ensure_request_overlay(overlay_request_id) + plugin_names_payload = payload.get("plugin_names") + if plugin_names_payload is None: + overlay["handler_whitelist"] = None + elif isinstance(plugin_names_payload, list): + overlay["handler_whitelist"] = { + str(item) for item in plugin_names_payload if str(item).strip() + } + else: + raise AstrBotError.invalid_input( + "system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null" + ) + return await self._system_event_handler_whitelist_get( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _registry_get_handlers_by_event_type( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + event_type = str(payload.get("event_type", "")).strip() + handlers: list[dict[str, Any]] = [] + for plugin in self._plugins.values(): + handlers.extend( + [ + dict(handler) + for handler in plugin.handlers + if event_type in handler.get("event_types", []) + ] + ) + if event_type == "message": + for plugin_name, routes in self._dynamic_command_routes.items(): + for route in routes: + if not isinstance(route, dict): + continue + handlers.append( + { + "plugin_name": str(route.get("plugin_name", plugin_name)), + "handler_full_name": str( + route.get("handler_full_name", "") + ), + "trigger_type": ( + "message" + if bool(route.get("use_regex", False)) + else "command" + ), + "description": ( + None + if route.get("desc") is None + else str(route.get("desc", "")).strip() or None + ), + "event_types": ["message"], + "enabled": True, + "group_path": [], + "priority": int(route.get("priority", 0) or 0), + "kind": "handler", + "require_admin": False, + "required_role": None, + } + ) + return {"handlers": handlers} + + async def _registry_get_handler_by_full_name( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + full_name = str(payload.get("full_name", "")).strip() + for plugin in self._plugins.values(): + for handler in plugin.handlers: + if handler.get("handler_full_name") == full_name: + return {"handler": dict(handler)} + return {"handler": None} + + async def _registry_command_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + source_event_type = str(payload.get("source_event_type", "")).strip() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if bool(payload.get("ignore_prefix", False)): + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + priority_value = payload.get("priority", 0) + if isinstance(priority_value, bool) or not isinstance(priority_value, int): + raise AstrBotError.invalid_input( + "registry.command.register 的 priority 必须是 integer" + ) + plugin_id = self._require_caller_plugin_id("registry.command.register") + self.register_dynamic_command_route( + plugin_id=plugin_id, + command_name=str(payload.get("command_name", "")), + handler_full_name=str(payload.get("handler_full_name", "")), + desc=str(payload.get("desc", "")), + priority=priority_value, + use_regex=bool(payload.get("use_regex", False)), + ) + return {} + + async def _system_session_waiter_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.session_waiter.register") + session_key = str(payload.get("session_key", "")).strip() + if not session_key: + raise AstrBotError.invalid_input( + "system.session_waiter.register requires session_key" + ) + self._session_waiters.setdefault(plugin_id, set()).add(session_key) + return {} + + async def _system_session_waiter_unregister( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister") + session_key = str(payload.get("session_key", "")).strip() + plugin_waiters = self._session_waiters.get(plugin_id) + if plugin_waiters is None: + return {} + plugin_waiters.discard(session_key) + if not plugin_waiters: + self._session_waiters.pop(plugin_id, None) + return {} + + async def _system_event_react( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self.event_actions.append( + { + "action": "react", + "emoji": str(payload.get("emoji", "")), + "target": _clone_target_payload(payload.get("target")), + } + ) + return {"supported": True} + + async def _system_event_send_typing( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self.event_actions.append( + { + "action": "send_typing", + "target": _clone_target_payload(payload.get("target")), + } + ) + return {"supported": True} + + async def _system_event_send_streaming( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream_id = f"mock-stream-{len(self._event_streams) + 1}" + stream_state: dict[str, Any] = { + "target": _clone_target_payload(payload.get("target")), + "chunks": [], + "use_fallback": bool(payload.get("use_fallback", False)), + } + self._event_streams[stream_id] = stream_state + return {"supported": True, "stream_id": stream_id} + + async def _system_event_send_streaming_chunk( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream = self._event_streams.get(str(payload.get("stream_id", ""))) + if stream is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + chain = payload.get("chain") + if not isinstance(chain, list): + raise AstrBotError.invalid_input( + "system.event.send_streaming_chunk requires a chain array" + ) + stream["chunks"].append({"chain": _clone_chain_payload(chain)}) + return {} + + async def _system_event_send_streaming_close( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream_id = str(payload.get("stream_id", "")) + stream = self._event_streams.pop(stream_id, None) + if stream is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + self.event_actions.append( + { + "action": "send_streaming", + "target": stream["target"], + "chunks": list(stream["chunks"]), + "use_fallback": bool(stream["use_fallback"]), + } + ) + return {"supported": True} diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py new file mode 100644 index 0000000000..cb8ba44c2a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import re +import shlex +from collections.abc import Sequence +from typing import Any + +from ..protocol.descriptors import ParamSpec + + +def normalize_command_invocation(text: str) -> str: + normalized = re.sub(r"\s+", " ", str(text).strip()) + if not normalized: + return "" + normalized = re.sub(r"^/\s*", "", normalized) + return normalized.strip() + + +def command_root_name(text: str) -> str: + normalized = normalize_command_invocation(text) + if not normalized: + return "" + return normalized.split(" ", 1)[0] + + +def match_command_name(text: str, command_name: str) -> str | None: + normalized_command = normalize_command_invocation(command_name) + if not normalized_command: + return None + command_tokens = [re.escape(token) for token in normalized_command.split()] + command_pattern = r"\s+".join(command_tokens) + pattern = rf"^\s*/?\s*{command_pattern}(?:\s+(?P.*))?\s*$" + match = re.match(pattern, text) + if match is None: + return None + remainder = match.group("remainder") + if remainder is None: + return "" + return remainder.strip() + + +def build_command_args( + param_specs: Sequence[ParamSpec], remainder: str +) -> dict[str, Any]: + if not param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = split_command_remainder(remainder) + values: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + values[spec.name] = " ".join(parts[index:]) + break + values[spec.name] = parts[index] + return values + + +def build_regex_args( + param_specs: Sequence[ParamSpec], match: re.Match[str] +) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + names = [spec.name for spec in param_specs if spec.name not in named] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + +def split_command_remainder(remainder: str) -> list[str]: + if not remainder: + return [] + try: + return shlex.split(remainder) + except ValueError: + return remainder.split() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py new file mode 100644 index 0000000000..40d162d355 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py @@ -0,0 +1,156 @@ +"""Support helpers for runtime loader reflection and signature validation. + +本模块提供运行时加载器所需的反射和签名验证工具函数,主要用于: +1. 解析 handler/capability 函数签名,提取参数类型信息 +2. 识别需要注入的框架对象(如 Context、MessageEvent、ScheduleContext) +3. 构建参数规格 (ParamSpec) 供协议层使用 +4. 验证 schedule handler 的签名合法性 + +关键函数: +- build_param_specs: 从 handler 签名构建参数规格列表 +- is_injected_parameter: 判断参数是否应由框架注入而非从命令行解析 +- validate_schedule_signature: 确保 schedule handler 只接受允许的注入参数 +""" + +from __future__ import annotations + +import inspect +import typing +from typing import Any, Literal, TypeAlias, cast + +from .._internal.injected_params import is_framework_injected_parameter +from .._internal.typing_utils import unwrap_optional +from ..decorators import get_capability_meta, get_handler_meta +from ..protocol.descriptors import ParamSpec +from ..types import GreedyStr + +ParamTypeName: TypeAlias = Literal[ + "str", "int", "float", "bool", "optional", "greedy_str" +] +OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None + + +def is_injected_parameter(annotation: Any, parameter_name: str) -> bool: + return is_framework_injected_parameter(parameter_name, annotation) + + +def param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]: + normalized, is_optional = unwrap_optional(annotation) + if normalized is GreedyStr: + return "greedy_str", None, False + if normalized in {int, float, bool, str}: + normalized_name = cast( + Literal["str", "int", "float", "bool"], normalized.__name__ + ) + if is_optional: + return "optional", normalized_name, False + return normalized_name, None, True + if is_optional: + return "optional", "str", False + return "str", None, True + + +def build_param_specs(handler: Any) -> list[ParamSpec]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = typing.get_type_hints(handler) + except Exception: + type_hints = {} + + specs: list[ParamSpec] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = type_hints.get(parameter.name) + if is_injected_parameter(annotation, parameter.name): + continue + param_type, inner_type, required = param_type_name(annotation) + if parameter.default is not inspect.Parameter.empty: + required = False + specs.append( + ParamSpec( + name=parameter.name, + type=param_type, + required=required, + inner_type=inner_type, + ) + ) + + greedy_indexes = [ + index for index, spec in enumerate(specs) if spec.type == "greedy_str" + ] + if greedy_indexes and greedy_indexes[-1] != len(specs) - 1: + greedy_spec = specs[greedy_indexes[-1]] + raise ValueError(f"参数 '{greedy_spec.name}' (GreedyStr) 必须是最后一个参数。") + return specs + + +def validate_schedule_signature(handler: Any) -> None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return + allowed_names = {"ctx", "context", "sched", "schedule"} + invalid = [ + parameter.name + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.name not in allowed_names + ] + if invalid: + raise ValueError( + "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。" + ) + + +def resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + for candidate in candidates: + meta = get_handler_meta(candidate) + if meta is not None and meta.trigger is not None: + return getattr(instance, name), meta + return None + + +def resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + for candidate in candidates: + meta = get_capability_meta(candidate) + if meta is not None: + return getattr(instance, name), meta + return None + + +__all__ = [ + "build_param_specs", + "is_injected_parameter", + "param_type_name", + "resolve_capability_candidate", + "resolve_handler_candidate", + "unwrap_optional", + "validate_schedule_signature", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py new file mode 100644 index 0000000000..29d2671caa --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py @@ -0,0 +1,28 @@ +"""Shared stream execution primitives for runtime internals. + +本模块定义流式执行的通用数据结构 StreamExecution,用于: +1. 封装异步生成器迭代器,支持逐块返回数据 +2. 提供收集完成后的聚合回调 (finalize) +3. 控制是否需要在内存中累积所有分块 + +使用场景: +- LLM 流式对话返回逐字输出 +- DB watch 监听键值变更流 +- 任何需要分块返回而非一次性返回的能力调用 +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class StreamExecution: + iterator: AsyncIterator[dict[str, Any]] + finalize: Callable[[list[dict[str, Any]]], dict[str, Any]] + collect_chunks: bool = True + + +__all__ = ["StreamExecution"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py new file mode 100644 index 0000000000..d735caae9c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py @@ -0,0 +1,171 @@ +"""启动引导入口。 + +对外提供三个顶层启动函数: + +- ``run_supervisor``: 启动 Supervisor 进程 +- ``run_plugin_worker``: 启动单插件或组 Worker 进程 +- ``run_websocket_server``: 以 WebSocket 方式启动 Worker + +运行时核心类分布在同目录的子模块: + +- ``runtime.supervisor``: ``SupervisorRuntime`` / ``WorkerSession`` +- ``runtime.worker``: ``PluginWorkerRuntime`` / ``GroupWorkerRuntime`` +""" + +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path +from typing import IO + +from .loader import PluginEnvironmentManager +from .supervisor import ( + SupervisorRuntime, + WorkerSession, + _install_signal_handlers, + _prepare_stdio_transport, + _sdk_source_dir, + _wait_for_shutdown, +) +from .transport import ( + StdioTransport, + WebSocketServerTransport, + build_websocket_server_ssl_context, +) +from .worker import GroupWorkerRuntime, PluginWorkerRuntime, _load_plugin_specs + +__all__ = [ + "GroupWorkerRuntime", + "PluginWorkerRuntime", + "SupervisorRuntime", + "WorkerSession", + "_install_signal_handlers", + "_prepare_stdio_transport", + "_sdk_source_dir", + "_wait_for_shutdown", + "run_supervisor", + "run_plugin_worker", + "run_websocket_server", +] + + +async def run_supervisor( + *, + plugins_dir: Path = Path("plugins"), + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, + env_manager: PluginEnvironmentManager | None = None, + workers_manifest: Path | None = None, +) -> None: + transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport( + stdin, + stdout, + ) + transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout) + runtime = SupervisorRuntime( + transport=transport, + plugins_dir=plugins_dir, + env_manager=env_manager, + workers_manifest=workers_manifest, + ) + + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() + if original_stdout is not None: + sys.stdout = original_stdout + + +async def run_plugin_worker( + *, + plugin_dir: Path | None = None, + group_metadata: Path | None = None, + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, +) -> None: + if plugin_dir is None and group_metadata is None: + raise ValueError("plugin_dir or group_metadata is required") + if plugin_dir is not None and group_metadata is not None: + raise ValueError("plugin_dir and group_metadata are mutually exclusive") + + transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport( + stdin, + stdout, + ) + transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout) + if group_metadata is not None: + runtime = GroupWorkerRuntime( + group_metadata_path=group_metadata, + transport=transport, + ) + else: + # 前置互斥校验已保证单插件模式下 plugin_dir 一定存在;这里显式收窄, + # 避免把入口层的 Optional 继续传播到单插件运行时。 + assert plugin_dir is not None + runtime = PluginWorkerRuntime(plugin_dir=plugin_dir, transport=transport) + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() + if original_stdout is not None: + sys.stdout = original_stdout + + +async def run_websocket_server( + *, + worker_id: str | None = None, + host: str = "127.0.0.1", + port: int = 8765, + path: str = "/", + plugin_dirs: list[Path] | None = None, + tls_ca_file: Path | None = None, + tls_cert_file: Path | None = None, + tls_key_file: Path | None = None, +) -> None: + resolved_plugin_dirs = [path.resolve() for path in (plugin_dirs or [Path.cwd()])] + if tls_ca_file is None or tls_cert_file is None or tls_key_file is None: + raise ValueError( + "tls_ca_file, tls_cert_file, and tls_key_file are required for websocket workers" + ) + transport = WebSocketServerTransport( + host=host, + port=port, + path=path, + ssl_context=build_websocket_server_ssl_context( + ca_file=tls_ca_file, + cert_file=tls_cert_file, + key_file=tls_key_file, + ), + ) + resolved_worker_id = worker_id + if resolved_worker_id is None and len(resolved_plugin_dirs) == 1: + resolved_worker_id = _load_plugin_specs([resolved_plugin_dirs[0]])[0].name + if len(resolved_plugin_dirs) == 1: + runtime = PluginWorkerRuntime( + plugin_dir=resolved_plugin_dirs[0], + worker_id=resolved_worker_id, + transport=transport, + ) + else: + if resolved_worker_id is None: + raise ValueError("worker_id is required when serving multiple plugins") + runtime = GroupWorkerRuntime( + plugin_dirs=resolved_plugin_dirs, + worker_id=resolved_worker_id, + transport=transport, + ) + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py new file mode 100644 index 0000000000..1e149413a1 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py @@ -0,0 +1,515 @@ +"""Capability invocation dispatcher. + +本模块实现能力调用的分发器,负责: +1. 接收能力调用请求,定位对应的已注册能力 +2. 构建调用上下文 (Context),注入必要的依赖 +3. 支持同步和流式两种调用模式 +4. 管理活跃调用任务的生命周期和取消 + +参数注入策略: +按类型注入 Context / CancelToken / dict,或按参数名注入 +ctx / context / payload / input / data / cancel_token / token。 +若无法匹配则抛出详细的错误信息,帮助开发者定位问题。 +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import typing +from collections.abc import AsyncIterator, Sequence +from typing import Any, cast, get_type_hints + +from .._internal.invocation_context import caller_plugin_scope +from .._internal.plugin_logger import PluginLogger +from .._internal.sdk_logger import logger +from .._internal.star_runtime import bind_star_runtime +from .._internal.typing_utils import unwrap_optional +from ..context import CancelToken, Context +from ..errors import AstrBotError +from ..events import MessageEvent +from ..star import Star +from ._streaming import StreamExecution +from .loader import LoadedCapability, LoadedLLMTool + + +class CapabilityDispatcher: + def __init__( + self, + *, + plugin_id: str, + peer, + capabilities: Sequence[LoadedCapability], + llm_tools: Sequence[LoadedLLMTool] | None = None, + ) -> None: + self._plugin_id = plugin_id + self._peer = peer + self._capabilities = {item.descriptor.name: item for item in capabilities} + self._llm_tools: dict[tuple[str, str], LoadedLLMTool] = {} + try: + setattr(peer, "_sdk_capability_dispatcher", self) + except AttributeError: + logger.warning( + f"Failed to attach _sdk_capability_dispatcher to peer {peer}, " + "dynamic LLM tool registration may not work" + ) + for item in llm_tools or []: + self._register_llm_tool(item, item.plugin_id or plugin_id) + self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {} + + def _register_llm_tool( + self, + loaded: LoadedLLMTool, + owner_plugin: str, + ) -> None: + self._llm_tools[(owner_plugin, loaded.spec.name)] = loaded + if loaded.spec.handler_ref and loaded.spec.handler_ref != loaded.spec.name: + self._llm_tools[(owner_plugin, loaded.spec.handler_ref)] = loaded + + def add_dynamic_llm_tool( + self, + *, + plugin_id: str, + spec, + callable_obj, + owner: Any | None = None, + ) -> None: + self.remove_llm_tool(plugin_id, spec.name) + loaded = LoadedLLMTool( + spec=spec.model_copy(deep=True), + callable=callable_obj, + owner=owner, + plugin_id=plugin_id, + ) + self._register_llm_tool(loaded, plugin_id) + + def remove_llm_tool(self, plugin_id: str, name: str) -> bool: + removed = False + for key, value in list(self._llm_tools.items()): + if key[0] != plugin_id: + continue + spec_name = str(getattr(value.spec, "name", "")).strip() + handler_ref = str(getattr(value.spec, "handler_ref", "") or "").strip() + if name not in {spec_name, handler_ref}: + continue + self._llm_tools.pop(key, None) + removed = True + return removed + + async def invoke( + self, + message, + cancel_token: CancelToken, + ) -> dict[str, Any] | StreamExecution: + if message.capability == "internal.llm_tool.execute": + return await self._invoke_registered_llm_tool(message, cancel_token) + + loaded = self._capabilities.get(message.capability) + if loaded is None: + raise LookupError(f"capability not found: {message.capability}") + + plugin_id = self._resolve_plugin_id(loaded) + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + ) + bound_logger = cast(PluginLogger, ctx.logger).bind( + plugin_id=plugin_id, + request_id=message.id, + capability=message.capability, + session_id=self._logger_session_id(dict(message.input)), + event_type=self._logger_event_type(dict(message.input)), + ) + ctx.logger = bound_logger + + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._run_capability( + loaded, + payload=dict(message.input), + ctx=ctx, + cancel_token=cancel_token, + stream=bool(message.stream), + ) + ) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + async def _invoke_registered_llm_tool( + self, + message, + cancel_token: CancelToken, + ) -> dict[str, Any]: + payload = dict(message.input) + plugin_id = str(payload.get("plugin_id") or self._plugin_id) + tool_name = str(payload.get("tool_name", "")) + handler_ref = str(payload.get("handler_ref") or tool_name) + loaded = self._llm_tools.get((plugin_id, handler_ref)) + if loaded is None: + loaded = self._llm_tools.get((plugin_id, tool_name)) + if loaded is None: + raise LookupError(f"llm tool not found: {plugin_id}:{tool_name}") + + event_payload = payload.get("event") + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + bound_logger = cast(PluginLogger, ctx.logger).bind( + plugin_id=plugin_id, + request_id=message.id, + capability="internal.llm_tool.execute", + session_id=self._logger_session_id(payload), + event_type=self._logger_event_type(payload), + ) + ctx.logger = bound_logger + event = MessageEvent.from_payload( + event_payload if isinstance(event_payload, dict) else {}, + context=ctx, + ) + self._bind_event_reply_handler(ctx, event) + tool_args = payload.get("tool_args") + normalized_args = dict(tool_args) if isinstance(tool_args, dict) else {} + + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._run_registered_llm_tool(loaded, event, ctx, normalized_args) + ) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + def _bind_event_reply_handler(self, ctx: Context, event: MessageEvent) -> None: + async def reply(text: str) -> None: + try: + await ctx.platform.send(event.session_ref or event.session_id, text) + except TypeError: + send = getattr(self._peer, "send", None) + if not callable(send): + raise + result = send(event.session_id, text) + if inspect.isawaitable(result): + await result + + event.bind_reply_handler(reply) + + async def _run_registered_llm_tool( + self, + loaded: LoadedLLMTool, + event: MessageEvent, + ctx: Context, + tool_args: dict[str, Any], + ) -> dict[str, Any]: + owner = loaded.owner if isinstance(loaded.owner, Star) else None + with bind_star_runtime(owner, ctx): + result = loaded.callable( + *self._build_tool_args( + loaded.callable, + event, + ctx, + tool_args, + ) + ) + if inspect.isasyncgen(result): + raise AstrBotError.protocol_error( + "SDK LLM tool must return awaitable result, async generator is unsupported" + ) + if inspect.isawaitable(result): + result = await result + if result is None: + # content=None means the tool completed successfully but produced no + # textual payload. The core bridge preserves this as a real None. + return {"content": None, "success": True} + if isinstance(result, dict): + return { + "content": json.dumps(result, ensure_ascii=False, default=str), + "success": True, + } + return {"content": str(result), "success": True} + + def _build_tool_args( + self, + handler, + event: MessageEvent, + ctx: Context, + tool_args: dict[str, Any], + ) -> list[Any]: + signature = inspect.signature(handler) + args: list[Any] = [] + type_hints: dict[str, Any] = {} + try: + type_hints = get_type_hints(handler) + except Exception: + type_hints = {} + + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + + injected = None + param_type = type_hints.get(parameter.name) + if param_type is not None: + injected = self._inject_tool_by_type(param_type, event, ctx) + if injected is None: + if parameter.name == "event": + injected = event + elif parameter.name in {"ctx", "context"}: + injected = ctx + elif parameter.name in tool_args: + injected = tool_args[parameter.name] + if injected is None: + if parameter.default is not parameter.empty: + continue + raise TypeError( + f"SDK LLM tool '{getattr(handler, '__name__', repr(handler))}' missing required argument '{parameter.name}'" + ) + args.append(injected) + return args + + def _inject_tool_by_type( + self, + param_type: Any, + event: MessageEvent, + ctx: Context, + ) -> Any: + param_type, _is_optional = unwrap_optional(param_type) + + if param_type is Context or ( + isinstance(param_type, type) and issubclass(param_type, Context) + ): + return ctx + if param_type is MessageEvent or ( + isinstance(param_type, type) and issubclass(param_type, MessageEvent) + ): + return event + return None + + def _resolve_plugin_id(self, loaded: LoadedCapability) -> str: + if loaded.plugin_id: + return loaded.plugin_id + return self._plugin_id + + @staticmethod + def _logger_session_id(payload: dict[str, Any]) -> str: + if isinstance(payload.get("event"), dict): + return str(payload["event"].get("session_id", "")) + return str(payload.get("session", "")) + + @staticmethod + def _logger_event_type(payload: dict[str, Any]) -> str: + if isinstance(payload.get("event"), dict): + event_payload = payload["event"] + return str( + event_payload.get("event_type") + or event_payload.get("type") + or event_payload.get("message_type") + or "message" + ) + if payload.get("session") is not None: + return "capability" + return "capability" + + async def cancel(self, request_id: str) -> None: + active = self._active.get(request_id) + if active is None: + return + task, cancel_token = active + cancel_token.cancel() + task.cancel() + + async def _run_capability( + self, + loaded: LoadedCapability, + *, + payload: dict[str, Any], + ctx: Context, + cancel_token: CancelToken, + stream: bool, + ) -> dict[str, Any] | StreamExecution: + result = loaded.callable( + *self._build_args( + loaded.callable, + payload, + ctx, + cancel_token, + plugin_id=self._resolve_plugin_id(loaded), + capability_name=loaded.descriptor.name, + ) + ) + if stream: + if inspect.isasyncgen(result): + return StreamExecution( + iterator=self._iterate_generator(result), + finalize=lambda chunks: {"items": chunks}, + ) + if inspect.isawaitable(result): + result = await result + if inspect.isasyncgen(result): + return StreamExecution( + iterator=self._iterate_generator(result), + finalize=lambda chunks: {"items": chunks}, + ) + if isinstance(result, StreamExecution): + return result + raise AstrBotError.protocol_error( + "stream=true 的插件 capability 必须返回 async generator 或 StreamExecution" + ) + + if inspect.isasyncgen(result): + raise AstrBotError.protocol_error( + "stream=false 的插件 capability 不能返回 async generator" + ) + if inspect.isawaitable(result): + result = await result + return self._normalize_output(result) + + def _build_args( + self, + handler, + payload: dict[str, Any], + ctx: Context, + cancel_token: CancelToken, + *, + plugin_id: str | None = None, + capability_name: str | None = None, + ) -> list[Any]: + signature = inspect.signature(handler) + args: list[Any] = [] + + type_hints: dict[str, Any] = {} + try: + type_hints = get_type_hints(handler) + except Exception: + pass + + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + + injected = None + param_type = type_hints.get(parameter.name) + if param_type is not None: + injected = self._inject_by_type(param_type, payload, ctx, cancel_token) + + if injected is None: + if parameter.name in {"ctx", "context"}: + injected = ctx + elif parameter.name in {"payload", "input", "data"}: + injected = payload + elif parameter.name in {"cancel_token", "token"}: + injected = cancel_token + + if injected is None: + if parameter.default is not parameter.empty: + continue + raise TypeError( + self._format_capability_injection_error( + handler=handler, + parameter_name=parameter.name, + plugin_id=plugin_id, + capability_name=capability_name, + payload=payload, + ) + ) + args.append(injected) + + return args + + def _inject_by_type( + self, + param_type: Any, + payload: dict[str, Any], + ctx: Context, + cancel_token: CancelToken, + ) -> Any: + param_type, _is_optional = unwrap_optional(param_type) + origin = typing.get_origin(param_type) + + if param_type is Context or ( + isinstance(param_type, type) and issubclass(param_type, Context) + ): + return ctx + if param_type is CancelToken or ( + isinstance(param_type, type) and issubclass(param_type, CancelToken) + ): + return cancel_token + if param_type is dict or origin is dict: + return payload + return None + + def _format_capability_injection_error( + self, + *, + handler, + parameter_name: str, + plugin_id: str | None, + capability_name: str | None, + payload: dict[str, Any], + ) -> str: + plugin_text = plugin_id or self._plugin_id + target = capability_name or getattr(handler, "__name__", "") + payload_keys = sorted(str(key) for key in payload.keys()) + payload_keys_text = ", ".join(payload_keys) if payload_keys else "" + return ( + f"插件 '{plugin_text}' 的 capability '{target}' 参数注入失败:" + f"必填参数 '{parameter_name}' 无法注入。" + f"签名: {getattr(handler, '__name__', '')}" + f"{self._callable_signature(handler)}。" + "当前支持按类型注入 Context / CancelToken / dict," + "按参数名注入 ctx / context / payload / input / data / cancel_token / token," + f"以及 payload 中现有键:{payload_keys_text}。" + ) + + async def _iterate_generator( + self, + generator: AsyncIterator[Any], + ) -> AsyncIterator[dict[str, Any]]: + async for item in generator: + yield self._normalize_chunk(item) + + def _normalize_chunk(self, item: Any) -> dict[str, Any]: + output = self._normalize_output(item) + if output: + return output + return {"ok": True} + + def _normalize_output(self, result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, dict): + return result + model_dump = getattr(result, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + return dumped + raise AstrBotError.invalid_input("插件 capability 必须返回 dict 或可序列化对象") + + @staticmethod + def _callable_signature(handler) -> str: + try: + return str(inspect.signature(handler)) + except (TypeError, ValueError): + return "(?)" + + +__all__ = ["CapabilityDispatcher"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py new file mode 100644 index 0000000000..bd0fa68d61 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py @@ -0,0 +1,990 @@ +"""能力路由模块。 + +定义 CapabilityRouter 类,负责能力的注册、发现和执行路由。 +能力是核心侧提供给插件侧调用的功能,如 LLM 聊天、存储、消息发送等。 + +核心概念: + CapabilityDescriptor: 能力描述符,声明能力名称、输入输出 Schema 等 + CallHandler: 同步调用处理器,签名 (request_id, payload, cancel_token) -> dict + StreamHandler: 流式调用处理器,签名 (request_id, payload, cancel_token) -> AsyncIterator + FinalizeHandler: 流式结果聚合器,签名 (chunks) -> dict + +内置能力: + LLM: + llm.chat: 同步 LLM 聊天 + llm.chat_raw: 同步 LLM 聊天(完整响应) + llm.stream_chat: 流式 LLM 聊天 + Memory: + memory.search: 搜索记忆 + memory.save: 保存记忆 + memory.save_with_ttl: 保存带过期时间的记忆 + memory.get: 读取单条记忆 + memory.list_keys: 列出命名空间中的记忆键 + memory.exists: 检查记忆键是否存在 + memory.get_many: 批量获取多条记忆 + memory.delete: 删除记忆 + memory.clear_namespace: 清理命名空间中的记忆 + memory.delete_many: 批量删除多条记忆 + memory.count: 统计命名空间中的记忆数量 + memory.stats: 获取记忆统计信息 + DB: + db.get: 读取 KV 存储 + db.set: 写入 KV 存储 + db.delete: 删除 KV 存储 + db.list: 列出 KV 键 + db.get_many: 批量读取多个 KV 键 + db.set_many: 批量写入多个 KV 键 + db.watch: 订阅 KV 变更事件 + Platform: + platform.send: 发送消息 + platform.send_image: 发送图片 + platform.send_chain: 发送消息链 + platform.send_by_session: 主动按会话发送消息链 + platform.get_group: 获取当前群信息 + platform.get_members: 获取群成员 + Permission: + permission.check: 查询用户权限角色 + permission.get_admins: 列出管理员 ID + permission.manager.add_admin: 添加管理员 ID + permission.manager.remove_admin: 移除管理员 ID + HTTP: + http.register_api: 注册 HTTP 路由到插件 capability + http.unregister_api: 注销 HTTP 路由 + http.list_apis: 查询已注册的 HTTP 路由 + Metadata: + metadata.get_plugin: 获取单个插件元数据 + metadata.list_plugins: 列出所有插件元数据 + metadata.get_plugin_config: 获取当前调用插件自己的配置 + Provider: + provider.get_using: 获取当前聊天 Provider + provider.get_current_chat_provider_id: 获取当前聊天 Provider ID + provider.list_all: 列出聊天 Providers + provider.list_all_tts: 列出 TTS Providers + provider.list_all_stt: 列出 STT Providers + provider.list_all_embedding: 列出 Embedding Providers + provider.list_all_rerank: 列出 Rerank Providers + provider.get_using_tts: 获取当前 TTS Provider + provider.get_using_stt: 获取当前 STT Provider + provider.get_by_id: 按 ID 获取 Provider + provider.stt.get_text: STT 转写 + provider.tts.get_audio: TTS 合成音频 + provider.tts.support_stream: 检查 TTS 原生流式支持 + provider.tts.get_audio_stream: 流式 TTS 音频输出 + provider.embedding.get_embedding: 获取单条向量 + provider.embedding.get_embeddings: 批量获取向量 + provider.embedding.get_dim: 获取向量维度 + provider.rerank.rerank: 文档重排序 + provider.manager.set: 设置当前 Provider + provider.manager.get_by_id: 按 ID 获取 Provider 管理记录 + provider.manager.get_merged_provider_config: 获取 Provider 合并配置 + provider.manager.load: 运行时加载 Provider + provider.manager.terminate: 终止已加载的 Provider + provider.manager.create: 创建 Provider + provider.manager.update: 更新 Provider + provider.manager.delete: 删除 Provider + provider.manager.get_insts: 列出已加载聊天 Provider + provider.manager.watch_changes: 订阅 Provider 变更(流式) + Platform Manager: + platform.manager.get_by_id: 按 ID 获取平台管理快照 + platform.manager.clear_errors: 清除平台错误 + platform.manager.get_stats: 获取平台统计信息 + LLM Tool: + llm_tool.manager.get: 获取 LLM 工具状态 + llm_tool.manager.activate: 激活 LLM 工具 + llm_tool.manager.deactivate: 停用 LLM 工具 + llm_tool.manager.add: 动态添加 LLM 工具 + llm_tool.manager.remove: 动态移除 LLM 工具 + Agent: + agent.tool_loop.run: 运行 tool loop + agent.registry.list: 列出 Agent 元数据 + agent.registry.get: 获取 Agent 元数据 + Registry: + registry.get_handlers_by_event_type: 按事件类型列出 handler 元数据 + registry.get_handler_by_full_name: 按 full name 查询 handler 元数据 + Session: + session.plugin.is_enabled: 获取会话级插件开关 + session.plugin.filter_handlers: 按会话过滤 handler 元数据 + session.service.is_llm_enabled: 获取会话级 LLM 开关 + session.service.set_llm_status: 写入会话级 LLM 开关 + session.service.is_tts_enabled: 获取会话级 TTS 开关 + session.service.set_tts_status: 写入会话级 TTS 开关 + Managers: + persona.get / persona.list / persona.create / persona.update / persona.delete + conversation.new / conversation.switch / conversation.delete + conversation.get / conversation.list / conversation.update + kb.list / kb.get / kb.create / kb.update / kb.delete / kb.retrieve + kb.document.upload / kb.document.list / kb.document.get + kb.document.delete / kb.document.refresh + System (内部使用): + system.get_data_dir: 获取插件数据目录 + system.text_to_image: 文本转图片 + system.html_render: 渲染 HTML 模板 + system.file.register: 注册文件令牌 + system.file.handle: 解析文件令牌 + system.session_waiter.register: 注册会话等待器 + system.session_waiter.unregister: 注销会话等待器 + system.event.react: 发送事件表情回应 + system.event.send_typing: 发送输入中状态 + system.event.send_streaming: 发送事件流式消息 + system.event.send_streaming_chunk: 推送事件流式消息分片 + system.dynamic_command.register: 注册动态命令路由 + system.dynamic_command.list: 列出动态命令路由 + system.dynamic_command.remove: 移除动态命令路由 + +能力命名规范: + - 格式: {namespace}.{action} 或 {namespace}.{sub_namespace}.{action} + - 内置能力命名空间: llm, memory, db, platform, permission, http, metadata, provider, llm_tool, agent, registry + - 保留命名空间前缀: handler., system., internal. + +使用示例: + router = CapabilityRouter() + + # 注册同步能力 + router.register( + CapabilityDescriptor( + name="my_plugin.calculate", + description="执行计算", + input_schema={"type": "object", "properties": {"x": {"type": "number"}}}, + output_schema={"type": "object", "properties": {"result": {"type": "number"}}}, + ), + call_handler=my_calculate, + ) + + # 注册流式能力 + async def stream_data(request_id, payload, token): + for i in range(10): + yield {"index": i} + + router.register( + CapabilityDescriptor( + name="my_plugin.stream", + description="流式数据", + supports_stream=True, + cancelable=True, + ), + stream_handler=stream_data, + finalize=lambda chunks: {"count": len(chunks)}, + ) + + # 执行能力 + result = await router.execute("my_plugin.calculate", {"x": 42}, stream=False, ...) + stream_result = await router.execute("my_plugin.stream", {}, stream=True, ...) +""" + +from __future__ import annotations + +import asyncio +import inspect +import re +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from .._internal.invocation_context import current_caller_plugin_id +from ..errors import AstrBotError +from ..protocol.descriptors import ( + RESERVED_CAPABILITY_PREFIXES, + CapabilityDescriptor, +) +from ._capability_router_builtins import BuiltinCapabilityRouterMixin +from ._streaming import StreamExecution + +CallHandler = Callable[[str, dict[str, Any], object], Awaitable[dict[str, Any]]] +FinalizeHandler = Callable[[list[dict[str, Any]]], dict[str, Any]] +CAPABILITY_NAME_PATTERN = re.compile(r"^[a-z][a-z0-9_]*(?:\.[a-z][a-z0-9_]*)+$") + + +StreamHandler = Callable[ + [str, dict[str, Any], object], + AsyncIterator[dict[str, Any]] + | StreamExecution + | Awaitable[AsyncIterator[dict[str, Any]] | StreamExecution], +] + + +@dataclass(slots=True) +class _CapabilityRegistration: + descriptor: CapabilityDescriptor + call_handler: CallHandler | None = None + stream_handler: StreamHandler | None = None + finalize: FinalizeHandler | None = None + exposed: bool = True + + +@dataclass(slots=True) +class _RegisteredPlugin: + metadata: dict[str, Any] + config: dict[str, Any] + handlers: list[dict[str, Any]] + llm_tools: dict[str, dict[str, Any]] = field(default_factory=dict) + active_llm_tools: set[str] = field(default_factory=set) + local_mcp_servers: dict[str, dict[str, Any]] = field(default_factory=dict) + agents: dict[str, dict[str, Any]] = field(default_factory=dict) + skills: dict[str, dict[str, str]] = field(default_factory=dict) + + +class CapabilityRouter(BuiltinCapabilityRouterMixin): + def __init__(self) -> None: + self._registrations: dict[str, _CapabilityRegistration] = {} + self.db_store: dict[str, Any] = {} + self.memory_store: dict[str, dict[str, Any]] = {} + self._memory_backends: dict[str, Any] = {} + self._memory_index: dict[str, dict[str, Any]] = {} + self._memory_dirty_keys: set[str] = set() + self._memory_expires_at: dict[str, datetime | None] = {} + self.sent_messages: list[dict[str, Any]] = [] + self.event_actions: list[dict[str, Any]] = [] + self._event_streams: dict[str, dict[str, Any]] = {} + self.http_api_store: list[dict[str, Any]] = [] + self._plugins: dict[str, _RegisteredPlugin] = {} + self._request_overlays: dict[str, dict[str, Any]] = {} + self._provider_catalog: dict[str, list[dict[str, Any]]] = { + "chat": [ + { + "id": "mock-chat-provider", + "model": "mock-chat-model", + "type": "mock", + "provider_type": "chat_completion", + } + ], + "tts": [ + { + "id": "mock-tts-provider", + "model": "mock-tts-model", + "type": "mock", + "provider_type": "text_to_speech", + } + ], + "stt": [ + { + "id": "mock-stt-provider", + "model": "mock-stt-model", + "type": "mock", + "provider_type": "speech_to_text", + } + ], + "embedding": [ + { + "id": "mock-embedding-provider", + "model": "mock-embedding-model", + "type": "mock", + "provider_type": "embedding", + } + ], + "rerank": [ + { + "id": "mock-rerank-provider", + "model": "mock-rerank-model", + "type": "mock", + "provider_type": "rerank", + } + ], + } + self._provider_configs: dict[str, dict[str, Any]] = { + str(item["id"]): {**item, "enable": True} + for providers in self._provider_catalog.values() + for item in providers + } + self._active_provider_ids: dict[str, str | None] = { + kind: providers[0]["id"] if providers else None + for kind, providers in self._provider_catalog.items() + } + self._provider_change_subscriptions: dict[ + str, asyncio.Queue[dict[str, Any]] + ] = {} + self._system_data_root = Path.cwd() / ".astrbot_sdk_testing" / "plugin_data" + self._session_waiters: dict[str, set[str]] = {} + self._db_watch_subscriptions: dict[ + str, tuple[str | None, asyncio.Queue[dict[str, Any]]] + ] = {} + self._session_plugin_configs: dict[str, dict[str, Any]] = {} + self._session_service_configs: dict[str, dict[str, Any]] = {} + self._dynamic_command_routes: dict[str, list[dict[str, Any]]] = {} + self._file_token_store: dict[str, str] = {} + self._persona_store: dict[str, dict[str, Any]] = {} + self._conversation_store: dict[str, dict[str, Any]] = {} + self._session_current_conversation_ids: dict[str, str] = {} + self._message_history_store: dict[str, list[dict[str, Any]]] = {} + self._message_history_next_id = 1 + self._mcp_session_store: dict[str, dict[str, Any]] = {} + self._mcp_global_servers: dict[str, dict[str, Any]] = {} + self._mcp_audit_logs: list[dict[str, str]] = [] + self._kb_store: dict[str, dict[str, Any]] = {} + self._kb_document_store: dict[str, dict[str, dict[str, Any]]] = {} + self._kb_document_content_store: dict[str, str] = {} + self._platform_instances: list[dict[str, Any]] = [ + { + "id": "mock-platform", + "name": "Mock Platform", + "type": "mock", + "status": "running", + } + ] + self._permission_admin_ids: list[str] = ["astrbot"] + self._register_builtin_capabilities() + + def upsert_plugin( + self, + *, + metadata: dict[str, Any], + config: dict[str, Any] | None = None, + ) -> None: + name = str(metadata.get("name", "")).strip() + if not name: + raise ValueError("plugin metadata must include a non-empty name") + normalized_metadata = dict(metadata) + normalized_metadata.setdefault("display_name", name) + normalized_metadata.setdefault("description", "") + normalized_metadata.setdefault("repo", "") + normalized_metadata.setdefault("author", "") + normalized_metadata.setdefault("version", "0.0.0") + normalized_metadata.setdefault("enabled", True) + normalized_metadata.setdefault("reserved", False) + normalized_metadata.setdefault("acknowledge_global_mcp_risk", False) + normalized_metadata.setdefault("support_platforms", []) + normalized_metadata.setdefault("astrbot_version", None) + local_mcp_servers = normalized_metadata.pop("local_mcp_servers", {}) + normalized_servers = ( + { + str(server_name): dict(server_payload) + for server_name, server_payload in local_mcp_servers.items() + if str(server_name).strip() and isinstance(server_payload, dict) + } + if isinstance(local_mcp_servers, dict) + else {} + ) + existing = self._plugins.get(name) + if existing is not None: + existing.metadata = normalized_metadata + existing.config = dict(config or {}) + existing.local_mcp_servers = normalized_servers + return + self._plugins[name] = _RegisteredPlugin( + metadata=normalized_metadata, + config=dict(config or {}), + handlers=[], + local_mcp_servers=normalized_servers, + ) + + def set_plugin_handlers( + self, + name: str, + handlers: list[dict[str, Any]], + ) -> None: + plugin = self._plugins.get(name) + if plugin is None: + return + plugin.handlers = [dict(item) for item in handlers] + valid_handlers = { + str(item.get("handler_full_name", "")).strip() + for item in plugin.handlers + if isinstance(item, dict) + } + if not valid_handlers: + self._dynamic_command_routes.pop(name, None) + return + routes = self._dynamic_command_routes.get(name) + if routes is None: + return + self._dynamic_command_routes[name] = [ + dict(item) + for item in routes + if str(item.get("handler_full_name", "")).strip() in valid_handlers + ] + if not self._dynamic_command_routes[name]: + self._dynamic_command_routes.pop(name, None) + + def set_plugin_enabled(self, name: str, enabled: bool) -> None: + plugin = self._plugins.get(name) + if plugin is None: + return + plugin.metadata["enabled"] = enabled + + def register_dynamic_command_route( + self, + *, + plugin_id: str, + command_name: str, + handler_full_name: str, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ) -> None: + command_text = str(command_name).strip() + if not command_text: + raise AstrBotError.invalid_input("command_name must not be empty") + handler_text = str(handler_full_name).strip() + if not handler_text: + raise AstrBotError.invalid_input("handler_full_name must not be empty") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + if not self._plugin_has_handler(plugin_id, handler_text): + raise AstrBotError.invalid_input( + "handler_full_name must belong to the caller plugin and exist" + ) + route = { + "plugin_name": plugin_id, + "command_name": command_text, + "handler_full_name": handler_text, + "desc": str(desc), + "priority": int(priority), + "use_regex": bool(use_regex), + } + routes = [ + item + for item in self._dynamic_command_routes.get(plugin_id, []) + if str(item.get("command_name", "")).strip() != command_text + or bool(item.get("use_regex", False)) != bool(use_regex) + ] + routes.append(route) + self._dynamic_command_routes[plugin_id] = routes + + def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]: + return [dict(item) for item in self._dynamic_command_routes.get(plugin_id, [])] + + def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None: + self._dynamic_command_routes.pop(plugin_id, None) + + def set_platform_instances(self, instances: list[dict[str, Any]]) -> None: + normalized: list[dict[str, Any]] = [] + for item in instances: + if not isinstance(item, dict): + continue + platform_id = str(item.get("id", "")).strip() + platform_type = str(item.get("type", "")).strip() + if not platform_id or not platform_type: + continue + errors = item.get("errors") + last_error = item.get("last_error") + stats = item.get("stats") + meta = item.get("meta") + normalized.append( + { + "id": platform_id, + "name": str(item.get("name", platform_id)), + "type": platform_type, + "status": str(item.get("status", "unknown")), + "errors": [ + dict(error) for error in errors if isinstance(error, dict) + ] + if isinstance(errors, list) + else [], + "last_error": ( + dict(last_error) if isinstance(last_error, dict) else None + ), + "unified_webhook": bool(item.get("unified_webhook", False)), + "stats": dict(stats) if isinstance(stats, dict) else None, + "meta": dict(meta) if isinstance(meta, dict) else {}, + "started_at": item.get("started_at"), + } + ) + self._platform_instances = normalized + + def get_platform_instances(self) -> list[dict[str, Any]]: + return [dict(item) for item in self._platform_instances] + + def set_admin_ids(self, admin_ids: list[str]) -> None: + self._permission_admin_ids = [ + user_id for user_id in (str(item).strip() for item in admin_ids) if user_id + ] + + def _plugin_has_handler(self, plugin_id: str, handler_full_name: str) -> bool: + plugin = self._plugins.get(plugin_id) + if plugin is None: + return False + handler_name = str(handler_full_name).strip() + if not handler_name: + return False + for handler in plugin.handlers: + if not isinstance(handler, dict): + continue + if str(handler.get("handler_full_name", "")).strip() == handler_name: + return True + return False + + def set_plugin_llm_tools( + self, + name: str, + tools: list[dict[str, Any]], + ) -> None: + plugin = self._plugins.get(name) + if plugin is None: + return + plugin.llm_tools = { + str(item.get("name", "")): dict(item) + for item in tools + if isinstance(item, dict) and str(item.get("name", "")).strip() + } + plugin.active_llm_tools = { + tool_name + for tool_name, item in plugin.llm_tools.items() + if bool(item.get("active", True)) + } + + def set_plugin_agents( + self, + name: str, + agents: list[dict[str, Any]], + ) -> None: + plugin = self._plugins.get(name) + if plugin is None: + return + plugin.agents = { + str(item.get("name", "")): dict(item) + for item in agents + if isinstance(item, dict) and str(item.get("name", "")).strip() + } + + def set_provider_catalog( + self, + kind: str, + providers: list[dict[str, Any]], + *, + active_id: str | None = None, + ) -> None: + self._provider_catalog[kind] = [ + dict(item) + for item in providers + if isinstance(item, dict) and str(item.get("id", "")).strip() + ] + for item in self._provider_catalog[kind]: + provider_id = str(item.get("id", "")).strip() + if not provider_id: + continue + self._provider_configs[provider_id] = {**item, "enable": True} + if active_id is not None: + self._active_provider_ids[kind] = active_id + else: + catalog = self._provider_catalog[kind] + self._active_provider_ids[kind] = catalog[0]["id"] if catalog else None + + def emit_provider_change( + self, + provider_id: str, + provider_type: str, + umo: str | None = None, + ) -> None: + event = { + "provider_id": str(provider_id), + "provider_type": str(provider_type), + "umo": str(umo) if umo is not None else None, + } + for queue in list(self._provider_change_subscriptions.values()): + queue.put_nowait(dict(event)) + + def record_platform_error( + self, + platform_id: str, + message: str, + *, + traceback: str | None = None, + ) -> None: + for item in self._platform_instances: + if str(item.get("id", "")) != str(platform_id): + continue + error = { + "message": str(message), + "timestamp": datetime.now(timezone.utc).isoformat(), + "traceback": str(traceback) if traceback is not None else None, + } + errors = item.setdefault("errors", []) + if isinstance(errors, list): + errors.append(error) + item["last_error"] = error + item["status"] = "error" + return + + def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None: + for item in self._platform_instances: + if str(item.get("id", "")) != str(platform_id): + continue + item["stats"] = dict(stats) + return + + def set_session_plugin_config( + self, + session_id: str, + *, + enabled_plugins: list[str] | None = None, + disabled_plugins: list[str] | None = None, + ) -> None: + config: dict[str, Any] = {} + if enabled_plugins is not None: + config["enabled_plugins"] = [str(item) for item in enabled_plugins] + if disabled_plugins is not None: + config["disabled_plugins"] = [str(item) for item in disabled_plugins] + self._session_plugin_configs[str(session_id)] = config + + def set_session_service_config( + self, + session_id: str, + *, + llm_enabled: bool | None = None, + tts_enabled: bool | None = None, + ) -> None: + config: dict[str, Any] = {} + if llm_enabled is not None: + config["llm_enabled"] = bool(llm_enabled) + if tts_enabled is not None: + config["tts_enabled"] = bool(tts_enabled) + self._session_service_configs[str(session_id)] = config + + def remove_http_apis_for_plugin(self, plugin_id: str) -> None: + self.http_api_store = [ + entry + for entry in self.http_api_store + if entry.get("plugin_id") != plugin_id + ] + + @staticmethod + def _require_caller_plugin_id(capability_name: str) -> str: + caller_plugin_id = current_caller_plugin_id() + if caller_plugin_id: + return caller_plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} 只能在插件运行时上下文中调用" + ) + + def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None: + event = {"op": op, "key": key, "value": value} + for prefix, queue in list(self._db_watch_subscriptions.values()): + if prefix is not None and not key.startswith(prefix): + continue + queue.put_nowait(event) + + def descriptors(self) -> list[CapabilityDescriptor]: + return [ + entry.descriptor for entry in self._registrations.values() if entry.exposed + ] + + def all_descriptors(self) -> list[CapabilityDescriptor]: + return [entry.descriptor for entry in self._registrations.values()] + + def contains(self, name: str) -> bool: + return name in self._registrations + + def unregister(self, name: str) -> None: + self._registrations.pop(name, None) + + def register( + self, + descriptor: CapabilityDescriptor, + *, + call_handler: CallHandler | None = None, + stream_handler: StreamHandler | None = None, + finalize: FinalizeHandler | None = None, + exposed: bool = True, + ) -> None: + is_internal_reserved = not exposed and descriptor.name.startswith( + RESERVED_CAPABILITY_PREFIXES + ) + if ( + not CAPABILITY_NAME_PATTERN.fullmatch(descriptor.name) + and not is_internal_reserved + ): + raise ValueError( + f"capability 名称必须匹配 {{namespace}}.{{method}}:{descriptor.name}" + ) + if exposed and descriptor.name.startswith(RESERVED_CAPABILITY_PREFIXES): + raise ValueError( + f"保留 capability 命名空间仅供框架内部使用:{descriptor.name}" + ) + self._registrations[descriptor.name] = _CapabilityRegistration( + descriptor=descriptor, + call_handler=call_handler, + stream_handler=stream_handler, + finalize=finalize, + exposed=exposed, + ) + + async def execute( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool, + cancel_token, + request_id: str, + ) -> dict[str, Any] | StreamExecution: + registration = self._registrations.get(capability) + if registration is None: + raise AstrBotError.capability_not_found(capability) + + self._validate_schema_with_context( + capability=capability, + phase="输入", + schema=registration.descriptor.input_schema, + payload=payload, + ) + if stream: + if registration.stream_handler is None: + raise AstrBotError.invalid_input(f"{capability} 不支持 stream=true") + raw_execution = registration.stream_handler( + request_id, payload, cancel_token + ) + if inspect.isawaitable(raw_execution): + raw_execution = await raw_execution + if isinstance(raw_execution, StreamExecution): + return self._wrap_stream_execution( + registration.descriptor, + raw_execution, + ) + finalize = registration.finalize or (lambda chunks: {"items": chunks}) + return self._wrap_stream_execution( + registration.descriptor, + StreamExecution( + iterator=raw_execution, + finalize=finalize, + ), + ) + + if registration.call_handler is None: + raise AstrBotError.invalid_input( + f"{capability} 只能以 stream=true 调用,registration.call_handler 为 None" + ) + output = await registration.call_handler(request_id, payload, cancel_token) + self._validate_schema_with_context( + capability=capability, + phase="输出", + schema=registration.descriptor.output_schema, + payload=output, + ) + return output + + def _wrap_stream_execution( + self, + descriptor: CapabilityDescriptor, + execution: StreamExecution, + ) -> StreamExecution: + def validated_finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]: + output = execution.finalize(chunks) + self._validate_schema_with_context( + capability=descriptor.name, + phase="输出", + schema=descriptor.output_schema, + payload=output, + ) + return output + + return StreamExecution( + iterator=execution.iterator, + finalize=validated_finalize, + collect_chunks=execution.collect_chunks, + ) + + # ------------------------------------------------------------------ + # Schema validation + # ------------------------------------------------------------------ + + def _validate_schema( + self, + schema: dict[str, Any] | None, + payload: Any, + ) -> None: + if not isinstance(schema, dict) or not schema: + return + self._validate_value(schema, payload, path="") + + def _validate_schema_with_context( + self, + *, + capability: str, + phase: str, + schema: dict[str, Any] | None, + payload: Any, + ) -> None: + try: + self._validate_schema(schema, payload) + except AstrBotError as exc: + if exc.code != "invalid_input": + raise + raise AstrBotError.invalid_input( + f"capability '{capability}' 的{phase}校验失败:{exc.message}", + hint=( + f"请检查 capability '{capability}' 的{phase.lower()}是否符合声明的 schema" + ), + ) from exc + + def _validate_value( + self, + schema: dict[str, Any], + value: Any, + *, + path: str, + ) -> None: + any_of = schema.get("anyOf") + if isinstance(any_of, list): + for candidate in any_of: + if not isinstance(candidate, dict): + continue + try: + self._validate_value(candidate, value, path=path) + return + except AstrBotError: + continue + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 不符合允许的 schema 约束," + f"实际收到 {self._value_type_name(value)}" + ) + + enum = schema.get("enum") + if isinstance(enum, list) and value not in enum: + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 {enum},实际收到 {value!r}" + ) + + schema_type = schema.get("type") + if schema_type == "object": + if not isinstance(value, dict): + if not path: + raise AstrBotError.invalid_input( + f"输入必须是 object,实际收到 {self._value_type_name(value)}" + ) + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 object," + f"实际收到 {self._value_type_name(value)}" + ) + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + for field_name in required_fields: + field_path = self._join_path(path, str(field_name)) + if field_name not in value: + raise AstrBotError.invalid_input(f"缺少必填字段:{field_path}") + field_schema = self._property_schema(properties, field_name) + if value[field_name] is None and not self._schema_allows_null( + field_schema + ): + raise AstrBotError.invalid_input(f"缺少必填字段:{field_path}") + self._validate_value( + field_schema, + value[field_name], + path=field_path, + ) + for field_name, field_value in value.items(): + field_schema = properties.get(field_name) + if isinstance(field_schema, dict): + self._validate_value( + field_schema, + field_value, + path=self._join_path(path, str(field_name)), + ) + return + + if schema_type == "array": + if not isinstance(value, list): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 array," + f"实际收到 {self._value_type_name(value)}" + ) + item_schema = schema.get("items") + if isinstance(item_schema, dict): + for index, item in enumerate(value): + self._validate_value( + item_schema, + item, + path=self._index_path(path, index), + ) + return + + if schema_type == "string": + if not isinstance(value, str): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 string," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "integer": + if not isinstance(value, int) or isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 integer," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "number": + if not isinstance(value, (int, float)) or isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 number," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "boolean": + if not isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 boolean," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "null": + if value is not None: + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 null," + f"实际收到 {self._value_type_name(value)}" + ) + return + + @staticmethod + def _field_label(path: str) -> str: + if not path: + return "输入" + return f"字段 {path}" + + @staticmethod + def _join_path(path: str, field_name: str) -> str: + if not path: + return field_name + return f"{path}.{field_name}" + + @staticmethod + def _index_path(path: str, index: int) -> str: + return f"{path}[{index}]" if path else f"[{index}]" + + @staticmethod + def _property_schema( + properties: Any, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(properties, dict): + return {} + field_schema = properties.get(field_name) + if isinstance(field_schema, dict): + return field_schema + return {} + + @staticmethod + def _schema_allows_null(field_schema: Any) -> bool: + if not isinstance(field_schema, dict): + return False + if field_schema.get("type") == "null": + return True + any_of = field_schema.get("anyOf") + if not isinstance(any_of, list): + return False + return any( + isinstance(candidate, dict) and candidate.get("type") == "null" + for candidate in any_of + ) + + @staticmethod + def _value_type_name(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "boolean" + if isinstance(value, int): + return "integer" + if isinstance(value, float): + return "number" + if isinstance(value, str): + return "string" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return type(value).__name__ diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py new file mode 100644 index 0000000000..6503cb842d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py @@ -0,0 +1,675 @@ +"""astrbot-sdk runtime 的插件共享环境规划模块。 + +这个模块负责“多个插件,共享较少数量 Python 环境”的策略。核心约束是: + +- 插件仍然独立发现、独立加载 +- Worker 运行时既可以是一插件一进程,也可以由 GroupWorkerRuntime 在同一进程承载多个插件 +- 只有在依赖兼容时才共享 Python 环境 + +整体流程如下: + +1. 先按插件声明的 `runtime.python` 分桶 +2. 再按依赖兼容性构建候选分组 +3. 为每个分组在 `.astrbot/` 下落地 source、lock、metadata 和 venv 路径 +4. 在 worker 启动前准备或同步该分组的共享环境 + +当前阶段优先保证兼容性,因此仍保留 `--system-site-packages`,也不改变 +现有插件 manifest 语义。 +""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .loader import PluginSpec + +GROUP_STATE_FILE_NAME = ".group-venv-state.json" + +_EXACT_PIN_PATTERN = re.compile(r"^([A-Za-z0-9_.-]+)==([^\s;]+)$") +_NORMALIZE_PATTERN = re.compile(r"[-_.]+") +_PYVENV_VERSION_PATTERN = re.compile( + r"^(?:version|version_info)\s*=\s*(\d+\.\d+)(?:\.\d+)?\s*$", + re.IGNORECASE | re.MULTILINE, +) + + +def _require_uv_binary(uv_binary: str | None) -> str: + if not uv_binary: + raise RuntimeError("uv executable not found") + return uv_binary + + +def _venv_python_path(venv_path: Path) -> Path: + if os.name == "nt": + return venv_path / "Scripts" / "python.exe" + return venv_path / "bin" / "python" + + +def _normalize_package_name(name: str) -> str: + return _NORMALIZE_PATTERN.sub("-", name).lower() + + +def _read_pyvenv_major_minor(pyvenv_cfg: Path) -> str | None: + if not pyvenv_cfg.exists(): + return None + try: + content = pyvenv_cfg.read_text(encoding="utf-8") + except OSError: + return None + match = _PYVENV_VERSION_PATTERN.search(content) + if match is None: + return None + return match.group(1) + + +def _requirement_lines(plugin: PluginSpec) -> list[str]: + if not plugin.requirements_path.exists(): + return [] + + lines: list[str] = [] + for raw_line in plugin.requirements_path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + lines.append(line) + return lines + + +@dataclass(slots=True) +class EnvironmentGroup: + """一个或多个兼容插件最终共享的环境描述。 + + 分组是环境复用的最小单位。`plugins` 中的所有插件都会使用同一个 + `python_path`、lockfile 和 venv 目录,但运行时仍然各自启动独立的 + worker 进程。 + """ + + id: str + python_version: str + plugins: list[PluginSpec] + source_path: Path + lockfile_path: Path + metadata_path: Path + venv_path: Path + python_path: Path + environment_fingerprint: str + + +@dataclass(slots=True) +class EnvironmentPlanResult: + """一次完整规划得到的结果。 + + `plugins` 只包含成功完成规划的插件。 + `skipped_plugins` 记录规划失败的插件及原因,这类插件即使单独成组也没 + 有得到可用的共享环境。 + """ + + groups: list[EnvironmentGroup] = field(default_factory=list) + plugins: list[PluginSpec] = field(default_factory=list) + plugin_to_group: dict[str, EnvironmentGroup] = field(default_factory=dict) + skipped_plugins: dict[str, str] = field(default_factory=dict) + + +class EnvironmentPlanner: + """负责共享环境规划和分组工件落地。 + + 对 supervisor 启动来说,这个类主要回答两个问题: + + - 哪些插件可以共享一个环境 + - 这个共享环境应该对应哪份 lockfile 和哪个 venv 路径 + + 它本身不负责真正创建或同步 venv,这部分在规划结束后交给 + `GroupEnvironmentManager` 处理。 + """ + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary or shutil.which("uv") + self.cache_dir = self.repo_root / ".uv-cache" + self.artifacts_dir = self.repo_root / ".astrbot" + self.group_dir = self.artifacts_dir / "groups" + self.lock_dir = self.artifacts_dir / "locks" + self.env_dir = self.artifacts_dir / "envs" + self._compatibility_cache: dict[str, bool] = {} + + def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult: + """为当前插件集合生成稳定的共享环境规划。 + + 之所以在 worker 启动前完成规划,是为了让 supervisor 能够: + + - 只跳过依赖无法满足的那部分插件 + - 在兼容插件之间复用同一个环境 + - 清理旧规划遗留的 `.astrbot` 工件 + """ + if not plugins: + self.cleanup_artifacts([]) + return EnvironmentPlanResult() + _require_uv_binary(self.uv_binary) + + candidate_groups = self._build_candidate_groups(plugins) + planned_groups: list[EnvironmentGroup] = [] + skipped_plugins: dict[str, str] = {} + for group_plugins in candidate_groups: + materialized, skipped = self._materialize_candidate_group(group_plugins) + planned_groups.extend(materialized) + skipped_plugins.update(skipped) + + planned_groups.sort(key=lambda group: (group.python_version, group.id)) + self.cleanup_artifacts(planned_groups) + + plugin_to_group = { + plugin.name: group for group in planned_groups for plugin in group.plugins + } + planned_plugins = [ + plugin for plugin in plugins if plugin.name in plugin_to_group + ] + return EnvironmentPlanResult( + groups=planned_groups, + plugins=planned_plugins, + plugin_to_group=plugin_to_group, + skipped_plugins=skipped_plugins, + ) + + def _build_candidate_groups( + self, plugins: list[PluginSpec] + ) -> list[list[PluginSpec]]: + """用贪心方式把插件装入兼容性候选组。 + + 分组过程保持确定性,规则是: + + - Python 版本是第一层硬边界 + - `requirements.txt` 约束更多的插件优先落位 + - 若仍相同,则按插件名排序 + """ + buckets: dict[str, list[PluginSpec]] = {} + for plugin in plugins: + buckets.setdefault(plugin.python_version, []).append(plugin) + + planned_groups: list[list[PluginSpec]] = [] + for python_version in sorted(buckets): + python_groups: list[list[PluginSpec]] = [] + for plugin in self._sort_plugins(buckets[python_version]): + placed = False + for group_plugins in python_groups: + if self._is_compatible([*group_plugins, plugin]): + group_plugins.append(plugin) + placed = True + break + if not placed: + python_groups.append([plugin]) + planned_groups.extend(python_groups) + return planned_groups + + @staticmethod + def _sort_plugins(plugins: list[PluginSpec]) -> list[PluginSpec]: + return sorted( + plugins, + key=lambda plugin: (-len(_requirement_lines(plugin)), plugin.name), + ) + + def _is_compatible(self, plugins: list[PluginSpec]) -> bool: + """判断一组插件是否可以共享一个环境。 + + 兼容性判断先走一个便宜的快速路径: + + - 如果每条 requirement 都是 `pkg==1.2.3` 这种精确版本锁定 + - 且归一化后的包名之间没有解析出冲突版本 + - 那么无需调用求解器,直接认为这一组兼容 + + 更复杂的情况则回退到 `uv pip compile`,以它的求解结果作为最终依 + 赖兼容性的判断依据。 + """ + cache_key = self._compatibility_cache_key(plugins) + cached = self._compatibility_cache.get(cache_key) + if cached is not None: + return cached + + requirement_lines = self._collect_requirement_lines(plugins) + if not requirement_lines: + self._compatibility_cache[cache_key] = True + return True + + if self._merge_exact_requirements(requirement_lines) is not None: + self._compatibility_cache[cache_key] = True + return True + + with tempfile.TemporaryDirectory( + prefix="astrbot-env-plan-", + dir=self.repo_root, + ) as temp_dir: + source_path = Path(temp_dir) / "compat.in" + output_path = Path(temp_dir) / "compat.txt" + self._write_source_file(source_path, plugins) + try: + self._compile_lockfile( + source_path=source_path, + output_path=output_path, + python_version=plugins[0].python_version, + ) + except RuntimeError: + self._compatibility_cache[cache_key] = False + return False + + self._compatibility_cache[cache_key] = True + return True + + def _materialize_candidate_group( + self, + plugins: list[PluginSpec], + ) -> tuple[list[EnvironmentGroup], dict[str, str]]: + """为一个候选组创建工件,失败时自动拆分。 + + 如果整组插件无法生成 lockfile,规划器会退回到“一插件一组”继续尝 + 试,避免单个坏插件阻塞整批插件启动。 + """ + try: + return [self._materialize_group(plugins)], {} + except RuntimeError as exc: + if len(plugins) == 1: + return [], {plugins[0].name: str(exc)} + + materialized: list[EnvironmentGroup] = [] + skipped: dict[str, str] = {} + for plugin in plugins: + groups, child_skipped = self._materialize_candidate_group([plugin]) + materialized.extend(groups) + skipped.update(child_skipped) + return materialized, skipped + + def _materialize_group(self, plugins: list[PluginSpec]) -> EnvironmentGroup: + """落地定义一个共享环境所需的全部文件。 + + 分组身份由 Python 版本和插件集合共同决定。 + 环境指纹则会进一步包含编译后的 lockfile 内容,这样当依赖解析结果 + 变化时,已有环境就可以走增量同步而不是盲目重建。 + """ + group_id = self._group_identity(plugins)[:16] + python_version = plugins[0].python_version + source_path = self.group_dir / f"{group_id}.in" + lockfile_path = self.lock_dir / f"{group_id}.txt" + metadata_path = self.group_dir / f"{group_id}.json" + venv_path = self.env_dir / group_id + python_path = _venv_python_path(venv_path) + + source_path.parent.mkdir(parents=True, exist_ok=True) + lockfile_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + venv_path.parent.mkdir(parents=True, exist_ok=True) + + self._write_source_file(source_path, plugins) + self._write_lockfile( + lockfile_path=lockfile_path, + source_path=source_path, + plugins=plugins, + python_version=python_version, + ) + environment_fingerprint = self._environment_fingerprint( + plugins=plugins, + python_version=python_version, + lockfile_path=lockfile_path, + ) + metadata_path.write_text( + json.dumps( + { + "group_id": group_id, + "python_version": python_version, + "plugins": [plugin.name for plugin in plugins], + "plugin_entries": [ + { + "name": plugin.name, + "plugin_dir": str(plugin.plugin_dir), + } + for plugin in plugins + ], + "source_path": str(source_path), + "lockfile_path": str(lockfile_path), + "venv_path": str(venv_path), + "environment_fingerprint": environment_fingerprint, + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + return EnvironmentGroup( + id=group_id, + python_version=python_version, + plugins=list(plugins), + source_path=source_path, + lockfile_path=lockfile_path, + metadata_path=metadata_path, + venv_path=venv_path, + python_path=python_path, + environment_fingerprint=environment_fingerprint, + ) + + def _write_source_file(self, source_path: Path, plugins: list[PluginSpec]) -> None: + """写入供 lockfile 生成使用的分组 requirements 输入文件。""" + lines: list[str] = [] + for plugin in sorted(plugins, key=lambda item: item.name): + requirements = _requirement_lines(plugin) + if not requirements: + continue + lines.append(f"# {plugin.name}") + lines.extend(requirements) + lines.append("") + + content = "\n".join(lines).rstrip() + if content: + content += "\n" + source_path.write_text(content, encoding="utf-8") + + def _write_lockfile( + self, + *, + lockfile_path: Path, + source_path: Path, + plugins: list[PluginSpec], + python_version: str, + ) -> None: + """为一个分组生成 lockfile。 + + 即使依赖集合为空,也会故意生成空 lockfile,这样整个共享环境流水 + 线的处理方式可以保持一致。 + """ + if not self._collect_requirement_lines(plugins): + lockfile_path.write_text("", encoding="utf-8") + return + + self._compile_lockfile( + source_path=source_path, + output_path=lockfile_path, + python_version=python_version, + ) + + def _compile_lockfile( + self, + *, + source_path: Path, + output_path: Path, + python_version: str, + ) -> None: + """把依赖求解委托给 `uv pip compile`。""" + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "pip", + "compile", + "--python-version", + python_version, + "--no-managed-python", + "--no-python-downloads", + "--quiet", + str(source_path), + "-o", + str(output_path), + ], + cwd=self.repo_root, + command_name=f"compile lockfile for {source_path.name}", + ) + + def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None: + process = subprocess.run( + command, + cwd=str(cwd), + env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)}, + capture_output=True, + text=True, + check=False, + ) + if process.returncode != 0: + raise RuntimeError( + f"{command_name} failed with exit code {process.returncode}: " + f"{process.stderr.strip() or process.stdout.strip()}" + ) + + def cleanup_artifacts(self, groups: list[EnvironmentGroup]) -> None: + """清理不再被当前规划引用的 `.astrbot` 工件。 + + 清理范围只覆盖规划器自己维护的共享环境工件,不会碰旧式插件目录下 + 的本地 `.venv`。 + """ + active_group_ids = {group.id for group in groups} + self._cleanup_group_artifacts(active_group_ids) + self._cleanup_lockfiles(active_group_ids) + self._cleanup_envs(active_group_ids) + + def _cleanup_group_artifacts(self, active_group_ids: set[str]) -> None: + if not self.group_dir.exists(): + return + for entry in self.group_dir.iterdir(): + if entry.suffix not in {".in", ".json"}: + continue + if entry.stem in active_group_ids: + continue + entry.unlink(missing_ok=True) + + def _cleanup_lockfiles(self, active_group_ids: set[str]) -> None: + if not self.lock_dir.exists(): + return + for entry in self.lock_dir.iterdir(): + if entry.suffix != ".txt": + continue + if entry.stem in active_group_ids: + continue + entry.unlink(missing_ok=True) + + def _cleanup_envs(self, active_group_ids: set[str]) -> None: + if not self.env_dir.exists(): + return + for entry in self.env_dir.iterdir(): + if entry.name in active_group_ids: + continue + if entry.is_dir(): + shutil.rmtree(entry) + else: + entry.unlink(missing_ok=True) + + def _compatibility_cache_key(self, plugins: list[PluginSpec]) -> str: + payload = { + "python_version": plugins[0].python_version if plugins else "", + "plugins": [ + { + "name": plugin.name, + "requirements": _requirement_lines(plugin), + } + for plugin in sorted(plugins, key=lambda item: item.name) + ], + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _group_identity(plugins: list[PluginSpec]) -> str: + payload = { + "python_version": plugins[0].python_version if plugins else "", + "plugins": sorted(plugin.name for plugin in plugins), + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _environment_fingerprint( + *, + plugins: list[PluginSpec], + python_version: str, + lockfile_path: Path, + ) -> str: + payload = { + "python_version": python_version, + "plugins": sorted(plugin.name for plugin in plugins), + "lockfile": lockfile_path.read_text(encoding="utf-8"), + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _collect_requirement_lines(plugins: list[PluginSpec]) -> list[str]: + lines: list[str] = [] + for plugin in plugins: + lines.extend(_requirement_lines(plugin)) + return lines + + @staticmethod + def _merge_exact_requirements(requirement_lines: list[str]) -> list[str] | None: + merged: dict[str, str] = {} + for line in requirement_lines: + match = _EXACT_PIN_PATTERN.fullmatch(line) + if match is None: + return None + package_name = _normalize_package_name(match.group(1)) + existing = merged.get(package_name) + if existing is not None and existing != line: + return None + merged[package_name] = line + return [merged[name] for name in sorted(merged)] + + +class GroupEnvironmentManager: + """负责创建、校验和同步一个已经规划好的共享环境。""" + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary or shutil.which("uv") + self.cache_dir = self.repo_root / ".uv-cache" + + def prepare(self, group: EnvironmentGroup) -> Path: + """确保分组对应的解释器路径已经可以用于 worker 启动。 + + 行为概括如下: + + - 环境缺失、Python 版本不对、lockfile 丢失:重建 + - 环境结构还在但指纹变化:执行 `uv pip sync` + - 否则:直接复用现有解释器路径 + """ + _require_uv_binary(self.uv_binary) + + state_path = group.venv_path / GROUP_STATE_FILE_NAME + state = self._load_state(state_path) + if ( + not group.python_path.exists() + or not self._matches_python_version(group.venv_path, group.python_version) + or not group.lockfile_path.exists() + ): + self._rebuild(group) + self._write_state(state_path, group) + elif not self._state_matches_group(state, group): + self._sync_existing(group) + self._write_state(state_path, group) + return group.python_path + + def _rebuild(self, group: EnvironmentGroup) -> None: + if group.venv_path.exists(): + shutil.rmtree(group.venv_path) + self._create_venv(group) + self._sync_lockfile(group) + + def _sync_existing(self, group: EnvironmentGroup) -> None: + self._sync_lockfile(group) + + def _sync_lockfile(self, group: EnvironmentGroup) -> None: + """让已安装包与该分组的 lockfile 精确对齐。""" + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "pip", + "sync", + "--python", + str(group.python_path), + "--allow-empty-requirements", + str(group.lockfile_path), + ], + cwd=self.repo_root, + command_name=f"sync group env {group.id}", + ) + + def _create_venv(self, group: EnvironmentGroup) -> None: + """为一个分组创建共享 venv。 + + 当前迁移阶段仍保留 `--system-site-packages`,以兼容那些仍然隐式依 + 赖宿主环境包的旧插件。 + """ + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "venv", + "--python", + group.python_version, + "--system-site-packages", + "--no-python-downloads", + "--no-managed-python", + str(group.venv_path), + ], + cwd=self.repo_root, + command_name=f"create group venv {group.id}", + ) + + def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None: + process = subprocess.run( + command, + cwd=str(cwd), + env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)}, + capture_output=True, + text=True, + check=False, + ) + if process.returncode != 0: + raise RuntimeError( + f"{command_name} failed with exit code {process.returncode}: " + f"{process.stderr.strip() or process.stdout.strip()}" + ) + + @staticmethod + def _matches_python_version(venv_path: Path, version: str) -> bool: + return _read_pyvenv_major_minor(venv_path / "pyvenv.cfg") == version + + @staticmethod + def _load_state(state_path: Path) -> dict[str, object]: + if not state_path.exists(): + return {} + try: + data = json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + @staticmethod + def _write_state(state_path: Path, group: EnvironmentGroup) -> None: + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text( + json.dumps( + { + "group_id": group.id, + "python_version": group.python_version, + "environment_fingerprint": group.environment_fingerprint, + "plugins": [plugin.name for plugin in group.plugins], + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + @staticmethod + def _state_matches_group(state: dict[str, object], group: EnvironmentGroup) -> bool: + return ( + state.get("group_id") == group.id + and state.get("python_version") == group.python_version + and state.get("environment_fingerprint") == group.environment_fingerprint + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py new file mode 100644 index 0000000000..f92b296398 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py @@ -0,0 +1,990 @@ +"""处理器分发模块。 + +定义 HandlerDispatcher 类,负责将能力调用分发到具体的处理器函数。 +支持参数注入、流式执行、错误处理。 + +核心职责: + - 根据处理器 ID 查找处理器 + - 构建处理器参数(支持类型注解注入) + - 执行处理器并处理结果 + - 处理异步生成器流式结果 + - 统一的错误处理 + +参数注入优先级: + 1. 按类型注解注入(支持 Optional[Type]) + 2. 按参数名注入(兼容无类型注解) + 3. 从 args 注入(命令参数等) + +支持的注入类型: + - MessageEvent: 消息事件 + - Context: 运行时上下文 +""" + +from __future__ import annotations + +import asyncio +import inspect +import re +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast, get_type_hints + +from .._internal.command_model import ( + parse_command_model_remainder, + resolve_command_model_param, +) +from .._internal.injected_params import legacy_arg_parameter_names +from .._internal.invocation_context import caller_plugin_scope +from .._internal.plugin_logger import PluginLogger +from .._internal.sdk_logger import logger +from .._internal.star_runtime import bind_star_runtime +from .._internal.typing_utils import unwrap_optional +from ..clients.llm import LLMResponse +from ..context import CancelToken, Context +from ..conversation import ( + DEFAULT_BUSY_MESSAGE, + ConversationClosed, + ConversationReplaced, + ConversationSession, + ConversationState, +) +from ..events import MessageEvent +from ..filters import LocalFilterBinding +from ..llm.entities import ProviderRequest +from ..message.components import BaseMessageComponent +from ..message.result import ( + MessageChain, + MessageEventResult, + coerce_message_chain, +) +from ..protocol.descriptors import ( + CommandTrigger, + MessageTrigger, + ParamSpec, + ScheduleTrigger, +) +from ..schedule import ScheduleContext +from ..session_waiter import ( + SessionWaiterManager, + _mark_session_waiter_handler_task, + _unmark_session_waiter_handler_task, +) +from ..star import Star +from ._command_matching import ( + build_command_args, + build_regex_args, + match_command_name, +) +from .capability_dispatcher import CapabilityDispatcher +from .limiter import LimiterEngine +from .loader import LoadedHandler + + +@dataclass(slots=True) +class _ActiveConversation: + session: ConversationSession + task: asyncio.Task[Any] + + +@dataclass(slots=True) +class _InjectedEventPayloads: + provider_request: ProviderRequest | None = None + llm_response: LLMResponse | None = None + event_result: MessageEventResult | None = None + + +class HandlerDispatcher: + def __init__( + self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler] + ) -> None: + self._plugin_id = plugin_id + self._peer = peer + self._handlers = {item.descriptor.id: item for item in handlers} + self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {} + self._session_waiters = SessionWaiterManager(plugin_id=plugin_id, peer=peer) + self._limiter = LimiterEngine() + self._conversations: dict[str, _ActiveConversation] = {} + try: + setattr(peer, "_session_waiter_manager", self._session_waiters) + except AttributeError: + logger.warning( + f"Failed to attach _session_waiter_manager to peer {peer}, " + "some features may not work as expected" + ) + + def has_active_waiter(self, event: MessageEvent) -> bool: + return self._session_waiters.has_active_waiter(event) + + async def invoke(self, message, cancel_token: CancelToken) -> dict[str, Any]: + handler_id = str(message.input.get("handler_id", "")) + if handler_id == "__sdk_session_waiter__": + event_payload = message.input.get("event", {}) + requested_plugin_id = str(message.input.get("plugin_id") or "").strip() + ctx = Context( + peer=self._peer, + plugin_id=requested_plugin_id or self._plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + session_key = event.unified_msg_origin + if requested_plugin_id: + plugin_id = requested_plugin_id + else: + plugin_ids = self._session_waiters.get_waiter_plugin_ids(session_key) + if len(plugin_ids) > 1: + raise LookupError( + "multiple active session_waiters found for session; " + "dispatch requires explicit plugin identity" + ) + plugin_id = plugin_ids[0] if plugin_ids else self._plugin_id + if plugin_id != ctx.plugin_id: + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + event.bind_reply_handler(self._create_reply_handler(ctx, event)) + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._session_waiters.dispatch(event, plugin_id=plugin_id) + ) + _mark_session_waiter_handler_task(task) + task.add_done_callback(_unmark_session_waiter_handler_task) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + loaded = self._handlers.get(handler_id) + if loaded is None: + raise LookupError(f"handler not found: {handler_id}") + + plugin_id = self._resolve_plugin_id(loaded) + event_payload = message.input.get("event", {}) + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + bound_logger = cast(PluginLogger, ctx.logger).bind( + plugin_id=plugin_id, + request_id=message.id, + handler_ref=handler_id, + session_id=event.session_id, + event_type=str( + event_payload.get("event_type") + or event_payload.get("type") + or event.message_type + ), + ) + ctx.logger = bound_logger + event.bind_reply_handler(self._create_reply_handler(ctx, event)) + schedule_context = self._build_schedule_context(loaded, event_payload) + + # 提取 args 用于兼容 handler 签名 + raw_args = message.input.get("args") or {} + args = dict(raw_args) if isinstance(raw_args, dict) else {} + if not args: + args = self._derive_args(loaded, event) + + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._run_handler( + loaded, + event, + ctx, + args, + schedule_context=schedule_context, + ) + ) + _mark_session_waiter_handler_task(task) + task.add_done_callback(_unmark_session_waiter_handler_task) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + def _resolve_plugin_id(self, loaded: LoadedHandler) -> str: + if loaded.plugin_id: + return loaded.plugin_id + handler_id = getattr(loaded.descriptor, "id", "") + if isinstance(handler_id, str) and ":" in handler_id: + return handler_id.split(":", 1)[0] + return self._plugin_id + + def _create_reply_handler(self, ctx: Context, event: MessageEvent): + async def reply(text: str) -> None: + try: + await ctx.platform.send(event.session_ref or event.session_id, text) + except TypeError: + send = getattr(self._peer, "send", None) + if not callable(send): + raise + result = send(event.session_id, text) + if inspect.isawaitable(result): + await result + + return reply + + async def cancel(self, request_id: str) -> None: + active = self._active.get(request_id) + if active is None: + return + task, cancel_token = active + cancel_token.cancel() + task.cancel() + + async def _run_handler( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + args: dict[str, Any] | None = None, + *, + schedule_context: ScheduleContext | None = None, + ) -> dict[str, Any]: + summary = {"sent_message": False, "stop": False, "call_llm": False} + injected_payloads = _InjectedEventPayloads() + event_type = self._event_type_name(event) + try: + limiter = loaded.limiter + if limiter is not None: + decision = self._limiter.evaluate( + plugin_id=self._resolve_plugin_id(loaded), + handler_id=loaded.descriptor.id, + limiter=limiter, + event=event, + ) + if not decision.allowed: + if decision.error is not None: + raise decision.error + if decision.hint: + await event.reply(decision.hint) + summary["sent_message"] = True + return summary + if not self._run_local_filters( + loaded.local_filters, + event=event, + ctx=ctx, + ): + return summary + parsed_args, help_text = self._prepare_handler_args( + loaded, + args or {}, + ) + if help_text is not None: + await event.reply(help_text) + summary["sent_message"] = True + return summary + if loaded.conversation is not None: + return await self._start_conversation( + loaded, + event, + ctx, + parsed_args, + schedule_context=schedule_context, + ) + owner = loaded.owner if isinstance(loaded.owner, Star) else None + with bind_star_runtime(owner, ctx): + result = loaded.callable( + *self._build_args( + loaded.callable, + event, + ctx, + parsed_args, + plugin_id=self._resolve_plugin_id(loaded), + handler_ref=loaded.descriptor.id, + schedule_context=schedule_context, + injected_payloads=injected_payloads, + ) + ) + if inspect.isasyncgen(result): + async for item in result: + self._merge_handler_summary( + summary, + await self._handle_result_item(item, event, ctx), + ) + summary["stop"] = bool(summary.get("stop")) or event.is_stopped() + self._append_injected_payloads( + summary, + injected_payloads, + event=event, + event_type=event_type, + ) + return summary + if inspect.isawaitable(result): + result = await result + if result is not None: + self._merge_handler_summary( + summary, + await self._handle_result_item(result, event, ctx), + ) + summary["stop"] = bool(summary.get("stop")) or event.is_stopped() + self._append_injected_payloads( + summary, + injected_payloads, + event=event, + event_type=event_type, + ) + return summary + except Exception as exc: + await self._handle_error( + loaded.owner, + exc, + event, + ctx, + handler_name=loaded.callable.__name__, + plugin_id=self._resolve_plugin_id(loaded), + ) + raise + + def _derive_args( + self, + loaded: LoadedHandler, + event: MessageEvent, + ) -> dict[str, Any]: + trigger = loaded.descriptor.trigger + if isinstance(trigger, CommandTrigger): + param_specs = loaded.descriptor.param_specs + for command_name in [trigger.command, *trigger.aliases]: + remainder = match_command_name(event.text, command_name) + if remainder is not None: + model_param = resolve_command_model_param(loaded.callable) + if model_param is not None: + return { + "__command_model_remainder__": remainder, + "__command_name__": command_name, + } + if param_specs: + return build_command_args(param_specs, remainder) + return build_command_args( + [ + ParamSpec(name=name, type="str") + for name in legacy_arg_parameter_names(loaded.callable) + ], + remainder, + ) + return {} + if isinstance(trigger, MessageTrigger) and trigger.regex: + match = re.search(trigger.regex, event.text) + if match is None: + return {} + if loaded.descriptor.param_specs: + return build_regex_args(loaded.descriptor.param_specs, match) + return build_regex_args( + [ + ParamSpec(name=name, type="str") + for name in legacy_arg_parameter_names(loaded.callable) + ], + match, + ) + return {} + + def _build_args( + self, + handler, + event: MessageEvent, + ctx: Context, + args: dict[str, Any] | None = None, + *, + plugin_id: str | None = None, + handler_ref: str | None = None, + schedule_context: ScheduleContext | None = None, + conversation_session: ConversationSession | None = None, + injected_payloads: _InjectedEventPayloads | None = None, + ) -> list[Any]: + """构建 handler 参数列表。""" + from .._internal.sdk_logger import logger + + signature = inspect.signature(handler) + injected_args: list[Any] = [] + args = args or {} + + type_hints: dict[str, Any] = {} + try: + type_hints = get_type_hints(handler) + except Exception: + pass + + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + + injected = None + + # 1. 优先按类型注解注入 + param_type = type_hints.get(parameter.name) + if param_type is not None: + injected = self._inject_by_type( + param_type, + event, + ctx, + schedule_context, + conversation_session, + injected_payloads=injected_payloads, + ) + + # 2. Fallback 按名字注入 + if injected is None: + if parameter.name == "event": + injected = event + elif parameter.name in {"ctx", "context"}: + injected = ctx + elif parameter.name in {"sched", "schedule"}: + injected = schedule_context + elif parameter.name in {"conversation", "conv"}: + injected = conversation_session + elif parameter.name in args: + injected = args[parameter.name] + + # 3. 检查是否有默认值 + if injected is None: + if parameter.default is not parameter.empty: + continue + logger.error( + "Handler '{}' 的必填参数 '{}' 无法注入", + handler.__name__, + parameter.name, + ) + raise TypeError( + self._format_handler_injection_error( + handler=handler, + parameter_name=parameter.name, + plugin_id=plugin_id, + handler_ref=handler_ref, + args=args, + ) + ) + else: + injected_args.append(injected) + + return injected_args + + def _prepare_handler_args( + self, + loaded: LoadedHandler, + args: dict[str, Any], + ) -> tuple[dict[str, Any], str | None]: + parsed_args = ( + self._parse_handler_args(loaded.descriptor.param_specs, args) + if loaded.descriptor.param_specs + else { + key: value + for key, value in dict(args).items() + if not str(key).startswith("__command_") + } + ) + if not isinstance(loaded.descriptor.trigger, CommandTrigger): + return parsed_args, None + model_param = resolve_command_model_param(loaded.callable) + if model_param is None: + return parsed_args, None + if "__command_model_remainder__" not in args: + return parsed_args, None + trigger = loaded.descriptor.trigger + command_name = str(args.get("__command_name__", "")) or ( + trigger.command + if isinstance(trigger, CommandTrigger) + else loaded.descriptor.id.rsplit(".", 1)[-1] + ) + result = parse_command_model_remainder( + remainder=str(args.get("__command_model_remainder__", "")), + model_param=model_param, + command_name=command_name, + ) + if result.help_text is not None: + return parsed_args, result.help_text + if result.model is not None: + parsed_args[model_param.name] = result.model + return parsed_args, None + + async def _start_conversation( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + parsed_args: dict[str, Any], + *, + schedule_context: ScheduleContext | None, + ) -> dict[str, Any]: + assert loaded.conversation is not None + conversation_meta = loaded.conversation + summary = {"sent_message": False, "stop": True, "call_llm": False} + key = f"{self._resolve_plugin_id(loaded)}:{event.session_id}" + active = self._conversations.get(key) + if active is not None and not active.task.done(): + if conversation_meta.mode == "reject": + await event.reply( + conversation_meta.busy_message or DEFAULT_BUSY_MESSAGE + ) + summary["sent_message"] = True + return summary + active.session.mark_replaced() + await self._session_waiters.fail( + active.session.session_key, + ConversationReplaced("conversation replaced by a newer session"), + ) + await asyncio.sleep(0) + active.task.cancel() + try: + await asyncio.wait_for( + asyncio.shield(active.task), + timeout=conversation_meta.grace_period, + ) + except asyncio.TimeoutError: + cast(PluginLogger, ctx.logger).warning( + "Conversation replacement grace period exceeded for handler {}", + loaded.descriptor.id, + ) + except asyncio.CancelledError: + pass + except Exception: + pass + finally: + if self._conversations.get(key) is active: + self._conversations.pop(key, None) + + conversation = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=self._session_waiters, + timeout=conversation_meta.timeout, + ) + + async def _runner() -> None: + try: + await self._run_conversation_task( + loaded, + event, + ctx, + parsed_args, + conversation, + schedule_context=schedule_context, + ) + finally: + if conversation.state == ConversationState.ACTIVE: + conversation.close(ConversationState.COMPLETED) + current = self._conversations.get(key) + if current is not None and current.session is conversation: + self._conversations.pop(key, None) + + task = await ctx.register_task( + _runner(), + f"conversation:{loaded.descriptor.id}", + ) + conversation.bind_owner_task(task) + self._conversations[key] = _ActiveConversation( + session=conversation, + task=task, + ) + return summary + + async def _run_conversation_task( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + parsed_args: dict[str, Any], + conversation: ConversationSession, + *, + schedule_context: ScheduleContext | None, + ) -> None: + owner = loaded.owner if isinstance(loaded.owner, Star) else None + args_with_conversation = dict(parsed_args) + args_with_conversation.setdefault("conversation", conversation) + try: + with bind_star_runtime(owner, ctx): + result = loaded.callable( + *self._build_args( + loaded.callable, + event, + ctx, + args_with_conversation, + plugin_id=self._resolve_plugin_id(loaded), + handler_ref=loaded.descriptor.id, + schedule_context=schedule_context, + conversation_session=conversation, + ) + ) + if inspect.isasyncgen(result): + async for item in result: + await self._handle_result_item(item, event, ctx) + return + if inspect.isawaitable(result): + result = await result + if result is not None: + await self._handle_result_item(result, event, ctx) + except asyncio.CancelledError: + if conversation.state == ConversationState.ACTIVE: + conversation.close(ConversationState.CANCELLED) + raise + except (ConversationReplaced, ConversationClosed): + return + except Exception as exc: + await self._handle_error( + loaded.owner, + exc, + event, + ctx, + handler_name=loaded.callable.__name__, + plugin_id=self._resolve_plugin_id(loaded), + ) + + def _inject_by_type( + self, + param_type: Any, + event: MessageEvent, + ctx: Context, + schedule_context: ScheduleContext | None, + conversation_session: ConversationSession | None, + *, + injected_payloads: _InjectedEventPayloads | None = None, + ) -> Any: + """根据类型注解注入参数。""" + param_type, _is_optional = unwrap_optional(param_type) + + # 注入 MessageEvent 及其子类 + if param_type is MessageEvent: + return event + if isinstance(param_type, type) and issubclass(param_type, MessageEvent): + if isinstance(event, param_type): + return event + factory = getattr(param_type, "from_message_event", None) + if callable(factory): + return factory(event) + return event + + # 注入 Context 及其子类 + if param_type is Context or ( + isinstance(param_type, type) and issubclass(param_type, Context) + ): + return ctx + if param_type is ScheduleContext or ( + isinstance(param_type, type) and issubclass(param_type, ScheduleContext) + ): + return schedule_context + if param_type is ConversationSession or ( + isinstance(param_type, type) and issubclass(param_type, ConversationSession) + ): + return conversation_session + if param_type is ProviderRequest or ( + isinstance(param_type, type) and issubclass(param_type, ProviderRequest) + ): + return self._inject_provider_request(event, injected_payloads) + if param_type is LLMResponse or ( + isinstance(param_type, type) and issubclass(param_type, LLMResponse) + ): + return self._inject_llm_response(event, injected_payloads) + if param_type is MessageEventResult or ( + isinstance(param_type, type) and issubclass(param_type, MessageEventResult) + ): + return self._inject_event_result(event, injected_payloads) + + return None + + @staticmethod + def _event_type_name(event: MessageEvent) -> str: + raw = event.raw if isinstance(event.raw, dict) else {} + value = raw.get("event_type") or raw.get("type") + return str(value or "") + + @staticmethod + def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None: + raw = event.raw if isinstance(event.raw, dict) else {} + payload = raw.get(key) + if isinstance(payload, dict): + return payload + nested_raw = raw.get("raw") + if isinstance(nested_raw, dict): + nested_payload = nested_raw.get(key) + if isinstance(nested_payload, dict): + return nested_payload + return None + + def _inject_provider_request( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> ProviderRequest | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "provider_request") + return ( + ProviderRequest.from_payload(payload) if payload is not None else None + ) + if injected_payloads.provider_request is None: + payload = self._payload_from_event(event, "provider_request") + if payload is None: + return None + injected_payloads.provider_request = ProviderRequest.from_payload(payload) + return injected_payloads.provider_request + + def _inject_llm_response( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> LLMResponse | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "llm_response") + return LLMResponse.model_validate(payload) if payload is not None else None + if injected_payloads.llm_response is None: + payload = self._payload_from_event(event, "llm_response") + if payload is None: + return None + injected_payloads.llm_response = LLMResponse.model_validate(payload) + return injected_payloads.llm_response + + def _inject_event_result( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> MessageEventResult | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "event_result") + return ( + MessageEventResult.from_payload(payload) + if payload is not None + else None + ) + if injected_payloads.event_result is None: + payload = self._payload_from_event(event, "event_result") + if payload is None: + return None + injected_payloads.event_result = MessageEventResult.from_payload(payload) + return injected_payloads.event_result + + @staticmethod + def _append_injected_payloads( + summary: dict[str, Any], + injected_payloads: _InjectedEventPayloads, + *, + event: MessageEvent, + event_type: str, + ) -> None: + if ( + event_type == "llm_request" + and injected_payloads.provider_request is not None + ): + summary["provider_request"] = ( + injected_payloads.provider_request.to_payload() + ) + elif ( + event_type in {"llm_response", "agent_done"} + and injected_payloads.llm_response is not None + ): + summary["llm_response"] = injected_payloads.llm_response.model_dump( + exclude_none=True + ) + elif ( + event_type in {"decorating_result", "streaming_delta"} + and injected_payloads.event_result is not None + ): + summary["event_result"] = injected_payloads.event_result.to_payload() + if event._should_serialize_sdk_local_extras(): # noqa: SLF001 + summary["sdk_local_extras"] = event._sdk_local_extras_payload() # noqa: SLF001 + + def _format_handler_injection_error( + self, + *, + handler, + parameter_name: str, + plugin_id: str | None, + handler_ref: str | None, + args: dict[str, Any], + ) -> str: + plugin_text = plugin_id or self._plugin_id + target = handler_ref or getattr(handler, "__name__", "") + arg_keys = sorted(str(key) for key in args.keys()) + arg_keys_text = ", ".join(arg_keys) if arg_keys else "" + return ( + f"插件 '{plugin_text}' 的 handler '{target}' 参数注入失败:" + f"必填参数 '{parameter_name}' 无法注入。" + f"签名: {getattr(handler, '__name__', '')}" + f"{self._callable_signature(handler)}。" + "当前支持按类型注入 MessageEvent / Context," + "按参数名注入 event / ctx / context," + f"以及 args 中现有键:{arg_keys_text}。" + ) + + @staticmethod + def _callable_signature(handler) -> str: + try: + return str(inspect.signature(handler)) + except (TypeError, ValueError): + return "(...)" + + async def _handle_result_item( + self, + item: Any, + event: MessageEvent, + ctx: Context | None = None, + ) -> dict[str, Any]: + sent_message = await self._send_result(item, event, ctx) + if isinstance(item, dict): + return { + "sent_message": sent_message, + "stop": bool(item.get("stop", False)), + "call_llm": bool(item.get("call_llm", False)), + } + return { + "sent_message": sent_message, + "stop": False, + "call_llm": False, + } + + @staticmethod + def _merge_handler_summary( + target: dict[str, Any], + source: dict[str, Any], + ) -> None: + target["sent_message"] = bool(target.get("sent_message")) or bool( + source.get("sent_message") + ) + target["stop"] = bool(target.get("stop")) or bool(source.get("stop")) + target["call_llm"] = bool(target.get("call_llm")) or bool( + source.get("call_llm") + ) + + async def _send_result( + self, + item: Any, + event: MessageEvent, + ctx: Context | None = None, + ) -> bool: + """发送处理器结果。""" + if isinstance(item, str): + await event.reply(item) + return True + if isinstance(item, dict) and "text" in item: + await event.reply(str(item["text"])) + return True + if isinstance(item, MessageEventResult): + chain = item.chain + if chain.components: + await event.reply_chain(chain) + return True + return False + chain = coerce_message_chain(item) + if chain is not None: + if chain.components: + await event.reply_chain(chain) + return True + return False + if isinstance(item, list) and all( + isinstance(component, BaseMessageComponent) for component in item + ): + await event.reply_chain(MessageChain(list(item))) + return True + # 支持带 text 属性的对象 + text = getattr(item, "text", None) + if isinstance(text, str): + await event.reply(text) + return True + return False + + @staticmethod + def _parse_handler_args( + param_specs: Sequence[ParamSpec], + args: dict[str, Any], + ) -> dict[str, Any]: + parsed: dict[str, Any] = {} + for spec in param_specs: + if spec.name not in args: + if spec.type == "optional": + parsed[spec.name] = None + continue + if spec.required: + raise TypeError(f"缺少参数: {spec.name}") + continue + parsed[spec.name] = HandlerDispatcher._convert_param(spec, args[spec.name]) + return parsed + + @staticmethod + def _convert_param(spec: ParamSpec, value: Any) -> Any: + if spec.type in {"str", "greedy_str"}: + return str(value) + if spec.type == "int": + return int(str(value)) + if spec.type == "float": + return float(str(value)) + if spec.type == "bool": + normalized = str(value).strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + raise TypeError(f"无法解析布尔参数 {spec.name}: {value!r}") + if spec.type == "optional": + if value is None: + return None + inner = ParamSpec( + name=spec.name, + type=spec.inner_type or "str", + required=False, + ) + return HandlerDispatcher._convert_param(inner, value) + return value + + @staticmethod + def _run_local_filters( + bindings: list[LocalFilterBinding], + *, + event: MessageEvent, + ctx: Context, + ) -> bool: + for binding in bindings: + if not binding.evaluate(event=event, ctx=ctx): + return False + return True + + @staticmethod + def _build_schedule_context( + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> ScheduleContext | None: + if not isinstance(loaded.descriptor.trigger, ScheduleTrigger): + return None + try: + return ScheduleContext.from_payload(event_payload) + except Exception: + return None + + async def _handle_error( + self, + owner: Any, + exc: Exception, + event: MessageEvent, + ctx: Context, + *, + handler_name: str = "", + plugin_id: str | None = None, + ) -> None: + if hasattr(owner, "on_error") and callable(owner.on_error): + bound_owner = owner if isinstance(owner, Star) else None + with bind_star_runtime(bound_owner, ctx): + result = owner.on_error(exc, event, ctx) + if inspect.isawaitable(result): + await result + return + await Star.default_on_error(exc, event, ctx) + + +__all__ = ["CapabilityDispatcher", "HandlerDispatcher"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py new file mode 100644 index 0000000000..b32fe6e2da --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import time +from collections import deque +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from ..decorators import LimiterMeta +from ..errors import AstrBotError + +DEFAULT_RATE_LIMIT_MESSAGE = "操作过于频繁,请稍后再试。" +DEFAULT_COOLDOWN_MESSAGE = "冷却中,请在 {remaining_seconds}s 后重试。" + + +@dataclass(slots=True) +class LimiterDecision: + allowed: bool + error: AstrBotError | None = None + hint: str | None = None + + +class LimiterEngine: + def __init__(self, *, clock: Callable[[], float] | None = None) -> None: + self._clock = clock or time.monotonic + self._windows: dict[str, deque[float]] = {} + + def evaluate( + self, + *, + plugin_id: str, + handler_id: str, + limiter: LimiterMeta, + event: Any, + ) -> LimiterDecision: + now = float(self._clock()) + key = self._make_key( + plugin_id=plugin_id, + handler_id=handler_id, + scope=limiter.scope, + event=event, + ) + bucket = self._windows.setdefault(key, deque()) + threshold = now - limiter.window + while bucket and bucket[0] <= threshold: + bucket.popleft() + + if len(bucket) < limiter.limit: + bucket.append(now) + return LimiterDecision(allowed=True) + + remaining = 0.0 + if bucket: + remaining = max(0.0, limiter.window - (now - bucket[0])) + hint = self._hint_text(limiter, remaining) + details = { + "scope": limiter.scope, + "handler_id": handler_id, + "remaining_seconds": round(remaining, 3), + } + if limiter.behavior == "silent": + return LimiterDecision(allowed=False) + if limiter.behavior == "error": + if limiter.kind == "cooldown": + return LimiterDecision( + allowed=False, + error=AstrBotError.cooldown_active(hint=hint, details=details), + ) + return LimiterDecision( + allowed=False, + error=AstrBotError.rate_limited(hint=hint, details=details), + ) + return LimiterDecision(allowed=False, hint=hint) + + @staticmethod + def _make_key( + *, + plugin_id: str, + handler_id: str, + scope: str, + event: Any, + ) -> str: + prefix = f"{plugin_id}:{handler_id}" + if scope == "global": + return prefix + if scope == "session": + return f"{prefix}:{getattr(event, 'session_id', '')}" + if scope == "user": + return ( + f"{prefix}:{getattr(event, 'platform_id', '')}" + f":{getattr(event, 'user_id', '')}" + ) + if scope == "group": + return ( + f"{prefix}:{getattr(event, 'platform_id', '')}" + f":{getattr(event, 'group_id', '')}" + ) + return prefix + + @staticmethod + def _hint_text(limiter: LimiterMeta, remaining: float) -> str: + if limiter.message: + return limiter.message.format( + remaining_seconds=max(1, int(remaining + 0.999)) + ) + if limiter.kind == "cooldown": + return DEFAULT_COOLDOWN_MESSAGE.format( + remaining_seconds=max(1, int(remaining + 0.999)) + ) + return DEFAULT_RATE_LIMIT_MESSAGE + + +__all__ = [ + "DEFAULT_COOLDOWN_MESSAGE", + "DEFAULT_RATE_LIMIT_MESSAGE", + "LimiterDecision", + "LimiterEngine", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/loader.py b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py new file mode 100644 index 0000000000..9422b68a95 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py @@ -0,0 +1,1556 @@ +"""插件加载模块。 + +定义插件发现、环境管理和加载的核心逻辑。 +仅支持 astrbot-sdk 新版 Star 组件。 + +核心概念: + PluginSpec: 插件规范,描述插件的基本信息 + PluginDiscoveryResult: 插件发现结果,包含成功和跳过的插件 + PluginEnvironmentManager: 插件虚拟环境管理器 + LoadedHandler: 加载后的处理器,包含描述符和可调用对象 + LoadedPlugin: 加载后的插件,包含处理器和实例 + +插件发现流程: + 1. 扫描 plugins_dir 下的子目录 + 2. 检查 plugin.yaml 和 requirements.txt + 3. 解析 manifest_data 获取插件信息 + 4. 验证必要字段(name, components, runtime.python) + 5. 返回 PluginDiscoveryResult + +环境管理流程: + 1. 对插件集合做共享环境规划 + 2. 按 Python 版本和依赖兼容性构建环境分组 + 3. 为每个分组生成 lock/source/metadata 工件 + 4. 必要时重建或同步分组虚拟环境 + 5. 将单个插件映射到所属分组环境 + +插件加载流程: + 1. 将插件目录添加到 sys.path + 2. 遍历 components 列表 + 3. 动态导入组件类 + 4. 直接实例化(无参构造函数) + 5. 扫描处理器方法 + 6. 构建 HandlerDescriptor + +plugin.yaml 格式: + name: my_plugin + author: author_name + repo: my_plugin + desc: Plugin description + version: 1.0.0 + runtime: + python: "3.11" + components: + - class: my_plugin.main:MyComponent + +`loader` 是 runtime 与插件代码之间的边界层,负责三件事: + +- 从 `plugin.yaml` 解析出可运行的 `PluginSpec` +- 用 `uv` 为插件准备独立环境 +- 把组件实例和 handler 元数据整理成 `LoadedPlugin` +""" + +from __future__ import annotations + +import builtins +import contextlib +import copy +import hashlib +import importlib +import importlib.abc +import inspect +import json +import os +import re +import shutil +import sys +import threading +import types +import typing +from dataclasses import dataclass, field +from importlib import import_module +from pathlib import Path +from typing import Any, Literal, TypeAlias, cast + +import yaml + +from .._internal.command_model import resolve_command_model_param +from .._internal.injected_params import is_framework_injected_parameter +from .._internal.invocation_context import caller_plugin_scope, current_caller_plugin_id +from .._internal.plugin_ids import ( + capability_belongs_to_plugin, + plugin_capability_prefix, + validate_plugin_id, +) +from .._internal.sdk_logger import logger +from .._internal.typing_utils import unwrap_optional +from ..decorators import ( + ConversationMeta, + LimiterMeta, + get_agent_meta, + get_capability_meta, + get_handler_meta, + get_llm_tool_meta, +) +from ..llm.agents import AgentSpec +from ..llm.entities import LLMToolSpec +from ..protocol.descriptors import ( + CapabilityDescriptor, + HandlerDescriptor, + ParamSpec, + ScheduleTrigger, +) +from ..types import GreedyStr +from .environment_groups import ( + EnvironmentGroup, + EnvironmentPlanner, + EnvironmentPlanResult, + GroupEnvironmentManager, +) + +PLUGIN_MANIFEST_FILE = "plugin.yaml" +STATE_FILE_NAME = ".astrbot-worker-state.json" +CONFIG_SCHEMA_FILE = "_conf_schema.json" +PLUGIN_METADATA_ATTR = "__astrbot_plugin_metadata__" +ParamTypeName: TypeAlias = Literal[ + "str", "int", "float", "bool", "optional", "greedy_str" +] +OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None +HandlerKind: TypeAlias = Literal["handler", "hook", "tool", "session"] +DiscoverySeverity: TypeAlias = Literal["warning", "error"] +DiscoveryPhase: TypeAlias = Literal["discovery", "load", "lifecycle", "reload"] +_PLUGIN_IMPORT_LOCK = threading.RLock() +_VALID_HANDLER_KINDS: tuple[HandlerKind, ...] = ("handler", "hook", "tool", "session") +_PLUGIN_PACKAGE_PREFIX = "astrbot_ext_" +_GITHUB_REPO_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_GITHUB_REPO_SLUG_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$") +_GITHUB_REPO_URL_RE = re.compile( + r"^https://github\.com/[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+/?$", + re.IGNORECASE, +) +_PLUGIN_IMPORT_NAMESPACES: dict[str, _PluginImportNamespace] = {} +_ORIGINAL_BUILTIN_IMPORT = builtins.__import__ +_PLUGIN_IMPORT_HOOK_INSTALLED = False +_PLUGIN_IMPORT_META_FINDER: _PluginScopedMetaPathFinder | None = None +_PLUGIN_IMPORT_ALIAS_STATE = threading.local() + + +def _default_python_version() -> str: + return f"{sys.version_info.major}.{sys.version_info.minor}" + + +def _is_valid_github_repo_ref(value: str) -> bool: + normalized = value.strip() + if not normalized: + return False + return bool( + _GITHUB_REPO_NAME_RE.fullmatch(normalized) + or _GITHUB_REPO_SLUG_RE.fullmatch(normalized) + or _GITHUB_REPO_URL_RE.fullmatch(normalized) + ) + + +def _venv_python_path(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +@dataclass(slots=True) +class PluginSpec: + name: str + plugin_dir: Path + manifest_path: Path + requirements_path: Path + python_version: str + manifest_data: dict[str, Any] + + +@dataclass(slots=True) +class PluginDiscoveryResult: + plugins: list[PluginSpec] + skipped_plugins: dict[str, str] + issues: list[PluginDiscoveryIssue] = field(default_factory=list) + + +@dataclass(slots=True) +class PluginDiscoveryIssue: + severity: DiscoverySeverity + phase: DiscoveryPhase + plugin_id: str + message: str + details: str = "" + hint: str = "" + + def to_payload(self) -> dict[str, str]: + return { + "severity": self.severity, + "phase": self.phase, + "plugin_id": self.plugin_id, + "message": self.message, + "details": self.details, + "hint": self.hint, + } + + +@dataclass(slots=True) +class LoadedHandler: + descriptor: HandlerDescriptor + callable: Any + owner: Any + plugin_id: str = "" + local_filters: list[Any] = field(default_factory=list) + limiter: LimiterMeta | None = None + conversation: ConversationMeta | None = None + + +@dataclass(slots=True) +class LoadedCapability: + descriptor: CapabilityDescriptor + callable: Any + owner: Any + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedLLMTool: + spec: LLMToolSpec + callable: Any + owner: Any + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedAgent: + spec: AgentSpec + runner_class: type[Any] + owner: Any | None = None + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedPlugin: + plugin: PluginSpec + handlers: list[LoadedHandler] + capabilities: list[LoadedCapability] = field(default_factory=list) + llm_tools: list[LoadedLLMTool] = field(default_factory=list) + agents: list[LoadedAgent] = field(default_factory=list) + instances: list[Any] = field(default_factory=list) + + +@dataclass(slots=True) +class _ResolvedComponent: + cls: type[Any] + class_path: str + index: int + + +@dataclass(slots=True) +class _PluginImportNamespace: + plugin_id: str + plugin_dir: Path + package_name: str + + +@dataclass(slots=True) +class _ParamTypeInfo: + type_name: ParamTypeName + inner_type: OptionalInnerType + required: bool + + +class _PluginScopedAliasLoader(importlib.abc.Loader): + def __init__(self, *, alias_name: str, target_name: str) -> None: + self.alias_name = alias_name + self.target_name = target_name + + def create_module(self, spec: importlib.machinery.ModuleSpec) -> types.ModuleType: + del spec + module = sys.modules.get(self.target_name) + if not isinstance(module, types.ModuleType): + module = import_module(self.target_name) + _record_plugin_import_alias(self.alias_name) + return module + + def exec_module(self, module: types.ModuleType) -> None: + del module + + +class _PluginScopedMetaPathFinder(importlib.abc.MetaPathFinder): + def find_spec( + self, + fullname: str, + path: list[str] | None = None, + target: types.ModuleType | None = None, + ) -> importlib.machinery.ModuleSpec | None: + del path, target + namespace = _plugin_import_namespace_for_current_caller() + if namespace is None: + return None + rewritten_name = _rewrite_plugin_import_name(namespace, fullname) + if rewritten_name is None: + return None + parent_name, _, _ = rewritten_name.rpartition(".") + parent_search_path = None + if parent_name: + parent_module = sys.modules.get(parent_name) + if not isinstance(parent_module, types.ModuleType): + parent_module = import_module(parent_name) + parent_search_path = getattr(parent_module, "__path__", None) + target_spec = importlib.machinery.PathFinder.find_spec( + rewritten_name, + parent_search_path, + ) + if target_spec is None: + return None + alias_spec = importlib.machinery.ModuleSpec( + fullname, + _PluginScopedAliasLoader( + alias_name=fullname, + target_name=rewritten_name, + ), + is_package=target_spec.submodule_search_locations is not None, + ) + alias_spec.origin = target_spec.origin + alias_spec.cached = target_spec.cached + alias_spec.has_location = target_spec.has_location + if target_spec.submodule_search_locations is not None: + alias_spec.submodule_search_locations = list( + target_spec.submodule_search_locations + ) + return alias_spec + + +def _sanitize_package_component(plugin_id: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9_]+", "_", plugin_id).strip("_") + return sanitized or "plugin" + + +def _plugin_package_name(plugin_id: str) -> str: + digest = hashlib.sha256(plugin_id.encode("utf-8")).hexdigest()[:8] + return f"{_PLUGIN_PACKAGE_PREFIX}{_sanitize_package_component(plugin_id)}_{digest}" + + +def _plugin_module_name(package_name: str, module_name: str) -> str: + normalized = module_name.strip() + return f"{package_name}.{normalized}" if normalized else package_name + + +def _iter_handler_names(instance: Any) -> list[str]: + handler_names = getattr(instance.__class__, "__handlers__", ()) + if handler_names: + return list(handler_names) + return list(dir(instance)) + + +def _iter_discoverable_names(instance: Any) -> list[str]: + handler_names = list(dict.fromkeys(_iter_handler_names(instance))) + known_names = set(handler_names) + extra_names = sorted(name for name in dir(instance) if name not in known_names) + return [*handler_names, *extra_names] + + +def _validate_loaded_capability_namespace( + plugin: PluginSpec, + *, + resolved_component: _ResolvedComponent, + attribute_name: str, + capability_name: str, +) -> None: + if capability_belongs_to_plugin(capability_name, plugin.name): + return + expected_prefix = plugin_capability_prefix(plugin.name) + raise ValueError( + f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"方法 {attribute_name!r} 导出的 capability {capability_name!r} 必须使用当前插件名前缀 " + f"{expected_prefix!r},例如 {expected_prefix}" + ) + + +def _register_loaded_capability_name( + seen_capability_sources: dict[str, str], + *, + capability_name: str, + source_ref: str, +) -> None: + existing_source = seen_capability_sources.get(capability_name) + if existing_source is not None: + raise ValueError( + f"capability {capability_name!r} 重复定义:{existing_source} 与 {source_ref}" + ) + seen_capability_sources[capability_name] = source_ref + + +def _is_injected_parameter(annotation: Any, parameter_name: str) -> bool: + return is_framework_injected_parameter(parameter_name, annotation) + + +def _param_type_name(annotation: Any) -> _ParamTypeInfo: + normalized, is_optional = unwrap_optional(annotation) + if normalized is GreedyStr: + return _ParamTypeInfo("greedy_str", None, False) + if normalized in {int, float, bool, str}: + normalized_name = cast( + Literal["str", "int", "float", "bool"], normalized.__name__ + ) + if is_optional: + return _ParamTypeInfo("optional", normalized_name, False) + return _ParamTypeInfo(normalized_name, None, True) + if is_optional: + return _ParamTypeInfo("optional", "str", False) + return _ParamTypeInfo("str", None, True) + + +def _build_param_specs(handler: Any) -> list[ParamSpec]: + model_param = resolve_command_model_param(handler) + if model_param is not None: + return [] + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = typing.get_type_hints(handler) + except Exception as exc: + logger.warning( + "Failed to resolve type hints for handler {}: {}", + getattr(handler, "__qualname__", repr(handler)), + exc, + ) + type_hints = {} + + specs: list[ParamSpec] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = type_hints.get(parameter.name) + if _is_injected_parameter(annotation, parameter.name): + continue + type_info = _param_type_name(annotation) + required = type_info.required + if parameter.default is not inspect.Parameter.empty: + required = False + specs.append( + ParamSpec( + name=parameter.name, + type=type_info.type_name, + required=required, + inner_type=type_info.inner_type, + ) + ) + + greedy_indexes = [ + index for index, spec in enumerate(specs) if spec.type == "greedy_str" + ] + if greedy_indexes and greedy_indexes[-1] != len(specs) - 1: + greedy_spec = specs[greedy_indexes[-1]] + raise ValueError(f"参数 '{greedy_spec.name}' (GreedyStr) 必须是最后一个参数。") + return specs + + +def _validate_schedule_signature(handler: Any) -> None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return + allowed_names = {"ctx", "context", "sched", "schedule"} + invalid = [ + parameter.name + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.name not in allowed_names + ] + if invalid: + raise ValueError( + "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。" + ) + + +def _plugin_context(plugin: PluginSpec) -> str: + return f"插件 '{plugin.name}'({plugin.manifest_path})" + + +def _component_context(plugin: PluginSpec, *, class_path: str, index: int) -> str: + return f"{_plugin_context(plugin)} 的 components[{index}].class='{class_path}'" + + +def _resolve_candidate( + instance: Any, + name: str, + meta_getter: typing.Callable[[Any], Any | None], + *, + predicate: typing.Callable[[Any], bool] | None = None, +) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + + for candidate in candidates: + meta = meta_getter(candidate) + if meta is None: + continue + if predicate is not None and not predicate(meta): + continue + try: + return getattr(instance, name), meta + except AttributeError: + return None + return None + + +def _resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + """Resolve handler candidates without triggering unrelated descriptor side effects.""" + return _resolve_candidate( + instance, + name, + get_handler_meta, + predicate=lambda meta: meta.trigger is not None, + ) + + +def _resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + return _resolve_candidate(instance, name, get_capability_meta) + + +def _resolve_llm_tool_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + return _resolve_candidate(instance, name, get_llm_tool_meta) + + +def _iter_agent_candidates(component_cls: type[Any]) -> list[tuple[type[Any], Any]]: + module = import_module(component_cls.__module__) + seen: set[str] = set() + resolved: list[tuple[type[Any], Any]] = [] + + def _collect(candidate: Any) -> None: + if not inspect.isclass(candidate): + return + meta = get_agent_meta(candidate) + if meta is None: + return + key = f"{candidate.__module__}.{candidate.__qualname__}" + if key in seen: + return + seen.add(key) + resolved.append((candidate, meta)) + + for candidate in vars(module).values(): + _collect(candidate) + for candidate in vars(component_cls).values(): + _collect(candidate) + return resolved + + +def _read_yaml(path: Path) -> dict[str, Any]: + data = yaml.safe_load(path.read_text(encoding="utf-8")) or {} + return data if isinstance(data, dict) else {} + + +def _read_requirements_text(path: Path) -> str: + if not path.exists(): + return "" + return path.read_text(encoding="utf-8") + + +def _plugin_config_dir(plugin_dir: Path) -> Path: + if plugin_dir.parent.name == "plugins" and plugin_dir.parent.parent.exists(): + return plugin_dir.parent.parent / "config" + return plugin_dir / "data" / "config" + + +def _plugin_config_path(plugin_dir: Path, plugin_name: str) -> Path: + return _plugin_config_dir(plugin_dir) / f"{plugin_name}_config.json" + + +def _schema_default(field_schema: dict[str, Any]) -> Any: + if "default" in field_schema: + return copy.deepcopy(field_schema["default"]) + + field_type = str(field_schema.get("type") or "string") + if field_type == "object": + items = field_schema.get("items") + if isinstance(items, dict): + return { + key: _normalize_config_value(child_schema, None) + for key, child_schema in items.items() + if isinstance(child_schema, dict) + } + return {} + if field_type in {"list", "template_list", "file"}: + return [] + if field_type == "dict": + return {} + if field_type == "int": + return 0 + if field_type == "float": + return 0.0 + if field_type == "bool": + return False + return "" + + +def _normalize_config_value(field_schema: dict[str, Any], value: Any) -> Any: + field_type = str(field_schema.get("type") or "string") + default_value = _schema_default(field_schema) + + if field_type == "object": + items = field_schema.get("items") + if not isinstance(items, dict): + return default_value + current = value if isinstance(value, dict) else {} + return { + key: _normalize_config_value(child_schema, current.get(key)) + for key, child_schema in items.items() + if isinstance(child_schema, dict) + } + if field_type in {"list", "template_list", "file"}: + return copy.deepcopy(value) if isinstance(value, list) else default_value + if field_type == "dict": + return copy.deepcopy(value) if isinstance(value, dict) else default_value + if field_type == "int": + return ( + value + if isinstance(value, int) and not isinstance(value, bool) + else default_value + ) + if field_type == "float": + return ( + value + if isinstance(value, (int, float)) and not isinstance(value, bool) + else default_value + ) + if field_type == "bool": + return value if isinstance(value, bool) else default_value + if field_type in {"string", "text"}: + return value if isinstance(value, str) else default_value + return copy.deepcopy(value) if value is not None else default_value + + +def load_plugin_config_schema(plugin: PluginSpec) -> dict[str, Any]: + """加载插件配置 schema,解析失败时记录日志并返回空对象。""" + schema_path = plugin.plugin_dir / CONFIG_SCHEMA_FILE + if not schema_path.exists(): + return {} + + try: + schema_payload = json.loads(schema_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + logger.warning( + "Failed to parse SDK plugin config schema {}: {}", + schema_path, + exc, + ) + return {} + except OSError as exc: + logger.warning( + "Failed to read SDK plugin config schema {}: {}", + schema_path, + exc, + ) + return {} + if not isinstance(schema_payload, dict): + logger.warning( + "SDK plugin config schema {} must be a JSON object, got {}", + schema_path, + type(schema_payload).__name__, + ) + return {} + return schema_payload + + +def save_plugin_config( + plugin: PluginSpec, + payload: dict[str, Any], + *, + schema: dict[str, Any] | None = None, +) -> dict[str, Any]: + """按 schema 归一化并写回插件配置。""" + active_schema = ( + load_plugin_config_schema(plugin) if schema is None else dict(schema) + ) + normalized = { + key: _normalize_config_value(field_schema, payload.get(key)) + for key, field_schema in active_schema.items() + if isinstance(field_schema, dict) + } + + config_path = _plugin_config_path(plugin.plugin_dir, plugin.name) + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text( + json.dumps(normalized, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return normalized + + +def load_plugin_config( + plugin: PluginSpec, + *, + schema: dict[str, Any] | None = None, +) -> dict[str, Any]: + """加载插件配置,返回普通字典。""" + active_schema = ( + load_plugin_config_schema(plugin) if schema is None else dict(schema) + ) + if not active_schema: + return {} + + config_path = _plugin_config_path(plugin.plugin_dir, plugin.name) + try: + existing_payload = ( + json.loads(config_path.read_text(encoding="utf-8")) + if config_path.exists() + else {} + ) + except json.JSONDecodeError as exc: + logger.warning( + "Failed to parse SDK plugin config {}: {}", + config_path, + exc, + ) + existing_payload = {} + except OSError as exc: + logger.warning( + "Failed to read SDK plugin config {}: {}", + config_path, + exc, + ) + existing_payload = {} + existing = existing_payload if isinstance(existing_payload, dict) else {} + normalized = { + key: _normalize_config_value(field_schema, existing.get(key)) + for key, field_schema in active_schema.items() + if isinstance(field_schema, dict) + } + + if not config_path.exists() or normalized != existing: + save_plugin_config(plugin, normalized, schema=active_schema) + return normalized + + +def _is_new_star_component(cls: type[Any]) -> bool: + """检查组件类是否为 astrbot-sdk 新版 Star。""" + return bool(getattr(cls, "__astrbot_is_new_star__", False)) + + +def _plugin_component_classes(plugin: PluginSpec) -> list[_ResolvedComponent]: + """解析插件组件类列表。""" + components = plugin.manifest_data.get("components") or [] + if not isinstance(components, list): + return [] + + classes: list[_ResolvedComponent] = [] + for index, component in enumerate(components): + if not isinstance(component, dict): + raise ValueError( + f"{_plugin_context(plugin)} 的 components[{index}] 必须是 object。" + ) + class_path = component.get("class") + if not isinstance(class_path, str) or ":" not in class_path: + raise ValueError( + f"{_plugin_context(plugin)} 的 components[{index}].class " + "必须是 ':'。" + ) + try: + cls = _import_plugin_string(class_path, plugin) + except Exception as exc: + raise ValueError( + f"{_component_context(plugin, class_path=class_path, index=index)} " + f"加载失败:{exc}" + ) from exc + if not isinstance(cls, type): + raise ValueError( + f"{_component_context(plugin, class_path=class_path, index=index)} " + "解析结果不是类,请检查导出名称。" + ) + classes.append( + _ResolvedComponent( + cls=cls, + class_path=class_path, + index=index, + ) + ) + if not classes: + raise ValueError( + f"{_plugin_context(plugin)} 未声明任何可加载组件。" + "请检查 plugin.yaml 中的 components 配置。" + ) + return classes + + +def load_plugin_spec(plugin_dir: Path) -> PluginSpec: + """从插件目录加载插件规范。""" + plugin_dir = plugin_dir.resolve() + manifest_path = plugin_dir / PLUGIN_MANIFEST_FILE + requirements_path = plugin_dir / "requirements.txt" + + if not manifest_path.exists(): + raise ValueError(f"插件目录 '{plugin_dir}' 缺少 {PLUGIN_MANIFEST_FILE}。") + + manifest_data = _read_yaml(manifest_path) + runtime = manifest_data.get("runtime") or {} + python_version = runtime.get("python") or _default_python_version() + + return PluginSpec( + name=str(manifest_data.get("name") or plugin_dir.name), + plugin_dir=plugin_dir, + manifest_path=manifest_path, + requirements_path=requirements_path, + python_version=str(python_version), + manifest_data=manifest_data, + ) + + +def validate_plugin_spec(plugin: PluginSpec) -> None: + """校验单个插件规范,供 CLI 和发现流程复用。""" + manifest_data = plugin.manifest_data + manifest_label = f"插件 '{plugin.name}'({plugin.manifest_path})" + + raw_name = manifest_data.get("name") + if not isinstance(raw_name, str) or not raw_name: + raise ValueError(f"{manifest_label} 缺少 name。") + try: + validate_plugin_id(raw_name) + except ValueError as exc: + raise ValueError(f"{manifest_label} 的 name 不合法:{exc}") from exc + + raw_runtime = manifest_data.get("runtime") or {} + raw_python = raw_runtime.get("python") + if not isinstance(raw_python, str) or not raw_python: + raise ValueError(f"{manifest_label} 缺少 runtime.python。") + + raw_author = manifest_data.get("author") + if not isinstance(raw_author, str) or not raw_author.strip(): + raise ValueError(f"{manifest_label} 缺少 author。") + + raw_repo = manifest_data.get("repo") + if not isinstance(raw_repo, str) or not raw_repo.strip(): + raise ValueError(f"{manifest_label} 缺少 repo。") + if not _is_valid_github_repo_ref(raw_repo): + raise ValueError( + f"{manifest_label} 的 repo 不合法:" + "请填写 GitHub 仓库名(repo)、owner/repo,或 https://github.com/owner/repo。" + ) + + components = manifest_data.get("components") + if not isinstance(components, list): + raise ValueError(f"{manifest_label} 的 components 必须是数组。") + + for index, component in enumerate(components): + if not isinstance(component, dict): + raise ValueError(f"{manifest_label} 的 components[{index}] 必须是 object。") + class_path = component.get("class") + if not isinstance(class_path, str) or ":" not in class_path: + raise ValueError( + f"{manifest_label} 的 components[{index}].class " + "必须是 ':'。" + ) + + +# TODO: 不能保证插件和命令冲突消失,真有那么一天我们sdk小团体也是好起来了 +def discover_plugins(plugins_dir: Path) -> PluginDiscoveryResult: + """扫描目录发现所有插件。""" + plugins_root = plugins_dir.resolve() + skipped_plugins: dict[str, str] = {} + issues: list[PluginDiscoveryIssue] = [] + plugins: list[PluginSpec] = [] + # TODO: 改用 dict 记录 name -> plugin_dir 映射,以便在重复时报错时显示冲突路径 + seen_name_sources: dict[str, Path] = {} # plugin_name -> plugin_dir + + if not plugins_root.exists(): + return PluginDiscoveryResult([], {}, []) + + for entry in sorted(plugins_root.iterdir()): + if not entry.is_dir() or entry.name.startswith("."): + continue + manifest_path = entry / PLUGIN_MANIFEST_FILE + if not manifest_path.exists(): + continue + + plugin: PluginSpec | None = None + try: + plugin = load_plugin_spec(entry) + validate_plugin_spec(plugin) + except Exception as exc: + skip_key = entry.name + if plugin is not None: + raw_name = plugin.manifest_data.get("name") + if isinstance(raw_name, str) and raw_name: + skip_key = raw_name + details = str(exc) + skipped_plugins[skip_key] = f"failed to parse plugin manifest: {details}" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=skip_key, + message="插件发现失败", + details=details, + ) + ) + continue + + plugin_name = plugin.name + if not isinstance(plugin_name, str) or not plugin_name: + skipped_plugins[entry.name] = "plugin name is required" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=entry.name, + message="插件缺少名称", + details="plugin name is required", + ) + ) + continue + if plugin_name in seen_name_sources: + existing_source = seen_name_sources.get(plugin_name, Path("")) + skipped_plugins[plugin_name] = "duplicate plugin name" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=plugin_name, + message="插件名称重复", + details=f"冲突的插件目录:{existing_source} 与 {plugin.plugin_dir}", + hint="请修改其中一个插件的名称后重试", + ) + ) + continue + seen_name_sources[plugin_name] = plugin.plugin_dir + plugins.append(plugin) + + return PluginDiscoveryResult( + plugins=plugins, + skipped_plugins=skipped_plugins, + issues=issues, + ) + + +class PluginEnvironmentManager: + """运行时访问分组环境管理的门面层。 + + 运行时仍然保留历史上的 `prepare_environment(plugin)` 调用入口,但底层 + 实现已经变成两阶段模型: + + 1. `plan()` 负责解析跨插件分组和共享工件 + 2. `prepare_environment()` 负责把单个插件映射到它所属的分组环境 + """ + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary + self.cache_dir = self.repo_root / ".uv-cache" + self._planner = EnvironmentPlanner(self.repo_root, uv_binary=uv_binary) + self._group_manager = GroupEnvironmentManager( + self.repo_root, uv_binary=uv_binary + ) + self.uv_binary = self._planner.uv_binary + self._plan_result: EnvironmentPlanResult | None = None + + def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult: + """为当前插件集合生成共享环境规划。""" + plan_result = self._planner.plan(plugins) + self._plan_result = plan_result + return plan_result + + def prepare_group_environment(self, group: EnvironmentGroup) -> Path: + """返回指定分组的解释器路径。""" + if self._plan_result is None: + self._plan_result = EnvironmentPlanResult(groups=[group]) + return self._group_manager.prepare(group) + + def prepare_environment(self, plugin: PluginSpec) -> Path: + """返回该插件所属分组环境的解释器路径。 + + 如果调用方还没有先对整批插件做规划,这里会自动创建一个至少包含当 + 前插件的最小规划,以保证旧的"单插件直接调用"模式仍然可用。 + """ + if ( + self._plan_result is None + or plugin.name not in self._plan_result.plugin_to_group + ): + planned_plugins = ( + list(self._plan_result.plugins) if self._plan_result else [] + ) + if plugin.name not in {item.name for item in planned_plugins}: + planned_plugins.append(plugin) + self.plan(planned_plugins) + + assert self._plan_result is not None + group = self._plan_result.plugin_to_group.get(plugin.name) + if group is None: + reason = self._plan_result.skipped_plugins.get(plugin.name) + if reason is not None: + raise RuntimeError(reason) + raise RuntimeError(f"environment plan missing plugin: {plugin.name}") + + return self.prepare_group_environment(group) + + @staticmethod + def _fingerprint(plugin: PluginSpec) -> str: + requirements = _read_requirements_text(plugin.requirements_path) + payload = { + "python_version": plugin.python_version, + "requirements": requirements, + } + return json.dumps(payload, ensure_ascii=True, sort_keys=True) + + @staticmethod + def _load_state(state_path: Path) -> dict[str, Any]: + if not state_path.exists(): + return {} + try: + data = json.loads(state_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + logger.warning( + "Failed to parse plugin worker state {}: {}", state_path, exc + ) + return {} + except OSError as exc: + logger.warning("Failed to read plugin worker state {}: {}", state_path, exc) + return {} + return data if isinstance(data, dict) else {} + + @staticmethod + def _write_state(state_path: Path, plugin: PluginSpec, fingerprint: str) -> None: + state_path.write_text( + json.dumps( + { + "plugin": plugin.name, + "python_version": plugin.python_version, + "fingerprint": fingerprint, + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + @staticmethod + def _matches_python_version(venv_dir: Path, version: str) -> bool: + pyvenv_cfg = venv_dir / "pyvenv.cfg" + if not pyvenv_cfg.exists(): + return False + try: + content = pyvenv_cfg.read_text(encoding="utf-8") + except OSError: + return False + match = re.search(r"version\s*=\s*(\d+\.\d+)\.\d+", content, re.IGNORECASE) + return match is not None and match.group(1) == version + + +def _copy_limiter_meta(meta: LimiterMeta | None) -> LimiterMeta | None: + if meta is None: + return None + return LimiterMeta( + kind=meta.kind, + limit=meta.limit, + window=meta.window, + scope=meta.scope, + behavior=meta.behavior, + message=meta.message, + ) + + +def _copy_conversation_meta(meta: ConversationMeta | None) -> ConversationMeta | None: + if meta is None: + return None + return ConversationMeta( + timeout=meta.timeout, + mode=meta.mode, + busy_message=meta.busy_message, + grace_period=meta.grace_period, + ) + + +def _validate_handler_kind( + plugin: PluginSpec, + *, + resolved_component: _ResolvedComponent, + attribute_name: str, + kind: str, +) -> HandlerKind: + if kind in _VALID_HANDLER_KINDS: + return cast(HandlerKind, kind) + raise ValueError( + f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"方法 {attribute_name!r} 的 handler kind {kind!r} 不合法;" + f"允许的值为 {', '.join(_VALID_HANDLER_KINDS)}。" + ) + + +def _load_component_instance( + plugin: PluginSpec, + resolved_component: _ResolvedComponent, +) -> Any: + component_cls = resolved_component.cls + if not _is_new_star_component(component_cls): + raise ValueError( + f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"解析到的类 {component_cls.__module__}.{component_cls.__qualname__} " + "不是 astrbot-sdk Star 组件。请继承 astrbot_sdk.Star。" + ) + try: + instance = component_cls() + except Exception as exc: + raise ValueError( + f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"实例化失败:{exc}" + ) from exc + logger.debug( + "Instantiated SDK plugin component {} for plugin {}", + resolved_component.class_path, + plugin.name, + ) + return instance + + +def _collect_component_agents( + plugin: PluginSpec, + component_cls: type[Any], + *, + seen_agents: set[str], +) -> list[LoadedAgent]: + agents: list[LoadedAgent] = [] + for runner_class, meta in _iter_agent_candidates(component_cls): + runner_key = f"{runner_class.__module__}.{runner_class.__qualname__}" + if runner_key in seen_agents: + continue + seen_agents.add(runner_key) + agents.append( + LoadedAgent( + spec=meta.spec.model_copy(deep=True), + runner_class=runner_class, + owner=None, + plugin_id=plugin.name, + ) + ) + return agents + + +def _build_loaded_handler( + plugin: PluginSpec, + *, + resolved_component: _ResolvedComponent, + instance: Any, + attribute_name: str, + bound: Any, + meta: Any, +) -> LoadedHandler: + handler_kind = _validate_handler_kind( + plugin, + resolved_component=resolved_component, + attribute_name=attribute_name, + kind=meta.kind, + ) + handler_id = ( + f"{plugin.name}:{instance.__class__.__module__}.{instance.__class__.__name__}." + f"{attribute_name}" + ) + if isinstance(meta.trigger, ScheduleTrigger): + _validate_schedule_signature(bound) + param_specs = _build_param_specs(bound) + return LoadedHandler( + descriptor=HandlerDescriptor( + id=handler_id, + trigger=meta.trigger, + kind=handler_kind, + contract=meta.contract, + description=meta.description, + priority=meta.priority, + permissions=meta.permissions.model_copy(deep=True), + filters=[item.model_copy(deep=True) for item in meta.filters], + param_specs=[item.model_copy(deep=True) for item in param_specs], + command_route=( + meta.command_route.model_copy(deep=True) + if meta.command_route is not None + else None + ), + ), + callable=bound, + owner=instance, + plugin_id=plugin.name, + local_filters=list(meta.local_filters), + limiter=_copy_limiter_meta(meta.limiter), + conversation=_copy_conversation_meta(meta.conversation), + ) + + +def _collect_component_members( + plugin: PluginSpec, + *, + resolved_component: _ResolvedComponent, + instance: Any, + seen_capability_sources: dict[str, str], +) -> tuple[list[LoadedHandler], list[LoadedCapability], list[LoadedLLMTool]]: + handlers: list[LoadedHandler] = [] + capabilities: list[LoadedCapability] = [] + llm_tools: list[LoadedLLMTool] = [] + + for name in _iter_discoverable_names(instance): + resolved = _resolve_handler_candidate(instance, name) + capability = _resolve_capability_candidate(instance, name) + llm_tool = _resolve_llm_tool_candidate(instance, name) + if resolved is None and capability is None and llm_tool is None: + continue + if capability is not None: + bound_capability, capability_meta = capability + capability_name = capability_meta.descriptor.name + _validate_loaded_capability_namespace( + plugin, + resolved_component=resolved_component, + attribute_name=name, + capability_name=capability_name, + ) + _register_loaded_capability_name( + seen_capability_sources, + capability_name=capability_name, + source_ref=f"{resolved_component.class_path}.{name}", + ) + capabilities.append( + LoadedCapability( + descriptor=capability_meta.descriptor.model_copy(deep=True), + callable=bound_capability, + owner=instance, + plugin_id=plugin.name, + ) + ) + if llm_tool is not None: + bound_tool, tool_meta = llm_tool + llm_tools.append( + LoadedLLMTool( + spec=tool_meta.spec.model_copy(deep=True), + callable=bound_tool, + owner=instance, + plugin_id=plugin.name, + ) + ) + if resolved is not None: + bound_handler, handler_meta = resolved + handlers.append( + _build_loaded_handler( + plugin, + resolved_component=resolved_component, + instance=instance, + attribute_name=name, + bound=bound_handler, + meta=handler_meta, + ) + ) + return handlers, capabilities, llm_tools + + +def load_plugin(plugin: PluginSpec) -> LoadedPlugin: + """加载插件,返回处理器和能力列表。 + + 仅支持 astrbot-sdk 新版 Star 组件(无参构造函数)。 + """ + with _PLUGIN_IMPORT_LOCK: + logger.debug("Loading SDK plugin {} from {}", plugin.name, plugin.plugin_dir) + _ensure_plugin_import_hook_installed() + namespace = _register_plugin_import_namespace(plugin) + _purge_plugin_bytecode(plugin.plugin_dir) + _purge_plugin_package(namespace.package_name) + _purge_plugin_modules(plugin.plugin_dir) + _ensure_plugin_package(namespace) + importlib.invalidate_caches() + + instances: list[Any] = [] + handlers: list[LoadedHandler] = [] + capabilities: list[LoadedCapability] = [] + llm_tools: list[LoadedLLMTool] = [] + agents: list[LoadedAgent] = [] + seen_agents: set[str] = set() + seen_capability_sources: dict[str, str] = {} + with caller_plugin_scope(plugin.name): + resolved_components = _plugin_component_classes(plugin) + + for resolved_component in resolved_components: + instance = _load_component_instance(plugin, resolved_component) + instances.append(instance) + agents.extend( + _collect_component_agents( + plugin, + resolved_component.cls, + seen_agents=seen_agents, + ) + ) + component_handlers, component_capabilities, component_tools = ( + _collect_component_members( + plugin, + resolved_component=resolved_component, + instance=instance, + seen_capability_sources=seen_capability_sources, + ) + ) + handlers.extend(component_handlers) + capabilities.extend(component_capabilities) + llm_tools.extend(component_tools) + + logger.debug( + "Loaded SDK plugin {}: {} components, {} handlers, {} capabilities, {} llm tools, {} agents", + plugin.name, + len(resolved_components), + len(handlers), + len(capabilities), + len(llm_tools), + len(agents), + ) + return LoadedPlugin( + plugin=plugin, + handlers=handlers, + capabilities=capabilities, + llm_tools=llm_tools, + agents=agents, + instances=instances, + ) + + +def _path_within_root(path: Path, root: Path) -> bool: + try: + path.resolve().relative_to(root.resolve()) + except ValueError: + return False + return True + + +def _plugin_defines_module_root(plugin_dir: Path, root_name: str) -> bool: + return (plugin_dir / f"{root_name}.py").exists() or ( + plugin_dir / root_name + ).exists() + + +def _register_plugin_import_namespace(plugin: PluginSpec) -> _PluginImportNamespace: + existing = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name) + package_name = ( + existing.package_name + if existing is not None + else _plugin_package_name(plugin.name) + ) + namespace = _PluginImportNamespace( + plugin_id=plugin.name, + plugin_dir=plugin.plugin_dir.resolve(), + package_name=package_name, + ) + _PLUGIN_IMPORT_NAMESPACES[plugin.name] = namespace + return namespace + + +def _ensure_plugin_package(namespace: _PluginImportNamespace) -> types.ModuleType: + existing = sys.modules.get(namespace.package_name) + if isinstance(existing, types.ModuleType): + existing.__path__ = [str(namespace.plugin_dir)] + existing.__package__ = namespace.package_name + return existing + + module = types.ModuleType(namespace.package_name) + module.__file__ = str(namespace.plugin_dir) + module.__package__ = namespace.package_name + module.__path__ = [str(namespace.plugin_dir)] + module.__loader__ = None + spec = importlib.machinery.ModuleSpec( + namespace.package_name, + loader=None, + is_package=True, + ) + spec.submodule_search_locations = [str(namespace.plugin_dir)] + module.__spec__ = spec + sys.modules[namespace.package_name] = module + return module + + +def _module_belongs_to_plugin(module: Any, plugin_dir: Path) -> bool: + file_path = getattr(module, "__file__", None) + if isinstance(file_path, str) and _path_within_root(Path(file_path), plugin_dir): + return True + + package_paths = getattr(module, "__path__", None) + if package_paths is None: + return False + return any( + isinstance(candidate, str) and _path_within_root(Path(candidate), plugin_dir) + for candidate in package_paths + ) + + +def _purge_plugin_modules(plugin_dir: Path) -> None: + plugin_root = plugin_dir.resolve() + for module_name, module in list(sys.modules.items()): + if module is None: + continue + if _module_belongs_to_plugin(module, plugin_root): + sys.modules.pop(module_name, None) + + +def _purge_plugin_package(package_name: str) -> None: + for module_name in list(sys.modules): + if module_name == package_name or module_name.startswith(f"{package_name}."): + sys.modules.pop(module_name, None) + + +def _purge_plugin_bytecode(plugin_dir: Path) -> None: + plugin_root = plugin_dir.resolve() + for path in plugin_root.rglob("*"): + try: + if path.is_dir() and path.name == "__pycache__": + shutil.rmtree(path, ignore_errors=True) + continue + if path.is_file() and path.suffix in {".pyc", ".pyo"}: + path.unlink(missing_ok=True) + except OSError: + continue + + +def _import_plugin_string(path: str, plugin: PluginSpec) -> Any: + module_name, attr = path.split(":", 1) + namespace = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name) + if namespace is None: + raise RuntimeError(f"plugin import namespace missing: {plugin.name}") + module = import_module(_plugin_module_name(namespace.package_name, module_name)) + return getattr(module, attr) + + +def _plugin_import_namespace_for_current_caller() -> _PluginImportNamespace | None: + plugin_id = current_caller_plugin_id() + if not plugin_id: + return None + return _PLUGIN_IMPORT_NAMESPACES.get(plugin_id) + + +def _rewrite_plugin_import_name( + namespace: _PluginImportNamespace, + name: str, +) -> str | None: + normalized = name.strip() + if not normalized: + return None + if normalized.startswith(_PLUGIN_PACKAGE_PREFIX): + return None + root_name = normalized.split(".", 1)[0] + if not _plugin_defines_module_root(namespace.plugin_dir, root_name): + return None + return _plugin_module_name(namespace.package_name, normalized) + + +def _plugin_import_alias_buckets() -> list[set[str]]: + buckets = getattr(_PLUGIN_IMPORT_ALIAS_STATE, "buckets", None) + if buckets is None: + buckets = [] + _PLUGIN_IMPORT_ALIAS_STATE.buckets = buckets + return buckets + + +def _push_plugin_import_alias_bucket() -> set[str]: + bucket: set[str] = set() + _plugin_import_alias_buckets().append(bucket) + return bucket + + +def _pop_plugin_import_alias_bucket(bucket: set[str]) -> set[str]: + buckets = _plugin_import_alias_buckets() + if buckets and buckets[-1] is bucket: + buckets.pop() + else: + with contextlib.suppress(ValueError): + buckets.remove(bucket) + return bucket + + +def _record_plugin_import_alias(alias_name: str) -> None: + normalized = alias_name.strip() + if not normalized or normalized.startswith(_PLUGIN_PACKAGE_PREFIX): + return + buckets = _plugin_import_alias_buckets() + if not buckets: + return + buckets[-1].add(normalized) + + +def _cleanup_plugin_import_aliases(alias_names: set[str]) -> None: + for alias_name in sorted( + alias_names, key=lambda item: item.count("."), reverse=True + ): + sys.modules.pop(alias_name, None) + + +def _plugin_scoped_import( + name: str, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + fromlist: tuple[Any, ...] | list[Any] = (), + level: int = 0, +) -> Any: + with _PLUGIN_IMPORT_LOCK: + alias_bucket = _push_plugin_import_alias_bucket() + try: + return _ORIGINAL_BUILTIN_IMPORT(name, globals, locals, fromlist, level) + finally: + _cleanup_plugin_import_aliases( + _pop_plugin_import_alias_bucket(alias_bucket) + ) + + +def _ensure_plugin_import_meta_finder_installed() -> None: + global _PLUGIN_IMPORT_META_FINDER + if ( + _PLUGIN_IMPORT_META_FINDER is not None + and _PLUGIN_IMPORT_META_FINDER in sys.meta_path + ): + return + finder = _PluginScopedMetaPathFinder() + sys.meta_path.insert(0, finder) + _PLUGIN_IMPORT_META_FINDER = finder + + +def _ensure_plugin_import_hook_installed() -> None: + global _PLUGIN_IMPORT_HOOK_INSTALLED + _ensure_plugin_import_meta_finder_installed() + # 防御性检查:如果 hook 已在位,只补全标志位,不重复安装 + if builtins.__import__ is _plugin_scoped_import: + _PLUGIN_IMPORT_HOOK_INSTALLED = True + return + # 标志位声称已安装但实际 builtin 已被外部篡改(如测试框架 monkeypatch), + # 需要重置标志位以触发重新安装 + if ( + _PLUGIN_IMPORT_HOOK_INSTALLED + and builtins.__import__ is not _plugin_scoped_import + ): + _PLUGIN_IMPORT_HOOK_INSTALLED = False + if _PLUGIN_IMPORT_HOOK_INSTALLED: + return + builtins.__import__ = _plugin_scoped_import + _PLUGIN_IMPORT_HOOK_INSTALLED = True + + +def _restore_plugin_import_hook() -> None: + """还原 builtin __import__,用于插件卸载或测试 teardown 时清理全局状态。""" + global _PLUGIN_IMPORT_HOOK_INSTALLED, _PLUGIN_IMPORT_META_FINDER + if builtins.__import__ is _plugin_scoped_import: + builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT + if _PLUGIN_IMPORT_META_FINDER is not None: + with contextlib.suppress(ValueError): + sys.meta_path.remove(_PLUGIN_IMPORT_META_FINDER) + _PLUGIN_IMPORT_META_FINDER = None + _PLUGIN_IMPORT_HOOK_INSTALLED = False + + +def import_string(path: str, plugin_dir: Path | None = None) -> Any: + """通过字符串路径导入对象。""" + with _PLUGIN_IMPORT_LOCK: + module_name, attr = path.split(":", 1) + module = import_module(module_name) + return getattr(module, attr) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/peer.py b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py new file mode 100644 index 0000000000..1ebbbd2830 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py @@ -0,0 +1,852 @@ +"""协议对等端模块。 + +定义 Peer 类,封装双向传输通道上的消息收发、初始化握手、能力调用、 +流式事件转发与取消处理。这里的 peer 指"通信对端/本端"这一网络协议概念, +而不是业务上的用户、群聊或会话对象。 + +核心职责: + - 消息序列化/反序列化 + - 初始化握手协议 + - 能力调用(同步/流式) + - 取消处理 + - 连接生命周期管理 +消息处理: + 入站: + ResultMessage -> 唤醒等待的 Future + EventMessage -> 投递到流式队列 + InitializeMessage -> 调用 initialize_handler + InvokeMessage -> 创建任务调用 invoke_handler + CancelMessage -> 取消对应的任务 + + 出站: + initialize() -> InitializeMessage + invoke() -> InvokeMessage(stream=False) + invoke_stream() -> InvokeMessage(stream=True) + cancel() -> CancelMessage + +使用示例: + # 作为客户端发起调用 + peer = Peer(transport=transport, peer_info=PeerInfo(...)) + await peer.start() + output = await peer.initialize(handlers) + result = await peer.invoke("llm.chat", {"prompt": "hello"}) + + # 作为服务端处理调用 + peer.set_invoke_handler(my_handler) + await peer.start() + +消息处理流程: + 入站消息: + ResultMessage -> 唤醒等待的 Future + EventMessage -> 投递到流式队列 + InitializeMessage -> 调用 _initialize_handler + InvokeMessage -> 创建任务调用 _invoke_handler + CancelMessage -> 取消对应的任务 + + 出站消息: + initialize() -> InitializeMessage + invoke() -> InvokeMessage(stream=False) + invoke_stream() -> InvokeMessage(stream=True) + cancel() -> CancelMessage + +取消机制: + - CancelToken 用于检查取消状态 + - 入站任务在收到 CancelMessage 时被取消 + - 早到取消:在任务执行前检查 cancel_token,避免竞态条件 + +`Peer` 把 `Transport` 和 s5r 协议消息模型接起来,负责: + +- 握手与远端元数据缓存 +- 请求 ID 关联 +- 非流式 / 流式调用分发 +- 取消传播 +- 连接异常时的统一收口 + +它本身不做业务路由,真正的执行逻辑交给 `CapabilityRouter` 或 +`HandlerDispatcher`。 +""" + +from __future__ import annotations + +import asyncio +import inspect +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from typing import Any + +from .._internal.invocation_context import ( + caller_plugin_scope, + current_caller_plugin_id, +) +from .._internal.sdk_logger import logger +from ..context import CancelToken +from ..errors import AstrBotError, ErrorCodes +from ..protocol.messages import ( + CancelMessage, + ErrorPayload, + EventMessage, + InitializeMessage, + InitializeOutput, + InvokeMessage, + PeerInfo, + ResultMessage, + parse_message, +) +from .capability_router import StreamExecution + +InitializeHandler = Callable[[InitializeMessage], Awaitable[InitializeOutput]] +InvokeHandler = Callable[ + [InvokeMessage, CancelToken], Awaitable[dict[str, Any] | StreamExecution] +] +CancelHandler = Callable[[str], Awaitable[None]] + +SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY = "supported_protocol_versions" +NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY = "negotiated_protocol_version" +# 入站消息字符数上限(8 MB)。超过此阈值的协议消息会被直接拒绝, +# 避免恶意或异常的巨型消息耗尽内存或阻塞解析 +MAX_INBOUND_MESSAGE_CHARS = 8 * 1024 * 1024 + + +def _dedupe_protocol_versions( + versions: Sequence[str] | None, *, preferred_version: str +) -> list[str]: + ordered_versions: list[str] = [preferred_version] + if versions is not None: + ordered_versions.extend(versions) + deduped: list[str] = [] + for version in ordered_versions: + if not isinstance(version, str) or not version: + continue + if version not in deduped: + deduped.append(version) + return deduped + + +def _parse_protocol_version(version: str) -> tuple[int, int] | None: + major, dot, minor = version.partition(".") + if not dot or not major.isdigit() or not minor.isdigit(): + return None + return int(major), int(minor) + + +def _select_negotiated_protocol_version( + requested_version: str, + remote_metadata: dict[str, Any], + local_supported_versions: Sequence[str], +) -> str | None: + """从双方支持的版本中选出最佳兼容版本。 + + 协商策略:优先精确匹配,否则在同主版本号范围内选双方都支持的最高版本。 + 排除比请求版本更高的候选,因为远端能提供高于我们请求的版本说明我们本地 + 尚未实现该版本协议,无法正确处理对应的协议消息。 + """ + if requested_version in local_supported_versions: + return requested_version + requested_key = _parse_protocol_version(requested_version) + if requested_key is None: + return None + remote_supported = remote_metadata.get(SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY) + if not isinstance(remote_supported, (list, tuple)): + return None + local_supported_set = set(local_supported_versions) + compatible_versions: list[tuple[tuple[int, int], str]] = [] + for version in remote_supported: + if not isinstance(version, str) or version not in local_supported_set: + continue + parsed_version = _parse_protocol_version(version) + if parsed_version is None: + continue + if parsed_version[0] != requested_key[0] or parsed_version > requested_key: + continue + compatible_versions.append((parsed_version, version)) + if not compatible_versions: + return None + compatible_versions.sort(reverse=True) + return compatible_versions[0][1] + + +class Peer: + """表示协议连接中的一个对等端。 + + `Peer` 封装一条双向传输通道上的消息收发、初始化握手、能力调用、 + 流式事件转发与取消处理。这里的 `peer` 指“通信对端/本端”这一网络 + 协议概念,而不是业务上的用户、群聊或会话对象。 + """ + + def __init__( + self, + *, + transport, + peer_info: PeerInfo, + protocol_version: str = "1.0", + supported_protocol_versions: Sequence[str] | None = None, + ) -> None: + """创建一个协议对等端实例。 + + Args: + transport: 底层传输实现,负责发送字符串消息并回调入站消息。 + peer_info: 当前端点对外声明的身份信息。 + protocol_version: 当前端点首选的协议版本,用于初始化握手。 + supported_protocol_versions: 当前端点可接受的协议版本列表。 + """ + self.transport = transport + self.peer_info = peer_info + self.protocol_version = protocol_version + self.supported_protocol_versions = _dedupe_protocol_versions( + supported_protocol_versions, + preferred_version=protocol_version, + ) + self.negotiated_protocol_version: str | None = None + self.remote_peer: PeerInfo | None = None + self.remote_handlers = [] + self.remote_provided_capabilities = [] + self.remote_capabilities = [] + self.remote_capability_map: dict[str, Any] = {} + self.remote_provided_capability_map: dict[str, Any] = {} + self.remote_metadata: dict[str, Any] = {} + + self._initialize_handler: InitializeHandler | None = None + self._invoke_handler: InvokeHandler | None = None + self._cancel_handler: CancelHandler | None = None + self._counter = 0 + self._closed = asyncio.Event() + self._unusable = False + self._stopping = False + self._pending_results: dict[str, asyncio.Future[ResultMessage]] = {} + self._pending_streams: dict[str, asyncio.Queue[Any]] = {} + self._inbound_tasks: dict[ + str, tuple[asyncio.Task[None], CancelToken, asyncio.Event] + ] = {} + self._remote_initialized = asyncio.Event() + self._remote_initialized_successfully = False + self._transport_watch_task: asyncio.Task[None] | None = None + # 记录当前正在执行 stop() 的 Task,用于防止 stop() 被并发重入 + self._stop_task: asyncio.Task[None] | None = None + + def set_initialize_handler(self, handler: InitializeHandler) -> None: + """注册处理远端 `initialize` 请求的握手处理器。""" + self._initialize_handler = handler + + def set_invoke_handler(self, handler: InvokeHandler) -> None: + """注册处理远端 `invoke` 请求的能力调用处理器。""" + self._invoke_handler = handler + + def set_cancel_handler(self, handler: CancelHandler) -> None: + """注册处理远端 `cancel` 请求的取消回调。""" + self._cancel_handler = handler + + async def start(self) -> None: + """启动传输层并将原始入站消息绑定到当前 `Peer`。""" + self._closed.clear() + self._unusable = False + self._stopping = False + self.negotiated_protocol_version = None + self._remote_initialized.clear() + self._remote_initialized_successfully = False + self.transport.set_message_handler(self._handle_raw_message) + await self.transport.start() + self._transport_watch_task = asyncio.create_task(self._watch_transport_closed()) + + async def stop(self) -> None: + """关闭 `Peer` 并清理所有挂起中的请求、流和入站任务。 + + 重入安全性:transport.stop() 关闭底层连接时会触发原始消息处理器的 + 异常路径,该路径调用 _fail_connection() -> _schedule_stop() -> stop(), + 形成间接递归。_stopping 标志和 _stop_task 引用共同防止重复清理资源。 + 使用 asyncio.shield 等待是因为:如果当前任务在等待另一个 stop() 完成 + 期间被取消,shield 保护内部 stop_task 不被连带取消,避免 Peer 停留在 + 半关闭状态。 + """ + if self._closed.is_set(): + return + current_task = asyncio.current_task() + if self._stopping: + # 防止并发重入:如果 stop() 已在其他协程中执行,则等待它完成而不是重复清理 + stop_task = self._stop_task + if stop_task is not None and stop_task is not current_task: + await asyncio.shield(stop_task) + return + self._stopping = True + # 记录当前 task,供后续重入检测和 _schedule_stop() 判定 + if current_task is not None and self._stop_task is None: + self._stop_task = current_task + try: + # 终止所有挂起的 RPC,避免调用方永久挂起 + for future in list(self._pending_results.values()): + if not future.done(): + future.set_exception(AstrBotError.internal_error("连接已关闭")) + self._pending_results.clear() + + for queue in list(self._pending_streams.values()): + await queue.put(AstrBotError.internal_error("连接已关闭")) + self._pending_streams.clear() + + # 取消所有入站任务 + for task, token, _started in list(self._inbound_tasks.values()): + token.cancel() + task.cancel() + self._inbound_tasks.clear() + + await self.transport.stop() + self._closed.set() + finally: + # 只在当前 task 就是 stop_task 时才清除引用,避免误清其他 task 的记录。 + # 场景:A 任务正在 stop() 中,B 任务也进入了 stop() 并等待 A 完成, + # 如果 B 在 finally 中清除了 _stop_task,A 还未执行完就会失去引用。 + if self._stop_task is current_task: + self._stop_task = None + + async def wait_closed(self) -> None: + """等待底层传输彻底关闭。""" + await self.transport.wait_closed() + + async def _watch_transport_closed(self) -> None: + """监视底层传输的意外关闭,并主动失败挂起调用。""" + try: + await self.transport.wait_closed() + if self._closed.is_set() or self._stopping: + return + await self._fail_connection( + AstrBotError( + code=ErrorCodes.NETWORK_ERROR, + message="连接已关闭", + hint="请检查对端进程或传输连接", + retryable=True, + ) + ) + finally: + current_task = asyncio.current_task() + if self._transport_watch_task is current_task: + self._transport_watch_task = None + + async def wait_until_remote_initialized(self, timeout: float | None = 30.0) -> None: + """等待远端完成初始化握手。 + + Args: + timeout: 等待秒数。传入 `None` 表示无限等待。 + """ + init_waiter = asyncio.create_task(self._remote_initialized.wait()) + closed_waiter = asyncio.create_task(self.wait_closed()) + try: + done, pending = await asyncio.wait( + {init_waiter, closed_waiter}, + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + if not done: + raise TimeoutError() + if init_waiter in done and self._remote_initialized_successfully: + return + raise AstrBotError.protocol_error("连接在初始化完成前关闭") + finally: + for task in (init_waiter, closed_waiter): + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def initialize( + self, + handlers, + *, + provided_capabilities=None, + metadata: dict[str, Any] | None = None, + ) -> InitializeOutput: + """向远端发送初始化请求并缓存远端声明的能力信息。 + + Args: + handlers: 当前端点声明可接收的处理器列表。 + metadata: 附带给远端的握手元数据。 + + Returns: + 远端返回的初始化结果。 + """ + self._ensure_usable() + request_id = self._next_id() + handshake_metadata = dict(metadata or {}) + handshake_metadata[SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY] = list( + self.supported_protocol_versions + ) + future: asyncio.Future[ResultMessage] = ( + asyncio.get_running_loop().create_future() + ) + self._pending_results[request_id] = future + await self._send( + InitializeMessage( + id=request_id, + protocol_version=self.protocol_version, + peer=self.peer_info, + handlers=list(handlers), + provided_capabilities=list(provided_capabilities or []), + metadata=handshake_metadata, + ) + ) + result = await future + if result.kind != "initialize_result": + raise AstrBotError.protocol_error("initialize 必须收到 initialize_result") + if not result.success: + self._unusable = True + await self.stop() + raise AstrBotError.from_payload( + result.error.model_dump() if result.error else {} + ) + output = InitializeOutput.model_validate(result.output) + negotiated_protocol_version = ( + output.protocol_version + or output.metadata.get(NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY) + or self.protocol_version + ) + if ( + not isinstance(negotiated_protocol_version, str) + or negotiated_protocol_version not in self.supported_protocol_versions + ): + self._unusable = True + await self.stop() + raise AstrBotError.protocol_version_mismatch( + f"对端返回了当前端点不支持的协商协议版本:{negotiated_protocol_version}" + ) + self.remote_peer = output.peer + self.remote_capabilities = output.capabilities + self.remote_capability_map = {item.name: item for item in output.capabilities} + self.remote_metadata = output.metadata + self.negotiated_protocol_version = negotiated_protocol_version + self._remote_initialized_successfully = True + self._remote_initialized.set() + return output + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: + """发起一次非流式能力调用并等待最终结果。 + + Args: + capability: 远端能力名。 + payload: 调用输入。 + stream: 必须为 `False`;流式场景应改用 `invoke_stream()`。 + request_id: 可选的请求 ID;未提供时自动生成。 + """ + self._ensure_usable() + if stream: + raise ValueError("stream=True 请使用 invoke_stream()") + request_id = request_id or self._next_id() + future: asyncio.Future[ResultMessage] = ( + asyncio.get_running_loop().create_future() + ) + self._pending_results[request_id] = future + await self._send( + InvokeMessage( + id=request_id, + capability=capability, + input=payload, + stream=False, + caller_plugin_id=current_caller_plugin_id(), + ) + ) + result = await future + if not result.success: + raise AstrBotError.from_payload( + result.error.model_dump() if result.error else {} + ) + return result.output + + async def invoke_stream( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + include_completed: bool = False, + ) -> AsyncIterator[EventMessage]: + """发起一次流式能力调用并返回事件迭代器。 + + 调用方会收到 `delta` 事件,`started` 会被内部吞掉, + 默认情况下 `completed` 用于结束迭代,`failed` 会转换为异常抛出。 + + Args: + capability: 远端能力名。 + payload: 调用输入。 + request_id: 可选的请求 ID;未提供时自动生成。 + include_completed: 是否把 `completed` 事件也返回给调用方。 + """ + self._ensure_usable() + request_id = request_id or self._next_id() + queue: asyncio.Queue[Any] = asyncio.Queue() + self._pending_streams[request_id] = queue + await self._send( + InvokeMessage( + id=request_id, + capability=capability, + input=payload, + stream=True, + caller_plugin_id=current_caller_plugin_id(), + ) + ) + + async def iterator() -> AsyncIterator[EventMessage]: + terminal_received = False + try: + while True: + item = await queue.get() + if isinstance(item, Exception): + raise item + if not isinstance(item, EventMessage): + raise AstrBotError.protocol_error("流式调用收到非法事件") + if item.phase == "started": + continue + if item.phase == "delta": + yield item + continue + if item.phase == "completed": + terminal_received = True + if include_completed: + yield item + break + if item.phase == "failed": + terminal_received = True + raise AstrBotError.from_payload( + item.error.model_dump() if item.error else {} + ) + finally: + self._pending_streams.pop(request_id, None) + if not terminal_received: + try: + await self.cancel( + request_id, + reason="consumer_closed_stream_early", + ) + except Exception as exc: + # 取消失败不应中断整个流处理流程,仅记录日志 + logger.debug( + "Failed to cancel stream after consumer closed early: " + "request_id={} error={}", + request_id, + exc, + ) + + return iterator() + + async def cancel(self, request_id: str, reason: str = "user_cancelled") -> None: + """向远端发送取消请求,尝试中止指定 ID 的在途调用。""" + await self._send(CancelMessage(id=request_id, reason=reason)) + + def _next_id(self) -> str: + """生成当前连接内递增的消息 ID。""" + self._counter += 1 + return f"msg_{self._counter:04d}" + + def _ensure_usable(self) -> None: + """确保连接仍处于可用状态,否则立即抛出协议错误。""" + if self._unusable: + raise AstrBotError.protocol_error("连接已进入不可用状态") + + async def _handle_raw_message(self, payload: str) -> None: + """解析原始消息并分发到对应的消息处理分支。""" + try: + # 入站消息大小检查:拒绝巨型消息,防止 OOM 和解析阻塞 + if len(payload) > MAX_INBOUND_MESSAGE_CHARS: + raise AstrBotError.protocol_error( + f"协议消息过大,已拒绝处理:" + f"当前 {len(payload) / 1024 / 1024:.1f} MB," + f"上限 {MAX_INBOUND_MESSAGE_CHARS / 1024 / 1024:.0f} MB" + ) + message = parse_message(payload) + if isinstance(message, ResultMessage): + await self._handle_result(message) + return + if isinstance(message, EventMessage): + await self._handle_event(message) + return + if isinstance(message, InitializeMessage): + await self._handle_initialize(message) + return + if isinstance(message, InvokeMessage): + token = CancelToken() + started = asyncio.Event() + task = asyncio.create_task(self._handle_invoke(message, token, started)) + self._inbound_tasks[message.id] = (task, token, started) + + def _on_invoke_done( + _task: asyncio.Task[None], request_id: str = message.id + ) -> None: + self._inbound_tasks.pop(request_id, None) + if _task.cancelled(): + return + exc = _task.exception() + if exc is None: + return + # 为什么整个连接都要失败?正常情况下 invoke handler 会把错误编码成 + # ResultMessage 发回给对端。如果异常仍然逃逸,说明要么回复发送失败 + # (transport 已断),要么 handler 实现有 bug。无论哪种情况,连接的 + # 消息交换契约已不可靠,继续使用可能导致对端无限等待或数据丢失。 + # 采用"单点故障 → 全连接失败"策略而非隔离单个 handler,是因为协议层 + # 无法保证后续消息的正确性。 + logger.error( + "Peer inbound invoke task crashed unexpectedly: " + "request_id={} error={!r}", + request_id, + exc, + ) + error = ( + exc + if isinstance(exc, AstrBotError) + else AstrBotError( + code=ErrorCodes.NETWORK_ERROR, + message="处理入站调用响应时连接已失效", + hint=str(exc), + retryable=True, + ) + ) + asyncio.create_task(self._fail_connection(error)) + + task.add_done_callback(_on_invoke_done) + return + if isinstance(message, CancelMessage): + await self._handle_cancel(message) + return + except Exception as exc: + if isinstance(exc, AstrBotError): + error = exc + else: + error = AstrBotError.protocol_error(f"无法解析协议消息: {exc}") + await self._fail_connection(error) + # 不再向上抛出异常,避免在 transport 读循环中引发未处理的异常导致整个连接崩溃 + logger.warning( + "Peer connection marked unusable after inbound message failure: {}", + error, + ) + return + + async def _handle_initialize(self, message: InitializeMessage) -> None: + """处理远端发起的初始化握手并返回握手结果。""" + self.remote_peer = message.peer + self.remote_handlers = message.handlers + self.remote_provided_capabilities = message.provided_capabilities + self.remote_provided_capability_map = { + item.name: item for item in message.provided_capabilities + } + self.remote_metadata = dict(message.metadata) + if self._initialize_handler is None: + await self._reject_initialize( + message, + AstrBotError.protocol_error("对端不接受 initialize"), + ) + return + + negotiated_protocol_version = _select_negotiated_protocol_version( + message.protocol_version, + self.remote_metadata, + self.supported_protocol_versions, + ) + if negotiated_protocol_version is None: + supported_versions = ", ".join(self.supported_protocol_versions) + await self._reject_initialize( + message, + AstrBotError.protocol_version_mismatch( + "服务端支持协议版本 " + f"{supported_versions},客户端请求版本 {message.protocol_version}" + ), + ) + return + + self.negotiated_protocol_version = negotiated_protocol_version + self.remote_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = ( + negotiated_protocol_version + ) + output = await self._initialize_handler(message) + response_metadata = dict(output.metadata) + response_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = ( + negotiated_protocol_version + ) + output = output.model_copy( + update={ + "protocol_version": negotiated_protocol_version, + "metadata": response_metadata, + } + ) + await self._send( + ResultMessage( + id=message.id, + kind="initialize_result", + success=True, + output=output.model_dump(), + ) + ) + self._remote_initialized_successfully = True + self._remote_initialized.set() + + async def _handle_invoke( + self, + message: InvokeMessage, + token: CancelToken, + started: asyncio.Event, + ) -> None: + """处理远端发起的能力调用,并按流式或非流式协议返回结果。""" + try: + started.set() + token.raise_if_cancelled() + if self._invoke_handler is None: + raise AstrBotError.capability_not_found(message.capability) + with caller_plugin_scope(message.caller_plugin_id): + execution = await self._invoke_handler(message, token) + if inspect.isawaitable(execution): + execution = await execution + if message.stream: + if not isinstance(execution, StreamExecution): + raise AstrBotError.protocol_error( + "stream=true 必须返回 StreamExecution" + ) + await self._send(EventMessage(id=message.id, phase="started")) + collect_chunks = execution.collect_chunks + chunks: list[dict[str, Any]] = [] + async for chunk in execution.iterator: + if collect_chunks: + chunks.append(chunk) + await self._send( + EventMessage(id=message.id, phase="delta", data=chunk) + ) + await self._send( + EventMessage( + id=message.id, + phase="completed", + output=execution.finalize(chunks), + ) + ) + return + if isinstance(execution, StreamExecution): + raise AstrBotError.protocol_error("stream=false 不能返回流式执行对象") + await self._send( + ResultMessage(id=message.id, success=True, output=execution) + ) + except asyncio.CancelledError: + await self._send_cancelled_termination(message) + except LookupError as exc: + error = AstrBotError.invalid_input(str(exc)) + await self._send_error_result(message, error) + except AstrBotError as exc: + await self._send_error_result(message, exc) + except Exception as exc: + await self._send_error_result( + message, AstrBotError.internal_error(str(exc)) + ) + + async def _handle_cancel(self, message: CancelMessage) -> None: + """处理远端取消请求并终止对应的入站任务。""" + inbound = self._inbound_tasks.get(message.id) + if inbound is None: + return + task, token, started = inbound + token.cancel() + if self._cancel_handler is not None: + await self._cancel_handler(message.id) + if started.is_set(): + task.cancel() + + async def _handle_result(self, message: ResultMessage) -> None: + """处理非流式结果消息并唤醒等待中的调用方。""" + future = self._pending_results.pop(message.id, None) + if future is None: + queue = self._pending_streams.get(message.id) + if queue is not None: + await queue.put( + AstrBotError.protocol_error("stream=true 调用不应收到 result") + ) + return + # 检查 future 是否已完成(可能被调用方取消) + if not future.done(): + future.set_result(message) + + async def _handle_event(self, message: EventMessage) -> None: + """处理流式事件消息并投递到对应请求的事件队列。""" + queue = self._pending_streams.get(message.id) + if queue is None: + future = self._pending_results.get(message.id) + if future is not None and not future.done(): + future.set_exception( + AstrBotError.protocol_error("stream=false 调用不应收到 event") + ) + return + await queue.put(message) + + async def _send_error_result( + self, message: InvokeMessage, error: AstrBotError + ) -> None: + """根据调用模式,将错误编码为 `result` 或失败事件发回远端。""" + if message.stream: + await self._send( + EventMessage( + id=message.id, + phase="failed", + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + return + await self._send( + ResultMessage( + id=message.id, + success=False, + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + + async def _reject_initialize( + self, message: InitializeMessage, error: AstrBotError + ) -> None: + """拒绝一次初始化握手,并把连接标记为不可继续使用。""" + await self._send( + ResultMessage( + id=message.id, + kind="initialize_result", + success=False, + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + self._unusable = True + self._remote_initialized.set() + await self.stop() + + async def _send_cancelled_termination(self, message: InvokeMessage) -> None: + """把本端取消执行转换为标准化的取消错误响应。""" + error = AstrBotError.cancelled() + await self._send_error_result(message, error) + + async def _fail_connection(self, error: AstrBotError) -> None: + """把连接标记为不可用,并让所有等待中的调用尽快失败。""" + if self._unusable: + return + self._unusable = True + self._remote_initialized.set() + + for future in list(self._pending_results.values()): + if not future.done(): + future.set_exception(error) + self._pending_results.clear() + + for queue in list(self._pending_streams.values()): + await queue.put(error) + self._pending_streams.clear() + + for task, token, _started in list(self._inbound_tasks.values()): + token.cancel() + task.cancel() + self._inbound_tasks.clear() + + self._schedule_stop() + + def _schedule_stop(self) -> None: + """安全地调度 stop(),避免与正在执行的 stop() 产生并发冲突。""" + if self._closed.is_set(): + return + # 已有 stop task 在跑则不重复创建,防止产生竞态条件 + if self._stop_task is not None and not self._stop_task.done(): + return + self._stop_task = asyncio.create_task(self.stop(), name="astrbot-sdk-peer-stop") + + async def _send(self, message) -> None: + """序列化协议消息并通过底层传输发送出去。""" + await self.transport.send(message.model_dump_json(exclude_none=True)) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py new file mode 100644 index 0000000000..6fdcf7227b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py @@ -0,0 +1,1066 @@ +"""Supervisor 端运行时:SupervisorRuntime 管理多个 Worker 进程,WorkerSession 封装与单个 Worker 的通信。 + +架构层次: + AstrBot Core (Python) + | + v + SupervisorRuntime (管理多插件) + | + +-- WorkerSession (插件 A) -- StdioTransport -- PluginWorkerRuntime (子进程) + | + +-- WorkerSession (插件 B, 插件 C) -- StdioTransport -- GroupWorkerRuntime (子进程) + | + +-- WorkerSession (插件 D) -- StdioTransport -- PluginWorkerRuntime (子进程) + +核心类: + SupervisorRuntime: 监管者运行时 + - 发现并加载所有插件 + - 为单个插件或兼容插件组启动 Worker 进程 + - 聚合所有 handler 并向 Core 注册 + - 路由 Core 的调用请求到对应 Worker + - 处理 Worker 进程崩溃和重连 + - handler ID 冲突检测和警告 + + WorkerSession: Worker 会话 + - 管理单个插件 Worker 进程 + - 通过 Peer 与 Worker 通信 + - 提供 invoke_handler 和 cancel 方法 + - 处理连接关闭回调 + - 自动清理已注册的 handlers + +信号处理: + - SIGTERM: 设置 stop_event,触发优雅关闭 + - SIGINT: 设置 stop_event,触发优雅关闭 +""" + +from __future__ import annotations + +import asyncio +import os +import signal +import sys +from collections.abc import Callable +from pathlib import Path +from typing import IO, Any, cast + +from .._internal.plugin_ids import ( + capability_belongs_to_plugin, + plugin_capability_prefix, +) +from .._internal.sdk_logger import logger +from ..errors import AstrBotError +from ..protocol.descriptors import CapabilityDescriptor +from ..protocol.messages import EventMessage, InitializeOutput, PeerInfo +from .capability_router import CapabilityRouter, StreamExecution +from .environment_groups import EnvironmentGroup +from .loader import ( + PluginDiscoveryIssue, + PluginEnvironmentManager, + PluginSpec, + discover_plugins, + load_plugin_config, +) +from .peer import Peer +from .transport import ( + StdioTransport, + WebSocketClientTransport, + build_websocket_client_ssl_context, +) +from .workers_manifest import RemoteWorkerSpec, load_remote_workers_manifest + +__all__ = [ + "SupervisorRuntime", + "WorkerSession", + "_install_signal_handlers", + "_prepare_stdio_transport", + "_sdk_source_dir", + "_wait_for_shutdown", +] + +# Worker 进程初始化握手超时:60 秒内必须完成 initialize 协议交换, +# 否则视为进程卡死或挂载过慢,直接报错让上层感知 +WORKER_INITIALIZE_TIMEOUT_SECONDS = 60.0 + + +def _install_signal_handlers(stop_event: asyncio.Event) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, stop_event.set) + except NotImplementedError: + logger.debug("Signal handlers are not supported for {}", sig) + + +def _prepare_stdio_transport( + stdin: IO[str] | None, + stdout: IO[str] | None, +) -> tuple[IO[str], IO[str], IO[str] | None]: + if stdin is not None and stdout is not None: + return stdin, stdout, None + transport_stdin = stdin or sys.stdin + transport_stdout = stdout or sys.stdout + original_stdout = sys.stdout + sys.stdout = sys.stderr + return transport_stdin, transport_stdout, original_stdout + + +def _sdk_source_dir(repo_root: Path) -> Path: + candidate = repo_root.resolve() / "src" + if (candidate / "astrbot_sdk").exists(): + return candidate + return Path(__file__).resolve().parents[2] + + +async def _wait_for_shutdown(peer: Peer, stop_event: asyncio.Event) -> None: + stop_waiter = asyncio.create_task(stop_event.wait()) + transport_waiter = asyncio.create_task(peer.wait_closed()) + done, pending = await asyncio.wait( + {stop_waiter, transport_waiter}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + for task in done: + if not task.cancelled(): + task.result() + + +def _plugin_name_from_handler_id(handler_id: str) -> str: + if ":" in handler_id: + return handler_id.split(":", 1)[0] + return handler_id + + +class WorkerSession: + def __init__( + self, + *, + plugin: PluginSpec | None = None, + group: EnvironmentGroup | None = None, + remote_worker: RemoteWorkerSpec | None = None, + repo_root: Path, + env_manager: PluginEnvironmentManager, + capability_router: CapabilityRouter, + on_closed: Callable[[], None] | None = None, + ) -> None: + target_count = sum(item is not None for item in (plugin, group, remote_worker)) + if target_count != 1: + raise ValueError( + "WorkerSession requires exactly one of plugin, group, or remote_worker" + ) + group_ref = group + self.remote_worker = remote_worker + self.is_remote = remote_worker is not None + if group_ref is not None: + primary_plugin = group_ref.plugins[0] + elif plugin is not None: + primary_plugin = plugin + else: + primary_plugin = None + self.group = group + self.plugins = ( + list(group_ref.plugins) + if group_ref is not None + else ([primary_plugin] if primary_plugin is not None else []) + ) + self.plugin = primary_plugin + self.worker_id = ( + remote_worker.id + if remote_worker is not None + else ( + group_ref.id + if group_ref is not None + else cast(PluginSpec, primary_plugin).name + ) + ) + self.repo_root = repo_root.resolve() + self.env_manager = env_manager + self.capability_router = capability_router + self.on_closed = on_closed + self.peer: Peer | None = None + self.handlers = [] + self.provided_capabilities: list[CapabilityDescriptor] = [] + self.loaded_plugins: list[str] = [] + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self.capability_sources: dict[str, str] = {} + self.llm_tools: list[dict[str, Any]] = [] + self.agents: list[dict[str, Any]] = [] + self.worker_registry: list[dict[str, Any]] = [] + self._connection_watch_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + transport = self._build_transport() + self.peer = Peer( + transport=transport, + peer_info=PeerInfo(name="astrbot-core", role="core", version="s5r"), + ) + self.peer.set_initialize_handler(self._handle_initialize) + self.peer.set_invoke_handler(self._handle_capability_invoke) + try: + await self.peer.start() + await self._wait_until_initialized() + self._sync_remote_state() + self._validate_initialized_state() + + except Exception: + await self.stop() + raise + + def _build_transport(self): + if self.remote_worker is not None: + ssl_context = build_websocket_client_ssl_context( + ca_file=self.remote_worker.tls.ca_file, + cert_file=self.remote_worker.tls.cert_file, + key_file=self.remote_worker.tls.key_file, + ) + return WebSocketClientTransport( + url=self.remote_worker.url, + ssl_context=ssl_context, + server_hostname=self.remote_worker.tls.server_hostname, + ) + + python_path, command, cwd = self._worker_command() + repo_src_dir = str(_sdk_source_dir(self.repo_root)) + env = os.environ.copy() + existing_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + f"{repo_src_dir}{os.pathsep}{existing_pythonpath}" + if existing_pythonpath + else repo_src_dir + ) + env.setdefault("PYTHONIOENCODING", "utf-8") + env.setdefault("PYTHONUTF8", "1") + return StdioTransport(command=command, cwd=cwd, env=env) + + async def _wait_until_initialized(self) -> None: + assert self.peer is not None + try: + await self.peer.wait_until_remote_initialized( + timeout=WORKER_INITIALIZE_TIMEOUT_SECONDS + ) + except TimeoutError as exc: + raise RuntimeError( + f"worker {self.worker_id} 初始化超时 " + f"({WORKER_INITIALIZE_TIMEOUT_SECONDS:.0f}s);" + "请检查 worker 日志中的 on_start / 装饰器初始化错误" + ) from exc + except AstrBotError as exc: + raise RuntimeError(f"worker {self.worker_id} 在初始化阶段退出") from exc + + def _sync_remote_state(self) -> None: + assert self.peer is not None + self.handlers = list(self.peer.remote_handlers) + self.provided_capabilities = list(self.peer.remote_provided_capabilities) + metadata = dict(self.peer.remote_metadata) + + remote_loaded_plugins = metadata.get("loaded_plugins") + if isinstance(remote_loaded_plugins, list): + self.loaded_plugins = [ + plugin_name + for plugin_name in remote_loaded_plugins + if isinstance(plugin_name, str) + ] + else: + self.loaded_plugins = [plugin.name for plugin in self.plugins] + + remote_skipped_plugins = metadata.get("skipped_plugins") + if isinstance(remote_skipped_plugins, dict): + self.skipped_plugins = { + str(plugin_name): str(reason) + for plugin_name, reason in remote_skipped_plugins.items() + } + + remote_capability_sources = metadata.get("capability_sources") + if isinstance(remote_capability_sources, dict): + self.capability_sources = { + str(capability_name): str(plugin_name) + for capability_name, plugin_name in remote_capability_sources.items() + } + + remote_issues = metadata.get("issues") + default_issue_owner = ( + self.plugin.name if self.plugin is not None else self.worker_id + ) + if isinstance(remote_issues, list): + self.issues = [ + PluginDiscoveryIssue( + severity=str(item.get("severity", "error")), # type: ignore[arg-type] + phase=str(item.get("phase", "load")), # type: ignore[arg-type] + plugin_id=str(item.get("plugin_id", default_issue_owner)), + message=str(item.get("message", "")), + details=str(item.get("details", "")), + hint=str(item.get("hint", "")), + ) + for item in remote_issues + if isinstance(item, dict) + ] + + remote_llm_tools = metadata.get("llm_tools") + if isinstance(remote_llm_tools, list): + self.llm_tools = [ + dict(item) for item in remote_llm_tools if isinstance(item, dict) + ] + + remote_agents = metadata.get("agents") + if isinstance(remote_agents, list): + self.agents = [ + dict(item) for item in remote_agents if isinstance(item, dict) + ] + + remote_worker_registry = metadata.get("worker_registry") + if isinstance(remote_worker_registry, list): + self.worker_registry = [ + dict(item) + for item in remote_worker_registry + if isinstance(item, dict) and str(item.get("name", "")).strip() + ] + + def _validate_initialized_state(self) -> None: + assert self.peer is not None + if self.remote_worker is not None and self.peer.remote_peer is not None: + if self.peer.remote_peer.name != self.worker_id: + raise RuntimeError( + "remote worker identity mismatch: " + f"expected {self.worker_id!r}, got {self.peer.remote_peer.name!r}" + ) + + plugin_ids = { + str(item.get("name", "")).strip() + for item in self.worker_registry + if isinstance(item, dict) + } + plugin_ids.discard("") + if not plugin_ids and self.plugins: + plugin_ids = {plugin.name for plugin in self.plugins} + if self.remote_worker is not None and not plugin_ids: + raise RuntimeError( + f"remote worker {self.worker_id} did not provide worker_registry" + ) + + for plugin_name in self.loaded_plugins: + if plugin_ids and plugin_name not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} reported undeclared loaded plugin: " + f"{plugin_name}" + ) + for plugin_name in self.skipped_plugins: + if plugin_ids and plugin_name not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} reported undeclared skipped plugin: " + f"{plugin_name}" + ) + for capability_name, plugin_name in self.capability_sources.items(): + if plugin_ids and plugin_name not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} returned capability source outside " + f"worker_registry: {capability_name} -> {plugin_name}" + ) + for handler in self.handlers: + owner_plugin = _plugin_name_from_handler_id(handler.id) + if plugin_ids and owner_plugin not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} returned handler outside worker_registry: " + f"{handler.id}" + ) + for item in self.llm_tools: + plugin_name = str(item.get("plugin_id", "")).strip() + if plugin_ids and plugin_name and plugin_name not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} returned llm tool outside worker_registry: " + f"{plugin_name}" + ) + for item in self.agents: + plugin_name = str(item.get("plugin_id", "")).strip() + if plugin_ids and plugin_name and plugin_name not in plugin_ids: + raise RuntimeError( + f"worker {self.worker_id} returned agent outside worker_registry: " + f"{plugin_name}" + ) + + def _worker_command(self) -> tuple[Path, list[str], str]: + if self.group is not None: + prepare_group = getattr(self.env_manager, "prepare_group_environment", None) + if callable(prepare_group): + python_path = cast(Path, prepare_group(self.group)) + else: + python_path = self.env_manager.prepare_environment(self.plugins[0]) + return ( + python_path, + [ + str(python_path), + "-m", + "astrbot_sdk", + "worker", + "--group-metadata", + str(self.group.metadata_path), + ], + str(self.repo_root), + ) + + assert self.plugin is not None + plugin = self.plugin + python_path = self.env_manager.prepare_environment(plugin) + return ( + python_path, + [ + str(python_path), + "-m", + "astrbot_sdk", + "worker", + "--plugin-dir", + str(plugin.plugin_dir), + ], + str(plugin.plugin_dir), + ) + + def start_close_watch(self) -> None: + if ( + self.on_closed is None + or self.peer is None + or self._connection_watch_task is not None + ): + return + self._connection_watch_task = asyncio.create_task(self._watch_connection()) + + async def _watch_connection(self) -> None: + """监听 Worker 连接关闭,触发清理回调""" + try: + if self.peer is not None: + await self.peer.wait_closed() + if self.on_closed is not None: + try: + self.on_closed() + except Exception: + logger.exception( + "on_closed callback failed for worker {}", self.worker_id + ) + finally: + current_task = asyncio.current_task() + if self._connection_watch_task is current_task: + self._connection_watch_task = None + + async def stop(self) -> None: + if self.peer is not None: + await self.peer.stop() + + async def invoke_handler( + self, + handler_id: str, + event_payload: dict[str, Any], + *, + request_id: str, + args: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if self.peer is None: + raise RuntimeError("worker session is not running") + return await self.peer.invoke( + "handler.invoke", + { + "handler_id": handler_id, + "event": event_payload, + "args": dict(args or {}), + }, + request_id=request_id, + ) + + async def invoke_capability( + self, + capability_name: str, + payload: dict[str, Any], + *, + request_id: str, + ) -> dict[str, Any]: + if self.peer is None: + raise RuntimeError("worker session is not running") + return await self.peer.invoke( + capability_name, + payload, + request_id=request_id, + ) + + async def invoke_capability_stream( + self, + capability_name: str, + payload: dict[str, Any], + *, + request_id: str, + ): + if self.peer is None: + raise RuntimeError("worker session is not running") + event_stream = await self.peer.invoke_stream( + capability_name, + payload, + request_id=request_id, + include_completed=True, + ) + async for event in event_stream: + yield event + + async def cancel(self, request_id: str) -> None: + if self.peer is None: + return + await self.peer.cancel(request_id) + + async def _handle_initialize(self, _message) -> InitializeOutput: + return InitializeOutput( + peer=PeerInfo(name="astrbot-supervisor", role="core", version="s5r"), + capabilities=self.capability_router.all_descriptors(), + metadata={ + "worker_id": self.worker_id, + "plugins": [plugin.name for plugin in self.plugins], + }, + ) + + async def _handle_capability_invoke(self, message, cancel_token): + return await self.capability_router.execute( + message.capability, + message.input, + stream=message.stream, + cancel_token=cancel_token, + request_id=message.id, + ) + + def describe(self) -> dict[str, Any]: + return { + "worker_id": self.worker_id, + "plugins": [plugin.name for plugin in self.plugins], + "loaded_plugins": list(self.loaded_plugins), + "skipped_plugins": dict(self.skipped_plugins), + "issues": [issue.to_payload() for issue in self.issues], + } + + +class SupervisorRuntime: + def __init__( + self, + *, + transport, + plugins_dir: Path, + env_manager: PluginEnvironmentManager | None = None, + workers_manifest: Path | None = None, + ) -> None: + self.transport = transport + self.plugins_dir = plugins_dir.resolve() + self.repo_root = Path(__file__).resolve().parents[3] + self.env_manager = env_manager or PluginEnvironmentManager(self.repo_root) + self.workers_manifest = workers_manifest.resolve() if workers_manifest else None + self.capability_router = CapabilityRouter() + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name="astrbot-supervisor", role="plugin", version="s5r"), + ) + self.peer.set_invoke_handler(self._handle_upstream_invoke) + self.peer.set_cancel_handler(self._handle_upstream_cancel) + self.worker_sessions: dict[str, WorkerSession] = {} + self.handler_to_worker: dict[str, WorkerSession] = {} + self.capability_to_worker: dict[str, WorkerSession] = {} + self.plugin_to_worker_session: dict[str, WorkerSession] = {} + self._handler_sources: dict[str, str] = {} # handler_id -> plugin_name + self._capability_sources: dict[str, str] = {} # capability_name -> plugin_name + self.active_requests: dict[str, WorkerSession] = {} + self.loaded_plugins: list[str] = [] + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self._register_internal_capabilities() + + def _publish_plugin_registry_snapshot( + self, + plugins: list[PluginSpec], + *, + enabled_plugins: set[str], + ) -> None: + for plugin in plugins: + manifest = plugin.manifest_data + self.capability_router.upsert_plugin( + metadata={ + "name": plugin.name, + "display_name": str(manifest.get("display_name") or plugin.name), + "description": str( + manifest.get("desc") or manifest.get("description") or "" + ), + "repo": str(manifest.get("repo") or ""), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": plugin.name in enabled_plugins, + }, + config=load_plugin_config(plugin), + ) + + def _publish_discovered_plugin_registry(self, plugins: list[PluginSpec]) -> None: + """发布已发现插件的静态元数据。 + + 这一阶段发生在 worker 真正启动前。此时 supervisor 已经知道有哪些插件、 + 它们的 manifest/config 是什么,但尚未确认哪些插件实际完成加载,因此统一 + 以 `enabled=False` 暴露给 metadata 能力。 + """ + self._publish_plugin_registry_snapshot(plugins, enabled_plugins=set()) + + def _publish_loaded_plugin_registry(self, plugins: list[PluginSpec]) -> None: + """在 worker 启动完成后刷新插件启用状态。""" + self._publish_plugin_registry_snapshot( + plugins, + enabled_plugins=set(self.loaded_plugins), + ) + + def _publish_worker_registry(self, entries: list[dict[str, Any]]) -> None: + for item in entries: + plugin_name = str(item.get("name", "")).strip() + if not plugin_name: + continue + config = item.get("config") + metadata = dict(item) + metadata.pop("config", None) + self.capability_router.upsert_plugin( + metadata=metadata, + config=dict(config) if isinstance(config, dict) else {}, + ) + + def _publish_session_runtime_metadata(self, session: WorkerSession) -> None: + self._publish_worker_registry(session.worker_registry) + tools_by_plugin: dict[str, list[dict[str, Any]]] = {} + for item in session.llm_tools: + plugin_name = str(item.get("plugin_id", "")).strip() + if not plugin_name: + continue + tools_by_plugin.setdefault(plugin_name, []).append(dict(item)) + for plugin_name, items in tools_by_plugin.items(): + self.capability_router.set_plugin_llm_tools(plugin_name, items) + + agents_by_plugin: dict[str, list[dict[str, Any]]] = {} + for item in session.agents: + plugin_name = str(item.get("plugin_id", "")).strip() + if not plugin_name: + continue + agents_by_plugin.setdefault(plugin_name, []).append(dict(item)) + for plugin_name, items in agents_by_plugin.items(): + self.capability_router.set_plugin_agents(plugin_name, items) + + @staticmethod + def _session_plugin_ids(session: WorkerSession) -> set[str]: + plugin_ids = { + str(item.get("name", "")).strip() + for item in session.worker_registry + if isinstance(item, dict) + } + plugin_ids.discard("") + if plugin_ids: + return plugin_ids + return {plugin.name for plugin in session.plugins} + + def _validate_remote_session_plugins( + self, + session: WorkerSession, + *, + local_plugin_ids: set[str], + ) -> None: + if not session.is_remote: + return + conflicts = self._session_plugin_ids(session) & ( + local_plugin_ids | set(self.plugin_to_worker_session) + ) + if not conflicts: + return + names = ", ".join(sorted(conflicts)) + raise RuntimeError( + f"remote worker {session.worker_id} conflicts with existing plugins: {names}" + ) + + def _record_session_start_failure( + self, + session: WorkerSession, + exc: Exception, + ) -> None: + if session.plugins: + for plugin in session.plugins: + self.skipped_plugins[plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件 worker 启动失败", + details=str(exc), + ) + ) + return + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=session.worker_id, + message="远程 worker 连接失败", + details=str(exc), + ) + ) + + def _register_started_session(self, session: WorkerSession) -> None: + self.worker_sessions[session.worker_id] = session + self.skipped_plugins.update(session.skipped_plugins) + self.issues.extend(session.issues) + self._publish_session_runtime_metadata(session) + for plugin_name in session.loaded_plugins: + self.plugin_to_worker_session[plugin_name] = session + if plugin_name not in self.loaded_plugins: + self.loaded_plugins.append(plugin_name) + for handler in session.handlers: + self._register_handler( + handler, + session, + _plugin_name_from_handler_id(handler.id), + ) + for descriptor in session.provided_capabilities: + plugin_name = session.capability_sources.get(descriptor.name) + if plugin_name is None and len(session.loaded_plugins) == 1: + plugin_name = session.loaded_plugins[0] + if plugin_name is None: + plugin_name = session.worker_id + self._register_plugin_capability(descriptor, session, plugin_name) + session.start_close_watch() + + def _register_internal_capabilities(self) -> None: + self.capability_router.register( + CapabilityDescriptor( + name="handler.invoke", + description="框架内部:转发到插件 handler", + input_schema={ + "type": "object", + "properties": { + "handler_id": {"type": "string"}, + "event": {"type": "object"}, + }, + "required": ["handler_id", "event"], + }, + output_schema={ + "type": "object", + "properties": {}, + "required": [], + }, + cancelable=True, + ), + call_handler=self._route_handler_invoke, + exposed=False, + ) + + def _register_handler( + self, handler, session: WorkerSession, plugin_name: str + ) -> None: + """注册 handler,处理冲突时输出警告。 + + Args: + handler: Handler 描述符 + session: Worker 会话 + plugin_name: 插件名称 + """ + handler_id = handler.id + existing_plugin = self._handler_sources.get(handler_id) + + if existing_plugin is not None: + logger.warning( + f"Handler ID 冲突:'{handler_id}' 已被插件 '{existing_plugin}' 注册," + f"现在被插件 '{plugin_name}' 覆盖。" + ) + + self.handler_to_worker[handler_id] = session + self._handler_sources[handler_id] = plugin_name + + def _register_plugin_capability( + self, + descriptor: CapabilityDescriptor, + session: WorkerSession, + plugin_name: str, + ) -> None: + """注册插件 capability。""" + capability_name = descriptor.name + if not capability_belongs_to_plugin(capability_name, plugin_name): + expected_prefix = plugin_capability_prefix(plugin_name) + raise ValueError( + "插件导出的 capability 必须使用 plugin_id 作为公开命名空间前缀:" + f" plugin={plugin_name!r}, capability={capability_name!r}, " + f"expected_prefix={expected_prefix!r}" + ) + # Worker 侧 loader 已经做过命名空间校验;这里若还能撞名,说明协议数据 + # 与本地路由状态不一致,继续静默改名只会掩盖问题。 + if self.capability_router.contains(capability_name): + existing_plugin = self._capability_sources.get(capability_name, "") + raise RuntimeError( + "duplicate capability registration detected after worker load validation: " + f"{capability_name!r} already registered by {existing_plugin!r}, " + f"cannot register again for {plugin_name!r}" + ) + self._do_register_capability(descriptor, session, capability_name, plugin_name) + + def _do_register_capability( + self, + descriptor: CapabilityDescriptor, + session: WorkerSession, + capability_name: str, + plugin_name: str, + ) -> None: + """实际执行 capability 注册。""" + self.capability_router.register( + descriptor, + call_handler=self._make_plugin_capability_caller(session, capability_name), + stream_handler=( + self._make_plugin_capability_streamer(session, capability_name) + if descriptor.supports_stream + else None + ), + ) + self.capability_to_worker[capability_name] = session + self._capability_sources[capability_name] = plugin_name + + def _make_plugin_capability_caller( + self, + session: WorkerSession, + capability_name: str, + ): + async def call_handler( + request_id: str, + payload: dict[str, Any], + _cancel_token, + ) -> dict[str, Any]: + self.active_requests[request_id] = session + try: + return await session.invoke_capability( + capability_name, + payload, + request_id=request_id, + ) + finally: + self.active_requests.pop(request_id, None) + + return call_handler + + def _make_plugin_capability_streamer( + self, + session: WorkerSession, + capability_name: str, + ): + async def stream_handler( + request_id: str, + payload: dict[str, Any], + _cancel_token, + ): + completed_output: dict[str, Any] = {} + + async def iterator(): + self.active_requests[request_id] = session + try: + async for event in session.invoke_capability_stream( + capability_name, + payload, + request_id=request_id, + ): + if not isinstance(event, EventMessage): + raise AstrBotError.protocol_error( + "插件 worker 返回了非法的流式事件" + ) + if event.phase == "delta": + yield event.data or {} + continue + if event.phase == "completed": + completed_output.clear() + completed_output.update(event.output or {}) + finally: + self.active_requests.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda chunks: completed_output or {"items": chunks}, + ) + + return stream_handler + + async def start(self) -> None: + discovery = discover_plugins(self.plugins_dir) + self.skipped_plugins = dict(discovery.skipped_plugins) + self.issues = list(discovery.issues) + local_plugin_ids = {plugin.name for plugin in discovery.plugins} + plan_result = self.env_manager.plan(discovery.plugins) + remote_workers = ( + load_remote_workers_manifest(self.workers_manifest) + if self.workers_manifest is not None + else [] + ) + self.skipped_plugins.update(plan_result.skipped_plugins) + self.issues.extend( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin_name, + message="插件环境规划失败", + details=str(reason), + ) + for plugin_name, reason in plan_result.skipped_plugins.items() + ) + # 先发布静态插件元数据,允许 supervisor 侧在 worker 启动阶段就读取配置/清单。 + self._publish_discovered_plugin_registry(discovery.plugins) + try: + planned_sessions: list[WorkerSession] = [] + if plan_result.groups: + for group in plan_result.groups: + planned_sessions.append( + WorkerSession( + group=group, + repo_root=self.repo_root, + env_manager=self.env_manager, + capability_router=self.capability_router, + on_closed=lambda worker_id=group.id: ( + self._handle_worker_closed(worker_id) + ), + ) + ) + else: + for plugin in plan_result.plugins: + planned_sessions.append( + WorkerSession( + plugin=plugin, + repo_root=self.repo_root, + env_manager=self.env_manager, + capability_router=self.capability_router, + on_closed=lambda worker_id=plugin.name: ( + self._handle_worker_closed(worker_id) + ), + ) + ) + for remote_worker in remote_workers: + planned_sessions.append( + WorkerSession( + remote_worker=remote_worker, + repo_root=self.repo_root, + env_manager=self.env_manager, + capability_router=self.capability_router, + on_closed=lambda worker_id=remote_worker.id: ( + self._handle_worker_closed(worker_id) + ), + ) + ) + + for session in planned_sessions: + try: + await session.start() + self._validate_remote_session_plugins( + session, + local_plugin_ids=local_plugin_ids, + ) + except Exception as exc: + self._record_session_start_failure(session, exc) + await session.stop() + continue + self._register_started_session(session) + + # worker 启动后再用实际加载结果刷新 enabled 状态,形成显式两阶段发布。 + self._publish_loaded_plugin_registry(discovery.plugins) + + aggregated_handlers = list(self.handler_to_worker.keys()) + logger.info( + "Loaded plugins: {}", ", ".join(sorted(self.loaded_plugins)) or "none" + ) + + await self.peer.start() + await self.peer.initialize( + [ + handler + for session in self.worker_sessions.values() + for handler in session.handlers + ], + provided_capabilities=self.capability_router.descriptors(), + metadata={ + "plugins": sorted(self.loaded_plugins), + "skipped_plugins": self.skipped_plugins, + "issues": [issue.to_payload() for issue in self.issues], + "aggregated_handler_ids": aggregated_handlers, + "workers": [ + session.describe() for session in self.worker_sessions.values() + ], + "worker_count": len(self.worker_sessions), + }, + ) + except Exception: + await self.stop() + raise + + def _handle_worker_closed(self, worker_id: str) -> None: + """Worker 连接关闭时的清理回调""" + session = self.worker_sessions.pop(worker_id, None) + if session is None: + return + # 从 handler_to_worker 中移除该插件注册的 handlers(仅当来源仍为此插件时) + for handler in session.handlers: + source_plugin = self._handler_sources.get(handler.id) + if source_plugin == _plugin_name_from_handler_id(handler.id) or ( + source_plugin == worker_id + ): + self.handler_to_worker.pop(handler.id, None) + self._handler_sources.pop(handler.id, None) + for descriptor in session.provided_capabilities: + source_plugin = self._capability_sources.get(descriptor.name) + capability_plugin = session.capability_sources.get(descriptor.name) + if source_plugin == capability_plugin or ( + capability_plugin is None + and ( + source_plugin == worker_id + or source_plugin in session.loaded_plugins + ) + ): + self.capability_to_worker.pop(descriptor.name, None) + self._capability_sources.pop(descriptor.name, None) + self.capability_router.unregister(descriptor.name) + session_loaded_plugins = getattr(session, "loaded_plugins", None) + if not isinstance(session_loaded_plugins, list): + session_loaded_plugins = [worker_id] + for plugin_name in session_loaded_plugins: + if plugin_name in self.loaded_plugins: + self.loaded_plugins.remove(plugin_name) + self.plugin_to_worker_session.pop(plugin_name, None) + self.capability_router.set_plugin_enabled(plugin_name, False) + self.capability_router.remove_http_apis_for_plugin(plugin_name) + stale_requests = [ + request_id + for request_id, active_session in self.active_requests.items() + if active_session is session + ] + for request_id in stale_requests: + self.active_requests.pop(request_id, None) + logger.warning("worker {} 连接已关闭,已清理相关 handlers", worker_id) + + async def stop(self) -> None: + for session in list(self.worker_sessions.values()): + await session.stop() + await self.peer.stop() + + async def _handle_upstream_invoke(self, message, cancel_token): + return await self.capability_router.execute( + message.capability, + message.input, + stream=message.stream, + cancel_token=cancel_token, + request_id=message.id, + ) + + async def _route_handler_invoke( + self, + request_id: str, + payload: dict[str, Any], + _cancel_token, + ) -> dict[str, Any]: + handler_id = str(payload.get("handler_id", "")) + session = self.handler_to_worker.get(handler_id) + if session is None: + raise AstrBotError.invalid_input(f"handler not found: {handler_id}") + self.active_requests[request_id] = session + try: + return await session.invoke_handler( + handler_id, + payload.get("event", {}), + request_id=request_id, + args=payload.get("args", {}), + ) + finally: + self.active_requests.pop(request_id, None) + + async def _handle_upstream_cancel(self, request_id: str) -> None: + session = self.active_requests.get(request_id) + if session is not None: + await session.cancel(request_id) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/transport.py b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py new file mode 100644 index 0000000000..9f5f64c1b4 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py @@ -0,0 +1,557 @@ +"""传输层抽象模块。 + +定义 Transport 抽象基类及其实现,负责底层的消息传输。 +传输层只关心"发送字符串"和"接收字符串",不处理协议细节。 +传输实现: + Transport: 抽象基类,定义 start/stop/send/wait_closed 接口 + StdioTransport: 标准输入输出传输 + - 进程模式: 通过 command 参数启动子进程 + - 文件模式: 通过 stdin/stdout 参数指定文件描述符 + +传输类型: + Transport: 抽象基类,定义 start/stop/send 接口 + StdioTransport: 标准输入输出传输,支持进程模式和文件模式 + WebSocketServerTransport: WebSocket 服务端传输 + - 单连接限制,支持心跳配置 + - 通过 port 属性获取实际监听端口 + - 自动重连需要外部实现 + +使用示例: + # 子进程模式 + transport = StdioTransport( + command=["python", "-m", "my_plugin"], + cwd="/path/to/plugin", + ) + + # 标准输入输出模式 + transport = StdioTransport(stdin=sys.stdin, stdout=sys.stdout) + + # WebSocket 服务端 + transport = WebSocketServerTransport(host="0.0.0.0", port=8765) + + # WebSocket 客户端 + transport = WebSocketClientTransport(url="ws://localhost:8765") + + # 统一接口 + transport.set_message_handler(my_handler) + await transport.start() + await transport.send(json_string) + await transport.stop() + +`Transport` 只处理“字符串发出去 / 字符串收进来”这件事,不做协议解析,也不关心 +能力、handler 或迁移适配策略。当前实现包括: + +- `StdioTransport`: 子进程或文件对象上的按行文本传输 +- `WebSocketServerTransport`: 单连接 WebSocket 服务端 +- `WebSocketClientTransport`: WebSocket 客户端 + +自动重连、消息重放等策略不在这里实现,统一留给更上层编排。 +""" + +from __future__ import annotations + +import asyncio +import ssl +import sys +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path +from typing import IO, Any + +from .._internal.sdk_logger import logger + +MessageHandler = Callable[[str], Awaitable[None]] +STDIO_SUBPROCESS_STREAM_LIMIT = 8 * 1024 * 1024 + + +def build_websocket_server_ssl_context( + *, + ca_file: str | Path, + cert_file: str | Path, + key_file: str | Path, +) -> ssl.SSLContext: + """Build a mutual-TLS server SSL context for websocket workers.""" + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(cafile=str(ca_file)) + context.load_cert_chain(certfile=str(cert_file), keyfile=str(key_file)) + return context + + +def build_websocket_client_ssl_context( + *, + ca_file: str | Path, + cert_file: str | Path, + key_file: str | Path, +) -> ssl.SSLContext: + """Build a mutual-TLS client SSL context for websocket supervisor sessions.""" + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=str(ca_file)) + context.load_cert_chain(certfile=str(cert_file), keyfile=str(key_file)) + return context + + +def _get_aiohttp(): + import aiohttp + + return aiohttp + + +def _get_web(): + from aiohttp import web + + return web + + +def _frame_stdio_payload(payload: str) -> bytes: + body = payload.encode("utf-8") + return f"{len(body)}\n".encode("ascii") + body + + +def _parse_stdio_header(raw_header: bytes) -> int: + header = raw_header.decode("ascii").strip() + if not header: + raise ValueError("STDIO frame header is empty") + try: + size = int(header) + except ValueError as exc: + raise ValueError(f"Invalid STDIO frame header: {header!r}") from exc + # 拒绝负数 size,防止子进程写入畸形 header 导致 readexactly 行为异常 + if size < 0: + raise ValueError(f"STDIO frame size must be non-negative: {size}") + return size + + +# TODO 一个更好的解决方案? +def _is_windows_access_denied(error: BaseException) -> bool: + return ( + sys.platform == "win32" + and isinstance(error, PermissionError) + and getattr(error, "winerror", None) == 5 + ) + + +class Transport(ABC): + def __init__(self) -> None: + self._handler: MessageHandler | None = None + self._closed = asyncio.Event() + + def set_message_handler(self, handler: MessageHandler) -> None: + """注册收到原始字符串消息后的回调。""" + self._handler = handler + + @abstractmethod + async def start(self) -> None: + raise NotImplementedError + + @abstractmethod + async def stop(self) -> None: + raise NotImplementedError + + @abstractmethod + async def send(self, payload: str) -> None: + raise NotImplementedError + + async def wait_closed(self) -> None: + """等待传输层进入关闭状态。""" + await self._closed.wait() + + async def _dispatch(self, payload: str) -> None: + """把收到的原始载荷转交给上层处理器。""" + if self._handler is not None: + await self._handler(payload) + + async def _dispatch_safely(self, payload: str, *, source: str) -> None: + """安全地分发一帧消息:捕获所有非取消异常,避免单帧处理错误拖垮整个读循环。""" + try: + await self._dispatch(payload) + except asyncio.CancelledError: + # CancelledError 必须放行,否则无法优雅关闭 + raise + except Exception: + # 记录异常后继续读下一帧,而不是让读循环崩溃导致整个 transport 不可用 + logger.exception("Dropping inbound transport frame from {}", source) + + +class StdioTransport(Transport): + def __init__( + self, + *, + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, + command: Sequence[str] | None = None, + cwd: str | None = None, + env: dict[str, str] | None = None, + ) -> None: + super().__init__() + self._stdin = stdin + self._stdout = stdout + self._command = list(command) if command is not None else None + self._cwd = cwd + self._env = env + self._process: asyncio.subprocess.Process | None = None + self._reader_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + self._closed.clear() + if self._command is not None: + self._process = await self._start_subprocess_with_retry() + self._reader_task = asyncio.create_task(self._read_process_loop()) + return + + self._stdin = self._stdin or sys.stdin + self._stdout = self._stdout or sys.stdout + self._reader_task = asyncio.create_task(self._read_file_loop()) + + async def _start_subprocess_with_retry(self) -> asyncio.subprocess.Process: + assert self._command is not None # 类型收窄:start() 已确保非空 + delays = [0.15, 0.35, 0.75] + last_error: BaseException | None = None + for attempt, delay in enumerate([0.0, *delays], start=1): + if delay: + await asyncio.sleep(delay) + try: + return await asyncio.create_subprocess_exec( + *self._command, + cwd=self._cwd, + env=self._env, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr, + limit=STDIO_SUBPROCESS_STREAM_LIMIT, + ) + except Exception as exc: + last_error = exc + if not _is_windows_access_denied(exc) or attempt == len(delays) + 1: + raise + logger.warning( + "Windows denied access while starting freshly prepared worker " + "interpreter, retrying attempt {}/{}: {}", + attempt, + len(delays) + 1, + exc, + ) + assert last_error is not None + raise last_error + + async def stop(self) -> None: + if self._reader_task is not None: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + if self._process is not None: + if self._process.returncode is None: + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=5) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + self._process = None + self._closed.set() + + async def send(self, payload: str) -> None: + frame = _frame_stdio_payload(payload) + if self._process is not None: + if self._process.stdin is None: + raise RuntimeError("STDIO subprocess stdin 不可用") + self._process.stdin.write(frame) + await self._process.stdin.drain() + return + + if self._stdout is None: + raise RuntimeError("STDIO stdout 不可用") + + def _write() -> None: + assert self._stdout is not None + binary_stdout = getattr(self._stdout, "buffer", None) + if binary_stdout is None: + raise RuntimeError("STDIO stdout 必须提供可写入 bytes 的 buffer") + binary_stdout.write(frame) + binary_stdout.flush() + + await asyncio.to_thread(_write) + + async def _read_process_loop(self) -> None: + """从子进程 stdout 持续读取 STDIO 帧,单帧异常不中断整体读取。""" + assert self._process is not None + assert self._process.stdout is not None + try: + while True: + try: + raw_header = await self._process.stdout.readline() + if not raw_header: + break + payload_size = _parse_stdio_header(raw_header) + raw = await self._process.stdout.readexactly(payload_size) + # 使用 _dispatch_safely 而非 _dispatch,确保上层的单帧处理错误不会终结读循环 + await self._dispatch_safely( + raw.decode("utf-8"), + source="stdio-process", + ) + except asyncio.CancelledError: + raise + except asyncio.IncompleteReadError: + # 帧被截断说明子进程已经异常退出,读循环应终止 + logger.warning("STDIO subprocess frame truncated before completion") + break + except UnicodeDecodeError as exc: + # UTF-8 解码失败:跳过本帧继续,避免二进制脏数据导致整个连接断开 + logger.warning( + "Skipping STDIO subprocess frame with invalid UTF-8 payload: {}", + exc, + ) + continue + except ValueError as exc: + # header 解析失败后无法再可靠定位后续帧边界;继续读取只会让协议流长期失同步。 + logger.warning( + "Stopping STDIO subprocess read loop after malformed frame: {}", + exc, + ) + break + finally: + self._closed.set() + + async def _read_file_loop(self) -> None: + """从本地 stdin(file 模式)持续读取 STDIO 帧,单帧异常不中断整体读取。""" + assert self._stdin is not None + try: + while True: + try: + binary_stdin = getattr(self._stdin, "buffer", None) + if binary_stdin is None: + raise RuntimeError("STDIO stdin 必须提供可读取 bytes 的 buffer") + raw_header = await asyncio.to_thread(binary_stdin.readline) + if not raw_header: + break + payload_size = _parse_stdio_header(raw_header) + raw = await asyncio.to_thread(binary_stdin.read, payload_size) + if len(raw) != payload_size: + raise EOFError("STDIO frame truncated before payload completed") + await self._dispatch_safely( + raw.decode("utf-8"), + source="stdio-file", + ) + except asyncio.CancelledError: + raise + except EOFError as exc: + # 流被截断意味着上游已关闭,读循环应终止 + logger.warning("{}", exc) + break + except UnicodeDecodeError as exc: + # UTF-8 解码失败:跳过本帧继续,保留连接可用 + logger.warning( + "Skipping STDIO file frame with invalid UTF-8 payload: {}", + exc, + ) + continue + except ValueError as exc: + # 文件模式同样无法从坏 header 中恢复到下一帧边界;直接终止读取更安全。 + logger.warning( + "Stopping STDIO file read loop after malformed frame: {}", exc + ) + break + finally: + self._closed.set() + + +class WebSocketServerTransport(Transport): + def __init__( + self, + *, + host: str = "127.0.0.1", + port: int = 8765, + path: str = "/", + heartbeat: float = 30.0, + ssl_context: ssl.SSLContext | None = None, + ) -> None: + super().__init__() + self._host = host + self._port = port + self._actual_port: int | None = None + self._path = path + self._heartbeat = heartbeat + self._ssl_context = ssl_context + self._app: Any | None = None + self._runner: Any | None = None + self._site: Any | None = None + self._ws: Any | None = None + self._write_lock = asyncio.Lock() + self._connected = asyncio.Event() + + async def start(self) -> None: + web = _get_web() + self._closed.clear() + self._connected.clear() + self._app = web.Application() + self._app.router.add_get(self._path, self._handle_socket) + self._runner = web.AppRunner(self._app) + await self._runner.setup() + self._site = web.TCPSite( + self._runner, + self._host, + self._port, + ssl_context=self._ssl_context, + ) + await self._site.start() + if self._site._server and getattr(self._site._server, "sockets", None): + socket = self._site._server.sockets[0] + self._actual_port = socket.getsockname()[1] + + async def stop(self) -> None: + self._connected.clear() + if self._ws is not None and not self._ws.closed: + await self._ws.close() + if self._site is not None: + await self._site.stop() + self._site = None + if self._runner is not None: + await self._runner.cleanup() + self._runner = None + self._closed.set() + + async def send(self, payload: str) -> None: + if self._ws is None or self._ws.closed: + await asyncio.wait_for(self._connected.wait(), timeout=30.0) + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket 尚未连接") + async with self._write_lock: + await self._ws.send_str(payload) + + async def _handle_socket(self, request) -> Any: + web = _get_web() + aiohttp = _get_aiohttp() + if self._ws is not None and not self._ws.closed: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.close(code=1008, message=b"only one websocket connection allowed") + return ws + + ws = web.WebSocketResponse( + heartbeat=self._heartbeat if self._heartbeat > 0 else None + ) + await ws.prepare(request) + self._ws = ws + self._connected.set() + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + # 文本帧直接分发,无需编解码 + await self._dispatch_safely( + msg.data, source="websocket-server-text" + ) + elif msg.type == aiohttp.WSMsgType.BINARY: + # 二进制帧需要先尝试 UTF-8 解码;解码失败只跳过本帧,不断开连接 + try: + payload = msg.data.decode("utf-8") + except UnicodeDecodeError as exc: + logger.warning( + "Skipping websocket server binary frame with invalid UTF-8 payload: {}", + exc, + ) + continue + await self._dispatch_safely( + payload, + source="websocket-server-binary", + ) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error("websocket server error: {}", ws.exception()) + break + finally: + self._connected.clear() + self._closed.set() + self._ws = None + return ws + + @property + def port(self) -> int: + return self._actual_port or self._port + + @property + def url(self) -> str: + scheme = "wss" if self._ssl_context is not None else "ws" + return f"{scheme}://{self._host}:{self.port}{self._path}" + + +class WebSocketClientTransport(Transport): + def __init__( + self, + *, + url: str, + heartbeat: float = 30.0, + ssl_context: ssl.SSLContext | None = None, + server_hostname: str | None = None, + ) -> None: + super().__init__() + self._url = url + self._heartbeat = heartbeat + self._ssl_context = ssl_context + self._server_hostname = server_hostname + self._session: Any | None = None + self._ws: Any | None = None + self._reader_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + aiohttp = _get_aiohttp() + self._closed.clear() + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect( + self._url, + heartbeat=self._heartbeat if self._heartbeat > 0 else None, + ssl_context=self._ssl_context, + server_hostname=self._server_hostname, + ) + self._reader_task = asyncio.create_task(self._read_loop()) + + async def stop(self) -> None: + if self._reader_task is not None: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + if self._ws is not None and not self._ws.closed: + await self._ws.close() + if self._session is not None: + await self._session.close() + self._ws = None + self._session = None + self._closed.set() + + async def send(self, payload: str) -> None: + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket client 尚未连接") + await self._ws.send_str(payload) + + async def _read_loop(self) -> None: + assert self._ws is not None + aiohttp = _get_aiohttp() + try: + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await self._dispatch_safely( + msg.data, source="websocket-client-text" + ) + elif msg.type == aiohttp.WSMsgType.BINARY: + # 与 server 端一致:二进制帧解码失败仅跳过本帧,保持连接存活 + try: + payload = msg.data.decode("utf-8") + except UnicodeDecodeError as exc: + logger.warning( + "Skipping websocket client binary frame with invalid UTF-8 payload: {}", + exc, + ) + continue + await self._dispatch_safely( + payload, + source="websocket-client-binary", + ) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error("websocket client error: {}", self._ws.exception()) + break + finally: + self._closed.set() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/worker.py b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py new file mode 100644 index 0000000000..6d04b6cd89 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py @@ -0,0 +1,536 @@ +"""Worker 端运行时:PluginWorkerRuntime 运行单个插件,GroupWorkerRuntime 在同一进程中运行多个插件。 + +核心类: + GroupWorkerRuntime: 组 Worker 运行时 + - 在同一进程中加载并运行多个插件 + - 聚合所有插件的 handlers 和 capabilities + - 统一处理 invoke 和 cancel 请求 + - 管理每个插件的生命周期回调 + + PluginWorkerRuntime: 单插件 Worker 运行时 + - 加载单个插件 + - 通过 Peer 与 Supervisor 通信 + - 分发 handler 调用 + - 处理生命周期回调 (on_start, on_stop) + +启动流程: + Worker 启动: + 1. load_plugin_spec() 加载插件规范 + 2. load_plugin() 加载插件组件 + 3. 创建 Peer 并设置处理器 + 4. 向 Supervisor 发送 initialize + 5. 等待 Supervisor 的 initialize_result + 6. 执行 on_start 生命周期回调 +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .._internal.decorator_lifecycle import run_lifecycle_with_decorators +from .._internal.invocation_context import caller_plugin_scope +from .._internal.sdk_logger import logger +from ..context import Context as RuntimeContext +from ..errors import AstrBotError +from ..protocol.messages import PeerInfo +from .handler_dispatcher import CapabilityDispatcher, HandlerDispatcher +from .loader import ( + LoadedPlugin, + PluginDiscoveryIssue, + PluginSpec, + load_plugin, + load_plugin_config, + load_plugin_spec, +) +from .peer import Peer + +__all__ = [ + "GroupPluginRuntimeState", + "GroupWorkerRuntime", + "PluginWorkerRuntime", + "_load_plugin_specs", + "_load_group_plugin_specs", +] + +GLOBAL_MCP_RISK_ATTR = "__astrbot_acknowledge_global_mcp_risk__" + + +@dataclass(slots=True) +class GroupPluginRuntimeState: + plugin: PluginSpec + loaded_plugin: LoadedPlugin + lifecycle_context: RuntimeContext + + +def _plugin_acknowledges_global_mcp_risk(instances: list[Any]) -> bool: + return any( + bool(getattr(instance.__class__, GLOBAL_MCP_RISK_ATTR, False)) + for instance in instances + ) + + +def _metadata_plugin_instances(loaded_plugin: Any) -> list[Any]: + """Return plugin instances for metadata-only inspection. + + Metadata serialization is also exercised by lightweight tests that stub + ``loaded_plugin`` with only the fields relevant to the payload. Missing + ``instances`` means the plugin cannot acknowledge the global MCP risk, but + it should not break issue/metadata reporting. + """ + instances = getattr(loaded_plugin, "instances", []) + if isinstance(instances, list): + return instances + if isinstance(instances, tuple): + return list(instances) + return [] + + +def _load_group_plugin_specs(group_metadata_path: Path) -> tuple[str, list[PluginSpec]]: + try: + payload = json.loads(group_metadata_path.read_text(encoding="utf-8")) + except Exception as exc: + raise RuntimeError( + f"failed to read worker group metadata: {group_metadata_path}" + ) from exc + + if not isinstance(payload, dict): + raise RuntimeError(f"invalid worker group metadata: {group_metadata_path}") + + entries = payload.get("plugin_entries") + if not isinstance(entries, list) or not entries: + raise RuntimeError( + f"worker group metadata missing plugin_entries: {group_metadata_path}" + ) + + plugins: list[PluginSpec] = [] + for entry in entries: + if not isinstance(entry, dict): + raise RuntimeError( + f"worker group metadata contains invalid plugin entry: {group_metadata_path}" + ) + plugin_dir = entry.get("plugin_dir") + if not isinstance(plugin_dir, str) or not plugin_dir: + raise RuntimeError( + f"worker group metadata contains invalid plugin_dir: {group_metadata_path}" + ) + plugins.append(load_plugin_spec(Path(plugin_dir))) + + group_id = payload.get("group_id") + if not isinstance(group_id, str) or not group_id: + group_id = group_metadata_path.stem + return group_id, plugins + + +def _load_plugin_specs(plugin_dirs: list[Path]) -> list[PluginSpec]: + if not plugin_dirs: + raise RuntimeError("worker requires at least one plugin directory") + return [load_plugin_spec(plugin_dir) for plugin_dir in plugin_dirs] + + +def _build_worker_registry_entry( + plugin: PluginSpec, + *, + enabled: bool, +) -> dict[str, Any]: + manifest = plugin.manifest_data + return { + "name": plugin.name, + "display_name": str(manifest.get("display_name") or plugin.name), + "description": str(manifest.get("desc") or manifest.get("description") or ""), + "repo": str(manifest.get("repo") or ""), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": enabled, + "config": load_plugin_config(plugin), + } + + +async def run_plugin_lifecycle( + instances: list[Any], + method_name: str, + context: RuntimeContext, +) -> None: + """运行插件生命周期方法。""" + for instance in instances: + method = getattr(instance, method_name, None) + with caller_plugin_scope(context.plugin_id): + await run_lifecycle_with_decorators( + instance=instance, + hook=method if callable(method) else None, + method_name=method_name, + context=context, + ) + + +class GroupWorkerRuntime: + def __init__( + self, + *, + transport, + group_metadata_path: Path | None = None, + plugin_dirs: list[Path] | None = None, + worker_id: str | None = None, + ) -> None: + if group_metadata_path is None and not plugin_dirs: + raise ValueError("group_metadata_path or plugin_dirs is required") + if group_metadata_path is not None and plugin_dirs: + raise ValueError( + "group_metadata_path and plugin_dirs are mutually exclusive" + ) + self.group_metadata_path = ( + group_metadata_path.resolve() if group_metadata_path is not None else None + ) + if self.group_metadata_path is not None: + default_worker_id, plugins = _load_group_plugin_specs( + self.group_metadata_path + ) + else: + assert plugin_dirs is not None + plugins = _load_plugin_specs([path.resolve() for path in plugin_dirs]) + default_worker_id = plugins[0].name + self.plugins = plugins + self.worker_id = str(worker_id or default_worker_id) + self.transport = transport + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name=self.worker_id, role="plugin", version="s5r"), + ) + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self._plugin_states: list[GroupPluginRuntimeState] = [] + self._active_plugin_states: list[GroupPluginRuntimeState] = [] + self._load_plugins() + self._refresh_dispatchers() + self.peer.set_invoke_handler(self._handle_invoke) + self.peer.set_cancel_handler(self._handle_cancel) + + def _load_plugins(self) -> None: + for plugin in self.plugins: + try: + loaded_plugin = load_plugin(plugin) + except Exception as exc: + self.skipped_plugins[plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件加载失败", + details=str(exc), + ) + ) + logger.exception( + "worker {} 中插件 {} 加载失败,启动时将跳过", + self.worker_id, + plugin.name, + ) + continue + + lifecycle_context = RuntimeContext(peer=self.peer, plugin_id=plugin.name) + self._plugin_states.append( + GroupPluginRuntimeState( + plugin=plugin, + loaded_plugin=loaded_plugin, + lifecycle_context=lifecycle_context, + ) + ) + self._active_plugin_states = list(self._plugin_states) + + def _refresh_dispatchers(self) -> None: + handlers = [ + handler + for state in self._active_plugin_states + for handler in state.loaded_plugin.handlers + ] + capabilities = [ + capability + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + ] + self.dispatcher = HandlerDispatcher( + plugin_id=self.worker_id, + peer=self.peer, + handlers=handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.worker_id, + peer=self.peer, + capabilities=capabilities, + llm_tools=[ + tool + for state in self._active_plugin_states + for tool in state.loaded_plugin.llm_tools + ], + ) + + async def start(self) -> None: + await self.peer.start() + started_states: list[GroupPluginRuntimeState] = [] + try: + active_states: list[GroupPluginRuntimeState] = [] + for state in self._plugin_states: + try: + await self._run_lifecycle(state, "on_start") + except Exception as exc: + self.skipped_plugins[state.plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="lifecycle", + plugin_id=state.plugin.name, + message="插件 on_start 失败", + details=str(exc), + ) + ) + logger.exception( + "worker {} 中插件 {} on_start 失败,启动时将跳过", + self.worker_id, + state.plugin.name, + ) + continue + active_states.append(state) + started_states.append(state) + + self._active_plugin_states = active_states + self._refresh_dispatchers() + if not self._active_plugin_states: + raise RuntimeError(f"worker {self.worker_id} has no active plugins") + + await self.peer.initialize( + [ + handler.descriptor + for state in self._active_plugin_states + for handler in state.loaded_plugin.handlers + ], + provided_capabilities=[ + capability.descriptor + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + ], + metadata=self._initialize_metadata(), + ) + except Exception: + for state in reversed(started_states): + try: + await self._run_lifecycle(state, "on_stop") + except Exception: + logger.exception( + "worker {} 在启动失败清理插件 {} on_stop 时发生异常", + self.worker_id, + state.plugin.name, + ) + await self.peer.stop() + raise + + async def stop(self) -> None: + first_error: Exception | None = None + try: + for state in reversed(self._active_plugin_states): + try: + await self._run_lifecycle(state, "on_stop") + except Exception as exc: + if first_error is None: + first_error = exc + logger.exception( + "worker {} 停止插件 {} 时发生异常", + self.worker_id, + state.plugin.name, + ) + finally: + await self.peer.stop() + if first_error is not None: + raise first_error + + async def _handle_invoke(self, message, cancel_token): + if message.capability == "handler.invoke": + return await self.dispatcher.invoke(message, cancel_token) + try: + return await self.capability_dispatcher.invoke(message, cancel_token) + except LookupError as exc: + raise AstrBotError.capability_not_found(message.capability) from exc + + async def _handle_cancel(self, request_id: str) -> None: + await self.dispatcher.cancel(request_id) + await self.capability_dispatcher.cancel(request_id) + + def _initialize_metadata(self) -> dict[str, Any]: + return { + "worker_id": self.worker_id, + "plugins": [plugin.name for plugin in self.plugins], + "loaded_plugins": [ + state.plugin.name for state in self._active_plugin_states + ], + "skipped_plugins": dict(self.skipped_plugins), + "worker_registry": [ + _build_worker_registry_entry( + plugin, + enabled=plugin.name + in {state.plugin.name for state in self._active_plugin_states}, + ) + for plugin in self.plugins + ], + "capability_sources": { + capability.descriptor.name: state.plugin.name + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + }, + "issues": [issue.to_payload() for issue in self.issues], + "llm_tools": [ + { + **tool.spec.to_payload(), + "plugin_id": state.plugin.name, + } + for state in self._active_plugin_states + for tool in state.loaded_plugin.llm_tools + ], + "agents": [ + { + **agent.spec.to_payload(), + "plugin_id": state.plugin.name, + } + for state in self._active_plugin_states + for agent in state.loaded_plugin.agents + ], + "acknowledge_global_mcp_risk": any( + _plugin_acknowledges_global_mcp_risk( + _metadata_plugin_instances(state.loaded_plugin) + ) + for state in self._active_plugin_states + ), + } + + async def _run_lifecycle( + self, + state: GroupPluginRuntimeState, + method_name: str, + ) -> None: + await run_plugin_lifecycle( + state.loaded_plugin.instances, method_name, state.lifecycle_context + ) + + +class PluginWorkerRuntime: + def __init__( + self, + *, + plugin_dir: Path, + transport, + worker_id: str | None = None, + ) -> None: + self.plugin = load_plugin_spec(plugin_dir) + self.worker_id = str(worker_id or self.plugin.name) + self.transport = transport + self.loaded_plugin = load_plugin(self.plugin) + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name=self.worker_id, role="plugin", version="s5r"), + ) + self.dispatcher = HandlerDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + handlers=self.loaded_plugin.handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + capabilities=self.loaded_plugin.capabilities, + llm_tools=self.loaded_plugin.llm_tools, + ) + self._lifecycle_context = RuntimeContext( + peer=self.peer, plugin_id=self.plugin.name + ) + self.issues: list[PluginDiscoveryIssue] = [] + self.peer.set_invoke_handler(self._handle_invoke) + self.peer.set_cancel_handler(self._handle_cancel) + + async def start(self) -> None: + await self.peer.start() + lifecycle_started = False + try: + await self._run_lifecycle("on_start") + lifecycle_started = True + await self.peer.initialize( + [item.descriptor for item in self.loaded_plugin.handlers], + provided_capabilities=[ + item.descriptor for item in self.loaded_plugin.capabilities + ], + metadata={ + "worker_id": self.worker_id, + "plugins": [self.plugin.name], + "loaded_plugins": [self.plugin.name], + "skipped_plugins": {}, + "worker_registry": [ + _build_worker_registry_entry(self.plugin, enabled=True) + ], + "issues": [issue.to_payload() for issue in self.issues], + "capability_sources": { + item.descriptor.name: self.plugin.name + for item in self.loaded_plugin.capabilities + }, + "llm_tools": [ + { + **item.spec.to_payload(), + "plugin_id": self.plugin.name, + } + for item in self.loaded_plugin.llm_tools + ], + "agents": [ + { + **item.spec.to_payload(), + "plugin_id": self.plugin.name, + } + for item in self.loaded_plugin.agents + ], + "acknowledge_global_mcp_risk": _plugin_acknowledges_global_mcp_risk( + _metadata_plugin_instances(self.loaded_plugin) + ), + }, + ) + except Exception: + if lifecycle_started: + logger.exception( + "插件 {} 在向 supervisor 上报 initialize 时失败", + self.plugin.name, + ) + else: + logger.exception( + "插件 {} 在 on_start / 装饰器初始化阶段失败;" + "supervisor 可能随后只看到初始化超时,请优先检查这条异常", + self.plugin.name, + ) + if lifecycle_started: + try: + await self._run_lifecycle("on_stop") + except Exception: + logger.exception( + "插件 {} 在启动失败清理 on_stop 时发生异常", + self.plugin.name, + ) + await self.peer.stop() + raise + + async def stop(self) -> None: + try: + await self._run_lifecycle("on_stop") + finally: + await self.peer.stop() + + async def _handle_invoke(self, message, cancel_token): + if message.capability == "handler.invoke": + return await self.dispatcher.invoke(message, cancel_token) + try: + return await self.capability_dispatcher.invoke(message, cancel_token) + except LookupError as exc: + raise AstrBotError.capability_not_found(message.capability) from exc + + async def _handle_cancel(self, request_id: str) -> None: + await self.dispatcher.cancel(request_id) + await self.capability_dispatcher.cancel(request_id) + + async def _run_lifecycle(self, method_name: str) -> None: + await run_plugin_lifecycle( + self.loaded_plugin.instances, method_name, self._lifecycle_context + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py b/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py new file mode 100644 index 0000000000..724ffa247b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py @@ -0,0 +1,120 @@ +"""Supervisor-side manifest for remote websocket workers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import urlparse + +import yaml + + +@dataclass(slots=True) +class RemoteWorkerTLSConfig: + ca_file: Path + cert_file: Path + key_file: Path + server_hostname: str | None = None + + +@dataclass(slots=True) +class RemoteWorkerSpec: + id: str + url: str + tls: RemoteWorkerTLSConfig + + +def load_remote_workers_manifest(manifest_path: Path) -> list[RemoteWorkerSpec]: + resolved_path = manifest_path.resolve() + payload = yaml.safe_load(resolved_path.read_text(encoding="utf-8")) or {} + if not isinstance(payload, dict): + raise ValueError("workers manifest must be a mapping") + + entries = payload.get("workers") + if not isinstance(entries, list): + raise ValueError("workers manifest must define a 'workers' list") + + workers: list[RemoteWorkerSpec] = [] + seen_ids: set[str] = set() + for index, entry in enumerate(entries): + if not isinstance(entry, dict): + raise ValueError(f"workers[{index}] must be an object") + _reject_unsupported_worker_keys(entry, index=index) + worker_id = str(entry.get("id", "")).strip() + if not worker_id: + raise ValueError(f"workers[{index}].id must be a non-empty string") + if worker_id in seen_ids: + raise ValueError(f"duplicate worker id in workers manifest: {worker_id}") + seen_ids.add(worker_id) + + raw_url = str(entry.get("url", "")).strip() + parsed = urlparse(raw_url) + if parsed.scheme != "wss": + raise ValueError( + f"workers[{index}].url must use wss:// for mutual TLS: {raw_url!r}" + ) + if not parsed.netloc: + raise ValueError(f"workers[{index}].url must include a host: {raw_url!r}") + + tls_payload = entry.get("tls") + if not isinstance(tls_payload, dict): + raise ValueError(f"workers[{index}].tls must be an object") + tls = _load_tls_config( + tls_payload, + manifest_dir=resolved_path.parent, + prefix=f"workers[{index}].tls", + ) + workers.append(RemoteWorkerSpec(id=worker_id, url=raw_url, tls=tls)) + + return workers + + +def _reject_unsupported_worker_keys(entry: dict[str, object], *, index: int) -> None: + unsupported = {"group_id", "plugins"} & set(entry) + if unsupported: + names = ", ".join(sorted(unsupported)) + raise ValueError( + f"workers[{index}] must not declare {names}; websocket host config only " + "accepts worker connection settings" + ) + + +def _load_tls_config( + payload: dict[str, object], + *, + manifest_dir: Path, + prefix: str, +) -> RemoteWorkerTLSConfig: + ca_file = _resolve_required_path( + payload.get("ca_file"), manifest_dir, f"{prefix}.ca_file" + ) + cert_file = _resolve_required_path( + payload.get("cert_file"), + manifest_dir, + f"{prefix}.cert_file", + ) + key_file = _resolve_required_path( + payload.get("key_file"), manifest_dir, f"{prefix}.key_file" + ) + server_hostname_raw = payload.get("server_hostname") + server_hostname = ( + str(server_hostname_raw).strip() if server_hostname_raw is not None else None + ) + if server_hostname == "": + server_hostname = None + return RemoteWorkerTLSConfig( + ca_file=ca_file, + cert_file=cert_file, + key_file=key_file, + server_hostname=server_hostname, + ) + + +def _resolve_required_path(value: object, base_dir: Path, field_name: str) -> Path: + text = str(value or "").strip() + if not text: + raise ValueError(f"{field_name} must be a non-empty path") + path = Path(text) + if not path.is_absolute(): + path = (base_dir / path).resolve() + return path diff --git a/astrbot-sdk/src/astrbot_sdk/schedule.py b/astrbot-sdk/src/astrbot_sdk/schedule.py new file mode 100644 index 0000000000..5daccdd78a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/schedule.py @@ -0,0 +1,93 @@ +"""Schedule-specific SDK types. + +本模块定义定时任务相关的 SDK 类型,主要为 ScheduleContext 提供数据结构。 + +ScheduleContext 包含: +- schedule_id: 调度任务唯一标识 +- job_id: core cron_jobs 表中的任务 ID +- plugin_id: 所属插件 ID +- handler_id: 对应 handler 的标识 +- name: 调度任务名称 +- description: 调度任务说明 +- job_type: core cron job 类型(basic / active_agent) +- trigger_kind: 触发类型(cron / interval / once) +- cron: cron 表达式(仅 cron 类型) +- interval_seconds: 间隔秒数(仅 interval 类型) +- timezone: IANA 时区名称(仅声明了时区时存在) +- scheduled_at: 计划执行时间(仅 once 类型) + +使用方式: +通过 @on_schedule 装饰器注册的 handler 可通过参数注入获取 ScheduleContext。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class ScheduleContext: + schedule_id: str + plugin_id: str + handler_id: str + trigger_kind: str + job_id: str | None = None + name: str | None = None + description: str | None = None + job_type: str | None = None + cron: str | None = None + interval_seconds: int | None = None + timezone: str | None = None + scheduled_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ScheduleContext: + schedule = payload.get("schedule") + if not isinstance(schedule, dict): + raise ValueError("schedule payload is required") + return cls( + schedule_id=str(schedule.get("schedule_id", "")), + job_id=( + str(schedule["job_id"]) + if isinstance(schedule.get("job_id"), str) + else None + ), + plugin_id=str(schedule.get("plugin_id", "")), + handler_id=str(schedule.get("handler_id", "")), + name=( + str(schedule["name"]) if isinstance(schedule.get("name"), str) else None + ), + description=( + str(schedule["description"]) + if isinstance(schedule.get("description"), str) + else None + ), + job_type=( + str(schedule["job_type"]) + if isinstance(schedule.get("job_type"), str) + else None + ), + trigger_kind=str(schedule.get("trigger_kind", "")), + cron=( + str(schedule["cron"]) if isinstance(schedule.get("cron"), str) else None + ), + interval_seconds=( + int(schedule["interval_seconds"]) + if isinstance(schedule.get("interval_seconds"), int) + else None + ), + timezone=( + str(schedule["timezone"]) + if isinstance(schedule.get("timezone"), str) + else None + ), + scheduled_at=( + str(schedule["scheduled_at"]) + if isinstance(schedule.get("scheduled_at"), str) + else None + ), + ) + + +__all__ = ["ScheduleContext"] diff --git a/astrbot-sdk/src/astrbot_sdk/session_waiter.py b/astrbot-sdk/src/astrbot_sdk/session_waiter.py new file mode 100644 index 0000000000..4b7b92972d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/session_waiter.py @@ -0,0 +1,665 @@ +"""Session-based conversational flow management. + +本模块实现会话等待器 (session_waiter),用于构建多轮对话流程。 + +核心组件: +- SessionController: 控制会话生命周期,支持超时管理、会话保持、历史记录 +- SessionWaiterManager: 管理活跃的会话等待器,处理事件分发和注册/注销 +- @session_waiter 装饰器: 将普通 handler 转换为会话式 handler + +使用场景: +当需要在用户首次触发后继续监听后续消息(如分步表单、问答游戏), +可使用 @session_waiter 装饰器自动管理会话状态和超时。 + +注意事项: +在当前桥接设计中,不应在普通 SDK handler 内直接 await session_waiter, +这会导致首次 dispatch 保持打开直到下一条消息到达。 +推荐写法是 `await ctx.register_task(waiter(...), "...")`,让 waiter 在后台任务中 +承接后续消息;直接 await 仅适用于你明确需要保持当前 dispatch 挂起的场景。 +""" + +from __future__ import annotations + +import asyncio +import time +import weakref +from collections.abc import Awaitable, Callable, Coroutine +from contextvars import ContextVar +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload + +from ._internal.invocation_context import current_caller_plugin_id +from ._internal.sdk_logger import logger +from .events import MessageEvent + +_OwnerT = TypeVar("_OwnerT") +_P = ParamSpec("_P") +_ResultT = TypeVar("_ResultT") +_WaiterKey = tuple[str, str] + +_HANDLER_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_REGISTERED_BACKGROUND_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_WARNED_DIRECT_WAIT_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_ACTIVE_WAITER_KEY: ContextVar[_WaiterKey | None] = ContextVar( + "astrbot_sdk_active_waiter_key", + default=None, +) + + +class _TaskReentrantLock: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._owner: asyncio.Task[Any] | None = None + self._depth = 0 + + async def acquire(self) -> None: + current_task = asyncio.current_task() + if current_task is None: + raise RuntimeError("session waiter lock requires an active asyncio task") + if self._owner is current_task: + self._depth += 1 + return + await self._lock.acquire() + self._owner = current_task + self._depth = 1 + + def release(self) -> None: + current_task = asyncio.current_task() + if current_task is None or self._owner is not current_task: + raise RuntimeError("session waiter lock released by a non-owner task") + self._depth -= 1 + if self._depth > 0: + return + self._owner = None + self._lock.release() + + async def __aenter__(self) -> _TaskReentrantLock: + await self.acquire() + return self + + async def __aexit__(self, *_exc_info: object) -> None: + self.release() + + +def _mark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None: + _HANDLER_TASKS.add(task) + + +def _unmark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None: + _HANDLER_TASKS.discard(task) + + +def _mark_session_waiter_background_task(task: asyncio.Task[Any]) -> None: + _REGISTERED_BACKGROUND_TASKS.add(task) + + +def _unmark_session_waiter_background_task(task: asyncio.Task[Any]) -> None: + _REGISTERED_BACKGROUND_TASKS.discard(task) + + +class _SessionWaiterDecorator(Protocol): + @overload + def __call__( + self, + func: Callable[ + Concatenate[SessionController, MessageEvent, _P], + Awaitable[_ResultT], + ], + /, + ) -> Callable[Concatenate[MessageEvent, _P], Coroutine[Any, Any, _ResultT]]: ... + + @overload + def __call__( + self, + func: Callable[ + Concatenate[_OwnerT, SessionController, MessageEvent, _P], + Awaitable[_ResultT], + ], + /, + ) -> Callable[ + Concatenate[_OwnerT, MessageEvent, _P], + Coroutine[Any, Any, _ResultT], + ]: ... + + +@dataclass(slots=True) +class SessionController: + future: asyncio.Future[Any] = field(default_factory=asyncio.Future) + current_event: asyncio.Event | None = None + ts: float | None = None + timeout: float | None = None + history_chains: list[list[dict[str, Any]]] = field(default_factory=list) + + def stop(self, error: Exception | None = None) -> None: + if self.future.done(): + return + if error is not None: + self.future.set_exception(error) + else: + self.future.set_result(None) + + def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None: + new_ts = time.time() + if reset_timeout: + if timeout <= 0: + self.stop() + return + else: + if self.timeout is None or self.ts is None: + raise RuntimeError( + "session waiter keep(reset_timeout=False) requires an active timeout" + ) + left_timeout = self.timeout - (new_ts - self.ts) + timeout = left_timeout + timeout + if timeout <= 0: + self.stop() + return + + if self.current_event and not self.current_event.is_set(): + self.current_event.set() + + current_event = asyncio.Event() + self.current_event = current_event + self.ts = new_ts + self.timeout = timeout + asyncio.create_task(self._holding(current_event, timeout)) + + async def _holding(self, event: asyncio.Event, timeout: float) -> None: + try: + await asyncio.wait_for(event.wait(), timeout) + except asyncio.TimeoutError as exc: + self.stop(exc) + except asyncio.CancelledError: + return + + def get_history_chains(self) -> list[list[dict[str, Any]]]: + return list(self.history_chains) + + +@dataclass(slots=True) +class _WaiterEntry: + session_key: str + plugin_id: str + handler: Callable[[SessionController, MessageEvent], Awaitable[Any]] + controller: SessionController + record_history_chains: bool + unregister_enabled: bool = True + + +class SessionWaiterManager: + def __init__(self, *, plugin_id: str, peer) -> None: + self._plugin_id = plugin_id + self._peer = peer + self._entries: dict[str, dict[str, _WaiterEntry]] = {} + self._locks: dict[_WaiterKey, _TaskReentrantLock] = {} + + @staticmethod + def _make_key(*, plugin_id: str, session_key: str) -> _WaiterKey: + return (plugin_id, session_key) + + async def register( + self, + *, + event: MessageEvent, + handler: Callable[[SessionController, MessageEvent], Awaitable[Any]], + timeout: int, + record_history_chains: bool, + ) -> Any: + if event._context is None: + raise RuntimeError("session_waiter requires runtime context") + self._warn_if_direct_wait_in_handler(event) + session_key = event.unified_msg_origin + plugin_id = self._resolve_plugin_id(event) + entry = _WaiterEntry( + session_key=session_key, + plugin_id=plugin_id, + handler=handler, + controller=SessionController(), + record_history_chains=record_history_chains, + ) + previous = self._entries.setdefault(session_key, {}).get(plugin_id) + restorable_previous: _WaiterEntry | None = None + self._entries[session_key][plugin_id] = entry + self._lock_for(session_key, plugin_id) + if previous is not None: + previous.unregister_enabled = False + if _ACTIVE_WAITER_KEY.get() == self._make_key( + plugin_id=plugin_id, + session_key=session_key, + ): + restorable_previous = previous + else: + self._finish_entry( + previous, + RuntimeError("session waiter replaced by a newer waiter"), + ) + logger.warning( + "Session waiter replaced: plugin_id={} session_key={}", + plugin_id, + session_key, + ) + try: + await self._invoke_system_waiter( + "system.session_waiter.register", + session_key=session_key, + plugin_id=plugin_id, + ) + entry.controller.keep(timeout, reset_timeout=True) + except Exception: + entry.unregister_enabled = False + await self._remove_entry(entry) + if restorable_previous is not None: + self._entries.setdefault(session_key, {})[plugin_id] = ( + restorable_previous + ) + restorable_previous.unregister_enabled = True + self._lock_for(session_key, plugin_id) + raise + try: + return await entry.controller.future + finally: + if entry.unregister_enabled: + await self.unregister(session_key, plugin_id=plugin_id) + + def _warn_if_direct_wait_in_handler(self, event: MessageEvent) -> None: + current_task = asyncio.current_task() + if current_task is None: + return + if current_task not in _HANDLER_TASKS: + return + if current_task in _REGISTERED_BACKGROUND_TASKS: + return + if current_task in _WARNED_DIRECT_WAIT_TASKS: + return + _WARNED_DIRECT_WAIT_TASKS.add(current_task) + logger.warning( + "Direct await on session_waiter blocks the current handler dispatch; " + 'prefer `await ctx.register_task(waiter(...), "...")`: ' + "plugin_id={} session_key={}", + event._context.plugin_id, + event.unified_msg_origin, + ) + + async def wait_for_event( + self, + *, + event: MessageEvent, + timeout: int, + record_history_chains: bool = False, + ) -> MessageEvent: + future: asyncio.Future[MessageEvent] = ( + asyncio.get_running_loop().create_future() + ) + + async def _handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + if not future.done(): + future.set_result(waiter_event) + controller.stop() + + await self.register( + event=event, + handler=_handler, + timeout=timeout, + record_history_chains=record_history_chains, + ) + return future.result() + + async def unregister( + self, + session_key: str, + *, + plugin_id: str | None = None, + ) -> None: + target_plugin_id = self._resolve_unregister_plugin_id( + session_key, + plugin_id=plugin_id, + ) + if target_plugin_id is None: + return + lock_key = (session_key, target_plugin_id) + lock = self._lock_for(session_key, target_plugin_id) + removed = False + async with lock: + session_entries = self._entries.get(session_key) + if session_entries is None: + return + removed = session_entries.pop(target_plugin_id, None) is not None + if not session_entries: + self._entries.pop(session_key, None) + if self._locks.get(lock_key) is lock: + self._locks.pop(lock_key, None) + if not removed: + return + try: + await self._invoke_system_waiter( + "system.session_waiter.unregister", + session_key=session_key, + plugin_id=target_plugin_id, + ) + except Exception: + logger.debug( + "Failed to unregister session waiter: plugin_id={} session_key={}", + target_plugin_id, + session_key, + ) + + async def fail( + self, + session_key: str, + error: Exception, + *, + plugin_id: str | None = None, + ) -> bool: + resolved_plugin_id = plugin_id + if resolved_plugin_id is None: + caller_plugin_id = current_caller_plugin_id() + if caller_plugin_id: + resolved_plugin_id = caller_plugin_id + entry = self._select_entry( + session_key, + plugin_id=resolved_plugin_id, + allow_ambiguous=False, + missing_result=None, + ) + if entry is None: + return False + lock = self._lock_for(session_key, entry.plugin_id) + async with lock: + current = self._get_entry(session_key, entry.plugin_id) + if current is None or current.controller.future.done(): + return False + self._finish_entry(current, error) + return True + + def has_active_waiter(self, event: MessageEvent) -> bool: + session_key = event.unified_msg_origin + event_plugin_id = self._event_plugin_id(event) + if event_plugin_id is not None: + entry = self._get_entry(session_key, event_plugin_id) + return entry is not None and not entry.controller.future.done() + return bool(self.get_waiter_plugin_ids(session_key)) + + def has_waiter(self, event: MessageEvent) -> bool: + return self.has_active_waiter(event) + + def get_waiter_plugin_ids(self, session_key: str) -> list[str]: + return sorted( + plugin_id + for plugin_id, entry in self._entries.get(session_key, {}).items() + if not entry.controller.future.done() + ) + + async def dispatch( + self, + event: MessageEvent, + *, + plugin_id: str | None = None, + ) -> dict[str, Any]: + if event._context is None: + raise RuntimeError("session_waiter dispatch requires runtime context") + session_key = event.unified_msg_origin + entry = self._select_entry( + session_key, + plugin_id=plugin_id, + allow_ambiguous=False, + missing_result=None, + ambiguous_error=LookupError( + f"session waiter dispatch for session '{session_key}' requires explicit plugin identity" + ), + ) + if entry is None: + return {"sent_message": False, "stop": False, "call_llm": False} + lock = self._lock_for(session_key, entry.plugin_id) + async with lock: + current = self._get_entry(session_key, entry.plugin_id) + if current is None or current.controller.future.done(): + return {"sent_message": False, "stop": False, "call_llm": False} + waiter_event = self._build_waiter_event(current, event) + if current.record_history_chains: + chain = [] + raw_chain = ( + waiter_event.raw.get("chain") + if isinstance(waiter_event.raw, dict) + else None + ) + if isinstance(raw_chain, list): + chain = [dict(item) for item in raw_chain if isinstance(item, dict)] + current.controller.history_chains.append(chain) + active_key_token = _ACTIVE_WAITER_KEY.set( + self._make_key( + plugin_id=current.plugin_id, + session_key=current.session_key, + ) + ) + try: + # Keep follow-up handler execution serialized per waiter while still + # allowing nested waiter cleanup in the same task to re-enter safely. + await current.handler(current.controller, waiter_event) + finally: + _ACTIVE_WAITER_KEY.reset(active_key_token) + return { + "sent_message": False, + "stop": waiter_event.is_stopped(), + "call_llm": False, + } + + def _resolve_plugin_id(self, event: MessageEvent) -> str: + caller_plugin_id = current_caller_plugin_id() + if caller_plugin_id: + return caller_plugin_id + context = event._context + if context is not None and context.plugin_id.strip(): + return context.plugin_id + return self._plugin_id + + @staticmethod + def _event_plugin_id(event: MessageEvent) -> str | None: + context = event._context + if context is None: + return None + plugin_id = context.plugin_id.strip() + return plugin_id or None + + def _resolve_unregister_plugin_id( + self, + session_key: str, + *, + plugin_id: str | None, + ) -> str | None: + if plugin_id is not None: + normalized = str(plugin_id).strip() + return normalized or None + session_entries = self._entries.get(session_key, {}) + if len(session_entries) != 1: + return None + return next(iter(session_entries)) + + def _select_entry( + self, + session_key: str, + *, + plugin_id: str | None, + allow_ambiguous: bool, + missing_result: _WaiterEntry | None, + ambiguous_error: Exception | None = None, + ) -> _WaiterEntry | None: + if plugin_id is not None: + return self._get_entry(session_key, plugin_id) + active_entries = [ + entry + for entry in self._entries.get(session_key, {}).values() + if not entry.controller.future.done() + ] + if not active_entries: + return missing_result + if len(active_entries) > 1 and not allow_ambiguous: + if ambiguous_error is not None: + raise ambiguous_error + return missing_result + return active_entries[0] + + def _get_entry(self, session_key: str, plugin_id: str) -> _WaiterEntry | None: + return self._entries.get(session_key, {}).get(plugin_id) + + def _lock_for(self, session_key: str, plugin_id: str) -> _TaskReentrantLock: + return self._locks.setdefault((session_key, plugin_id), _TaskReentrantLock()) + + async def _remove_entry(self, entry: _WaiterEntry) -> None: + lock_key = (entry.session_key, entry.plugin_id) + lock = self._lock_for(entry.session_key, entry.plugin_id) + async with lock: + session_entries = self._entries.get(entry.session_key) + if session_entries is None: + return + current = session_entries.get(entry.plugin_id) + if current is not entry: + return + session_entries.pop(entry.plugin_id, None) + if not session_entries: + self._entries.pop(entry.session_key, None) + if self._locks.get(lock_key) is lock: + self._locks.pop(lock_key, None) + + @staticmethod + def _finish_entry(entry: _WaiterEntry, error: Exception | None = None) -> None: + entry.controller.stop(error) + if ( + entry.controller.current_event is not None + and not entry.controller.current_event.is_set() + ): + entry.controller.current_event.set() + + async def _invoke_system_waiter( + self, + capability: str, + *, + session_key: str, + plugin_id: str, + ) -> None: + from ._internal.invocation_context import caller_plugin_scope + + with caller_plugin_scope(plugin_id): + await self._peer.invoke( + capability, + {"session_key": session_key}, + ) + + def _build_waiter_event( + self, + entry: _WaiterEntry, + event: MessageEvent, + ) -> MessageEvent: + from .context import Context + + source_payload = self._source_payload_from_event(event) + cancel_token = ( + event._context.cancel_token if event._context is not None else None + ) + waiter_context = Context( + peer=self._peer, + plugin_id=entry.plugin_id, + request_id=( + event._context.request_id if event._context is not None else None + ), + cancel_token=cancel_token, + source_event_payload=source_payload, + ) + # Rebuild the event so the waiter always sees the registering plugin identity + # and the exact source payload that triggered the follow-up dispatch. + return MessageEvent.from_payload( + source_payload, + context=waiter_context, + ) + + @staticmethod + def _source_payload_from_event(event: MessageEvent) -> dict[str, Any]: + raw_payload = event.raw if isinstance(event.raw, dict) else None + if raw_payload is not None and { + "text", + "session_id", + "platform", + }.issubset(raw_payload): + return dict(raw_payload) + return event.to_payload() + + +def session_waiter( + timeout: int = 30, + *, + record_history_chains: bool = False, +) -> _SessionWaiterDecorator: + def decorator( + func: Callable[..., Awaitable[Any]], + ) -> Callable[..., Coroutine[Any, Any, Any]]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + owner = None + event: MessageEvent | None = None + trailing_args: tuple[Any, ...] = () + if args and isinstance(args[0], MessageEvent): + event = args[0] + trailing_args = args[1:] + elif len(args) >= 2 and isinstance(args[1], MessageEvent): + owner = args[0] + event = args[1] + trailing_args = args[2:] + if event is None: + raise RuntimeError("session_waiter requires a MessageEvent argument") + if event._context is None: + raise RuntimeError("session_waiter requires runtime context") + manager = getattr(event._context.peer, "_session_waiter_manager", None) + if manager is None: + raise RuntimeError("session_waiter manager is unavailable") + + if owner is None: + free_func = cast(Callable[..., Awaitable[Any]], func) + + async def bound_handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> Any: + return await free_func( + controller, + waiter_event, + *trailing_args, + **kwargs, + ) + else: + method_func = cast(Callable[..., Awaitable[Any]], func) + + async def bound_handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> Any: + return await method_func( + owner, + controller, + waiter_event, + *trailing_args, + **kwargs, + ) + + return await manager.register( + event=event, + handler=bound_handler, + timeout=timeout, + record_history_chains=record_history_chains, + ) + + return wrapper + + return cast(_SessionWaiterDecorator, decorator) + + +__all__ = [ + "_OwnerT", + "_P", + "_ResultT", + "SessionController", + "SessionWaiterManager", + "session_waiter", +] diff --git a/astrbot-sdk/src/astrbot_sdk/star.py b/astrbot-sdk/src/astrbot_sdk/star.py new file mode 100644 index 0000000000..d05d159d42 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/star.py @@ -0,0 +1,131 @@ +"""astrbot-sdk 原生插件基类。""" + +from __future__ import annotations + +import json +import traceback +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Any, cast + +from ._internal.sdk_logger import logger +from .errors import AstrBotError +from .plugin_kv import PluginKVStoreMixin + +if TYPE_CHECKING: + from .context import Context + + +class Star(PluginKVStoreMixin): + """astrbot-sdk 原生插件基类。""" + + __handlers__: tuple[str, ...] = () + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + from .decorators import get_handler_meta + + handlers: dict[str, None] = {} + for base in reversed(cls.__mro__): + for name, attr in getattr(base, "__dict__", {}).items(): + func = getattr(attr, "__func__", attr) + meta = get_handler_meta(func) + if meta is not None and meta.trigger is not None: + handlers[name] = None + cls.__handlers__ = tuple(handlers.keys()) + + @property + def context(self) -> Context | None: + return self._context_var().get() + + def _require_runtime_context(self) -> Context: + ctx = self.context + if ctx is None: + raise RuntimeError( + "Star runtime context is only available during lifecycle, " + "handler, and registered LLM tool execution" + ) + return ctx + + def _context_var(self) -> ContextVar[Context | None]: + existing_context_var = getattr(self, "__astrbot_context_var__", None) + if isinstance(existing_context_var, ContextVar): + return cast("ContextVar[Context | None]", existing_context_var) + created_context_var: ContextVar[Context | None] = ContextVar( + f"astrbot_sdk_star_context_{id(self)}", + default=None, + ) + setattr(self, "__astrbot_context_var__", created_context_var) + return created_context_var + + def _bind_runtime_context(self, ctx: Context | None) -> Token[Context | None]: + return self._context_var().set(ctx) + + def _reset_runtime_context(self, token: Token[Context | None]) -> None: + self._context_var().reset(token) + + async def on_start(self, ctx: Any | None = None) -> None: + await self.initialize() + + async def on_stop(self, ctx: Any | None = None) -> None: + await self.terminate() + + async def initialize(self) -> None: + return None + + async def terminate(self) -> None: + return None + + async def text_to_image( + self, + text: str, + *, + return_url: bool = True, + ) -> str: + return await self._require_runtime_context().text_to_image( + text, + return_url=return_url, + ) + + async def html_render( + self, + tmpl: str, + data: dict[str, Any], + *, + return_url: bool = True, + options: dict[str, Any] | None = None, + ) -> str: + return await self._require_runtime_context().html_render( + tmpl, + data, + return_url=return_url, + options=options, + ) + + @staticmethod + async def default_on_error(error: Exception, event, ctx) -> None: + del ctx + if isinstance(error, AstrBotError): + lines: list[str] = [] + if error.retryable: + lines.append("请求失败,请稍后重试") + elif error.hint: + lines.append(error.hint) + else: + lines.append(error.message) + if error.docs_url: + lines.append(f"文档:{error.docs_url}") + if error.details: + lines.append( + f"详情:{json.dumps(error.details, ensure_ascii=False, sort_keys=True)}" + ) + await event.reply("\n".join(lines)) + else: + await event.reply("出了点问题,请联系插件作者") + logger.error("handler 执行失败\n{}", traceback.format_exc()) + + async def on_error(self, error: Exception, event, ctx) -> None: + await Star.default_on_error(error, event, ctx) + + @classmethod + def __astrbot_is_new_star__(cls) -> bool: + return True diff --git a/astrbot-sdk/src/astrbot_sdk/star_tools.py b/astrbot-sdk/src/astrbot_sdk/star_tools.py new file mode 100644 index 0000000000..fe7aa451c0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/star_tools.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + +from ._internal.star_runtime import current_star_context +from .context import Context +from .message.components import BaseMessageComponent +from .message.result import MessageChain +from .message.session import MessageSession + +if TYPE_CHECKING: + from .clients.skills import SkillRegistration + from .llm.tools import LLMToolManager + + +class _StarToolsContextDescriptor: + def __get__(self, _instance: object, _owner: type[object]) -> Context | None: + return current_star_context() + + +class StarTools: + """Star 工具类,提供类方法访问运行时上下文能力。 + + 所有方法都通过当前上下文动态路由到对应的能力接口。 + 只在 lifecycle、handler 和已注册的 LLM tool 执行期间可用。 + """ + + _context = _StarToolsContextDescriptor() + + @classmethod + def _get_context(cls) -> Context | None: + """获取当前 Star 运行时上下文。""" + return cls._context + + @classmethod + def _require_context(cls) -> Context: + """获取当前运行时上下文,如果不存在则抛出 RuntimeError。""" + ctx = current_star_context() + if ctx is None: + raise RuntimeError( + "StarTools context is only available during lifecycle, " + "handler, and registered LLM tool execution" + ) + return ctx + + @classmethod + def get_llm_tool_manager(cls) -> LLMToolManager: + return cls._require_context().get_llm_tool_manager() + + @classmethod + async def activate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().activate_llm_tool(name) + + @classmethod + async def deactivate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().deactivate_llm_tool(name) + + @classmethod + async def send_message( + cls, + session: str | MessageSession, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> dict[str, Any]: + return await cls._require_context().send_message(session, content) + + @classmethod + async def send_message_by_id( + cls, + type: str, + id: str, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + *, + platform: str, + ) -> dict[str, Any]: + return await cls._require_context().send_message_by_id( + type, + id, + content, + platform=platform, + ) + + @classmethod + async def register_llm_tool( + cls, + name: str, + parameters_schema: dict[str, Any], + desc: str, + func_obj: Callable[..., Awaitable[Any]] | Callable[..., Any], + *, + active: bool = True, + ) -> list[str]: + return await cls._require_context().register_llm_tool( + name, + parameters_schema, + desc, + func_obj, + active=active, + ) + + @classmethod + async def unregister_llm_tool(cls, name: str) -> bool: + return await cls._require_context().unregister_llm_tool(name) + + @classmethod + async def register_skill( + cls, + *, + name: str, + path: str, + description: str = "", + ) -> SkillRegistration: + return await cls._require_context().skills.register( + name=name, + path=path, + description=description, + ) + + @classmethod + async def unregister_skill(cls, name: str) -> bool: + return await cls._require_context().skills.unregister(name) diff --git a/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md new file mode 100644 index 0000000000..33bb5548f5 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md @@ -0,0 +1,12 @@ +# AGENTS.md + +## AstrBot Plugin Notes + +- Prefer raising `AstrBotError` from `astrbot_sdk.errors` for expected failures. +- Reuse stable `ErrorCodes` and factory helpers instead of inventing ad-hoc `{"error": ...}` payloads. +- Validate the generated plugin with `astrbot-sdk validate --plugin-dir .` before packaging or sharing it. +- Run `python -m pytest tests/test_plugin.py -v` after changing plugin behavior so the sample harness contract stays honest. +- `astrbot-sdk build --plugin-dir .` should create the release zip without development-only files such as `AGENTS.md`, `CLAUDE.md`, `.claude/`, `.agents/`, or `.opencode/`. +- Exported capabilities should use `.`, and HTTP routes should use `/{plugin_id}` or `/{plugin_id}/...` so the plugin stays collision-safe inside `GroupWorkerRuntime`. + +- 除非有充分理由,插件的直接依赖应声明已验证的最低兼容版本。若已知存在不兼容的大版本或问题版本,应同时补充上界或排除约束 diff --git a/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md new file mode 100644 index 0000000000..6df0e003b9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md @@ -0,0 +1,12 @@ +# CLAUDE.md + +## AstrBot Plugin Notes + +- Prefer raising `AstrBotError` from `astrbot_sdk.errors` for expected failures. +- Reuse stable `ErrorCodes` and factory helpers instead of inventing ad-hoc `{"error": ...}` payloads. +- Validate the generated plugin with `astrbot-sdk validate --plugin-dir .` before packaging or sharing it. +- Run `python -m pytest tests/test_plugin.py -v` after changing plugin behavior so the sample harness contract stays honest. +- `astrbot-sdk build --plugin-dir .` should create the release zip without development-only files such as `AGENTS.md`, `CLAUDE.md`, `.claude/`, `.agents/`, or `.opencode/`. +- Exported capabilities should use `.`, and HTTP routes should use `/{plugin_id}` or `/{plugin_id}/...` so the plugin stays collision-safe inside `GroupWorkerRuntime`. + +- 除非有充分理由,插件的直接依赖应声明已验证的最低兼容版本。若已知存在不兼容的大版本或问题版本,应同时补充上界或排除约束 diff --git a/astrbot-sdk/src/astrbot_sdk/testing.py b/astrbot-sdk/src/astrbot_sdk/testing.py new file mode 100644 index 0000000000..de0c9627be --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/testing.py @@ -0,0 +1,859 @@ +"""本地开发与插件测试辅助。 + +`astrbot_sdk.testing` 是面向插件作者的稳定开发入口: + +- `PluginHarness` 负责复用现有 loader / dispatcher 执行链 +- `MockCapabilityRouter` 提供进程内 mock core 能力 +- `MockPeer` 让 `Context` 客户端继续走真实的 capability 调用路径 +- `StdoutPlatformSink` / `RecordedSend` 提供可观测的发送记录 + +这个模块刻意不暴露 runtime 内部编排数据结构,只封装本地开发/测试真正 +需要的最小稳定面。 +""" + +from __future__ import annotations + +import asyncio +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from ._internal.decorator_lifecycle import run_lifecycle_with_decorators +from ._internal.testing_support import ( + InMemoryDB, + InMemoryMemory, + MockCapabilityRouter, + MockContext, + MockLLMClient, + MockMessageEvent, + MockPeer, + MockPlatformClient, + RecordedSend, + StdoutPlatformSink, +) +from ._message_types import normalize_message_type +from .context import CancelToken +from .context import Context as RuntimeContext +from .errors import AstrBotError +from .events import MessageEvent +from .protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + PlatformFilterSpec, + ScheduleTrigger, +) +from .protocol.messages import InvokeMessage +from .runtime._command_matching import ( + build_command_args, + build_regex_args, + command_root_name, + match_command_name, +) +from .runtime._streaming import StreamExecution +from .runtime.handler_dispatcher import CapabilityDispatcher, HandlerDispatcher +from .runtime.loader import ( + LoadedHandler, + LoadedPlugin, + PluginSpec, + load_plugin, + load_plugin_config, + load_plugin_spec, + validate_plugin_spec, +) +from .star import Star + + +class _PluginLoadError(RuntimeError): + """本地 harness 初始化阶段的已知插件加载失败。""" + + +class _PluginExecutionError(RuntimeError): + """本地 harness 执行插件代码时的已知插件异常。""" + + +def _plugin_metadata_from_spec( + plugin: PluginSpec, + *, + enabled: bool, +) -> dict[str, Any]: + manifest = plugin.manifest_data + support_platforms = manifest.get("support_platforms") + return { + "name": plugin.name, + "display_name": str(manifest.get("display_name") or plugin.name), + "description": str(manifest.get("desc") or manifest.get("description") or ""), + "repo": str(manifest.get("repo") or ""), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": enabled, + "reserved": bool(manifest.get("reserved", False)), + "support_platforms": [ + str(item) for item in support_platforms if isinstance(item, str) + ] + if isinstance(support_platforms, list) + else [], + "astrbot_version": ( + str(manifest.get("astrbot_version")) + if manifest.get("astrbot_version") is not None + else None + ), + } + + +def _handler_metadata_from_loaded( + plugin_id: str, loaded: LoadedHandler +) -> dict[str, Any]: + event_types: list[str] = [] + trigger = loaded.descriptor.trigger + if isinstance(trigger, EventTrigger): + event_types.append(trigger.type) + return { + "plugin_name": plugin_id, + "handler_full_name": loaded.descriptor.id, + "trigger_type": trigger.type + if isinstance(trigger, EventTrigger) + else str(getattr(trigger, "kind", trigger.type)), + "event_types": event_types, + "enabled": True, + "group_path": list( + loaded.descriptor.command_route.group_path + if loaded.descriptor.command_route is not None + else [] + ), + "require_admin": loaded.descriptor.permissions.require_admin, + "required_role": loaded.descriptor.permissions.required_role, + } + + +@dataclass(slots=True) +class LocalRuntimeConfig: + """本地 harness 的稳定配置对象。""" + + plugin_dir: Path + session_id: str = "local-session" + user_id: str = "local-user" + platform: str = "test" + group_id: str | None = None + event_type: str = "message" + + +@dataclass(slots=True) +class MockClock: + now: float = 0.0 + + def time(self) -> float: + return self.now + + def advance(self, seconds: float) -> float: + self.now += float(seconds) + return self.now + + +@dataclass(slots=True) +class SDKTestEnvironment: + root: Path + + @property + def plugins_dir(self) -> Path: + path = self.root / "plugins" + path.mkdir(parents=True, exist_ok=True) + return path + + def plugin_dir(self, name: str) -> Path: + path = self.plugins_dir / name + path.mkdir(parents=True, exist_ok=True) + return path + + +class PluginHarness: + """本地插件消息泵。 + + 这里复用真实的 loader / dispatcher 执行链,只负责: + - 在同一个事件循环里装配单插件运行时 + - 维持本地 mock core 与发送记录 + - 把后续消息持续送入同一个 dispatcher + """ + + def __init__( + self, + config: LocalRuntimeConfig, + *, + platform_sink: StdoutPlatformSink | None = None, + ) -> None: + self.config = config + self.platform_sink = platform_sink or StdoutPlatformSink() + self.router = MockCapabilityRouter(platform_sink=self.platform_sink) + self.peer = MockPeer(self.router) + self.plugin: PluginSpec | None = None + self.loaded_plugin: LoadedPlugin | None = None + self.dispatcher: HandlerDispatcher | None = None + self.capability_dispatcher: CapabilityDispatcher | None = None + self.lifecycle_context: RuntimeContext | None = None + self._request_counter = 0 + self._started = False + + @classmethod + def from_plugin_dir( + cls, + plugin_dir: str | Path, + *, + session_id: str = "local-session", + user_id: str = "local-user", + platform: str = "test", + group_id: str | None = None, + event_type: str = "message", + platform_sink: StdoutPlatformSink | None = None, + ) -> PluginHarness: + return cls( + LocalRuntimeConfig( + plugin_dir=Path(plugin_dir), + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + ), + platform_sink=platform_sink, + ) + + async def __aenter__(self) -> PluginHarness: + await self.start() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.stop() + + @property + def sent_messages(self) -> list[RecordedSend]: + return list(self.platform_sink.records) + + def clear_sent_messages(self) -> None: + self.platform_sink.clear() + + async def start(self) -> None: + if self._started: + return + try: + self.plugin = load_plugin_spec(self.config.plugin_dir) + validate_plugin_spec(self.plugin) + self.loaded_plugin = load_plugin(self.plugin) + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginLoadError(str(exc)) from exc + self.dispatcher = HandlerDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + handlers=self.loaded_plugin.handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + capabilities=self.loaded_plugin.capabilities, + llm_tools=self.loaded_plugin.llm_tools, + ) + self.lifecycle_context = RuntimeContext( + peer=self.peer, + plugin_id=self.plugin.name, + ) + plugin_metadata = _plugin_metadata_from_spec(self.plugin, enabled=True) + plugin_metadata["acknowledge_global_mcp_risk"] = any( + bool( + getattr( + instance.__class__, + "__astrbot_acknowledge_global_mcp_risk__", + False, + ) + ) + for instance in self.loaded_plugin.instances + ) + self.router.upsert_plugin( + metadata=plugin_metadata, + config=load_plugin_config(self.plugin), + ) + self.router.set_plugin_handlers( + self.plugin.name, + [ + _handler_metadata_from_loaded(self.plugin.name, handler) + for handler in self.loaded_plugin.handlers + ], + ) + self.router.set_plugin_llm_tools( + self.plugin.name, + [tool.spec.to_payload() for tool in self.loaded_plugin.llm_tools], + ) + self.router.set_plugin_agents( + self.plugin.name, + [agent.spec.to_payload() for agent in self.loaded_plugin.agents], + ) + try: + await self._run_lifecycle("on_start") + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + self._started = True + + async def stop(self) -> None: + if ( + not self._started + or self.loaded_plugin is None + or self.lifecycle_context is None + ): + return + try: + await self._run_lifecycle("on_stop") + finally: + if self.plugin is not None: + self.router.set_plugin_enabled(self.plugin.name, False) + self.router.set_plugin_handlers(self.plugin.name, []) + self.router.remove_dynamic_command_routes_for_plugin(self.plugin.name) + self.router.remove_http_apis_for_plugin(self.plugin.name) + self._started = False + + async def dispatch_text( + self, + text: str, + *, + session_id: str | None = None, + user_id: str | None = None, + platform: str | None = None, + group_id: str | None = None, + event_type: str | None = None, + request_id: str | None = None, + ) -> list[RecordedSend]: + payload = self.build_event_payload( + text=text, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + request_id=request_id, + ) + return await self.dispatch_event(payload, request_id=request_id) + + async def dispatch_event( + self, + event_payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> list[RecordedSend]: + await self.start() + assert self.loaded_plugin is not None + assert self.dispatcher is not None + + start_index = len(self.platform_sink.records) + if self._has_waiter_for_event(event_payload): + await self._invoke_session_waiter( + event_payload, + request_id=request_id, + ) + await self._wait_for_followup_side_effects( + start_index=start_index, + event_payload=event_payload, + ) + return self.platform_sink.records[start_index:] + + matches = self._match_handlers(event_payload) + help_text = self._build_group_root_help(event_payload) + if help_text is not None and not any( + isinstance(loaded.descriptor.trigger, CommandTrigger) + for loaded, _args in matches + ): + assert self.lifecycle_context is not None + await self.lifecycle_context.platform.send( + str(event_payload.get("session_id", "")), + help_text, + ) + return self.platform_sink.records[start_index:] + if not matches: + raise AstrBotError.invalid_input("未找到匹配的 handler") + for loaded, args in matches: + result = await self._invoke_handler( + loaded, + event_payload, + args=args, + request_id=request_id, + ) + # Mirror the runtime dispatcher contract: once a handler explicitly + # stops the event, later matches in the same dispatch should not run. + if bool(result.get("stop", False)): + break + return self.platform_sink.records[start_index:] + + async def invoke_capability( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + stream: bool = False, + ) -> dict[str, Any] | StreamExecution: + await self.start() + assert self.capability_dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("cap"), + capability=capability, + input=dict(payload), + stream=stream, + ) + try: + return await self.capability_dispatcher.invoke(message, CancelToken()) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + def build_event_payload( + self, + *, + text: str, + session_id: str | None = None, + user_id: str | None = None, + platform: str | None = None, + group_id: str | None = None, + event_type: str | None = None, + request_id: str | None = None, + ) -> dict[str, Any]: + session_value = session_id or self.config.session_id + group_value = group_id if group_id is not None else self.config.group_id + event_type_value = event_type or self.config.event_type + payload = { + "type": event_type_value, + "event_type": event_type_value, + "text": text, + "session_id": session_value, + "user_id": user_id or self.config.user_id, + "platform": platform or self.config.platform, + "platform_id": platform or self.config.platform, + "group_id": group_value, + "self_id": f"{platform or self.config.platform}-bot", + "sender_name": str(user_id or self.config.user_id or ""), + "is_admin": False, + "raw": { + "trace_id": request_id or self._next_request_id("trace"), + "event_type": event_type_value, + }, + } + if group_value: + payload["message_type"] = "group" + elif payload["user_id"]: + payload["message_type"] = "private" + else: + payload["message_type"] = "other" + return payload + + async def _invoke_handler( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + *, + args: dict[str, Any], + request_id: str | None = None, + ) -> dict[str, Any]: + assert self.dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("msg"), + capability="handler.invoke", + input={ + "handler_id": loaded.descriptor.id, + "event": dict(event_payload), + "args": dict(args), + }, + ) + try: + result = await self.dispatcher.invoke(message, CancelToken()) + return dict(result) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + async def _invoke_session_waiter( + self, + event_payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> dict[str, Any]: + assert self.dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("msg"), + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(event_payload), + "args": {}, + }, + ) + try: + result = await self.dispatcher.invoke(message, CancelToken()) + return dict(result) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + async def _wait_for_followup_side_effects( + self, + *, + start_index: int, + event_payload: dict[str, Any], + ) -> None: + settled_rounds = 0 + for _ in range(20): + if len(self.platform_sink.records) > start_index: + return + await asyncio.sleep(0) + if self._has_waiter_for_event(event_payload): + settled_rounds = 0 + continue + settled_rounds += 1 + if settled_rounds >= 3: + return + + async def _run_lifecycle(self, method_name: str) -> None: + assert self.loaded_plugin is not None + assert self.lifecycle_context is not None + + for instance in self.loaded_plugin.instances: + hook = self._resolve_lifecycle_hook(instance, method_name) + await run_lifecycle_with_decorators( + instance=instance, + hook=hook, + method_name=method_name, + context=self.lifecycle_context, + ) + + def _match_handlers( + self, + event_payload: dict[str, Any], + ) -> list[tuple[LoadedHandler, dict[str, Any]]]: + assert self.loaded_plugin is not None + ranked: list[tuple[int, int, LoadedHandler, dict[str, Any]]] = [] + for index, loaded in enumerate(self.loaded_plugin.handlers): + args = self._match_handler(loaded, event_payload) + if args is None: + continue + ranked.append((loaded.descriptor.priority, index, loaded, args)) + for dynamic in self._match_dynamic_handlers(event_payload): + ranked.append(dynamic) + ranked.sort(key=lambda item: (-item[0], item[1])) + return [(loaded, args) for _priority, _index, loaded, args in ranked] + + def _match_dynamic_handlers( + self, + event_payload: dict[str, Any], + ) -> list[tuple[int, int, LoadedHandler, dict[str, Any]]]: + assert self.loaded_plugin is not None + assert self.plugin is not None + ranked: list[tuple[int, int, LoadedHandler, dict[str, Any]]] = [] + routes = self.router.list_dynamic_command_routes(self.plugin.name) + handler_map = { + loaded.descriptor.id: loaded for loaded in self.loaded_plugin.handlers + } + base_order = len(self.loaded_plugin.handlers) + for index, route in enumerate(routes): + if not isinstance(route, dict): + continue + handler_full_name = str(route.get("handler_full_name", "")).strip() + loaded = handler_map.get(handler_full_name) + if loaded is None: + continue + args = self._match_dynamic_route(loaded, route, event_payload) + if args is None: + continue + priority = route.get("priority", loaded.descriptor.priority) + if not isinstance(priority, int) or isinstance(priority, bool): + priority = loaded.descriptor.priority + ranked.append((priority, base_order + index, loaded, args)) + return ranked + + def _match_dynamic_route( + self, + loaded: LoadedHandler, + route: dict[str, Any], + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + command_name = str(route.get("command_name", "")).strip() + if not command_name: + return None + text = str(event_payload.get("text", "")) + if bool(route.get("use_regex", False)): + match = re.search(command_name, text) + if match is None: + return None + return build_regex_args(loaded.descriptor.param_specs, match) + remainder = match_command_name(text, command_name) + if remainder is None: + return None + return build_command_args(loaded.descriptor.param_specs, remainder) + + def _match_handler( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_permissions(loaded, event_payload): + return None + trigger = loaded.descriptor.trigger + if isinstance(trigger, CommandTrigger): + return self._match_command_trigger(loaded, trigger, event_payload) + if isinstance(trigger, MessageTrigger): + return self._match_message_trigger(loaded, trigger, event_payload) + if isinstance(trigger, EventTrigger): + current_type = str( + event_payload.get("event_type") + or event_payload.get("type") + or "message" + ) + if current_type != trigger.event_type: + return None + return {} + if isinstance(trigger, ScheduleTrigger): + if ( + str(event_payload.get("event_type") or event_payload.get("type")) + == "schedule" + ): + schedule_payload = event_payload.get("schedule") + if isinstance(schedule_payload, dict): + target_handler_id = str( + schedule_payload.get("handler_id", "") + ).strip() + if target_handler_id and target_handler_id != loaded.descriptor.id: + return None + return {} + return None + return None + + def _match_command_trigger( + self, + loaded: LoadedHandler, + trigger: CommandTrigger, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + text = str(event_payload.get("text", "")).strip() + for command_name in [trigger.command, *trigger.aliases]: + if not command_name: + continue + match = match_command_name(text, command_name) + if match is None: + continue + return build_command_args(loaded.descriptor.param_specs, match) + return None + + def _build_group_root_help(self, event_payload: dict[str, Any]) -> str | None: + assert self.loaded_plugin is not None + root_name = command_root_name(str(event_payload.get("text", ""))) + if not root_name: + return None + entries: list[tuple[str, str | None]] = [] + seen_commands: set[str] = set() + for loaded in self.loaded_plugin.handlers: + descriptor = loaded.descriptor + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + continue + if not self._passes_filters(loaded, event_payload): + continue + route = descriptor.command_route + root_candidates: list[str] = [] + if route is not None and route.group_path: + group_root = str(route.group_path[0]).strip() + if group_root: + root_candidates.append(group_root) + for name in [trigger.command, *trigger.aliases]: + normalized = str(name).strip() + if " " not in normalized: + continue + command_root = normalized.split()[0].strip() + if command_root: + root_candidates.append(command_root) + if root_name not in dict.fromkeys(root_candidates): + continue + display_command = ( + str(route.display_command).strip() + if route is not None and str(route.display_command).strip() + else str(trigger.command).strip() + ) + if not display_command or display_command in seen_commands: + continue + seen_commands.add(display_command) + description = ( + str(descriptor.description or "").strip() + or str(trigger.description or "").strip() + or None + ) + entries.append((display_command, description)) + if not entries: + return None + lines = [f"{root_name}命令:"] + for command_name, description in entries: + line = f"- /{command_name}" + if description: + line += f": {description}" + lines.append(line) + return "\n".join(lines) + + def _match_message_trigger( + self, + loaded: LoadedHandler, + trigger: MessageTrigger, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + text = str(event_payload.get("text", "")) + if trigger.regex: + match = re.search(trigger.regex, text) + if match is None: + return None + return build_regex_args(loaded.descriptor.param_specs, match) + if trigger.keywords and not any( + keyword in text for keyword in trigger.keywords + ): + return None + return {} + + @staticmethod + def _passes_permissions( + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> bool: + permissions = loaded.descriptor.permissions + required_role = permissions.required_role + if required_role is None and permissions.require_admin: + required_role = "admin" + if required_role == "admin": + return bool(event_payload.get("is_admin", False)) + return True + + def _passes_filters( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> bool: + for filter_spec in loaded.descriptor.filters: + if isinstance(filter_spec, PlatformFilterSpec): + if str(event_payload.get("platform", "")) not in filter_spec.platforms: + return False + elif isinstance(filter_spec, MessageTypeFilterSpec): + if ( + self._message_type_name(event_payload) + not in filter_spec.message_types + ): + return False + elif isinstance(filter_spec, CompositeFilterSpec): + if not self._passes_composite_filter(filter_spec, event_payload): + return False + elif isinstance(filter_spec, LocalFilterRefSpec): + continue + return True + + def _passes_composite_filter( + self, + filter_spec: CompositeFilterSpec, + event_payload: dict[str, Any], + ) -> bool: + results: list[bool] = [] + for child in filter_spec.children: + if isinstance(child, PlatformFilterSpec): + results.append( + str(event_payload.get("platform", "")) in child.platforms + ) + elif isinstance(child, MessageTypeFilterSpec): + results.append( + self._message_type_name(event_payload) in child.message_types + ) + elif isinstance(child, LocalFilterRefSpec): + results.append(True) + elif isinstance(child, CompositeFilterSpec): + results.append(self._passes_composite_filter(child, event_payload)) + if filter_spec.kind == "and": + return all(results) + return any(results) + + def _has_waiter_for_event(self, event_payload: dict[str, Any]) -> bool: + assert self.dispatcher is not None + probe_event = MessageEvent.from_payload( + event_payload, + context=self.lifecycle_context, + ) + public_probe = getattr(self.dispatcher, "has_active_waiter", None) + if callable(public_probe): + return bool(public_probe(probe_event)) + session_waiters = getattr(self.dispatcher, "_session_waiters", None) + if session_waiters is None: + return False + if hasattr(session_waiters, "has_waiter"): + return session_waiters.has_waiter(probe_event) + if isinstance(session_waiters, dict): + return any( + manager.has_waiter(probe_event) + for manager in session_waiters.values() + if hasattr(manager, "has_waiter") + ) + return False + + @staticmethod + def _message_type_name(event_payload: dict[str, Any]) -> str: + return normalize_message_type( + event_payload.get("message_type", ""), + group_id=str(event_payload.get("group_id", "")).strip() or None, + user_id=str(event_payload.get("user_id", "")).strip() or None, + empty_default="other", + ) + + @staticmethod + def _resolve_lifecycle_hook(instance: Any, method_name: str): + hook = getattr(instance, method_name, None) + marker = getattr(instance.__class__, "__astrbot_is_new_star__", None) + is_new_star = True + if callable(marker): + is_new_star = bool(marker()) + + if hook is not None and callable(hook): + bound_func = getattr(hook, "__func__", hook) + star_default = getattr(Star, method_name, None) + if star_default is None or bound_func is not star_default: + return hook + + if not is_new_star: + alias = {"on_start": "initialize", "on_stop": "terminate"}.get(method_name) + if alias is not None: + legacy_hook = getattr(instance, alias, None) + if legacy_hook is not None and callable(legacy_hook): + return legacy_hook + + if hook is not None and callable(hook): + return hook + return None + + def _next_request_id(self, prefix: str) -> str: + self._request_counter += 1 + return f"{prefix}_{self._request_counter:04d}" + + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + "LocalRuntimeConfig", + "MockClock", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "SDKTestEnvironment", + "PluginHarness", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/types.py b/astrbot-sdk/src/astrbot_sdk/types.py new file mode 100644 index 0000000000..c2bc911ec7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/types.py @@ -0,0 +1,22 @@ +"""SDK parameter helper types. + +本模块提供 SDK 参数类型助手,用于增强命令参数解析能力。 + +GreedyStr: +用于标记"贪婪字符串"参数,在命令解析时将剩余所有文本作为一个整体参数。 +例如:/echo hello world this is a test +如果最后一个参数类型为 GreedyStr,将获取 "hello world this is a test" 而非仅 "hello" + +使用方式: +在 handler 签名中将最后一个参数标注为 GreedyStr 类型, +_loader_support 会识别此类型并调整参数解析逻辑。 +""" + +from __future__ import annotations + + +class GreedyStr(str): + """Consume the remaining command text as one argument.""" + + +__all__ = ["GreedyStr"] diff --git a/astrbot/__init__.py b/astrbot/__init__.py index 73d64f303f..f7604c5b15 100644 --- a/astrbot/__init__.py +++ b/astrbot/__init__.py @@ -1,3 +1,16 @@ -from .core.log import LogManager +from __future__ import annotations -logger = LogManager.GetLogger(log_name="astrbot") +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .core import logger as logger + +__all__ = ["logger"] + + +def __getattr__(name: str) -> Any: + if name == "logger": + from .core import logger + + return logger + raise AttributeError(name) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 51690ede27..a11435a84b 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,47 +1,185 @@ +from __future__ import annotations + import os +from importlib import import_module +from typing import TYPE_CHECKING, Any -from astrbot.core.config import AstrBotConfig -from astrbot.core.config.default import DB_PATH -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.file_token_service import FileTokenService -from astrbot.core.utils.pip_installer import ( - DependencyConflictError as DependencyConflictError, -) -from astrbot.core.utils.pip_installer import ( - PipInstaller, -) -from astrbot.core.utils.requirements_utils import ( - RequirementsPrecheckFailed as RequirementsPrecheckFailed, -) -from astrbot.core.utils.requirements_utils import ( - find_missing_requirements as find_missing_requirements, -) -from astrbot.core.utils.requirements_utils import ( - find_missing_requirements_or_raise as find_missing_requirements_or_raise, -) -from astrbot.core.utils.shared_preferences import SharedPreferences -from astrbot.core.utils.t2i.renderer import HtmlRenderer - -from .log import LogBroker, LogManager # noqa from .utils.astrbot_path import get_astrbot_data_path -# 初始化数据存储文件夹 +if TYPE_CHECKING: + from .config import AstrBotConfig + from .db.sqlite import SQLiteDatabase + from .file_token_service import FileTokenService + from .log import LogBroker, LogManager + from .utils.pip_installer import DependencyConflictError, PipInstaller + from .utils.requirements_utils import ( + RequirementsPrecheckFailed, + find_missing_requirements, + find_missing_requirements_or_raise, + ) +else: + AstrBotConfig: Any + SQLiteDatabase: Any + FileTokenService: Any + LogBroker: Any + LogManager: Any + DependencyConflictError: Any + PipInstaller: Any + RequirementsPrecheckFailed: Any + find_missing_requirements: Any + find_missing_requirements_or_raise: Any + astrbot_config: Any + db_helper: Any + file_token_service: Any + html_renderer: Any + logger: Any + pip_installer: Any + sp: Any + os.makedirs(get_astrbot_data_path(), exist_ok=True) DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t") -astrbot_config = AstrBotConfig() -t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") -html_renderer = HtmlRenderer(t2i_base_url) -logger = LogManager.GetLogger(log_name="astrbot") -LogManager.configure_logger(logger, astrbot_config) -LogManager.configure_trace_logger(astrbot_config) -db_helper = SQLiteDatabase(DB_PATH) -# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 -sp = SharedPreferences(db_helper=db_helper) -# 文件令牌服务 -file_token_service = FileTokenService() -pip_installer = PipInstaller( - astrbot_config.get("pip_install_arg", ""), - astrbot_config.get("pypi_index_url", None), -) +__all__ = [ + "AstrBotConfig", + "DEMO_MODE", + "DependencyConflictError", + "FileTokenService", + "LogBroker", + "LogManager", + "PipInstaller", + "RequirementsPrecheckFailed", + "SQLiteDatabase", + "astrbot_config", + "db_helper", + "file_token_service", + "find_missing_requirements", + "find_missing_requirements_or_raise", + "html_renderer", + "logger", + "pip_installer", + "sp", +] + +_SINGLETON_CACHE: dict[str, Any] = {} + + +def _get_astrbot_config(): + config_module = import_module(".config", __name__) + cached = _SINGLETON_CACHE.get("astrbot_config") + if cached is None: + cached = config_module.AstrBotConfig() + _SINGLETON_CACHE["astrbot_config"] = cached + return cached + + +def _get_log_manager(): + return import_module(".log", __name__).LogManager + + +def _get_logger(): + cached = _SINGLETON_CACHE.get("logger") + if cached is None: + logger_obj = _get_log_manager().GetLogger(log_name="astrbot") + config = _get_astrbot_config() + log_manager = _get_log_manager() + log_manager.configure_logger(logger_obj, config) + log_manager.configure_trace_logger(config) + _SINGLETON_CACHE["logger"] = logger_obj + cached = logger_obj + return cached + + +def _get_db_helper(): + cached = _SINGLETON_CACHE.get("db_helper") + if cached is None: + sqlite_module = import_module(".db.sqlite", __name__) + default_module = import_module(".config.default", __name__) + cached = sqlite_module.SQLiteDatabase(default_module.DB_PATH) + _SINGLETON_CACHE["db_helper"] = cached + return cached + + +def _get_shared_preferences(): + cached = _SINGLETON_CACHE.get("sp") + if cached is None: + shared_preferences_module = import_module(".utils.shared_preferences", __name__) + cached = shared_preferences_module.SharedPreferences(db_helper=_get_db_helper()) + _SINGLETON_CACHE["sp"] = cached + return cached + + +def _get_file_token_service(): + cached = _SINGLETON_CACHE.get("file_token_service") + if cached is None: + service_module = import_module(".file_token_service", __name__) + cached = service_module.FileTokenService() + _SINGLETON_CACHE["file_token_service"] = cached + return cached + + +def _get_html_renderer(): + cached = _SINGLETON_CACHE.get("html_renderer") + if cached is None: + renderer_module = import_module(".utils.t2i.renderer", __name__) + config = _get_astrbot_config() + endpoint = config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") + cached = renderer_module.HtmlRenderer(endpoint) + _SINGLETON_CACHE["html_renderer"] = cached + return cached + + +def _get_pip_installer(): + cached = _SINGLETON_CACHE.get("pip_installer") + if cached is None: + installer_module = import_module(".utils.pip_installer", __name__) + config = _get_astrbot_config() + cached = installer_module.PipInstaller( + config.get("pip_install_arg", ""), + config.get("pypi_index_url", None), + ) + _SINGLETON_CACHE["pip_installer"] = cached + return cached + + +def __getattr__(name: str) -> Any: + if name == "AstrBotConfig": + return import_module(".config", __name__).AstrBotConfig + if name in {"LogBroker", "LogManager"}: + module = import_module(".log", __name__) + return getattr(module, name) + if name == "DependencyConflictError": + return import_module(".utils.pip_installer", __name__).DependencyConflictError + if name == "FileTokenService": + return import_module(".file_token_service", __name__).FileTokenService + if name == "PipInstaller": + return import_module(".utils.pip_installer", __name__).PipInstaller + if name == "RequirementsPrecheckFailed": + return import_module( + ".utils.requirements_utils", __name__ + ).RequirementsPrecheckFailed + if name == "SQLiteDatabase": + return import_module(".db.sqlite", __name__).SQLiteDatabase + if name == "find_missing_requirements": + return import_module( + ".utils.requirements_utils", __name__ + ).find_missing_requirements + if name == "find_missing_requirements_or_raise": + return import_module( + ".utils.requirements_utils", __name__ + ).find_missing_requirements_or_raise + if name == "astrbot_config": + return _get_astrbot_config() + if name == "logger": + return _get_logger() + if name == "db_helper": + return _get_db_helper() + if name == "sp": + return _get_shared_preferences() + if name == "file_token_service": + return _get_file_token_service() + if name == "html_renderer": + return _get_html_renderer() + if name == "pip_installer": + return _get_pip_installer() + raise AttributeError(name) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index af969a3fac..aceb2261ba 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -137,6 +137,7 @@ def __init__(self) -> None: self.tools: list[mcp.Tool] = [] self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() + self.process_pid: int | None = None # Store connection config for reconnection self._mcp_server_config: dict | None = None @@ -144,6 +145,24 @@ def __init__(self) -> None: self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging + @staticmethod + def _extract_stdio_process_pid(streams_context: object) -> int | None: + """Best-effort extraction for stdio subprocess PID used by lease cleanup. + + TODO(refactor): replace this async-generator frame introspection with a + stable MCP library hook once the upstream transport exposes process PID. + """ + generator = getattr(streams_context, "gen", None) + frame = getattr(generator, "ag_frame", None) + if frame is None: + return None + process = frame.f_locals.get("process") + pid = getattr(process, "pid", None) + try: + return int(pid) if pid is not None else None + except (TypeError, ValueError): + return None + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server @@ -159,6 +178,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: # Store config for reconnection self._mcp_server_config = mcp_server_config self._server_name = name + self.process_pid = None cfg = _prepare_config(mcp_server_config.copy()) @@ -261,6 +281,7 @@ def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None: ), # type: ignore ), ) + self.process_pid = self._extract_stdio_process_pid(self._streams_context) # Create a new client session self.session = await self.exit_stack.enter_async_context( @@ -390,6 +411,7 @@ async def cleanup(self) -> None: # Set running_event first to unblock any waiting tasks self.running_event.set() + self.process_pid = None class MCPTool(FunctionTool, Generic[TContext]): diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index 09bf32deb4..89a6edd73e 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -11,7 +11,42 @@ from astrbot.core.star.star_handler import EventType +def _sdk_safe_payload(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [_sdk_safe_payload(item) for item in value] + if isinstance(value, dict): + return {str(key): _sdk_safe_payload(item) for key, item in value.items()} + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + dumped = model_dump() + except Exception: + return str(value) + return _sdk_safe_payload(dumped) + return str(value) + + class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): + async def on_agent_begin( + self, + run_context: ContextWrapper[AstrAgentContext], + ) -> None: + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "agent_begin", + run_context.context.event, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK agent_begin dispatch failed: %s", exc) + async def on_agent_done(self, run_context, llm_response) -> None: # 执行事件钩子 if llm_response and llm_response.reasoning_content: @@ -25,6 +60,45 @@ async def on_agent_done(self, run_context, llm_response) -> None: EventType.OnLLMResponseEvent, llm_response, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_response", + run_context.context.event, + { + "completion_text": ( + llm_response.completion_text if llm_response else "" + ), + }, + llm_response=llm_response, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK llm_response dispatch failed: %s", exc) + try: + await sdk_plugin_bridge.dispatch_message_event( + "agent_done", + run_context.context.event, + { + "completion_text": ( + llm_response.completion_text if llm_response else "" + ), + "tool_call_names": ( + list(llm_response.tools_call_name) + if llm_response and llm_response.tools_call_name + else [] + ), + }, + llm_response=llm_response, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK agent_done dispatch failed: %s", exc) async def on_tool_start( self, @@ -38,6 +112,23 @@ async def on_tool_start( tool, tool_args, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_tool_start", + run_context.context.event, + { + "tool_name": tool.name, + "tool_args": _sdk_safe_payload(tool_args), + }, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK llm_tool_start dispatch failed: %s", exc) async def on_tool_end( self, @@ -54,6 +145,24 @@ async def on_tool_end( tool_args, tool_result, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_tool_end", + run_context.context.event, + { + "tool_name": tool.name, + "tool_args": _sdk_safe_payload(tool_args), + "tool_result": _sdk_safe_payload(tool_result), + }, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK llm_tool_end dispatch failed: %s", exc) # special handle web_search_tavily platform_name = run_context.context.event.get_platform_name() diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index eca24699ae..c4ec095a4c 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -87,6 +87,38 @@ def _build_tool_result_status_message( return status_msg +async def _apply_sdk_streaming_delta_filters( + sdk_plugin_bridge, + astr_event, + chain: MessageChain, +) -> MessageChain: + if sdk_plugin_bridge is None: + return chain + try: + stream_result = MessageEventResult(chain=list(chain.chain)) + stream_result.type = chain.type + stream_result.use_t2i_ = chain.use_t2i_ + await sdk_plugin_bridge.dispatch_message_event( + "streaming_delta", + astr_event, + { + "message_outline": chain.get_plain_text(with_other_comps_mark=True), + "result_content_type": "streaming_delta", + }, + event_result=stream_result, + ) + return MessageChain( + chain=list(stream_result.chain or []), + use_t2i_=stream_result.use_t2i_, + type=stream_result.type or chain.type, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK streaming_delta dispatch failed: %s", exc) + return chain + + async def run_agent( agent_runner: AgentRunner, max_step: int = 30, @@ -97,6 +129,9 @@ async def run_agent( ) -> AsyncGenerator[MessageChain | None, None]: step_idx = 0 astr_event = agent_runner.run_context.context.event + sdk_plugin_bridge = getattr( + agent_runner.run_context.context.context, "sdk_plugin_bridge", None + ) tool_name_by_call_id: dict[str, str] = {} while step_idx < max_step + 1: step_idx += 1 @@ -215,7 +250,13 @@ async def run_agent( if chain.type == "reasoning" and not show_reasoning: # display the reasoning content only when configured continue - yield resp.data["chain"] # MessageChain + chain = await _apply_sdk_streaming_delta_filters( + sdk_plugin_bridge, + astr_event, + chain, + ) + if chain is not None: + yield chain if not stop_watcher.done(): stop_watcher.cancel() try: diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 1fb4b03368..a3154a38af 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -586,6 +586,24 @@ async def _execute_local( if awaitable is None: raise ValueError("Tool must have a valid handler or override 'run' method.") + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "calling_func_tool", + event, + { + "tool_name": tool.name, + "tool_args": json.loads( + json.dumps(tool_args, ensure_ascii=False, default=str) + ), + }, + ) + except Exception as exc: + logger.warning("SDK calling_func_tool dispatch failed: %s", exc) + wrapper = call_local_llm_tool( context=run_context, handler=awaitable, diff --git a/astrbot/core/backup/constants.py b/astrbot/core/backup/constants.py index be206b3074..10962abe46 100644 --- a/astrbot/core/backup/constants.py +++ b/astrbot/core/backup/constants.py @@ -26,6 +26,7 @@ get_astrbot_config_path, get_astrbot_plugin_data_path, get_astrbot_plugin_path, + get_astrbot_sdk_plugins_path, get_astrbot_t2i_templates_path, get_astrbot_temp_path, get_astrbot_webchat_path, @@ -67,6 +68,7 @@ def get_backup_directories() -> dict[str, str]: """ return { "plugins": get_astrbot_plugin_path(), # 插件本体 + "sdk_plugins": get_astrbot_sdk_plugins_path(), # SDK 插件本体 "plugin_data": get_astrbot_plugin_data_path(), # 插件数据 "config": get_astrbot_config_path(), # 配置目录 "t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板 diff --git a/astrbot/core/command_compatibility.py b/astrbot/core/command_compatibility.py new file mode 100644 index 0000000000..46edcc6248 --- /dev/null +++ b/astrbot/core/command_compatibility.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Literal + +from astrbot_sdk.protocol.descriptors import CommandTrigger, HandlerDescriptor + +from astrbot.core.star.filter.command import CommandFilter +from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import ( + EventType, + StarHandlerMetadata, + star_handlers_registry, +) + + +@dataclass(slots=True) +class CommandRegistration: + runtime_kind: Literal["legacy", "sdk"] + plugin_name: str + plugin_display_name: str | None + handler_full_name: str + command_name: str + + +@dataclass(slots=True) +class CrossSystemCommandConflict: + command_name: str + legacy: CommandRegistration + sdk: CommandRegistration + + def to_dashboard_payload(self) -> dict[str, Any]: + return { + "conflict_key": self.command_name, + "handlers": [ + { + "handler_full_name": self.legacy.handler_full_name, + "plugin": self.legacy.plugin_name, + "plugin_display_name": self.legacy.plugin_display_name, + "current_name": self.legacy.command_name, + "runtime_kind": self.legacy.runtime_kind, + }, + { + "handler_full_name": self.sdk.handler_full_name, + "plugin": self.sdk.plugin_name, + "plugin_display_name": self.sdk.plugin_display_name, + "current_name": self.sdk.command_name, + "runtime_kind": self.sdk.runtime_kind, + }, + ], + } + + +def normalize_command_name(value: str) -> str: + return re.sub(r"\s+", " ", str(value).strip()) + + +def command_matches_text(command_name: str, text: str) -> bool: + normalized_command = normalize_command_name(command_name) + normalized_text = normalize_command_name(text) + if not normalized_command or not normalized_text: + return False + return normalized_text == normalized_command or normalized_text.startswith( + f"{normalized_command} " + ) + + +def commands_overlap(left: str, right: str) -> bool: + normalized_left = normalize_command_name(left) + normalized_right = normalize_command_name(right) + if not normalized_left or not normalized_right: + return False + return ( + normalized_left == normalized_right + or normalized_left.startswith(f"{normalized_right} ") + or normalized_right.startswith(f"{normalized_left} ") + ) + + +def _command_prefixes(command_name: str) -> tuple[str, ...]: + normalized = normalize_command_name(command_name) + if not normalized: + return () + prefixes: list[str] = [] + parts: list[str] = [] + for token in normalized.split(" "): + parts.append(token) + prefixes.append(" ".join(parts)) + return tuple(prefixes) + + +def collect_legacy_command_registrations( + handlers: Iterable[StarHandlerMetadata] | None = None, +) -> list[CommandRegistration]: + source_handlers = ( + handlers + if handlers is not None + else star_handlers_registry.get_handlers_by_event_type( + EventType.AdapterMessageEvent, + only_activated=True, + ) + ) + registrations: list[CommandRegistration] = [] + for handler in source_handlers: + filter_ref = _locate_legacy_command_filter(handler) + if filter_ref is None: + continue + plugin_meta = star_map.get(handler.handler_module_path) + plugin_name = ( + plugin_meta.name if plugin_meta is not None else handler.handler_module_path + ) + plugin_display_name = ( + plugin_meta.display_name if plugin_meta is not None else None + ) + seen_names: set[str] = set() + for command_name in filter_ref.get_complete_command_names(): + normalized = normalize_command_name(command_name) + if not normalized or normalized in seen_names: + continue + seen_names.add(normalized) + registrations.append( + CommandRegistration( + runtime_kind="legacy", + plugin_name=plugin_name, + plugin_display_name=plugin_display_name, + handler_full_name=handler.handler_full_name, + command_name=normalized, + ) + ) + return registrations + + +def match_legacy_command_registrations( + handlers: Iterable[StarHandlerMetadata], + text: str, +) -> list[CommandRegistration]: + return [ + registration + for registration in collect_legacy_command_registrations(handlers) + if command_matches_text(registration.command_name, text) + ] + + +def collect_sdk_command_registrations( + *, + plugin_name: str, + plugin_display_name: str | None, + handler_full_name: str, + descriptor: HandlerDescriptor, +) -> list[CommandRegistration]: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return [] + registrations: list[CommandRegistration] = [] + seen_names: set[str] = set() + for command_name in [trigger.command, *trigger.aliases]: + normalized = normalize_command_name(command_name) + if not normalized or normalized in seen_names: + continue + seen_names.add(normalized) + registrations.append( + CommandRegistration( + runtime_kind="sdk", + plugin_name=plugin_name, + plugin_display_name=plugin_display_name, + handler_full_name=handler_full_name, + command_name=normalized, + ) + ) + return registrations + + +def match_sdk_command_registrations( + registrations: Iterable[CommandRegistration], + text: str, +) -> list[CommandRegistration]: + return [ + registration + for registration in registrations + if command_matches_text(registration.command_name, text) + ] + + +def build_cross_system_conflicts( + legacy_registrations: Iterable[CommandRegistration], + sdk_registrations: Iterable[CommandRegistration], +) -> list[CrossSystemCommandConflict]: + conflicts: list[CrossSystemCommandConflict] = [] + seen_pairs: set[tuple[str, str, str]] = set() + legacy_by_exact: dict[str, list[CommandRegistration]] = {} + legacy_by_prefix: dict[str, list[CommandRegistration]] = {} + for legacy_registration in legacy_registrations: + normalized_command = normalize_command_name(legacy_registration.command_name) + if not normalized_command: + continue + legacy_by_exact.setdefault(normalized_command, []).append(legacy_registration) + for prefix in _command_prefixes(normalized_command): + legacy_by_prefix.setdefault(prefix, []).append(legacy_registration) + + for sdk_registration in sdk_registrations: + normalized_sdk_command = normalize_command_name(sdk_registration.command_name) + if not normalized_sdk_command: + continue + candidate_legacy: list[CommandRegistration] = [] + seen_legacy_commands: set[tuple[str, str]] = set() + for prefix in _command_prefixes(normalized_sdk_command): + for legacy_registration in legacy_by_exact.get(prefix, []): + legacy_key = ( + legacy_registration.handler_full_name, + legacy_registration.command_name, + ) + if legacy_key in seen_legacy_commands: + continue + seen_legacy_commands.add(legacy_key) + candidate_legacy.append(legacy_registration) + for legacy_registration in legacy_by_prefix.get(normalized_sdk_command, []): + legacy_key = ( + legacy_registration.handler_full_name, + legacy_registration.command_name, + ) + if legacy_key in seen_legacy_commands: + continue + seen_legacy_commands.add(legacy_key) + candidate_legacy.append(legacy_registration) + + for legacy_registration in candidate_legacy: + pair_key = ( + _build_conflict_key( + legacy_registration.command_name, + sdk_registration.command_name, + ), + legacy_registration.handler_full_name, + sdk_registration.handler_full_name, + ) + if pair_key in seen_pairs: + continue + seen_pairs.add(pair_key) + conflicts.append( + CrossSystemCommandConflict( + command_name=_build_conflict_key( + legacy_registration.command_name, + sdk_registration.command_name, + ), + legacy=legacy_registration, + sdk=sdk_registration, + ) + ) + return conflicts + + +def _locate_legacy_command_filter( + handler: StarHandlerMetadata, +) -> CommandFilter | CommandGroupFilter | None: + for filter_ref in handler.event_filters: + if isinstance(filter_ref, CommandFilter | CommandGroupFilter): + return filter_ref + return None + + +def _build_conflict_key(legacy_command: str, sdk_command: str) -> str: + if legacy_command == sdk_command: + return legacy_command + return f"{legacy_command} <> {sdk_command}" diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 715f938679..579d80a97c 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -20,17 +20,6 @@ _MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json" -def _list_local_skill_dirs(skills_root: Path) -> list[Path]: - skills: list[Path] = [] - for entry in sorted(skills_root.iterdir()): - if not entry.is_dir(): - continue - skill_md = entry / "SKILL.md" - if skill_md.exists(): - skills.append(entry) - return skills - - def _discover_bay_credentials(endpoint: str) -> str: """Try to auto-discover Bay API key from credentials.json. @@ -383,20 +372,25 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: splitting into `apply` and `scan` phases. """ skills_root = Path(get_astrbot_skills_path()) - if not skills_root.is_dir(): - return - local_skill_dirs = _list_local_skill_dirs(skills_root) + skill_manager: SkillManager | None = None + local_skill_sources = [] + if skills_root.exists(): + skill_manager = SkillManager(skills_root=str(skills_root)) + local_skill_sources = skill_manager.list_local_skill_sources() temp_dir = Path(get_astrbot_temp_path()) temp_dir.mkdir(parents=True, exist_ok=True) zip_base = temp_dir / "skills_bundle" zip_path = zip_base.with_suffix(".zip") + bundle_dir = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" try: - if local_skill_dirs: + if local_skill_sources: + assert skill_manager is not None if zip_path.exists(): zip_path.unlink() - shutil.make_archive(str(zip_base), "zip", str(skills_root)) + skill_manager.materialize_local_skill_bundle(bundle_dir) + shutil.make_archive(str(zip_base), "zip", root_dir=str(bundle_dir)) remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip" logger.info("Uploading skills bundle to sandbox...") await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") @@ -420,6 +414,8 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: len(managed), ) finally: + if bundle_dir.exists(): + shutil.rmtree(bundle_dir, ignore_errors=True) if zip_path.exists(): try: zip_path.unlink() diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 77c298cac8..6a38311c67 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -2,6 +2,7 @@ import json import logging import os +from pathlib import Path from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -46,6 +47,7 @@ def __init__( if not self.check_exist(): """不存在时载入默认配置""" + Path(config_path).parent.mkdir(parents=True, exist_ok=True) with open(config_path, "w", encoding="utf-8-sig") as f: json.dump(default_config, f, indent=4, ensure_ascii=False) object.__setattr__(self, "first_deploy", True) # 标记第一次部署 @@ -158,6 +160,8 @@ def save_config(self, replace_config: dict | None = None) -> None: """ if replace_config: self.update(replace_config) + # Alternate config files may be created under data/config on first write. + Path(self.config_path).parent.mkdir(parents=True, exist_ok=True) with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2c282867f9..76cb0c303b 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -263,6 +263,8 @@ async def update_conversation( title: str | None = None, persona_id: str | None = None, token_usage: int | None = None, + *, + clear_persona: bool = False, ) -> None: """更新会话的对话. @@ -273,6 +275,8 @@ async def update_conversation( token_usage (int | None): token 使用量。None 表示不更新 """ + # TODO(compat): Keep clear_persona keyword-only until external plugins + # have fully migrated away from positional update_conversation calls. if not conversation_id: # 如果没有提供 conversation_id,则获取当前的 conversation_id = await self.get_curr_conversation_id(unified_msg_origin) @@ -281,6 +285,7 @@ async def update_conversation( cid=conversation_id, title=title, persona_id=persona_id, + clear_persona=clear_persona, content=history, token_usage=token_usage, ) @@ -329,6 +334,19 @@ async def update_conversation_persona_id( persona_id=persona_id, ) + async def unset_conversation_persona( + self, + unified_msg_origin: str, + conversation_id: str | None = None, + ) -> None: + """Clear the conversation-specific persona override and fall back to default.""" + + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + clear_persona=True, + ) + async def add_message_pair( self, cid: str, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351d..fc6a95e29e 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -16,8 +16,7 @@ import traceback from asyncio import Queue -from astrbot.api import logger, sp -from astrbot.core import LogBroker, LogManager +from astrbot.core import LogBroker, LogManager, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager @@ -29,6 +28,7 @@ from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.provider.manager import ProviderManager +from astrbot.core.sdk_bridge import SdkPluginBridge from astrbot.core.star.context import Context from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.star.star_manager import PluginManager @@ -200,6 +200,11 @@ async def initialize(self) -> None: # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() + self.sdk_plugin_bridge = SdkPluginBridge(self.star_context) + self.star_context.sdk_plugin_bridge = self.sdk_plugin_bridge + self.platform_manager.sdk_plugin_bridge = self.sdk_plugin_bridge + await self.sdk_plugin_bridge.start() + # 根据配置实例化各个 Provider await self.provider_manager.initialize() @@ -309,6 +314,12 @@ async def start(self) -> None: except BaseException: logger.error(traceback.format_exc()) + if getattr(self, "sdk_plugin_bridge", None) is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event("astrbot_loaded") + except Exception as exc: + logger.warning(f"SDK astrbot_loaded event dispatch failed: {exc}") + # 同时运行curr_tasks中的所有任务 await asyncio.gather(*self.curr_tasks, return_exceptions=True) @@ -324,6 +335,9 @@ async def stop(self) -> None: if self.cron_manager: await self.cron_manager.shutdown() + if getattr(self, "sdk_plugin_bridge", None) is not None: + await self.sdk_plugin_bridge.stop() + for plugin in self.plugin_manager.context.get_all_stars(): try: await self.plugin_manager._terminate_plugin(plugin) @@ -349,6 +363,8 @@ async def stop(self) -> None: async def restart(self) -> None: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" + if getattr(self, "sdk_plugin_bridge", None) is not None: + await self.sdk_plugin_bridge.stop() await self.provider_manager.terminate() await self.platform_manager.terminate() await self.kb_manager.terminate() diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index ff7facd247..24c8ab3872 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -8,6 +8,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger +from apscheduler.triggers.interval import IntervalTrigger from astrbot import logger from astrbot.core.agent.tool import ToolSet @@ -65,7 +66,8 @@ async def add_basic_job( self, *, name: str, - cron_expression: str, + cron_expression: str | None = None, + interval_seconds: int | None = None, handler: Callable[..., Any | Awaitable[Any]], description: str | None = None, timezone: str | None = None, @@ -73,12 +75,19 @@ async def add_basic_job( enabled: bool = True, persistent: bool = False, ) -> CronJob: + if (cron_expression is None) == (interval_seconds is None): + raise ValueError( + "cron_expression and interval_seconds must have exactly one value" + ) + payload_data = dict(payload or {}) + if interval_seconds is not None: + payload_data["interval_seconds"] = interval_seconds job = await self.db.create_cron_job( name=name, job_type="basic", cron_expression=cron_expression, timezone=timezone, - payload=payload or {}, + payload=payload_data, description=description, enabled=enabled, persistent=persistent, @@ -167,7 +176,21 @@ def _schedule_job(self, job: CronJob) -> None: run_at = run_at.replace(tzinfo=tzinfo) trigger = DateTrigger(run_date=run_at, timezone=tzinfo) else: - trigger = CronTrigger.from_crontab(job.cron_expression, timezone=tzinfo) + interval_seconds = None + if isinstance(job.payload, dict): + payload_interval = job.payload.get("interval_seconds") + if isinstance(payload_interval, int): + interval_seconds = payload_interval + if interval_seconds is not None: + trigger = IntervalTrigger( + seconds=interval_seconds, + timezone=tzinfo, + ) + else: + trigger = CronTrigger.from_crontab( + job.cron_expression, + timezone=tzinfo, + ) self.scheduler.add_job( self._run_job, id=job.job_id, diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index a18c127ebf..380ec31d5a 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -166,6 +166,8 @@ async def update_conversation( persona_id: str | None = None, content: list[dict] | None = None, token_usage: int | None = None, + *, + clear_persona: bool = False, ) -> None: """Update a conversation's history.""" ... @@ -213,6 +215,172 @@ async def get_platform_message_history( """Get platform message history for a specific user.""" ... + async def _collect_legacy_platform_message_history( + self, + platform_id: str, + user_id: str, + *, + page_size: int = 200, + ) -> list[PlatformMessageHistory]: + """Best-effort compatibility fallback for legacy database backends.""" + # TODO(compat): Remove this pagination shim after third-party database + # backends implement the SDK-native platform message history methods. + rows: list[PlatformMessageHistory] = [] + page = 1 + while True: + batch = list( + await self.get_platform_message_history( + platform_id=platform_id, + user_id=user_id, + page=page, + page_size=page_size, + ) + ) + if not batch: + break + rows.extend(batch) + if len(batch) < page_size: + break + page += 1 + return rows + + async def list_sdk_platform_message_history( + self, + platform_id: str, + user_id: str, + cursor_id: int | None = None, + limit: int = 50, + include_total: bool = False, + ) -> tuple[list[PlatformMessageHistory], int | None]: + """List SDK message history records ordered by descending id. + + Legacy third-party backends may still implement only the older paged + history API. Fall back to that API so they keep working without having + to implement the new SDK-specific helpers immediately. + """ + + rows = await self._collect_legacy_platform_message_history( + platform_id=platform_id, + user_id=user_id, + page_size=max(int(limit), 50), + ) + rows.sort(key=lambda item: int(item.id or 0), reverse=True) + if cursor_id is not None: + rows = [item for item in rows if int(item.id or 0) < int(cursor_id)] + total = len(rows) if include_total else None + return rows[: max(int(limit), 1)], total + + async def delete_platform_message_before( + self, + platform_id: str, + user_id: str, + before: datetime.datetime, + ) -> int: + """Delete platform message history records strictly older than ``before``.""" + + # TODO(compat): Add a real legacy fallback only if we introduce a safe + # record-level delete path for custom database backends. + raise NotImplementedError( + "This database backend does not implement delete_platform_message_before(). " + "Upgrade the backend to support SDK message history pruning.", + ) + + async def delete_platform_message_after( + self, + platform_id: str, + user_id: str, + after: datetime.datetime, + ) -> int: + """Delete platform message history records strictly newer than ``after``.""" + + rows = await self._collect_legacy_platform_message_history( + platform_id=platform_id, + user_id=user_id, + ) + deleted_count = sum( + 1 + for item in rows + if item.created_at is not None and item.created_at > after + ) + if deleted_count == 0: + return 0 + + now = ( + datetime.datetime.now(after.tzinfo) + if after.tzinfo is not None + else datetime.datetime.now() + ) + delta_seconds = max(0.0, (now - after).total_seconds()) + offset_sec = int(delta_seconds) + if delta_seconds > offset_sec: + offset_sec += 1 + await self.delete_platform_message_offset( + platform_id=platform_id, + user_id=user_id, + offset_sec=offset_sec, + ) + return deleted_count + + async def delete_all_platform_message_history( + self, + platform_id: str, + user_id: str, + ) -> int: + """Delete all platform message history records for a specific user.""" + + rows = await self._collect_legacy_platform_message_history( + platform_id=platform_id, + user_id=user_id, + ) + if not rows: + return 0 + + oldest_created_at = min( + (item.created_at for item in rows if item.created_at is not None), + default=None, + ) + if oldest_created_at is None: + offset_sec = 60 * 60 * 24 * 365 * 100 + else: + now = ( + datetime.datetime.now(oldest_created_at.tzinfo) + if oldest_created_at.tzinfo is not None + else datetime.datetime.now() + ) + delta_seconds = max(0.0, (now - oldest_created_at).total_seconds()) + offset_sec = int(delta_seconds) + if delta_seconds > offset_sec: + offset_sec += 1 + + await self.delete_platform_message_offset( + platform_id=platform_id, + user_id=user_id, + offset_sec=max(offset_sec, 1), + ) + return len(rows) + + async def find_platform_message_history_by_idempotency_key( + self, + platform_id: str, + user_id: str, + idempotency_key: str, + ) -> PlatformMessageHistory | None: + """Find one message history record by the SDK idempotency key.""" + + rows = await self._collect_legacy_platform_message_history( + platform_id=platform_id, + user_id=user_id, + ) + matched = [] + for item in rows: + content = item.content if isinstance(item.content, dict) else {} + if str(content.get("idempotency_key", "")) == str(idempotency_key): + matched.append(item) + if not matched: + return None + matched.sort(key=lambda item: int(item.id or 0), reverse=True) + return matched[0] + @abc.abstractmethod async def get_platform_message_history_by_id( self, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c8e50909d5..c55a05b1db 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -294,7 +294,14 @@ async def create_conversation( return new_conversation async def update_conversation( - self, cid, title=None, persona_id=None, content=None, token_usage=None + self, + cid, + title=None, + persona_id=None, + content=None, + token_usage=None, + *, + clear_persona: bool = False, ): async with self.get_db() as session: session: AsyncSession @@ -305,7 +312,9 @@ async def update_conversation( values = {} if title is not None: values["title"] = title - if persona_id is not None: + if clear_persona: + values["persona_id"] = None + elif persona_id is not None: values["persona_id"] = persona_id if content is not None: values["content"] = content @@ -510,6 +519,121 @@ async def get_platform_message_history( result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + async def list_sdk_platform_message_history( + self, + platform_id, + user_id, + cursor_id=None, + limit=50, + include_total=False, + ): + """List SDK message history records ordered by descending id.""" + async with self.get_db() as session: + session: AsyncSession + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + .order_by(desc(PlatformMessageHistory.id)) + ) + if cursor_id is not None: + query = query.where(PlatformMessageHistory.id < cursor_id) + result = await session.execute(query.limit(limit)) + total: int | None = None + if include_total: + total_query = ( + select(func.count()) + .select_from(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + ) + total_result = await session.execute(total_query) + total = int(total_result.scalar() or 0) + return list(result.scalars().all()), total + + async def delete_platform_message_before( + self, + platform_id, + user_id, + before, + ) -> int: + """Delete platform message history records strictly older than the boundary.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) < before, + ), + ) + return int(result.rowcount or 0) + + async def delete_platform_message_after( + self, + platform_id, + user_id, + after, + ) -> int: + """Delete platform message history records strictly newer than the boundary.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) > after, + ), + ) + return int(result.rowcount or 0) + + async def delete_all_platform_message_history( + self, + platform_id, + user_id, + ) -> int: + """Delete all platform message history records for a specific user.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + ), + ) + return int(result.rowcount or 0) + + async def find_platform_message_history_by_idempotency_key( + self, + platform_id, + user_id, + idempotency_key, + ) -> PlatformMessageHistory | None: + """Find a SDK message history record by its idempotency key.""" + async with self.get_db() as session: + session: AsyncSession + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + func.json_extract( + PlatformMessageHistory.content, "$.idempotency_key" + ) + == str(idempotency_key), + ) + .order_by(desc(PlatformMessageHistory.id)) + ) + result = await session.execute(query.limit(1)) + return result.scalar_one_or_none() + async def get_platform_message_history_by_id( self, message_id: int ) -> PlatformMessageHistory | None: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index f26409e56e..43a7987980 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import traceback from pathlib import Path +from typing import TYPE_CHECKING from astrbot.core import logger from astrbot.core.provider.manager import ProviderManager @@ -10,9 +13,9 @@ from .kb_db_sqlite import KBSQLiteDatabase from .kb_helper import KBHelper from .models import KBDocument, KnowledgeBase -from .retrieval.manager import RetrievalManager, RetrievalResult -from .retrieval.rank_fusion import RankFusion -from .retrieval.sparse_retriever import SparseRetriever + +if TYPE_CHECKING: + from .retrieval.manager import RetrievalManager, RetrievalResult FILES_PATH = get_astrbot_knowledge_base_path() DB_PATH = Path(FILES_PATH) / "kb.db" @@ -37,6 +40,10 @@ def __init__( async def initialize(self) -> None: """初始化知识库模块""" try: + from .retrieval.manager import RetrievalManager + from .retrieval.rank_fusion import RankFusion + from .retrieval.sparse_retriever import SparseRetriever + logger.info("正在初始化知识库模块...") # 初始化数据库 diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3dd0719b11..81bd091674 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -56,7 +56,12 @@ def _is_plugin_path(pathname: str | None) -> bool: if not pathname: return False norm_path = os.path.normpath(pathname) - return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path) + markers = ( + os.path.normpath("data/plugins"), + os.path.normpath("data/sdk_plugins"), + os.path.normpath("astrbot/builtin_stars"), + ) + return any(marker in norm_path for marker in markers) def _get_short_level_name(level_name: str) -> str: diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 0965fe7f7f..29a54047e2 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -30,6 +30,36 @@ class MessageChain: type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" + def __iter__(self): + return iter(self.chain) + + def __len__(self) -> int: + return len(self.chain) + + def __getitem__(self, index): + return self.chain[index] + + def __setitem__(self, index, value) -> None: + self.chain[index] = value + + def __bool__(self) -> bool: + return bool(self.chain) + + def append(self, component: BaseMessageComponent) -> None: + self.chain.append(component) + + def extend(self, components) -> None: + self.chain.extend(components) + + def insert(self, index: int, component: BaseMessageComponent) -> None: + self.chain.insert(index, component) + + def pop(self, index: int = -1): + return self.chain.pop(index) + + def clear(self) -> None: + self.chain.clear() + def message(self, message: str): """添加一条文本消息到消息链 `chain` 中。 diff --git a/astrbot/core/message/message_types.py b/astrbot/core/message/message_types.py new file mode 100644 index 0000000000..e8c7b32cfb --- /dev/null +++ b/astrbot/core/message/message_types.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +_GROUP_MESSAGE_TYPES = {"group", "groupmessage", "group_message"} +_PRIVATE_MESSAGE_TYPES = { + "private", + "privatemessage", + "private_message", + "friend", + "friendmessage", + "friend_message", +} +_OTHER_MESSAGE_TYPES = {"other", "othermessage", "other_message"} + + +def sdk_message_type( + value: Any, + *, + group_id: str | None = None, + user_id: str | None = None, + empty_default: str = "", +) -> str: + """Collapse core-visible message types to SDK canonical values.""" + + normalized = str(getattr(value, "value", value) or "").strip().lower() + if normalized in _GROUP_MESSAGE_TYPES: + return "group" + if normalized in _PRIVATE_MESSAGE_TYPES: + return "private" + if normalized in _OTHER_MESSAGE_TYPES: + return "other" + if group_id: + return "group" + if user_id: + return "private" + if not normalized: + return empty_default + return "other" diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index c7441d09f4..5d9a2bdfca 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -185,6 +185,20 @@ async def process( except Exception: logger.warning("send_typing failed", exc_info=True) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "waiting_llm_request", + event, + ) + except Exception as exc: + logger.warning( + "SDK waiting_llm_request dispatch failed: %s", + exc, + ) async with session_lock_manager.acquire_lock(event.unified_msg_origin): logger.debug("acquired session lock for llm request") @@ -230,6 +244,19 @@ async def process( if reset_coro: reset_coro.close() return + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_request", + event, + { + "prompt": req.prompt, + "provider_id": provider.meta().id, + }, + provider_request=req, + ) + except Exception as exc: + logger.warning("SDK llm_request dispatch failed: %s", exc) # apply reset if reset_coro: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 070ad7bdee..a44c71612e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -4,19 +4,12 @@ from typing import TYPE_CHECKING from astrbot.core import astrbot_config, logger -from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner -from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( - DashscopeAgentRunner, -) from astrbot.core.agent.runners.deerflow.constants import ( DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, DEERFLOW_PROVIDER_TYPE, ) -from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( - DeerFlowAgentRunner, -) -from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS +from astrbot.core.astr_agent_run_util import _apply_sdk_streaming_delta_filters from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( MessageChain, @@ -217,16 +210,25 @@ async def _handle_streaming_response( async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]: mark_stream_consumed() + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) try: async for chain, is_error in run_third_party_agent( runner, stream_to_general=False, custom_error_message=custom_error_message, ): + chain = await _apply_sdk_streaming_delta_filters( + sdk_plugin_bridge, + event, + chain, + ) aggregator.add_chunk(chain, is_error) if is_error: event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True) - yield chain + if chain is not None: + yield chain finally: # Streaming runner cleanup must happen after consumer # finishes iterating to avoid tearing down active streams. @@ -327,14 +329,46 @@ async def process( # call event hook if await call_event_hook(event, EventType.OnLLMRequestEvent, req): return + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_request", + event, + { + "prompt": req.prompt, + "provider_id": self.prov_id, + }, + provider_request=req, + ) + except Exception as exc: + logger.warning("SDK llm_request dispatch failed: %s", exc) if self.runner_type == "dify": + from astrbot.core.agent.runners.dify.dify_agent_runner import ( + DifyAgentRunner, + ) + runner = DifyAgentRunner[AstrAgentContext]() elif self.runner_type == "coze": + from astrbot.core.agent.runners.coze.coze_agent_runner import ( + CozeAgentRunner, + ) + runner = CozeAgentRunner[AstrAgentContext]() elif self.runner_type == "dashscope": + from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( + DashscopeAgentRunner, + ) + runner = DashscopeAgentRunner[AstrAgentContext]() elif self.runner_type == DEERFLOW_PROVIDER_TYPE: + from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( + DeerFlowAgentRunner, + ) + runner = DeerFlowAgentRunner[AstrAgentContext]() else: raise ValueError( diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 9422d6317a..a353832b0b 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -60,6 +60,23 @@ async def process( e, traceback_text, ) + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "plugin_error", + event, + { + "plugin_name": md.name, + "handler_name": handler.handler_name, + "error": str(e), + "traceback": traceback_text, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_error dispatch failed: %s", exc) if not event.is_stopped() and event.is_at_or_wake_command: ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}" diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 076f7f12ac..684d291db6 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator +from astrbot.core import logger from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider.entities import ProviderRequest from astrbot.core.star.star_handler import StarHandlerMetadata @@ -16,6 +17,9 @@ async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) # initialize agent sub stage self.agent_sub_stage = AgentRequestSubStage() @@ -33,6 +37,67 @@ async def process( activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) + if ( + activated_handlers + and self.sdk_plugin_bridge is not None + and not event.is_stopped() + and ( + not hasattr(self.sdk_plugin_bridge, "has_active_sdk_command_handlers") + or self.sdk_plugin_bridge.has_active_sdk_command_handlers() + ) + and hasattr(self.sdk_plugin_bridge, "detect_legacy_command_conflict") + ): + # 新旧插件命令冲突时,SDK 插件优先:循环移除所有冲突的旧插件 handler + removed_handler_names: set[str] = set() + max_iterations = len(activated_handlers) + iteration_count = 0 + while activated_handlers: + iteration_count += 1 + if iteration_count > max_iterations: + logger.warning( + "Legacy command conflict filtering exceeded the handler count guard, aborting the conflict loop: remaining_handlers=%s", + len(activated_handlers), + ) + break + conflict = self.sdk_plugin_bridge.detect_legacy_command_conflict( + event, + activated_handlers, + ) + if conflict is None: + break + logger.warning( + "新旧插件命令冲突,SDK 插件优先: command=%s legacy_handler=%s sdk_handler=%s", + conflict.command_name, + conflict.legacy.handler_full_name, + conflict.sdk.handler_full_name, + ) + target_handler_name = conflict.legacy.handler_full_name + filtered_handlers: list[StarHandlerMetadata] = [] + removed_current_conflict = False + for handler in activated_handlers: + handler_full_name = getattr(handler, "handler_full_name", None) + if handler_full_name == target_handler_name: + removed_current_conflict = True + removed_handler_names.add(target_handler_name) + continue + filtered_handlers.append(handler) + if not removed_current_conflict: + logger.warning( + "Legacy command conflict matched an unknown handler, keeping legacy handler list unchanged: legacy_handler=%s sdk_handler=%s", + conflict.legacy.handler_full_name, + conflict.sdk.handler_full_name, + ) + break + activated_handlers = filtered_handlers + if removed_handler_names: + # 同步更新 event extras,确保下游 sub stage 看到过滤后的列表 + event.set_extra("activated_handlers", activated_handlers) + # 清理已移除 handler 的解析参数 + handlers_parsed_params = event.get_extra("handlers_parsed_params") + if isinstance(handlers_parsed_params, dict): + for name in removed_handler_names: + handlers_parsed_params.pop(name, None) + # 有插件 Handler 被激活 if activated_handlers: async for resp in self.star_request_sub_stage.process(event): @@ -49,18 +114,40 @@ async def process( else: yield + if self.sdk_plugin_bridge is not None and not event.is_stopped(): + sdk_result = await self.sdk_plugin_bridge.dispatch_message(event) + if sdk_result.sent_message or sdk_result.stopped: + yield + # 调用 LLM 相关请求 if not self.ctx.astrbot_config["provider_settings"].get("enable", True): return - if ( - not event._has_send_oper - and event.is_at_or_wake_command - and not event.call_llm - ): + # LLM 调用意愿的三级回退:SDK bridge > 新版 event API > 旧版 event 字段 + should_call_llm = ( + self.sdk_plugin_bridge.get_effective_should_call_llm(event) + if self.sdk_plugin_bridge is not None + and hasattr(self.sdk_plugin_bridge, "get_effective_should_call_llm") + else ( + event.should_call_default_llm() + if hasattr(event, "should_call_default_llm") + else not event.call_llm + ) + ) + effective_result = ( + self.sdk_plugin_bridge.get_effective_result(event) + if self.sdk_plugin_bridge is not None + and hasattr(self.sdk_plugin_bridge, "get_effective_result") + else event.get_result() + ) + # 发送操作状态的两级回退:新版 has_send_operation() > 旧版 _has_send_oper + has_send_operation = ( + event.has_send_operation() + if hasattr(event, "has_send_operation") + else event._has_send_oper + ) + if not has_send_operation and event.is_at_or_wake_command and should_call_llm: # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 - if ( - event.get_result() and not event.is_stopped() - ) or not event.get_result(): + if (effective_result and not event.is_stopped()) or not effective_result: async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 6a884a5181..b8805824f7 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -7,6 +7,7 @@ from astrbot.core import logger from astrbot.core.message.components import BaseMessageComponent, ComponentType from astrbot.core.message.message_event_result import MessageChain, ResultContentType +from astrbot.core.message.message_types import sdk_message_type from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import EventType from astrbot.core.utils.path_util import path_Mapping @@ -53,6 +54,9 @@ class RespondStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) self.platform_settings: dict = self.config.get("platform_settings", {}) self.reply_with_mention = ctx.astrbot_config["platform_settings"][ @@ -86,7 +90,12 @@ async def initialize(self, ctx: PipelineContext) -> None: self.interval = [float(t) for t in interval_str_ls] except BaseException as e: logger.error(f"解析分段回复的间隔时间失败。{e}") - logger.info(f"分段回复间隔时间:{self.interval}") + logger.info(f"分段回复间隔时间:{self.interval}") + + def _get_effective_result(self, event: AstrMessageEvent): + if self.sdk_plugin_bridge is not None: + return self.sdk_plugin_bridge.get_effective_result(event) + return event.get_result() async def _word_cnt(self, text: str) -> int: """分段回复 统计字数""" @@ -128,12 +137,36 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo # 如果所有组件都为空 return True + @staticmethod + def _message_outline_for_sdk_event( + chain: MessageChain | list[BaseMessageComponent] | None, + ) -> str: + if isinstance(chain, MessageChain): + return chain.get_plain_text(with_other_comps_mark=True) + if isinstance(chain, list): + return MessageChain(chain).get_plain_text(with_other_comps_mark=True) + return "" + + @staticmethod + def _message_payloads_for_sdk_event( + chain: MessageChain | list[BaseMessageComponent] | None, + ) -> list[dict]: + from astrbot_sdk.message.components import component_to_payload_sync + + if isinstance(chain, MessageChain): + components = chain.chain + elif isinstance(chain, list): + components = chain + else: + components = [] + return [component_to_payload_sync(component) for component in components] + def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: """检查是否需要分段回复""" if not self.enable_seg: return False - if (result := event.get_result()) is None: + if (result := self._get_effective_result(event)) is None: return False if self.only_llm_result and not result.is_model_result(): return False @@ -167,21 +200,72 @@ def _extract_comp( return extracted + def _bind_plugin_log(self): + bind = getattr(logger, "bind", None) + if callable(bind): + return bind(plugin_tag="[Plug]") + return logger + + async def _dispatch_after_message_sent( + self, + event: AstrMessageEvent, + result, + ) -> bool: + if await call_event_hook(event, EventType.OnAfterMessageSentEvent): + return True + + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_message_event( + "after_message_sent", + event, + { + "session_id": event.unified_msg_origin, + "platform": event.get_platform_name(), + "platform_id": event.get_platform_id(), + "message_type": sdk_message_type(event.get_message_type()), + "sender_name": event.get_sender_name(), + "self_id": event.get_self_id(), + "message_outline": self._message_outline_for_sdk_event( + result.chain + ), + "sent_message_outline": self._message_outline_for_sdk_event( + result.chain + ), + "sent_messages": self._message_payloads_for_sdk_event( + result.chain + ), + }, + ) + except Exception as exc: + logger.warning(f"SDK after_message_sent dispatch failed: {exc}") + return False + async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - result = event.get_result() + result = self._get_effective_result(event) if result is None: return if event.get_extra("_streaming_finished", False): # prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again return if result.result_content_type == ResultContentType.STREAMING_FINISH: + logger.info( + "Streaming finish reached, dispatching after_message_sent hooks." + ) event.set_extra("_streaming_finished", True) + await self._dispatch_after_message_sent(event, result) + event.clear_result() return - logger.info( + log = ( + self._bind_plugin_log() + if event.get_extra("_sdk_origin_plugin_id") + else logger + ) + log.info( f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}", ) @@ -290,7 +374,7 @@ async def process( exc_info=True, ) - if await call_event_hook(event, EventType.OnAfterMessageSentEvent): + if await self._dispatch_after_message_sent(event, result): return event.clear_result() diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 4ee7461305..7ff2bbaa9d 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,8 +5,8 @@ from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger -from astrbot.core.message.components import At, Image, Json, Node, Plain, Record, Reply -from astrbot.core.message.message_event_result import ResultContentType +from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply +from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType @@ -20,8 +20,19 @@ @register_stage class ResultDecorateStage(Stage): + @staticmethod + def _message_outline_for_sdk_event(chain: MessageChain | list | None) -> str: + if isinstance(chain, MessageChain): + return chain.get_plain_text(with_other_comps_mark=True) + if isinstance(chain, list): + return MessageChain(chain).get_plain_text(with_other_comps_mark=True) + return "" + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ "reply_with_mention" @@ -101,6 +112,11 @@ async def initialize(self, ctx: PipelineContext) -> None: provider_cfg = ctx.astrbot_config.get("provider_settings", {}) self.show_reasoning = provider_cfg.get("display_reasoning_text", False) + def _get_effective_result(self, event: AstrMessageEvent): + if self.sdk_plugin_bridge is not None: + return self.sdk_plugin_bridge.get_effective_result(event) + return event.get_result() + def _split_text_by_words(self, text: str) -> list[str]: """使用分段词列表分段文本""" if not self.split_words_pattern: @@ -127,7 +143,7 @@ async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - result = event.get_result() + result = self._get_effective_result(event) if result is None or not result.chain: return @@ -184,13 +200,37 @@ async def process( ) return + result = self._get_effective_result(event) + if result is None or not result.chain: + return + + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_message_event( + "decorating_result", + event, + { + "message_outline": self._message_outline_for_sdk_event( + result.chain + ), + "result_content_type": ( + result.result_content_type.name.lower() + if result.result_content_type is not None + else "" + ), + }, + event_result=result, + ) + except Exception as exc: + logger.warning(f"SDK decorating_result dispatch failed: {exc}") + # 流式输出不执行下面的逻辑 if is_stream: logger.info("流式输出已启用,跳过结果装饰阶段") return # 需要再获取一次。插件可能直接对 chain 进行了替换。 - result = event.get_result() + result = self._get_effective_result(event) if result is None: return @@ -275,21 +315,8 @@ async def process( and event.get_extra("_llm_reasoning_content") ): # inject reasoning content to chain - reasoning_content = str(event.get_extra("_llm_reasoning_content")) - if event.get_platform_name() == "lark": - result.chain.insert( - 0, - Json( - data={ - "type": "lark_collapsible_panel_reasoning", - "title": "💭 Thinking", - "expanded": False, - "content": reasoning_content, - }, - ), - ) - else: - result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) + reasoning_content = event.get_extra("_llm_reasoning_content") + result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) if should_tts and tts_provider: new_chain = [] diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 243d03378c..e78db8660d 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -92,5 +92,14 @@ async def execute(self, event: AstrMessageEvent) -> None: logger.debug("pipeline 执行完毕。") finally: - event.cleanup_temporary_local_files() - active_event_registry.unregister(event) + try: + event.cleanup_temporary_local_files() + finally: + try: + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + sdk_plugin_bridge.close_request_overlay_for_event(event) + finally: + active_event_registry.unregister(event) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 0ecd47fedc..8e1293f85e 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import asyncio import hashlib @@ -6,11 +8,9 @@ import uuid from collections.abc import AsyncGenerator from time import time -from typing import Any +from typing import TYPE_CHECKING, Any from astrbot import logger -from astrbot.core.agent.tool import ToolSet -from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( At, AtAll, @@ -23,7 +23,6 @@ ) from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from astrbot.core.utils.trace import TraceSpan @@ -31,6 +30,11 @@ from .message_session import MessageSesion, MessageSession # noqa from .platform_metadata import PlatformMetadata +if TYPE_CHECKING: + from astrbot.core.agent.tool import ToolSet + from astrbot.core.db.po import Conversation + from astrbot.core.provider.entities import ProviderRequest + class AstrMessageEvent(abc.ABC): def __init__( @@ -86,9 +90,9 @@ def __init__( """事件级 TraceSpan(别名: span)""" self._has_send_oper = False - """在此次事件中是否有过至少一次发送消息的操作""" + """底层标记:事件是否已触发至少一次平台发送。新代码应通过 mark_send_operation() / has_send_operation() 操作。""" self.call_llm = False - """是否在此消息事件中禁止默认的 LLM 请求""" + """语义反转的遗留字段:True 表示阻止内置默认 LLM 阶段。新代码应使用 set_default_llm_blocked() / should_call_default_llm()。""" self._temporary_local_files: list[str] = [] """Temporary local files created during this event and safe to delete when it finishes.""" @@ -137,7 +141,10 @@ def get_message_str(self) -> str: """获取消息字符串。""" return self.message_str - def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: + def _outline_chain( + self, + chain: MessageChain | list[BaseMessageComponent] | None, + ) -> str: if not chain: return "" @@ -261,6 +268,10 @@ def is_admin(self) -> bool: """是否是管理员。""" return self.role == "admin" + def has_admin_permission(self) -> bool: + """语义更明确的别名:is_admin() 容易被误解为"判断身份",has_admin_permission 强调权限语义。""" + return self.is_admin() + async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str: """将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。""" while True: @@ -285,7 +296,7 @@ async def send_streaming( asyncio.create_task( Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name), ) - self._has_send_oper = True + self.mark_send_operation() async def send_typing(self) -> None: """发送输入中状态。 @@ -305,6 +316,15 @@ async def _pre_send(self) -> None: async def _post_send(self) -> None: """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" + def _active_sdk_result_binding(self): + binding = getattr(self, "_sdk_result_binding", None) + if binding is None: + return None + is_active = getattr(binding, "is_active", None) + if callable(is_active) and not is_active(): + return None + return binding + def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 @@ -332,10 +352,18 @@ async def check_count(self, event: AstrMessageEvent): # 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表 if isinstance(result, MessageEventResult) and result.chain is None: result.chain = [] + binding = self._active_sdk_result_binding() + if binding is not None: + binding.set_result(result) + return self._result = result def stop_event(self) -> None: """终止事件传播。""" + binding = self._active_sdk_result_binding() + if binding is not None: + binding.stop_event() + return if self._result is None: self.set_result(MessageEventResult().stop_event()) else: @@ -343,6 +371,10 @@ def stop_event(self) -> None: def continue_event(self) -> None: """继续事件传播。""" + binding = self._active_sdk_result_binding() + if binding is not None: + binding.continue_event() + return if self._result is None: self.set_result(MessageEventResult().continue_event()) else: @@ -350,23 +382,65 @@ def continue_event(self) -> None: def is_stopped(self) -> bool: """是否终止事件传播。""" + binding = self._active_sdk_result_binding() + if binding is not None and binding.has_result_state(): + return binding.is_stopped() if self._result is None: return False # 默认是继续传播 return self._result.is_stopped() def should_call_llm(self, call_llm: bool) -> None: - """是否在此消息事件中禁止默认的 LLM 请求。 + """向后兼容的包装器:历史调用者传 True 意为“阻止 LLM”,名字语义反转。 - 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 + 新代码应直接使用 set_default_llm_blocked() 或 should_call_default_llm()。 """ - self.call_llm = call_llm + self.set_default_llm_blocked(call_llm) + + def disable_default_llm(self, disabled: bool = True) -> None: + """向后兼容别名:disabled=True 阻止内置默认 LLM 阶段。""" + self.set_default_llm_blocked(disabled) + + def set_default_llm_blocked(self, blocked: bool = True) -> None: + """底层写入方法:blocked=True 阻止本事件的内置 LLM 阶段。""" + self.call_llm = bool(blocked) + + def set_default_llm_allowed(self, allowed: bool = True) -> None: + """allowed=True 表示允许内置 LLM 阶段(等价于 blocked=False)。""" + self.set_default_llm_blocked(not allowed) + + def should_call_default_llm(self) -> bool: + """返回内置默认 LLM 管道是否仍被允许。call_llm 语义反转:True=阻止。""" + return not bool(self.call_llm) + + def mark_send_operation(self) -> None: + """标记本事件已至少发送过一条平台消息。""" + self.set_send_operation_state(True) + + def set_send_operation_state(self, has_sent: bool) -> None: + """底层写入方法:更新事件的发送操作状态。""" + self._has_send_oper = bool(has_sent) + + def has_send_operation(self) -> bool: + """返回本事件是否已发送过至少一条平台消息。""" + return bool(self._has_send_oper) + + def get_send_operation_state(self) -> bool: + """向后兼容的读取方法,供 bridge 代码读取原始发送标记。""" + return self.has_send_operation() def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" + binding = self._active_sdk_result_binding() + if binding is not None and binding.has_result_state(): + return binding.get_result() return self._result def clear_result(self) -> None: """清除消息事件的结果。""" + binding = self._active_sdk_result_binding() + if binding is not None: + binding.clear_result() + return self._result = None """消息链相关""" @@ -446,6 +520,8 @@ def request_llm( if len(contexts) > 0 and conversation: conversation = None + from astrbot.core.provider.entities import ProviderRequest + return ProviderRequest( prompt=prompt, session_id=session_id, @@ -476,7 +552,7 @@ async def send(self, message: MessageChain) -> None: sid=sid, ), ) - self._has_send_oper = True + self.mark_send_operation() async def react(self, emoji: str) -> None: """对消息添加表情回应。 diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 15c04166dc..1a26ebd58d 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -2,6 +2,7 @@ import traceback from asyncio import Queue from dataclasses import dataclass +from typing import TYPE_CHECKING from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig @@ -12,6 +13,9 @@ from .register import platform_cls_map from .sources.webchat.webchat_adapter import WebChatAdapter +if TYPE_CHECKING: + from astrbot.core.sdk_bridge.plugin_bridge import SdkPluginBridge + @dataclass class PlatformTasks: @@ -34,6 +38,7 @@ def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: 这个配置中的 unique_session 需要特殊处理, 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue + self.sdk_plugin_bridge: SdkPluginBridge | None = None def _is_valid_platform_id(self, platform_id: str | None) -> bool: if not platform_id: @@ -202,6 +207,7 @@ async def load_platform(self, platform_config: dict) -> None: return cls_type = platform_cls_map[platform_config["type"]] inst: Platform = cls_type(platform_config, self.settings, self.event_queue) + setattr(inst, "sdk_plugin_bridge", self.sdk_plugin_bridge) self._inst_map[platform_config["id"]] = { "inst": inst, "client_id": inst.client_self_id, @@ -222,6 +228,17 @@ async def load_platform(self, platform_config: dict) -> None: await handler.handler() except Exception: logger.error(traceback.format_exc()) + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event( + "platform_loaded", + { + "platform": inst.meta().name, + "platform_id": inst.meta().id, + }, + ) + except Exception as exc: + logger.warning(f"SDK platform_loaded event dispatch failed: {exc}") async def _task_wrapper( self, task: asyncio.Task, platform: Platform | None = None @@ -300,6 +317,48 @@ async def terminate(self) -> None: def get_insts(self): return self.platform_insts + async def refresh_native_commands( + self, *, platforms: set[str] | None = None + ) -> None: + """Refresh native command menus for running platform adapters. + + Native command registration is platform-specific. Today Telegram owns its + own command sync path, so plugin hot reloads need an explicit follow-up + refresh to make newly loaded SDK commands visible without waiting for the + periodic registration job or a full restart. + """ + requested_platforms = ( + {item.strip().lower() for item in platforms if item and item.strip()} + if platforms + else None + ) + for inst in list(self.platform_insts): + platform_name = "" + try: + platform_name = str(inst.meta().name).strip().lower() + except Exception: + logger.debug("Failed to read platform metadata during command refresh.") + continue + + if ( + requested_platforms is not None + and platform_name not in requested_platforms + ): + continue + + register_commands = getattr(inst, "register_commands", None) + if not callable(register_commands): + continue + + try: + await register_commands() + except Exception as exc: + logger.warning( + "刷新 %s 平台原生命令失败: %s", + platform_name or "unknown", + exc, + ) + def get_all_stats(self) -> dict: """获取所有平台的统计信息 diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 62ec5070ab..0a1384d3c1 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -8,6 +8,20 @@ """维护了平台适配器名称和适配器类的映射""" +def _is_same_adapter_identity(existing_cls: type, new_cls: type) -> bool: + """Return whether two adapter classes represent the same logical adapter. + + Re-imports and hot reloads can create a new class object for the same + module/class name. Those cases should refresh the registry entry instead of + being treated as a real naming conflict. + """ + + return ( + existing_cls.__module__ == new_cls.__module__ + and existing_cls.__qualname__ == new_cls.__qualname__ + ) + + def register_platform_adapter( adapter_name: str, desc: str, @@ -26,11 +40,6 @@ def register_platform_adapter( """ def decorator(cls): - if adapter_name in platform_cls_map: - raise ValueError( - f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。", - ) - # 添加必备选项 if default_config_tmpl: if "type" not in default_config_tmpl: @@ -55,6 +64,28 @@ def decorator(cls): i18n_resources=i18n_resources, config_metadata=config_metadata, ) + + existing_cls = platform_cls_map.get(adapter_name) + if existing_cls is not None: + # SDK/adapter tests and hot reload paths can import the same adapter + # module more than once in one process. Refresh that registration in + # place so we keep conflict detection for genuinely different classes. + if not _is_same_adapter_identity(existing_cls, cls): + raise ValueError( + f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。", + ) + + for index, registered_pm in enumerate(platform_registry): + if registered_pm.name == adapter_name: + platform_registry[index] = pm + break + else: + platform_registry.append(pm) + + platform_cls_map[adapter_name] = cls + logger.debug(f"平台适配器 {adapter_name} 重复注册,已刷新既有注册信息") + return cls + platform_registry.append(pm) platform_cls_map[adapter_name] = cls logger.debug(f"平台适配器 {adapter_name} 已注册") diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7657962a11..50215ca44f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -48,6 +48,7 @@ def __init__( self.settings = platform_settings self.client_self_id: str | None = None self.registered_handlers = [] + self.sdk_plugin_bridge = None # 指令注册相关 self.enable_command_register = self.config.get("discord_command_register", True) self.guild_id = self.config.get("discord_guild_id_for_debug", None) @@ -366,42 +367,25 @@ async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] - - for handler_md in star_handlers_registry: - if not star_map[handler_md.handler_module_path].activated: - continue - if not handler_md.enabled: - continue - for event_filter in handler_md.event_filters: - cmd_info = self._extract_command_info(event_filter, handler_md) - if not cmd_info: - continue - - cmd_name, description, cmd_filter_instance = cmd_info - - # 创建动态回调 - callback = self._create_dynamic_callback(cmd_name) - - # 创建一个通用的参数选项来接收所有文本输入 - options = [ - discord.Option( - name="params", - description="指令的所有参数", - type=discord.SlashCommandOptionType.string, - required=False, - ), - ] - - # 创建SlashCommand - slash_command = discord.SlashCommand( - name=cmd_name, - description=description, - func=callback, - options=options, - guild_ids=[self.guild_id] if self.guild_id else None, - ) - self.client.add_application_command(slash_command) - registered_commands.append(cmd_name) + for cmd_name, description in self.collect_commands(): + callback = self._create_dynamic_callback(cmd_name) + options = [ + discord.Option( + name="params", + description="指令的所有参数", + type=discord.SlashCommandOptionType.string, + required=False, + ), + ] + slash_command = discord.SlashCommand( + name=cmd_name, + description=description, + func=callback, + options=options, + guild_ids=[self.guild_id] if self.guild_id else None, + ) + self.client.add_application_command(slash_command) + registered_commands.append(cmd_name) if registered_commands: logger.info( @@ -415,6 +399,53 @@ async def _collect_and_register_commands(self) -> None: await self.client.sync_commands() logger.info("[Discord] 指令同步完成。") + def collect_commands(self) -> list[tuple[str, str]]: + """收集 legacy 与 SDK 的顶层原生命令。""" + command_dict: dict[str, str] = {} + + for handler_md in star_handlers_registry: + if not star_map[handler_md.handler_module_path].activated: + continue + if not handler_md.enabled: + continue + for event_filter in handler_md.event_filters: + cmd_info = self._extract_command_info(event_filter, handler_md) + if not cmd_info: + continue + cmd_name, description, _cmd_filter_instance = cmd_info + if cmd_name in command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + + sdk_bridge = getattr(self, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + for item in sdk_bridge.list_native_command_candidates("discord"): + cmd_name = str(item.get("name", "")).strip() + if not cmd_name: + continue + if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): + logger.debug(f"[Discord] 跳过不符合规范的 SDK 指令: {cmd_name}") + continue + description = str(item.get("description") or "").strip() + if not description: + if item.get("is_group"): + description = f"Command group: {cmd_name}" + else: + description = f"Command: {cmd_name}" + if len(description) > 100: + description = f"{description[:97]}..." + if cmd_name in command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + + return sorted(command_dict.items(), key=lambda item: item[0].lower()) + def _create_dynamic_callback(self, cmd_name: str): """为每个指令动态创建一个异步回调函数""" @@ -481,7 +512,6 @@ def _extract_command_info( ) -> tuple[str, str, CommandFilter | None] | None: """从事件过滤器中提取指令信息""" cmd_name = None - # is_group = False cmd_filter_instance = None if isinstance(event_filter, CommandFilter): @@ -501,7 +531,6 @@ def _extract_command_info( if not cmd_name: return None - # Discord 斜杠指令名称规范 if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}") return None diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 6c22a1aa5f..40f058d307 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -3,7 +3,8 @@ import re import sys import uuid -from typing import cast +from collections.abc import Sequence +from typing import Protocol, cast from apscheduler.schedulers.asyncio import AsyncIOScheduler from telegram import BotCommand, Update @@ -40,6 +41,14 @@ from typing_extensions import override +class _CaptionEntityLike(Protocol): + # Telegram stubs expose caption_entities as tuples, so this helper only + # relies on the fields we actually read instead of a concrete container type. + type: str + offset: int + length: int + + @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): def __init__( @@ -51,6 +60,7 @@ def __init__( super().__init__(platform_config, event_queue) self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] + self.sdk_plugin_bridge = None base_url = self.config.get( "telegram_api_base_url", @@ -248,6 +258,31 @@ def collect_commands(self) -> list[BotCommand]: ) command_dict.setdefault(cmd_name, description) + sdk_bridge = getattr(self, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + for item in sdk_bridge.list_native_command_candidates("telegram"): + cmd_name = str(item.get("name", "")).strip() + if not cmd_name or cmd_name in skip_commands: + continue + if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32: + continue + + description = str(item.get("description") or "").strip() + if not description: + if item.get("is_group"): + description = f"Command group: {cmd_name}" + else: + description = f"Command: {cmd_name}" + if len(description) > 30: + description = description[:30] + "..." + + if cmd_name in command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + commands_a = sorted(command_dict.keys()) return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a] @@ -335,18 +370,6 @@ async def convert_message( logger.warning("Received an update without a message.") return None - def _apply_caption() -> None: - if update.message.caption: - message.message_str = update.message.caption - message.message.append(Comp.Plain(message.message_str)) - if update.message.caption and update.message.caption_entities: - for entity in update.message.caption_entities: - if entity.type == "mention": - name = update.message.caption[ - entity.offset + 1 : entity.offset + entity.length - ] - message.message.append(Comp.At(qq=name, name=name)) - message = AstrBotMessage() message.session_id = str(update.message.chat.id) @@ -466,7 +489,11 @@ def _apply_caption() -> None: photo = update.message.photo[-1] # get the largest photo file = await photo.get_file() message.message.append(Comp.Image(file=file.file_path, url=file.file_path)) - _apply_caption() + self._append_caption_components( + message, + update.message.caption, + update.message.caption_entities, + ) elif update.message.sticker: # 将sticker当作图片处理 @@ -489,7 +516,11 @@ def _apply_caption() -> None: message.message.append( Comp.File(file=file_path, name=file_name, url=file_path) ) - _apply_caption() + self._append_caption_components( + message, + update.message.caption, + update.message.caption_entities, + ) elif update.message.video: file = await update.message.video.get_file() @@ -501,10 +532,40 @@ def _apply_caption() -> None: ) else: message.message.append(Comp.Video(file=file_path, path=file.file_path)) - _apply_caption() + self._append_caption_components( + message, + update.message.caption, + update.message.caption_entities, + ) return message + @staticmethod + def _append_caption_components( + message: AstrBotMessage, + caption: str | None, + caption_entities: Sequence[_CaptionEntityLike] | None, + ) -> None: + """Keep media captions aligned with photo/document/video conversions.""" + + if not caption: + return + + # Telegram attaches captions to multiple media types; keeping the shared + # conversion here prevents photo/document/video from drifting again. + message.message_str = caption + message.message.append(Comp.Plain(message.message_str)) + + if not caption_entities: + return + + for entity in caption_entities: + if entity.type == "mention": + name = message.message_str[ + entity.offset + 1 : entity.offset + entity.length + ] + message.message.append(Comp.At(qq=name, name=name)) + async def handle_media_group_message( self, update: Update, context: ContextTypes.DEFAULT_TYPE ): diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 26b434573f..e0edd3933c 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -110,12 +110,15 @@ async def send_by_session( return for request_id in target_request_ids: + # Proactive sends are already complete messages. Do not replay them as + # streaming chunks tied to the active request, otherwise the frontend + # keeps the current request in a loading state until that request ends. await WebChatMessageEvent._send( request_id, message_chain, session.session_id, - streaming=True, - emit_complete=True, + streaming=False, + emit_complete=False, ) # If only passive subscription queues exist for this conversation, diff --git a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py index 2f87b88b90..6034b5e371 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py @@ -1,10 +1,22 @@ """企业微信智能机器人平台适配器包""" -from .wecomai_adapter import WecomAIBotAdapter -from .wecomai_api import WecomAIBotAPIClient -from .wecomai_event import WecomAIBotMessageEvent -from .wecomai_server import WecomAIBotServer -from .wecomai_utils import WecomAIBotConstants +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .wecomai_adapter import WecomAIBotAdapter + from .wecomai_api import WecomAIBotAPIClient + from .wecomai_event import WecomAIBotMessageEvent + from .wecomai_server import WecomAIBotServer + from .wecomai_utils import WecomAIBotConstants +else: + WecomAIBotAdapter: Any + WecomAIBotAPIClient: Any + WecomAIBotMessageEvent: Any + WecomAIBotServer: Any + WecomAIBotConstants: Any __all__ = [ "WecomAIBotAPIClient", @@ -13,3 +25,17 @@ "WecomAIBotMessageEvent", "WecomAIBotServer", ] + + +def __getattr__(name: str) -> Any: + if name == "WecomAIBotAdapter": + return import_module(".wecomai_adapter", __name__).WecomAIBotAdapter + if name == "WecomAIBotAPIClient": + return import_module(".wecomai_api", __name__).WecomAIBotAPIClient + if name == "WecomAIBotMessageEvent": + return import_module(".wecomai_event", __name__).WecomAIBotMessageEvent + if name == "WecomAIBotServer": + return import_module(".wecomai_server", __name__).WecomAIBotServer + if name == "WecomAIBotConstants": + return import_module(".wecomai_utils", __name__).WecomAIBotConstants + raise AttributeError(name) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index f27d4671e5..86931c2c43 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -1,15 +1,19 @@ """企业微信智能机器人事件处理模块,处理消息事件的发送和接收""" +from __future__ import annotations + import asyncio from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import At, Image, Plain -from .wecomai_api import WecomAIBotAPIClient -from .wecomai_queue_mgr import WecomAIQueueMgr -from .wecomai_webhook import WecomAIBotWebhookClient +if TYPE_CHECKING: + from .wecomai_api import WecomAIBotAPIClient + from .wecomai_queue_mgr import WecomAIQueueMgr + from .wecomai_webhook import WecomAIBotWebhookClient class WecomAIBotMessageEvent(AstrMessageEvent): diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index ad8bb44f6d..c674cd8195 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -1,8 +1,232 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from astrbot_sdk.message.components import component_to_payload_sync + from astrbot.core.db import BaseDatabase from astrbot.core.db.po import PlatformMessageHistory +from astrbot.core.message.components import ( + At, + AtAll, + BaseMessageComponent, + File, + Forward, + Image, + Plain, + Poke, + Record, + Reply, + Unknown, + Video, +) +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.platform.message_type import MessageType + + +@dataclass(frozen=True, slots=True) +class MessageHistorySender: + sender_id: str | None = None + sender_name: str | None = None + + +@dataclass(slots=True) +class MessageHistoryRecord: + id: int + session: MessageSession + sender: MessageHistorySender + parts: list[BaseMessageComponent] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + created_at: datetime | None = None + updated_at: datetime | None = None + idempotency_key: str | None = None + + +@dataclass(frozen=True, slots=True) +class MessageHistoryPage: + records: list[MessageHistoryRecord] + next_cursor: str | None + total: int | None + + +def _message_type_key(value: MessageType | str) -> str: + if isinstance(value, MessageType): + if value == MessageType.GROUP_MESSAGE: + return "group" + if value == MessageType.FRIEND_MESSAGE: + return "private" + return "other" + normalized = str(value).strip().lower() + if normalized in {"group", "groupmessage", "group_message"}: + return "group" + if normalized in { + "private", + "friend", + "friendmessage", + "privatemessage", + "friend_message", + "private_message", + }: + return "private" + if normalized in {"other", "othermessage", "other_message"}: + return "other" + raise ValueError(f"Unsupported message type: {value}") + + +def _message_type_enum(value: str) -> MessageType: + normalized = _message_type_key(value) + if normalized == "group": + return MessageType.GROUP_MESSAGE + if normalized == "private": + return MessageType.FRIEND_MESSAGE + return MessageType.OTHER_MESSAGE + + +def _session_storage_key(session: MessageSession) -> str: + # TODO(refactor): persist message_type as a first-class column once the + # legacy message history model can be migrated without impacting old plugins. + return f"{_message_type_key(session.message_type)}:{session.session_id}" + + +def _optional_int_cursor(cursor: str | None) -> int | None: + if cursor is None: + return None + text = str(cursor).strip() + if not text: + return None + return int(text) + + +def _payload_to_component(payload: Any) -> BaseMessageComponent: + if not isinstance(payload, dict): + return Unknown(text=str(payload)) + + raw_type = str(payload.get("type", "unknown") or "unknown").lower() + data = payload.get("data") + if not isinstance(data, dict): + data = {} + + if raw_type in {"text", "plain"}: + return Plain(str(data.get("text", "")), convert=False) + if raw_type == "image": + image_data = dict(data) + image_file = str(image_data.pop("file", "") or image_data.get("url") or "") + return Image(image_file, **image_data) + if raw_type == "at": + qq_value = data.get("qq") + if str(qq_value).lower() == "all": + return AtAll() + return At(qq=str(qq_value or ""), name=str(data.get("name", ""))) + if raw_type == "reply": + reply_data = dict(data) + chain_payload = reply_data.get("chain") + reply_data["chain"] = ( + [_payload_to_component(item) for item in chain_payload] + if isinstance(chain_payload, list) + else [] + ) + return Reply(**reply_data) + if raw_type == "record": + record_data = dict(data) + record_file = str(record_data.pop("file", "") or record_data.get("url") or "") + return Record(record_file, **record_data) + if raw_type == "video": + video_data = dict(data) + video_file = str(video_data.pop("file", "") or "") + return Video(video_file, **video_data) + if raw_type == "file": + file_value = str(data.get("file") or data.get("file_") or data.get("url") or "") + return File( + str(data.get("name", "") or "file"), + file="" if file_value.startswith(("http://", "https://")) else file_value, + url=file_value if file_value.startswith(("http://", "https://")) else "", + ) + if raw_type == "poke": + return Poke( + poke_type=data.get("type"), + id=data.get("id"), + qq=data.get("qq"), + ) + if raw_type == "forward": + return Forward(id=str(data.get("id", ""))) + return Unknown(text=str(payload)) + + +def _legacy_content_to_payloads( + content: dict[str, Any], +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + message_parts = content.get("message") + if not isinstance(message_parts, list): + return [], {} + payloads: list[dict[str, Any]] = [] + for part in message_parts: + if not isinstance(part, dict): + continue + part_type = str(part.get("type", "")).strip().lower() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + payloads.append({"type": "text", "data": {"text": text}}) + continue + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + continue + payloads.append( + { + "type": "reply", + "data": { + "id": str(message_id), + "message_str": str(part.get("selected_text", "")), + "chain": [], + }, + } + ) + continue + if part_type not in {"image", "record", "file", "video"}: + continue + payload_data: dict[str, Any] = {} + attachment_id = part.get("attachment_id") + if attachment_id is not None: + payload_data["attachment_id"] = str(attachment_id) + filename = part.get("filename") + if filename is not None: + payload_data["filename"] = str(filename) + if part_type == "file": + payload_data["name"] = str(filename) + path_value = part.get("path") + if path_value not in (None, ""): + payload_data["path"] = str(path_value) + payload_data["file"] = str(path_value) + payloads.append({"type": part_type, "data": payload_data}) + metadata = {key: value for key, value in content.items() if key != "message"} + return payloads, metadata + + +def _content_to_parts_and_metadata( + content: Any, +) -> tuple[list[dict[str, Any]], dict[str, Any], str | None]: + if not isinstance(content, dict): + return [], {}, None + if isinstance(content.get("parts"), list): + metadata = content.get("metadata") + idempotency_key = content.get("idempotency_key") + return ( + [dict(item) for item in content["parts"] if isinstance(item, dict)], + dict(metadata) if isinstance(metadata, dict) else {}, + str(idempotency_key) if idempotency_key is not None else None, + ) + payloads, metadata = _legacy_content_to_payloads(content) + return payloads, metadata, None class PlatformMessageHistoryManager: + MessageHistorySender = MessageHistorySender + MessageHistoryRecord = MessageHistoryRecord + MessageHistoryPage = MessageHistoryPage + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper @@ -10,7 +234,7 @@ async def insert( self, platform_id: str, user_id: str, - content: dict, # TODO: parse from message chain + content: dict, sender_id: str | None = None, sender_name: str | None = None, ) -> PlatformMessageHistory: @@ -49,3 +273,146 @@ async def delete( user_id=user_id, offset_sec=offset_sec, ) + + async def append( + self, + session: MessageSession, + *, + parts: list[BaseMessageComponent], + sender: MessageHistorySender, + metadata: dict[str, Any] | None = None, + idempotency_key: str | None = None, + ) -> MessageHistoryRecord: + storage_user_id = _session_storage_key(session) + if idempotency_key: + # TODO(refactor): move idempotency_key into a dedicated indexed column + # after the legacy history table is migrated for the new SDK path. + existing = await self.db.find_platform_message_history_by_idempotency_key( + platform_id=session.platform_id, + user_id=storage_user_id, + idempotency_key=idempotency_key, + ) + if existing is not None: + return self._record_from_model(existing) + + content = { + "parts": [component_to_payload_sync(part) for part in parts], + "metadata": dict(metadata or {}), + } + if idempotency_key is not None: + content["idempotency_key"] = str(idempotency_key) + + record = await self.db.insert_platform_message_history( + platform_id=session.platform_id, + user_id=storage_user_id, + content=content, + sender_id=sender.sender_id, + sender_name=sender.sender_name, + ) + return self._record_from_model(record) + + async def list( + self, + session: MessageSession, + *, + cursor: str | None = None, + limit: int = 50, + ) -> MessageHistoryPage: + normalized_limit = max(1, int(limit)) + rows, total = await self.db.list_sdk_platform_message_history( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + cursor_id=_optional_int_cursor(cursor), + limit=normalized_limit + 1, + include_total=True, + ) + has_more = len(rows) > normalized_limit + page_rows = rows[:normalized_limit] + records = [self._record_from_model(row) for row in page_rows] + next_cursor = str(page_rows[-1].id) if has_more and page_rows else None + return MessageHistoryPage(records=records, next_cursor=next_cursor, total=total) + + async def get_by_id( + self, + session: MessageSession, + record_id: int, + ) -> MessageHistoryRecord | None: + record = await self.db.get_platform_message_history_by_id(int(record_id)) + if record is None: + return None + if record.platform_id != session.platform_id: + return None + if record.user_id != _session_storage_key(session): + return None + return self._record_from_model(record) + + async def delete_before( + self, + session: MessageSession, + *, + before: datetime, + ) -> int: + return await self.db.delete_platform_message_before( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + before=before, + ) + + async def delete_after( + self, + session: MessageSession, + *, + after: datetime, + ) -> int: + return await self.db.delete_platform_message_after( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + after=after, + ) + + async def delete_all(self, session: MessageSession) -> int: + return await self.db.delete_all_platform_message_history( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + ) + + def _record_from_model( + self, record: PlatformMessageHistory + ) -> MessageHistoryRecord: + parts_payload, metadata, idempotency_key = _content_to_parts_and_metadata( + record.content + ) + return MessageHistoryRecord( + id=int(record.id or 0), + session=self._session_from_storage_record(record), + sender=MessageHistorySender( + sender_id=str(record.sender_id) + if record.sender_id is not None + else None, + sender_name=( + str(record.sender_name) if record.sender_name is not None else None + ), + ), + parts=[_payload_to_component(item) for item in parts_payload], + metadata=metadata, + created_at=record.created_at, + updated_at=record.updated_at, + idempotency_key=idempotency_key, + ) + + def _session_from_storage_record( + self, record: PlatformMessageHistory + ) -> MessageSession: + raw_user_id = str(record.user_id or "") + message_type = "private" + session_id = raw_user_id + if ":" in raw_user_id: + maybe_message_type, maybe_session_id = raw_user_id.split(":", 1) + if maybe_message_type in {"group", "private", "other"} and maybe_session_id: + message_type = maybe_message_type + session_id = maybe_session_id + return MessageSession( + platform_name=str(record.platform_id), + message_type=_message_type_enum(message_type), + session_id=session_id, + ) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7a3e1543a7..c1815d2e0d 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -96,6 +96,13 @@ def register_provider_change_hook( if hook not in self._provider_change_hooks: self._provider_change_hooks.append(hook) + def unregister_provider_change_hook( + self, + hook: Callable[[str, ProviderType, str | None], None], + ) -> None: + if hook in self._provider_change_hooks: + self._provider_change_hooks.remove(hook) + def _notify_provider_changed( self, provider_id: str, diff --git a/astrbot/core/sdk_bridge/__init__.py b/astrbot/core/sdk_bridge/__init__.py new file mode 100644 index 0000000000..9ebd9232dd --- /dev/null +++ b/astrbot/core/sdk_bridge/__init__.py @@ -0,0 +1,31 @@ +"""SDK bridge package public exports.""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .capability_bridge import CoreCapabilityBridge + from .plugin_bridge import SdkPluginBridge + from .trigger_converter import TriggerConverter +else: + CoreCapabilityBridge: Any + SdkPluginBridge: Any + TriggerConverter: Any + +__all__ = [ + "CoreCapabilityBridge", + "SdkPluginBridge", + "TriggerConverter", +] + + +def __getattr__(name: str) -> Any: + if name == "CoreCapabilityBridge": + return import_module(".capability_bridge", __name__).CoreCapabilityBridge + if name == "SdkPluginBridge": + return import_module(".plugin_bridge", __name__).SdkPluginBridge + if name == "TriggerConverter": + return import_module(".trigger_converter", __name__).TriggerConverter + raise AttributeError(name) diff --git a/astrbot/core/sdk_bridge/bridge_base.py b/astrbot/core/sdk_bridge/bridge_base.py new file mode 100644 index 0000000000..771525a510 --- /dev/null +++ b/astrbot/core/sdk_bridge/bridge_base.py @@ -0,0 +1,619 @@ +from __future__ import annotations + +import asyncio +import contextlib +import json +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, cast + +from astrbot_sdk._internal.invocation_context import current_caller_plugin_id +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import CapabilityRouter + +from astrbot.core.file_token_service import FileTokenService +from astrbot.core.message.components import ComponentTypes, Image, Plain +from astrbot.core.message.message_event_result import MessageChain + +if TYPE_CHECKING: + from astrbot.core.star.context import Context as StarContext + + +def _get_runtime_sp(): + from astrbot.core import sp + + return sp + + +def _get_runtime_html_renderer(): + from astrbot.core import html_renderer + + return html_renderer + + +def _get_runtime_astrbot_config(): + from astrbot.core import astrbot_config + + return astrbot_config + + +def _get_runtime_file_token_service() -> FileTokenService: + from astrbot.core import file_token_service + + return cast(FileTokenService, file_token_service) + + +def _get_runtime_tool_types(): + from astrbot.core.agent.tool import FunctionTool, ToolSet + + return FunctionTool, ToolSet + + +def _get_runtime_provider_types(): + from astrbot.core.provider.provider import ( + EmbeddingProvider, + RerankProvider, + STTProvider, + TTSProvider, + ) + + return STTProvider, TTSProvider, EmbeddingProvider, RerankProvider + + +@dataclass(slots=True) +class _EventStreamState: + request_context: Any + queue: asyncio.Queue[MessageChain | None] + task: asyncio.Task[None] + + +def _build_message_chain_from_payload( + chain_payload: list[dict[str, Any]], +) -> MessageChain: + components = [] + for item in chain_payload: + if not isinstance(item, dict): + continue + comp_type = str(item.get("type", "")).lower() + data = item.get("data", {}) + if comp_type in {"text", "plain"} and isinstance(data, dict): + components.append(Plain(str(data.get("text", "")), convert=False)) + continue + if comp_type == "image" and isinstance(data, dict): + file_value = str(data.get("file") or data.get("url") or "") + if file_value.startswith(("http://", "https://")): + components.append(Image.fromURL(file_value)) + elif file_value: + file_path = ( + file_value[8:] if file_value.startswith("file:///") else file_value + ) + components.append(Image.fromFileSystem(file_path)) + continue + component_cls = ComponentTypes.get(comp_type) + if component_cls is None: + components.append( + Plain(json.dumps(item, ensure_ascii=False), convert=False) + ) + continue + try: + if isinstance(data, dict): + components.append(component_cls(**data)) + else: + components.append(Plain(str(item), convert=False)) + except Exception: + components.append( + Plain(json.dumps(item, ensure_ascii=False), convert=False) + ) + return MessageChain(components) + + +class CapabilityBridgeBase(CapabilityRouter): + MEMORY_SCOPE = "sdk_memory" + + _star_context: StarContext + _plugin_bridge: Any + + @staticmethod + def _to_iso_datetime(value: Any) -> str | None: + if value is None: + return None + isoformat = getattr(value, "isoformat", None) + if callable(isoformat): + return str(isoformat()) + if isinstance(value, (int, float)) and value > 0: + return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat() + return None + + @staticmethod + def _optional_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _normalize_history_items(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [dict(item) for item in value if isinstance(item, dict)] + if isinstance(value, str): + with contextlib.suppress(json.JSONDecodeError, TypeError, ValueError): + decoded = json.loads(value) + if isinstance(decoded, list): + return [dict(item) for item in decoded if isinstance(item, dict)] + return [] + + @staticmethod + def _normalize_persona_dialogs(value: Any) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value if isinstance(item, str)] + if isinstance(value, str): + with contextlib.suppress(json.JSONDecodeError, TypeError, ValueError): + decoded = json.loads(value) + if isinstance(decoded, list): + return [str(item) for item in decoded if isinstance(item, str)] + return [] + + @staticmethod + def _normalize_session_scoped_config( + raw_config: Any, + session_id: str, + ) -> dict[str, Any]: + if not isinstance(raw_config, dict): + return {} + nested = raw_config.get(session_id) + if isinstance(nested, dict): + return dict(nested) + # Session plugin config is stored as {session_id: {...}}, but session + # service config already lives directly under the per-session storage key. + # Accept both shapes so the bridge stays compatible with existing data. + return dict(raw_config) + + def _serialize_persona(self, persona: Any) -> dict[str, Any] | None: + if persona is None: + return None + return { + "persona_id": str(getattr(persona, "persona_id", "") or ""), + "system_prompt": str(getattr(persona, "system_prompt", "") or ""), + "begin_dialogs": self._normalize_persona_dialogs( + getattr(persona, "begin_dialogs", None) + ), + "tools": ( + [str(item) for item in getattr(persona, "tools", [])] + if isinstance(getattr(persona, "tools", None), list) + else None + ), + "skills": ( + [str(item) for item in getattr(persona, "skills", [])] + if isinstance(getattr(persona, "skills", None), list) + else None + ), + "custom_error_message": ( + str(getattr(persona, "custom_error_message", "")) + if getattr(persona, "custom_error_message", None) is not None + else None + ), + "folder_id": ( + str(getattr(persona, "folder_id", "")) + if getattr(persona, "folder_id", None) is not None + else None + ), + "sort_order": int(getattr(persona, "sort_order", 0) or 0), + "created_at": self._to_iso_datetime(getattr(persona, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(persona, "updated_at", None)), + } + + def _serialize_conversation(self, conversation: Any) -> dict[str, Any] | None: + if conversation is None: + return None + return { + "conversation_id": str(getattr(conversation, "cid", "") or ""), + "session": str(getattr(conversation, "user_id", "") or ""), + "platform_id": str(getattr(conversation, "platform_id", "") or ""), + "history": self._normalize_history_items( + getattr(conversation, "history", None) + ), + "title": ( + str(getattr(conversation, "title", "")) + if getattr(conversation, "title", None) is not None + else None + ), + "persona_id": ( + str(getattr(conversation, "persona_id", "")) + if getattr(conversation, "persona_id", None) is not None + else None + ), + "created_at": self._to_iso_datetime( + getattr(conversation, "created_at", None) + ), + "updated_at": self._to_iso_datetime( + getattr(conversation, "updated_at", None) + ), + "token_usage": ( + int(getattr(conversation, "token_usage")) + if getattr(conversation, "token_usage", None) is not None + else None + ), + } + + def _serialize_kb(self, kb_helper_or_record: Any) -> dict[str, Any] | None: + kb = getattr(kb_helper_or_record, "kb", kb_helper_or_record) + if kb is None: + return None + return { + "kb_id": str(getattr(kb, "kb_id", "") or ""), + "kb_name": str(getattr(kb, "kb_name", "") or ""), + "description": ( + str(getattr(kb, "description", "")) + if getattr(kb, "description", None) is not None + else None + ), + "emoji": ( + str(getattr(kb, "emoji", "")) + if getattr(kb, "emoji", None) is not None + else None + ), + "embedding_provider_id": str( + getattr(kb, "embedding_provider_id", "") or "" + ), + "rerank_provider_id": ( + str(getattr(kb, "rerank_provider_id", "")) + if getattr(kb, "rerank_provider_id", None) is not None + else None + ), + "chunk_size": ( + int(getattr(kb, "chunk_size")) + if getattr(kb, "chunk_size", None) is not None + else None + ), + "chunk_overlap": ( + int(getattr(kb, "chunk_overlap")) + if getattr(kb, "chunk_overlap", None) is not None + else None + ), + "top_k_dense": ( + int(getattr(kb, "top_k_dense")) + if getattr(kb, "top_k_dense", None) is not None + else None + ), + "top_k_sparse": ( + int(getattr(kb, "top_k_sparse")) + if getattr(kb, "top_k_sparse", None) is not None + else None + ), + "top_m_final": ( + int(getattr(kb, "top_m_final")) + if getattr(kb, "top_m_final", None) is not None + else None + ), + "doc_count": int(getattr(kb, "doc_count", 0) or 0), + "chunk_count": int(getattr(kb, "chunk_count", 0) or 0), + "created_at": self._to_iso_datetime(getattr(kb, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(kb, "updated_at", None)), + } + + def _serialize_kb_document(self, document: Any) -> dict[str, Any] | None: + if document is None: + return None + return { + "doc_id": str(getattr(document, "doc_id", "") or ""), + "kb_id": str(getattr(document, "kb_id", "") or ""), + "doc_name": str(getattr(document, "doc_name", "") or ""), + "file_type": str(getattr(document, "file_type", "") or ""), + "file_size": int(getattr(document, "file_size", 0) or 0), + "file_path": str(getattr(document, "file_path", "") or ""), + "chunk_count": int(getattr(document, "chunk_count", 0) or 0), + "media_count": int(getattr(document, "media_count", 0) or 0), + "created_at": self._to_iso_datetime(getattr(document, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(document, "updated_at", None)), + } + + @staticmethod + def _serialize_member(member: Any) -> dict[str, Any] | None: + if member is None: + return None + user_id = getattr(member, "user_id", None) + if user_id is None and isinstance(member, dict): + user_id = member.get("user_id") + if user_id is None: + return None + nickname = getattr(member, "nickname", None) + if nickname is None and isinstance(member, dict): + nickname = member.get("nickname") + role = getattr(member, "role", None) + if role is None and isinstance(member, dict): + role = member.get("role") + return { + "user_id": str(user_id), + "nickname": str(nickname or ""), + "role": str(role or ""), + } + + @classmethod + def _serialize_group(cls, group: Any) -> dict[str, Any] | None: + if group is None: + return None + members_payload = [] + raw_members = getattr(group, "members", None) + if raw_members is None: + raw_members = getattr(group, "member_list", None) + if raw_members is None and isinstance(group, dict): + raw_members = group.get("members") or group.get("member_list") + if isinstance(raw_members, list): + for member in raw_members: + serialized_member = cls._serialize_member(member) + if serialized_member is not None: + members_payload.append(serialized_member) + group_id = getattr(group, "group_id", None) + if group_id is None and isinstance(group, dict): + group_id = group.get("group_id") + group_name = getattr(group, "group_name", None) + if group_name is None and isinstance(group, dict): + group_name = group.get("group_name") + group_avatar = getattr(group, "group_avatar", None) + if group_avatar is None and isinstance(group, dict): + group_avatar = group.get("group_avatar") + group_owner = getattr(group, "group_owner", None) + if group_owner is None and isinstance(group, dict): + group_owner = group.get("group_owner") + group_admins = getattr(group, "group_admins", None) + if group_admins is None and isinstance(group, dict): + group_admins = group.get("group_admins") + return { + "group_id": str(group_id or ""), + "group_name": str(group_name or ""), + "group_avatar": str(group_avatar or ""), + "group_owner": str(group_owner or ""), + "group_admins": ( + [str(item) for item in group_admins] + if isinstance(group_admins, list) + else [] + ), + "members": members_payload, + } + + @staticmethod + def _serialize_platform_error(error: Any) -> dict[str, Any] | None: + if error is None: + return None + message = getattr(error, "message", None) + timestamp = getattr(error, "timestamp", None) + traceback_value = getattr(error, "traceback", None) + if isinstance(error, dict): + message = error.get("message", message) + timestamp = error.get("timestamp", timestamp) + traceback_value = error.get("traceback", traceback_value) + if not message: + return None + return { + "message": str(message), + "timestamp": CapabilityBridgeBase._to_iso_datetime(timestamp) + or str(timestamp or ""), + "traceback": ( + str(traceback_value) if traceback_value is not None else None + ), + } + + @classmethod + def _serialize_platform_snapshot(cls, platform: Any) -> dict[str, Any] | None: + if platform is None: + return None + meta = None + try: + meta = platform.meta() + except Exception: + meta = None + platform_id = str( + getattr(meta, "id", None) or getattr(platform, "config", {}).get("id", "") + ).strip() + platform_type = str(getattr(meta, "name", "") or "").strip() + if not platform_id or not platform_type: + return None + status = getattr(platform, "status", None) + errors = getattr(platform, "errors", []) + status_value = getattr(status, "value", status) + return { + "id": platform_id, + "name": str(getattr(meta, "adapter_display_name", None) or platform_type), + "type": platform_type, + "status": str(status_value or "pending"), + "errors": [ + payload + for payload in ( + cls._serialize_platform_error(item) + for item in (errors if isinstance(errors, list) else []) + ) + if payload is not None + ], + "last_error": cls._serialize_platform_error( + getattr(platform, "last_error", None) + ), + "unified_webhook": bool( + platform.unified_webhook() + if hasattr(platform, "unified_webhook") + else False + ), + } + + @classmethod + def _serialize_platform_stats(cls, stats: Any) -> dict[str, Any] | None: + if not isinstance(stats, dict): + return None + payload = dict(stats) + payload["last_error"] = cls._serialize_platform_error(stats.get("last_error")) + meta = stats.get("meta") + payload["meta"] = dict(meta) if isinstance(meta, dict) else {} + return payload + + def _get_platform_inst_by_id(self, platform_id: str) -> Any | None: + platform_manager = getattr(self._star_context, "platform_manager", None) + if platform_manager is None or not hasattr(platform_manager, "get_insts"): + return None + normalized_platform_id = str(platform_id).strip() + if not normalized_platform_id: + return None + for platform in list(platform_manager.get_insts()): + meta = None + try: + meta = platform.meta() + except Exception: + continue + if str(getattr(meta, "id", "")).strip() == normalized_platform_id: + return platform + return None + + def _resolve_plugin_id(self, request_id: str) -> str: + plugin_id = current_caller_plugin_id() + if plugin_id: + return plugin_id + return self._plugin_bridge.resolve_request_plugin_id(request_id) + + def _reserved_plugin_names(self) -> set[str]: + reserved: set[str] = set() + get_all_stars = getattr(self._star_context, "get_all_stars", None) + if not callable(get_all_stars): + return reserved + stars = get_all_stars() + if not isinstance(stars, Iterable): + return reserved + for star in stars: + name = getattr(star, "name", None) + if name and bool(getattr(star, "reserved", False)): + reserved.add(str(name)) + return reserved + + def _require_reserved_plugin( + self, + request_id: str, + capability_name: str, + ) -> str: + plugin_id = self._resolve_plugin_id(request_id) + if plugin_id in {"system", "__system__"}: + return plugin_id + if plugin_id in self._reserved_plugin_names(): + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} is restricted to reserved/system plugins" + ) + + def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + checker = getattr(self._plugin_bridge, "plugin_supports_platform", None) + if not callable(checker): + return True + return bool(checker(plugin_id, platform_name)) + + def _platform_name_from_id(self, platform_id: str) -> str: + platform = self._get_platform_inst_by_id(platform_id) + if platform is None: + return "" + meta = getattr(platform, "meta", None) + if not callable(meta): + return "" + try: + payload = meta() + except Exception: + return "" + return str(getattr(payload, "name", "") or "").strip().lower() + + def _session_platform_name(self, session: str) -> str: + platform_id = str(session).split(":", maxsplit=1)[0].strip() + if not platform_id: + return "" + return self._platform_name_from_id(platform_id) + + def _require_platform_support_for_session( + self, + request_id: str, + session: str, + capability_name: str, + ) -> str: + plugin_id = self._resolve_plugin_id(request_id) + platform_name = self._session_platform_name(session) + if not platform_name or self._plugin_supports_platform( + plugin_id, platform_name + ): + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} does not support platform '{platform_name}' for plugin '{plugin_id}'" + ) + + def _resolve_dispatch_target( + self, + request_id: str, + payload: dict[str, Any], + ) -> tuple[str, str]: + target_payload = payload.get("target") + dispatch_token = "" + if isinstance(target_payload, dict): + raw_payload = target_payload.get("raw") + if isinstance(raw_payload, dict): + dispatch_token = str(raw_payload.get("dispatch_token", "")) + if not dispatch_token: + nested_raw_payload = raw_payload.get("raw") + if isinstance(nested_raw_payload, dict): + dispatch_token = str( + nested_raw_payload.get("dispatch_token", "") + ) + if not dispatch_token: + request_context = self._plugin_bridge.resolve_request_session(request_id) + if request_context is None: + raise AstrBotError.invalid_input( + "Missing dispatch token for platform send" + ) + dispatch_token = request_context.dispatch_token + session = str(payload.get("session", "")) + return session, dispatch_token + + def _resolve_event_request_context( + self, + request_id: str, + payload: dict[str, Any], + ): + def _has_event(request_context: Any | None) -> bool: + if request_context is None: + return False + has_event = getattr(request_context, "has_event", None) + if has_event is not None: + return bool(has_event) + return hasattr(request_context, "event") + + target_payload = payload.get("target") + dispatch_token = "" + if isinstance(target_payload, dict): + raw_payload = target_payload.get("raw") + if isinstance(raw_payload, dict): + dispatch_token = str(raw_payload.get("dispatch_token", "")) + if not dispatch_token: + nested_raw = raw_payload.get("raw") + if isinstance(nested_raw, dict): + dispatch_token = str(nested_raw.get("dispatch_token", "")) + if dispatch_token: + request_context = self._plugin_bridge.get_request_context_by_token( + dispatch_token + ) + return request_context if _has_event(request_context) else None + request_context = self._plugin_bridge.resolve_request_session(request_id) + return request_context if _has_event(request_context) else None + + def _resolve_current_group_request_context( + self, + request_id: str, + payload: dict[str, Any], + ): + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None: + return None + payload_session = str(payload.get("session", "")).strip() + if payload_session and payload_session != str( + request_context.event.unified_msg_origin + ): + raise AstrBotError.invalid_input( + "platform.get_group/get_members only support the current event session" + ) + return request_context + + @staticmethod + def _build_core_message_chain(chain_payload: list[dict[str, Any]]) -> MessageChain: + return _build_message_chain_from_payload(chain_payload) diff --git a/astrbot/core/sdk_bridge/capabilities/__init__.py b/astrbot/core/sdk_bridge/capabilities/__init__.py new file mode 100644 index 0000000000..4ba44e5e9c --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/__init__.py @@ -0,0 +1,29 @@ +from .basic import BasicCapabilityMixin +from .conversation import ConversationCapabilityMixin +from .kb import KnowledgeBaseCapabilityMixin +from .llm import LLMCapabilityMixin +from .mcp import MCPCapabilityMixin +from .message_history import MessageHistoryCapabilityMixin +from .permission import PermissionCapabilityMixin +from .persona import PersonaCapabilityMixin +from .platform import PlatformCapabilityMixin +from .provider import ProviderCapabilityMixin +from .session import SessionCapabilityMixin +from .skill import SkillCapabilityMixin +from .system import SystemCapabilityMixin + +__all__ = [ + "BasicCapabilityMixin", + "ConversationCapabilityMixin", + "KnowledgeBaseCapabilityMixin", + "LLMCapabilityMixin", + "MCPCapabilityMixin", + "MessageHistoryCapabilityMixin", + "PermissionCapabilityMixin", + "PersonaCapabilityMixin", + "PlatformCapabilityMixin", + "ProviderCapabilityMixin", + "SessionCapabilityMixin", + "SkillCapabilityMixin", + "SystemCapabilityMixin", +] diff --git a/astrbot/core/sdk_bridge/capabilities/_host.py b/astrbot/core/sdk_bridge/capabilities/_host.py new file mode 100644 index 0000000000..c3bda8de05 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/_host.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + + class CapabilityMixinHost: + MEMORY_SCOPE: str + _event_streams: dict[str, Any] + _plugin_bridge: Any + _star_context: Any + _memory_backends_by_plugin: dict[str, Any] + _memory_index_by_plugin: dict[str, dict[str, dict[str, Any]]] + _memory_dirty_keys_by_plugin: dict[str, set[str]] + _memory_expires_at_by_plugin: dict[str, dict[str, Any]] + + def register( + self, + descriptor: Any, + *, + call_handler: Any = None, + stream_handler: Any = None, + finalize: Any = None, + exposed: bool = True, + ) -> None: ... + + def _builtin_descriptor( + self, + name: str, + description: str, + *, + supports_stream: bool = False, + cancelable: bool = False, + ) -> Any: ... + + def _resolve_plugin_id(self, request_id: str) -> str: ... + + def _resolve_dispatch_target( + self, + request_id: str, + payload: dict[str, Any], + ) -> tuple[str, str]: ... + + def _resolve_event_request_context( + self, + request_id: str, + payload: dict[str, Any], + ) -> Any: ... + + def _resolve_current_group_request_context( + self, + request_id: str, + payload: dict[str, Any], + ) -> Any: ... + + def _build_core_message_chain( + self, chain_payload: list[dict[str, Any]] + ) -> Any: ... + + def _serialize_group(self, group: Any) -> dict[str, Any] | None: ... + + def _require_reserved_plugin( + self, + request_id: str, + capability_name: str, + ) -> str: ... + + def _plugin_supports_platform( + self, + plugin_id: str, + platform_name: str, + ) -> bool: ... + + def _platform_name_from_id(self, platform_id: str) -> str: ... + + def _session_platform_name(self, session: str) -> str: ... + + def _require_platform_support_for_session( + self, + request_id: str, + session: str, + capability_name: str, + ) -> str: ... + + def _get_platform_inst_by_id(self, platform_id: str) -> Any | None: ... + + def _serialize_platform_snapshot( + self, platform: Any + ) -> dict[str, Any] | None: ... + + def _serialize_platform_stats(self, stats: Any) -> dict[str, Any] | None: ... + + def _normalize_session_scoped_config( + self, + raw_config: Any, + session_id: str, + ) -> dict[str, Any]: ... + + def _get_typed_provider( + self, + payload: dict[str, Any], + capability_name: str, + provider_label: str, + expected_type: type[Any], + ) -> Any: ... + + def _provider_embedding_get_embedding( + self, + request_id: str, + payload: dict[str, Any], + token: Any, + ) -> Awaitable[dict[str, Any]]: ... + + def _provider_embedding_get_embeddings( + self, + request_id: str, + payload: dict[str, Any], + token: Any, + ) -> Awaitable[dict[str, Any]]: ... + + def _reserved_plugin_names(self) -> set[str]: ... + + def _serialize_persona(self, persona: Any) -> dict[str, Any] | None: ... + + def _normalize_persona_dialogs(self, value: Any) -> list[str]: ... + + def _serialize_conversation( + self, conversation: Any + ) -> dict[str, Any] | None: ... + + def _normalize_history_items(self, value: Any) -> list[dict[str, Any]]: ... + + def _optional_int(self, value: Any) -> int | None: ... + + def _serialize_kb(self, kb_helper_or_record: Any) -> dict[str, Any] | None: ... + + def _serialize_kb_document(self, document: Any) -> dict[str, Any] | None: ... + +else: + + class CapabilityMixinHost: + # Keep the runtime host empty so it cannot shadow CapabilityRouter methods in + # CoreCapabilityBridge's MRO. The typed method declarations above are only for + # static analysis. + pass diff --git a/astrbot/core/sdk_bridge/capabilities/basic.py b/astrbot/core/sdk_bridge/capabilities/basic.py new file mode 100644 index 0000000000..8a4bc765d1 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/basic.py @@ -0,0 +1,698 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from astrbot_sdk._memory_backend import PluginMemoryBackend +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import StreamExecution + +from astrbot.core.utils.astrbot_path import get_astrbot_plugin_data_path + +from ..bridge_base import _get_runtime_provider_types, _get_runtime_sp +from ._host import CapabilityMixinHost + + +class BasicCapabilityMixin(CapabilityMixinHost): + def _memory_backend_for_plugin(self, plugin_id: str) -> PluginMemoryBackend: + backend = self._memory_backends_by_plugin.get(plugin_id) + if backend is None: + backend = PluginMemoryBackend( + Path(get_astrbot_plugin_data_path()) / plugin_id + ) + self._memory_backends_by_plugin[plugin_id] = backend + return backend + + def _resolve_memory_embedding_provider_id( + self, + payload: dict[str, Any], + *, + required: bool, + ) -> str | None: + provider_id = str(payload.get("provider_id", "")).strip() + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + if provider_id: + provider = self._star_context.get_provider_by_id(provider_id) + if provider is None or not isinstance(provider, embedding_provider_cls): + raise AstrBotError.invalid_input( + f"memory.search unknown embedding provider: {provider_id}" + ) + return provider_id + providers = self._star_context.get_all_embedding_providers() + if providers: + provider = providers[0] + provider_id = str(getattr(provider.meta(), "id", "") or "").strip() + if provider_id: + return provider_id + if required: + raise AstrBotError.invalid_input( + "memory.search requires an embedding provider", + ) + return None + + def _register_db_capabilities(self) -> None: + self.register( + self._builtin_descriptor("db.get", "Read plugin kv"), + call_handler=self._db_get, + ) + self.register( + self._builtin_descriptor("db.set", "Write plugin kv"), + call_handler=self._db_set, + ) + self.register( + self._builtin_descriptor("db.delete", "Delete plugin kv"), + call_handler=self._db_delete, + ) + self.register( + self._builtin_descriptor("db.list", "List plugin kv"), + call_handler=self._db_list, + ) + self.register( + self._builtin_descriptor("db.get_many", "Read plugin kv in batch"), + call_handler=self._db_get_many, + ) + self.register( + self._builtin_descriptor("db.set_many", "Write plugin kv in batch"), + call_handler=self._db_set_many, + ) + self.register( + self._builtin_descriptor( + "db.watch", + "Watch plugin kv", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._db_watch, + ) + + async def _db_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "value": await _get_runtime_sp().get_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + None, + ) + } + + async def _db_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + await _get_runtime_sp().put_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + payload.get("value"), + ) + return {} + + async def _db_delete( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + await _get_runtime_sp().remove_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + ) + return {} + + async def _db_list( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + prefix = payload.get("prefix") + prefix_value = str(prefix) if isinstance(prefix, str) else None + items = await _get_runtime_sp().range_get_async("plugin", plugin_id, None) + keys = sorted( + item.key + for item in items + if prefix_value is None or item.key.startswith(prefix_value) + ) + return {"keys": keys} + + async def _db_get_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("db.get_many requires a keys array") + items = [] + for key in keys_payload: + key_text = str(key) + items.append( + { + "key": key_text, + "value": await _get_runtime_sp().get_async( + "plugin", + plugin_id, + key_text, + None, + ), + } + ) + return {"items": items} + + async def _db_set_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + items_payload = payload.get("items") + if not isinstance(items_payload, list): + raise AstrBotError.invalid_input("db.set_many requires an items array") + for item in items_payload: + if not isinstance(item, dict): + raise AstrBotError.invalid_input("db.set_many items must be objects") + await _get_runtime_sp().put_async( + "plugin", + plugin_id, + str(item.get("key", "")), + item.get("value"), + ) + return {} + + async def _db_watch( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> StreamExecution: + raise AstrBotError.invalid_input( + "db.watch is unsupported in AstrBot SDK MVP", + hint="Use db.get/list polling in MVP", + ) + + def _register_memory_capabilities(self) -> None: + self.register( + self._builtin_descriptor("memory.search", "Search plugin memory"), + call_handler=self._memory_search, + ) + self.register( + self._builtin_descriptor("memory.save", "Save plugin memory"), + call_handler=self._memory_save, + ) + self.register( + self._builtin_descriptor("memory.get", "Get plugin memory"), + call_handler=self._memory_get, + ) + self.register( + self._builtin_descriptor("memory.list_keys", "List plugin memory keys"), + call_handler=self._memory_list_keys, + ) + self.register( + self._builtin_descriptor("memory.exists", "Check plugin memory key"), + call_handler=self._memory_exists, + ) + self.register( + self._builtin_descriptor("memory.delete", "Delete plugin memory"), + call_handler=self._memory_delete, + ) + self.register( + self._builtin_descriptor( + "memory.clear_namespace", + "Delete plugin memory in a namespace", + ), + call_handler=self._memory_clear_namespace, + ) + self.register( + self._builtin_descriptor( + "memory.save_with_ttl", + "Save plugin memory with ttl metadata", + ), + call_handler=self._memory_save_with_ttl, + ) + self.register( + self._builtin_descriptor("memory.get_many", "Get plugin memories"), + call_handler=self._memory_get_many, + ) + self.register( + self._builtin_descriptor("memory.delete_many", "Delete plugin memories"), + call_handler=self._memory_delete_many, + ) + self.register( + self._builtin_descriptor("memory.count", "Count plugin memories"), + call_handler=self._memory_count, + ) + self.register( + self._builtin_descriptor("memory.stats", "Get plugin memory stats"), + call_handler=self._memory_stats, + ) + + async def _memory_search( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + query = str(payload.get("query", "")) + mode = str(payload.get("mode", "auto")).strip().lower() or "auto" + limit = self._optional_int(payload.get("limit")) + raw_min_score = payload.get("min_score") + min_score = float(raw_min_score) if raw_min_score is not None else None + namespace = str(payload.get("namespace")) if payload.get("namespace") else None + include_descendants = bool(payload.get("include_descendants", True)) + provider_id = self._resolve_memory_embedding_provider_id( + payload, + required=mode in {"vector", "hybrid"}, + ) + effective_mode = mode + if effective_mode == "auto": + effective_mode = "hybrid" if provider_id is not None else "keyword" + backend = self._memory_backend_for_plugin(plugin_id) + items = await backend.search( + query, + namespace=namespace, + include_descendants=include_descendants, + mode=effective_mode, + limit=limit, + min_score=min_score, + provider_id=provider_id, + embed_one=( + ( + lambda text: self._memory_embedding_for_text( + request_id, + provider_id, + text, + _token, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + embed_many=( + ( + lambda texts: self._memory_embeddings_for_texts( + request_id, + provider_id, + texts, + _token, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + ) + return {"items": items} + + async def _memory_embedding_for_text( + self, + request_id: str, + provider_id: str, + text: str, + token, + ) -> list[float]: + output = await self._provider_embedding_get_embedding( + request_id, + {"provider_id": provider_id, "text": text}, + token, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def _memory_embeddings_for_texts( + self, + request_id: str, + provider_id: str, + texts: list[str], + token, + ) -> list[list[float]]: + output = await self._provider_embedding_get_embeddings( + request_id, + {"provider_id": provider_id, "texts": texts}, + token, + ) + embeddings = output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def _memory_save( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input("memory.save requires an object value") + await self._memory_backend_for_plugin(plugin_id).save( + str(payload.get("key", "")), + value, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = await self._memory_backend_for_plugin(plugin_id).get( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"value": value} + + async def _memory_list_keys( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys = await self._memory_backend_for_plugin(plugin_id).list_keys( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"keys": keys} + + async def _memory_exists( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + exists = await self._memory_backend_for_plugin(plugin_id).exists( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"exists": exists} + + async def _memory_delete( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + await self._memory_backend_for_plugin(plugin_id).delete( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_clear_namespace( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + deleted_count = await self._memory_backend_for_plugin( + plugin_id + ).clear_namespace( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"deleted_count": deleted_count} + + async def _memory_save_with_ttl( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input( + "memory.save_with_ttl requires an object value" + ) + ttl_seconds = int(payload.get("ttl_seconds", 0)) + await self._memory_backend_for_plugin(plugin_id).save_with_ttl( + str(payload.get("key", "")), + value, + ttl_seconds, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("memory.get_many requires a keys array") + items = await self._memory_backend_for_plugin(plugin_id).get_many( + [str(key) for key in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"items": items} + + async def _memory_delete_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("memory.delete_many requires a keys array") + deleted_count = await self._memory_backend_for_plugin(plugin_id).delete_many( + [str(key) for key in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"deleted_count": deleted_count} + + async def _memory_count( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + count = await self._memory_backend_for_plugin(plugin_id).count( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"count": count} + + async def _memory_stats( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + stats = await self._memory_backend_for_plugin(plugin_id).stats( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", True)), + ) + stats["plugin_id"] = plugin_id + return stats + + def _register_http_capabilities(self) -> None: + self.register( + self._builtin_descriptor("http.register_api", "Register http route"), + call_handler=self._http_register_api, + ) + self.register( + self._builtin_descriptor("http.unregister_api", "Unregister http route"), + call_handler=self._http_unregister_api, + ) + self.register( + self._builtin_descriptor("http.list_apis", "List http routes"), + call_handler=self._http_list_apis, + ) + + async def _http_register_api( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + methods = payload.get("methods") + if not isinstance(methods, list) or not all( + isinstance(item, str) for item in methods + ): + raise AstrBotError.invalid_input( + "http.register_api requires a string methods array" + ) + self._plugin_bridge.register_http_api( + plugin_id=plugin_id, + route=str(payload.get("route", "")), + methods=methods, + handler_capability=str(payload.get("handler_capability", "")), + description=str(payload.get("description", "")), + ) + return {} + + async def _http_unregister_api( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + methods = payload.get("methods") + if not isinstance(methods, list) or not all( + isinstance(item, str) for item in methods + ): + raise AstrBotError.invalid_input( + "http.unregister_api requires a string methods array" + ) + self._plugin_bridge.unregister_http_api( + plugin_id=plugin_id, + route=str(payload.get("route", "")), + methods=methods, + ) + return {} + + async def _http_list_apis( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return {"apis": self._plugin_bridge.list_http_apis(plugin_id)} + + def _register_metadata_capabilities(self) -> None: + self.register( + self._builtin_descriptor("metadata.get_plugin", "Get plugin metadata"), + call_handler=self._metadata_get_plugin, + ) + self.register( + self._builtin_descriptor("metadata.list_plugins", "List plugins metadata"), + call_handler=self._metadata_list_plugins, + ) + self.register( + self._builtin_descriptor( + "metadata.get_plugin_config", + "Get current plugin config", + ), + call_handler=self._metadata_get_plugin_config, + ) + self.register( + self._builtin_descriptor( + "metadata.save_plugin_config", + "Save current plugin config", + ), + call_handler=self._metadata_save_plugin_config, + ) + + async def _metadata_get_plugin( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin = self._plugin_bridge.get_plugin_metadata(str(payload.get("name", ""))) + return {"plugin": plugin} + + async def _metadata_list_plugins( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return {"plugins": self._plugin_bridge.list_plugin_metadata()} + + async def _metadata_get_plugin_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + requested = str(payload.get("name", "")) + if requested != plugin_id: + return {"config": None} + return {"config": self._plugin_bridge.get_plugin_config(plugin_id)} + + async def _metadata_save_plugin_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + "metadata.save_plugin_config requires config object" + ) + return {"config": self._plugin_bridge.save_plugin_config(plugin_id, config)} diff --git a/astrbot/core/sdk_bridge/capabilities/conversation.py b/astrbot/core/sdk_bridge/capabilities/conversation.py new file mode 100644 index 0000000000..90ba6a15fa --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/conversation.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class ConversationCapabilityMixin(CapabilityMixinHost): + def _register_conversation_capabilities(self) -> None: + self.register( + self._builtin_descriptor("conversation.new", "Create conversation"), + call_handler=self._conversation_new, + ) + self.register( + self._builtin_descriptor("conversation.switch", "Switch conversation"), + call_handler=self._conversation_switch, + ) + self.register( + self._builtin_descriptor("conversation.delete", "Delete conversation"), + call_handler=self._conversation_delete, + ) + self.register( + self._builtin_descriptor("conversation.get", "Get conversation"), + call_handler=self._conversation_get, + ) + self.register( + self._builtin_descriptor( + "conversation.get_current", + "Get current conversation", + ), + call_handler=self._conversation_get_current, + ) + self.register( + self._builtin_descriptor("conversation.list", "List conversations"), + call_handler=self._conversation_list, + ) + self.register( + self._builtin_descriptor("conversation.update", "Update conversation"), + call_handler=self._conversation_update, + ) + self.register( + self._builtin_descriptor( + "conversation.unset_persona", + "Unset conversation persona override", + ), + call_handler=self._conversation_unset_persona, + ) + + async def _conversation_new( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.new requires session") + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.new requires conversation object" + ) + conversation_id = ( + await self._star_context.conversation_manager.new_conversation( + unified_msg_origin=session, + platform_id=( + str(raw_conversation.get("platform_id")) + if raw_conversation.get("platform_id") is not None + else None + ), + content=self._normalize_history_items(raw_conversation.get("history")), + title=( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + persona_id=( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + ) + ) + return {"conversation_id": conversation_id} + + async def _conversation_switch( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.switch requires session") + if not conversation_id: + raise AstrBotError.invalid_input( + "conversation.switch requires conversation_id" + ) + await self._star_context.conversation_manager.switch_conversation( + unified_msg_origin=session, + conversation_id=conversation_id, + ) + return {} + + async def _conversation_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + await self._star_context.conversation_manager.delete_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + ) + return {} + + async def _conversation_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + conversation = await self._star_context.conversation_manager.get_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=str(payload.get("conversation_id", "")), + create_if_not_exists=bool(payload.get("create_if_not_exists", False)), + ) + return {"conversation": self._serialize_conversation(conversation)} + + async def _conversation_get_current( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")) + conversation_id = ( + await self._star_context.conversation_manager.get_curr_conversation_id( + session + ) + ) + if not conversation_id and bool(payload.get("create_if_not_exists", False)): + conversation_id = ( + await self._star_context.conversation_manager.new_conversation(session) + ) + if not conversation_id: + return {"conversation": None} + conversation = await self._star_context.conversation_manager.get_conversation( + unified_msg_origin=session, + conversation_id=conversation_id, + create_if_not_exists=bool(payload.get("create_if_not_exists", False)), + ) + return {"conversation": self._serialize_conversation(conversation)} + + async def _conversation_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = payload.get("session") + platform_id = payload.get("platform_id") + conversations = await self._star_context.conversation_manager.get_conversations( + unified_msg_origin=( + str(session) if session is not None and str(session).strip() else None + ), + platform_id=( + str(platform_id) + if platform_id is not None and str(platform_id).strip() + else None + ), + ) + return { + "conversations": [ + item + for item in ( + self._serialize_conversation(conversation) + for conversation in conversations + ) + if item is not None + ] + } + + async def _conversation_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.update requires conversation object" + ) + await self._star_context.conversation_manager.update_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + history=( + self._normalize_history_items(raw_conversation.get("history")) + if "history" in raw_conversation + else None + ), + title=( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + persona_id=( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + token_usage=self._optional_int(raw_conversation.get("token_usage")), + ) + return {} + + async def _conversation_unset_persona( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + await self._star_context.conversation_manager.unset_conversation_persona( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + ) + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/kb.py b/astrbot/core/sdk_bridge/capabilities/kb.py new file mode 100644 index 0000000000..fe252d414f --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/kb.py @@ -0,0 +1,456 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.sdk_bridge.bridge_base import _get_runtime_file_token_service + +from ._host import CapabilityMixinHost + + +class KnowledgeBaseCapabilityMixin(CapabilityMixinHost): + def _register_kb_capabilities(self) -> None: + self.register( + self._builtin_descriptor("kb.list", "List knowledge bases"), + call_handler=self._kb_list, + ) + self.register( + self._builtin_descriptor("kb.get", "Get knowledge base"), + call_handler=self._kb_get, + ) + self.register( + self._builtin_descriptor("kb.create", "Create knowledge base"), + call_handler=self._kb_create, + ) + self.register( + self._builtin_descriptor("kb.update", "Update knowledge base"), + call_handler=self._kb_update, + ) + self.register( + self._builtin_descriptor("kb.delete", "Delete knowledge base"), + call_handler=self._kb_delete, + ) + self.register( + self._builtin_descriptor("kb.retrieve", "Retrieve from knowledge bases"), + call_handler=self._kb_retrieve, + ) + self.register( + self._builtin_descriptor( + "kb.document.upload", "Upload knowledge base document" + ), + call_handler=self._kb_document_upload, + ) + self.register( + self._builtin_descriptor( + "kb.document.list", "List knowledge base documents" + ), + call_handler=self._kb_document_list, + ) + self.register( + self._builtin_descriptor("kb.document.get", "Get knowledge base document"), + call_handler=self._kb_document_get, + ) + self.register( + self._builtin_descriptor( + "kb.document.delete", + "Delete knowledge base document", + ), + call_handler=self._kb_document_delete, + ) + self.register( + self._builtin_descriptor( + "kb.document.refresh", + "Refresh knowledge base document", + ), + call_handler=self._kb_document_refresh, + ) + + async def _get_kb_helper(self, kb_id: str): + return await self._star_context.kb_manager.get_kb(kb_id) + + async def _require_kb_helper(self, kb_id: str): + kb_id_text = str(kb_id).strip() + if not kb_id_text: + raise AstrBotError.invalid_input("kb capability requires kb_id") + kb_helper = await self._get_kb_helper(kb_id_text) + if kb_helper is None: + raise AstrBotError.invalid_input(f"Unknown knowledge base: {kb_id_text}") + return kb_helper + + @staticmethod + def _normalize_kb_names(payload: dict[str, Any]) -> list[str]: + raw_names = payload.get("kb_names") + if not isinstance(raw_names, list): + return [] + return [str(item).strip() for item in raw_names if str(item).strip()] + + @staticmethod + def _normalize_kb_ids(payload: dict[str, Any]) -> list[str]: + raw_ids = payload.get("kb_ids") + if not isinstance(raw_ids, list): + return [] + return [str(item).strip() for item in raw_ids if str(item).strip()] + + async def _resolve_retrieve_kb_names( + self, + payload: dict[str, Any], + ) -> list[str]: + kb_names = self._normalize_kb_names(payload) + if kb_names: + return kb_names + resolved_names: list[str] = [] + for kb_id in self._normalize_kb_ids(payload): + kb_helper = await self._get_kb_helper(kb_id) + if kb_helper is not None and getattr(kb_helper, "kb", None) is not None: + kb_name = str(getattr(kb_helper.kb, "kb_name", "")).strip() + if kb_name: + resolved_names.append(kb_name) + return resolved_names + + async def _kb_list( + self, + _request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, object]: + kbs = await self._star_context.kb_manager.list_kbs() + return { + "kbs": [ + payload + for payload in (self._serialize_kb(kb) for kb in kbs) + if payload is not None + ] + } + + async def _kb_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._get_kb_helper(str(payload.get("kb_id", ""))) + return {"kb": self._serialize_kb(kb_helper)} + + async def _kb_create( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.create requires kb object") + try: + kb_helper = await self._star_context.kb_manager.create_kb( + kb_name=str(raw_kb.get("kb_name", "")), + description=( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ), + emoji=( + str(raw_kb.get("emoji")) + if raw_kb.get("emoji") is not None + else None + ), + embedding_provider_id=( + str(raw_kb.get("embedding_provider_id")) + if raw_kb.get("embedding_provider_id") is not None + else None + ), + rerank_provider_id=( + str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ), + chunk_size=self._optional_int(raw_kb.get("chunk_size")), + chunk_overlap=self._optional_int(raw_kb.get("chunk_overlap")), + top_k_dense=self._optional_int(raw_kb.get("top_k_dense")), + top_k_sparse=self._optional_int(raw_kb.get("top_k_sparse")), + top_m_final=self._optional_int(raw_kb.get("top_m_final")), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"kb": self._serialize_kb(kb_helper)} + + async def _kb_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_id = str(payload.get("kb_id", "")).strip() + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.update requires kb object") + kb_helper = await self._get_kb_helper(kb_id) + if kb_helper is None: + return {"kb": None} + current_kb = getattr(kb_helper, "kb", None) + kb_name = raw_kb.get("kb_name") + try: + updated_helper = await self._star_context.kb_manager.update_kb( + kb_id=kb_id, + kb_name=( + str(kb_name) + if kb_name is not None + else str(getattr(current_kb, "kb_name", "")) + ), + description=( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ) + if "description" in raw_kb + else None, + emoji=( + str(raw_kb.get("emoji")) + if raw_kb.get("emoji") is not None + else None + ) + if "emoji" in raw_kb + else None, + embedding_provider_id=( + str(raw_kb.get("embedding_provider_id")) + if raw_kb.get("embedding_provider_id") is not None + else None + ) + if "embedding_provider_id" in raw_kb + else None, + rerank_provider_id=( + str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ) + if "rerank_provider_id" in raw_kb + else None, + chunk_size=( + self._optional_int(raw_kb.get("chunk_size")) + if "chunk_size" in raw_kb + else None + ), + chunk_overlap=( + self._optional_int(raw_kb.get("chunk_overlap")) + if "chunk_overlap" in raw_kb + else None + ), + top_k_dense=( + self._optional_int(raw_kb.get("top_k_dense")) + if "top_k_dense" in raw_kb + else None + ), + top_k_sparse=( + self._optional_int(raw_kb.get("top_k_sparse")) + if "top_k_sparse" in raw_kb + else None + ), + top_m_final=( + self._optional_int(raw_kb.get("top_m_final")) + if "top_m_final" in raw_kb + else None + ), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"kb": self._serialize_kb(updated_helper)} + + async def _kb_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + deleted = await self._star_context.kb_manager.delete_kb( + str(payload.get("kb_id", "")) + ) + return {"deleted": bool(deleted)} + + async def _kb_retrieve( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + query = str(payload.get("query", "")).strip() + if not query: + raise AstrBotError.invalid_input("kb.retrieve requires query") + kb_names = await self._resolve_retrieve_kb_names(payload) + if not kb_names: + raise AstrBotError.invalid_input("kb.retrieve requires kb_ids or kb_names") + result = await self._star_context.kb_manager.retrieve( + query=query, + kb_names=kb_names, + top_k_fusion=self._optional_int(payload.get("top_k_fusion")) or 20, + top_m_final=self._optional_int(payload.get("top_m_final")) or 5, + ) + if result is None: + return {"result": None} + return {"result": dict(result)} + + async def _kb_document_upload( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_id = str(payload.get("kb_id", "")).strip() + kb_helper = await self._require_kb_helper(kb_id) + raw_document = payload.get("document") + if not isinstance(raw_document, dict): + raise AstrBotError.invalid_input( + "kb.document.upload requires document object" + ) + + text_value = raw_document.get("text") + if isinstance(text_value, str) and text_value.strip(): + file_name = str(raw_document.get("file_name", "")).strip() or "document.txt" + file_type = ( + str(raw_document.get("file_type", "")).strip() + or Path(file_name).suffix.lstrip(".") + or "txt" + ) + document = await kb_helper.upload_document( + file_name=file_name, + file_content=None, + file_type=file_type, + chunk_size=self._optional_int(raw_document.get("chunk_size")) or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) or 3, + max_retries=self._optional_int(raw_document.get("max_retries")) or 3, + pre_chunked_text=[text_value], + ) + return {"document": self._serialize_kb_document(document)} + + url_value = raw_document.get("url") + if isinstance(url_value, str) and url_value.strip(): + try: + document = await self._star_context.kb_manager.upload_from_url( + kb_id=kb_id, + url=url_value.strip(), + chunk_size=self._optional_int(raw_document.get("chunk_size")) + or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) + or 3, + max_retries=self._optional_int(raw_document.get("max_retries")) + or 3, + enable_cleaning=bool(raw_document.get("enable_cleaning", False)), + cleaning_provider_id=( + str(raw_document.get("cleaning_provider_id")) + if raw_document.get("cleaning_provider_id") is not None + else None + ), + ) + except (OSError, ValueError) as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"document": self._serialize_kb_document(document)} + + file_token = str(raw_document.get("file_token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_token, url, or text" + ) + try: + file_path = await _get_runtime_file_token_service().handle_file(file_token) + except KeyError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + path = Path(file_path) + if not path.exists(): + raise AstrBotError.invalid_input(f"File does not exist: {file_path}") + file_name = str(raw_document.get("file_name", "")).strip() or path.name + file_type = str( + raw_document.get("file_type", "") + ).strip() or path.suffix.lstrip(".") + if not file_type: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_type when the file has no suffix" + ) + file_content = await asyncio.to_thread(path.read_bytes) + try: + document = await kb_helper.upload_document( + file_name=file_name, + file_content=file_content, + file_type=file_type, + chunk_size=self._optional_int(raw_document.get("chunk_size")) or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) or 3, + max_retries=self._optional_int(raw_document.get("max_retries")) or 3, + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"document": self._serialize_kb_document(document)} + + async def _kb_document_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + documents = await kb_helper.list_documents( + offset=self._optional_int(payload.get("offset")) or 0, + limit=self._optional_int(payload.get("limit")) or 100, + ) + return { + "documents": [ + item + for item in ( + self._serialize_kb_document(document) for document in documents + ) + if item is not None + ] + } + + async def _kb_document_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + document = await kb_helper.get_document(str(payload.get("doc_id", ""))) + return {"document": self._serialize_kb_document(document)} + + async def _kb_document_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + doc_id = str(payload.get("doc_id", "")).strip() + existing_document = await kb_helper.get_document(doc_id) + if existing_document is None: + return {"deleted": False} + await kb_helper.delete_document(doc_id) + return {"deleted": True} + + async def _kb_document_refresh( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + doc_id = str(payload.get("doc_id", "")).strip() + document = await kb_helper.get_document(doc_id) + if document is None: + return {"document": None} + try: + await kb_helper.refresh_document(doc_id) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + refreshed_document = await kb_helper.get_document(doc_id) + return {"document": self._serialize_kb_document(refreshed_document)} diff --git a/astrbot/core/sdk_bridge/capabilities/llm.py b/astrbot/core/sdk_bridge/capabilities/llm.py new file mode 100644 index 0000000000..c5bd47fb87 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/llm.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, Protocol, TypeGuard + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import StreamExecution + +from astrbot import logger + +from ..bridge_base import _get_runtime_tool_types +from ._host import CapabilityMixinHost + +if TYPE_CHECKING: + from astrbot.core.agent.tool import ToolSet + from astrbot.core.provider.entities import LLMResponse + + +class _ChatProvider(Protocol): + async def text_chat(self, **kwargs: Any) -> LLMResponse: ... + + async def text_chat_stream(self, **kwargs: Any) -> AsyncIterator[LLMResponse]: ... + + +class _ProviderMetaLike(Protocol): + id: str + model: str | None + + +class LLMCapabilityMixin(CapabilityMixinHost): + def _register_llm_capabilities(self) -> None: + self.register( + self._builtin_descriptor("llm.chat", "Send chat request"), + call_handler=self._llm_chat, + ) + self.register( + self._builtin_descriptor( + "llm.chat_raw", + "Send chat request and return raw response", + ), + call_handler=self._llm_chat_raw, + ) + self.register( + self._builtin_descriptor( + "llm.stream_chat", + "Stream chat response", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._llm_stream_chat, + ) + + async def _llm_chat( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + response = await self._call_llm(payload, request_id=request_id) + return {"text": response.completion_text} + + async def _llm_chat_raw( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + response = await self._call_llm(payload, request_id=request_id) + usage = None + if response.usage is not None: + usage = { + "input_tokens": response.usage.input, + "output_tokens": response.usage.output, + "total_tokens": response.usage.total, + } + return { + "text": response.completion_text, + "usage": usage, + "finish_reason": "tool_calls" if response.tools_call_ids else "stop", + "tool_calls": response.to_openai_tool_calls(), + "role": response.role, + "reasoning_content": response.reasoning_content or None, + "reasoning_signature": response.reasoning_signature, + } + + async def _llm_stream_chat( + self, + request_id: str, + payload: dict[str, Any], + token, + ) -> StreamExecution: + provider, request_kwargs = self._resolve_llm_request( + payload, + request_id=request_id, + ) + started_at = time.perf_counter() + provider_label = self._describe_provider(provider) + + async def fallback_iterator() -> AsyncIterator[dict[str, Any]]: + logger.warning( + f"SDK llm.stream_chat fell back to non-streaming provider.text_chat for {provider_label}" + ) + response = await provider.text_chat(**request_kwargs) + logger.info( + f"SDK llm.stream_chat fallback first output for {provider_label} after {time.perf_counter() - started_at:.3f}s" + ) + for char in response.completion_text: + token.raise_if_cancelled() + await asyncio.sleep(0) + yield {"text": char} + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + stream = provider.text_chat_stream(**request_kwargs) + yielded_text = False + first_text_logged = False + async for response in stream: + token.raise_if_cancelled() + text = response.completion_text + if response.is_chunk: + if text: + if not first_text_logged: + first_text_logged = True + logger.info( + f"SDK llm.stream_chat first streamed chunk for {provider_label} after {time.perf_counter() - started_at:.3f}s" + ) + yielded_text = True + yield {"text": text} + continue + if text: + if not first_text_logged: + first_text_logged = True + logger.info( + f"SDK llm.stream_chat first final chunk for {provider_label} after {time.perf_counter() - started_at:.3f}s" + ) + if yielded_text: + yield {"_final_text": text} + else: + yielded_text = True + yield {"text": text, "_final_text": text} + else: + yield {"_final_text": text} + except NotImplementedError: + async for item in fallback_iterator(): + yield item + + def finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]: + final_text = None + for item in reversed(chunks): + if "_final_text" in item: + final_text = str(item.get("_final_text", "")) + break + if final_text is None: + final_text = "".join(str(item.get("text", "")) for item in chunks) + return {"text": final_text} + + return StreamExecution( + iterator=iterator(), + finalize=finalize, + ) + + async def _call_llm( + self, + payload: dict[str, Any], + *, + request_id: str, + ) -> LLMResponse: + provider, request_kwargs = self._resolve_llm_request( + payload, + request_id=request_id, + ) + return await provider.text_chat(**request_kwargs) + + def _resolve_llm_request( + self, + payload: dict[str, Any], + *, + request_id: str, + ) -> tuple[_ChatProvider, dict[str, Any]]: + request_context = self._plugin_bridge.resolve_request_session(request_id) + provider_id = payload.get("provider_id") + if provider_id: + provider = self._star_context.get_provider_by_id(str(provider_id)) + else: + request_context_has_event = False + if request_context is not None: + has_event = getattr(request_context, "has_event", None) + request_context_has_event = ( + bool(has_event) + if has_event is not None + else hasattr(request_context, "event") + ) + provider = self._star_context.get_using_provider( + request_context.event.unified_msg_origin + if request_context is not None and request_context_has_event + else None, + ) + if provider is None: + raise AstrBotError.internal_error( + "No active chat provider is available", + hint="Please configure a chat provider in AstrBot first", + ) + if not self._is_chat_provider(provider): + raise AstrBotError.invalid_input( + f"Provider '{provider_id}' is not a chat provider", + hint="Please choose a configured chat provider for llm.chat requests", + ) + return provider, self._normalize_llm_payload(payload) + + @staticmethod + def _describe_provider(provider: _ChatProvider) -> str: + provider_meta_getter = getattr(provider, "meta", None) + if not callable(provider_meta_getter): + return provider.__class__.__name__ + provider_meta = provider_meta_getter() + if not LLMCapabilityMixin._is_provider_meta(provider_meta): + return provider.__class__.__name__ + return f"{provider_meta.id}/{provider_meta.model}" + + @staticmethod + def _is_chat_provider(provider: object) -> TypeGuard[_ChatProvider]: + return callable(getattr(provider, "text_chat", None)) and callable( + getattr(provider, "text_chat_stream", None) + ) + + @staticmethod + def _is_provider_meta(value: object) -> TypeGuard[_ProviderMetaLike]: + return hasattr(value, "id") and hasattr(value, "model") + + @staticmethod + def _normalize_llm_payload(payload: dict[str, Any]) -> dict[str, Any]: + contexts_payload = payload.get("contexts") + if contexts_payload is None: + contexts_payload = payload.get("history") + contexts = ( + [dict(item) for item in contexts_payload] + if isinstance(contexts_payload, list) + else None + ) + image_urls = payload.get("image_urls") + tool_calls_result = payload.get("tool_calls_result") + tools_payload = payload.get("tools") + request_kwargs: dict[str, Any] = { + "prompt": str(payload.get("prompt", "")), + "image_urls": ( + [str(item) for item in image_urls] + if isinstance(image_urls, list) + else None + ), + "func_tool": ( + LLMCapabilityMixin._build_toolset(tools_payload) + if isinstance(tools_payload, list) + else None + ), + "contexts": contexts, + "tool_calls_result": ( + [dict(item) for item in tool_calls_result] + if isinstance(tool_calls_result, list) + else None + ), + "system_prompt": str(payload.get("system", "")), + "model": (str(payload["model"]) if payload.get("model") else None), + "temperature": payload.get("temperature"), + } + return request_kwargs + + @staticmethod + def _build_toolset(tools_payload: list[Any]) -> ToolSet: + function_tool_cls, tool_set_cls = _get_runtime_tool_types() + tool_set = tool_set_cls() + for item in tools_payload: + if not isinstance(item, dict): + raise AstrBotError.invalid_input("llm tools items must be objects") + if str(item.get("type", "function")) != "function": + raise AstrBotError.invalid_input( + "Only function tools are supported in AstrBot SDK MVP" + ) + function_payload = item.get("function") + if not isinstance(function_payload, dict): + raise AstrBotError.invalid_input( + "llm tools items must contain a function object" + ) + name = str(function_payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input( + "llm function tool name must not be empty" + ) + description = str(function_payload.get("description", "") or "") + parameters = function_payload.get("parameters") + if not isinstance(parameters, dict): + parameters = {"type": "object", "properties": {}} + tool_set.add_tool( + function_tool_cls( + name=name, + description=description, + parameters=parameters, + handler=None, + ) + ) + return tool_set diff --git a/astrbot/core/sdk_bridge/capabilities/mcp.py b/astrbot/core/sdk_bridge/capabilities/mcp.py new file mode 100644 index 0000000000..ff58c83b5f --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/mcp.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core import logger + +from ._host import CapabilityMixinHost + + +class MCPCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _mcp_timeout(payload: dict[str, Any], capability_name: str) -> float: + raw_timeout = payload.get("timeout", 30.0) + try: + timeout = float(raw_timeout) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires numeric timeout" + ) from exc + if timeout <= 0: + raise AstrBotError.invalid_input(f"{capability_name} requires timeout > 0") + return timeout + + @staticmethod + def _mcp_name(payload: dict[str, Any], capability_name: str) -> str: + name = str(payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input(f"{capability_name} requires name") + return name + + @staticmethod + def _mcp_config(payload: dict[str, Any], capability_name: str) -> dict[str, Any]: + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires config object" + ) + return dict(config) + + def _func_tool_manager(self): + return self._star_context.get_llm_tool_manager() + + @staticmethod + def _global_mcp_record_from_state( + *, + name: str, + config: dict[str, Any], + runtime: Any | None, + ) -> dict[str, Any]: + client = getattr(runtime, "client", None) if runtime is not None else None + return { + "name": name, + "scope": "global", + "active": bool(config.get("active", True)), + "running": runtime is not None, + "config": dict(config), + "tools": [ + str(tool.name) + for tool in getattr(client, "tools", []) + if getattr(tool, "name", None) + ] + if client is not None + else [], + "errlogs": list(getattr(client, "server_errlogs", [])) + if client is not None + else [], + "last_error": None, + } + + def _get_global_mcp_record(self, name: str) -> dict[str, Any] | None: + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + return None + config = servers.get(name) + if not isinstance(config, dict): + return None + runtime = func_tool_manager.mcp_server_runtime_view.get(name) + return self._global_mcp_record_from_state( + name=name, + config=dict(config), + runtime=runtime, + ) + + def _list_global_mcp_records(self) -> list[dict[str, Any]]: + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + return [] + return [ + self._global_mcp_record_from_state( + name=str(name), + config=dict(config), + runtime=func_tool_manager.mcp_server_runtime_view.get(str(name)), + ) + for name, config in sorted(servers.items(), key=lambda item: str(item[0])) + if str(name).strip() and isinstance(config, dict) + ] + + def _require_global_mcp_ack(self, request_id: str, capability_name: str) -> str: + plugin_id = self._resolve_plugin_id(request_id) + if self._plugin_bridge.acknowledges_global_mcp_risk(plugin_id): + return plugin_id + raise PermissionError( + f"{capability_name} requires @acknowledge_global_mcp_risk" + ) + + @staticmethod + def _audit_global_mcp_mutation( + *, + plugin_id: str, + action: str, + server_name: str, + request_id: str, + ) -> None: + audit_entry = { + "plugin_id": plugin_id, + "action": action, + "server_name": server_name, + "request_id": request_id, + } + logger.info("SDK global MCP mutation: {}", audit_entry) + + async def _mcp_local_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.get") + return {"server": self._plugin_bridge.get_local_mcp_server(plugin_id, name)} + + async def _mcp_local_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return {"servers": self._plugin_bridge.list_local_mcp_servers(plugin_id)} + + async def _mcp_local_enable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.enable") + timeout = self._mcp_timeout(payload, "mcp.local.enable") + return { + "server": await self._plugin_bridge.enable_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + } + + async def _mcp_local_disable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.disable") + return { + "server": await self._plugin_bridge.disable_local_mcp_server( + plugin_id, + name, + ) + } + + async def _mcp_local_wait_until_ready( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.wait_until_ready") + timeout = self._mcp_timeout(payload, "mcp.local.wait_until_ready") + return { + "server": await self._plugin_bridge.wait_for_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + } + + async def _mcp_session_open( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.session.open") + config = self._mcp_config(payload, "mcp.session.open") + timeout = self._mcp_timeout(payload, "mcp.session.open") + session_id, tools = await self._plugin_bridge.open_temporary_mcp_session( + plugin_id, + name=name, + config=config, + timeout=timeout, + ) + return {"session_id": session_id, "tools": tools} + + async def _mcp_session_list_tools( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + return { + "tools": self._plugin_bridge.get_temporary_mcp_session_tools( + plugin_id, + session_id, + ) + } + + async def _mcp_session_call_tool( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + if not tool_name: + raise AstrBotError.invalid_input("mcp.session.call_tool requires tool_name") + args = payload.get("args") + if not isinstance(args, dict): + raise AstrBotError.invalid_input( + "mcp.session.call_tool requires args object" + ) + result = await self._plugin_bridge.call_temporary_mcp_tool( + plugin_id, + session_id=session_id, + tool_name=tool_name, + arguments=dict(args), + ) + return {"result": result} + + async def _mcp_session_close( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + await self._plugin_bridge.close_temporary_mcp_session(plugin_id, session_id) + return {} + + async def _mcp_global_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.register") + name = self._mcp_name(payload, "mcp.global.register") + config = self._mcp_config(payload, "mcp.global.register") + timeout = self._mcp_timeout(payload, "mcp.global.register") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.setdefault("mcpServers", {}) + if not isinstance(servers, dict): + raise AstrBotError.invalid_input("Invalid global MCP config shape") + if name in servers: + raise AstrBotError.invalid_input( + f"Global MCP server already exists: {name}" + ) + normalized_config = dict(config) + normalized_config.setdefault("active", True) + servers[name] = normalized_config + func_tool_manager.save_mcp_config(config_payload) + if bool(normalized_config.get("active", True)): + await func_tool_manager.enable_mcp_server( + name, normalized_config, timeout=timeout + ) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="register", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_global_mcp_ack(request_id, "mcp.global.get") + name = self._mcp_name(payload, "mcp.global.get") + return {"server": self._get_global_mcp_record(name)} + + async def _mcp_global_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_global_mcp_ack(request_id, "mcp.global.list") + return {"servers": self._list_global_mcp_records()} + + async def _mcp_global_enable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.enable") + name = self._mcp_name(payload, "mcp.global.enable") + timeout = self._mcp_timeout(payload, "mcp.global.enable") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if ( + not isinstance(servers, dict) + or name not in servers + or not isinstance(servers[name], dict) + ): + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + servers[name]["active"] = True + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.enable_mcp_server( + name, dict(servers[name]), timeout=timeout + ) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="enable", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_disable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.disable") + name = self._mcp_name(payload, "mcp.global.disable") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if ( + not isinstance(servers, dict) + or name not in servers + or not isinstance(servers[name], dict) + ): + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + servers[name]["active"] = False + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.disable_mcp_server(name) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="disable", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_unregister( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.unregister") + name = self._mcp_name(payload, "mcp.global.unregister") + func_tool_manager = self._func_tool_manager() + existing_record = self._get_global_mcp_record(name) + if existing_record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + raise AstrBotError.invalid_input("Invalid global MCP config shape") + servers.pop(name, None) + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.disable_mcp_server(name) + existing_record["running"] = False + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="unregister", + server_name=name, + request_id=request_id, + ) + return {"server": existing_record} + + async def _internal_mcp_local_execute( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = str(payload.get("plugin_id", "")).strip() + server_name = str(payload.get("server_name", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + tool_args = payload.get("tool_args") + if not plugin_id or not server_name or not tool_name: + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires plugin_id, server_name, and tool_name" + ) + if not isinstance(tool_args, dict): + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires tool_args object" + ) + return await self._plugin_bridge.execute_local_mcp_tool( + plugin_id, + server_name=server_name, + tool_name=tool_name, + tool_args=dict(tool_args), + ) + + def _register_mcp_capabilities(self) -> None: + self.register( + self._builtin_descriptor("mcp.local.get", "Get local MCP server"), + call_handler=self._mcp_local_get, + ) + self.register( + self._builtin_descriptor("mcp.local.list", "List local MCP servers"), + call_handler=self._mcp_local_list, + ) + self.register( + self._builtin_descriptor("mcp.local.enable", "Enable local MCP server"), + call_handler=self._mcp_local_enable, + ) + self.register( + self._builtin_descriptor("mcp.local.disable", "Disable local MCP server"), + call_handler=self._mcp_local_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.local.wait_until_ready", + "Wait until local MCP server is ready", + ), + call_handler=self._mcp_local_wait_until_ready, + ) + self.register( + self._builtin_descriptor("mcp.session.open", "Open temporary MCP session"), + call_handler=self._mcp_session_open, + ) + self.register( + self._builtin_descriptor( + "mcp.session.list_tools", + "List temporary MCP session tools", + ), + call_handler=self._mcp_session_list_tools, + ) + self.register( + self._builtin_descriptor( + "mcp.session.call_tool", + "Call tool on temporary MCP session", + ), + call_handler=self._mcp_session_call_tool, + ) + self.register( + self._builtin_descriptor( + "mcp.session.close", "Close temporary MCP session" + ), + call_handler=self._mcp_session_close, + ) + self.register( + self._builtin_descriptor( + "mcp.global.register", "Register global MCP server" + ), + call_handler=self._mcp_global_register, + ) + self.register( + self._builtin_descriptor("mcp.global.get", "Get global MCP server"), + call_handler=self._mcp_global_get, + ) + self.register( + self._builtin_descriptor("mcp.global.list", "List global MCP servers"), + call_handler=self._mcp_global_list, + ) + self.register( + self._builtin_descriptor("mcp.global.enable", "Enable global MCP server"), + call_handler=self._mcp_global_enable, + ) + self.register( + self._builtin_descriptor("mcp.global.disable", "Disable global MCP server"), + call_handler=self._mcp_global_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.unregister", + "Unregister global MCP server", + ), + call_handler=self._mcp_global_unregister, + ) + self.register( + self._builtin_descriptor( + "internal.mcp.local.execute", + "Execute local MCP tool", + ), + call_handler=self._internal_mcp_local_execute, + exposed=False, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/message_history.py b/astrbot/core/sdk_bridge/capabilities/message_history.py new file mode 100644 index 0000000000..ebcdb74378 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/message_history.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.message.components import component_to_payload_sync + +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform_message_history_mgr import MessageHistorySender + +from ._host import CapabilityMixinHost + + +def _core_message_type_from_sdk(value: str) -> MessageType: + normalized = str(value).strip().lower() + if normalized == "group": + return MessageType.GROUP_MESSAGE + if normalized == "private": + return MessageType.FRIEND_MESSAGE + if normalized == "other": + return MessageType.OTHER_MESSAGE + raise AstrBotError.invalid_input( + f"Unsupported message history message_type: {value}" + ) + + +def _sdk_message_type_from_core(value: MessageType | str) -> str: + if isinstance(value, MessageType): + if value == MessageType.GROUP_MESSAGE: + return "group" + if value == MessageType.FRIEND_MESSAGE: + return "private" + return "other" + return str(value).strip().lower() + + +class MessageHistoryCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _typed_message_history_session(payload: Any) -> MessageSession: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + "message_history capabilities require a session object" + ) + platform_id = str(payload.get("platform_id", "")).strip() + message_type = str(payload.get("message_type", "")).strip() + session_id = str(payload.get("session_id", "")).strip() + if not platform_id or not message_type or not session_id: + raise AstrBotError.invalid_input( + "message_history session requires platform_id, message_type, and session_id" + ) + return MessageSession( + platform_name=platform_id, + message_type=_core_message_type_from_sdk(message_type), + session_id=session_id, + ) + + @staticmethod + def _serialize_session(session: MessageSession) -> dict[str, str]: + return { + "platform_id": str(session.platform_id), + "message_type": _sdk_message_type_from_core(session.message_type), + "session_id": str(session.session_id), + } + + def _serialize_message_history_record(self, record: Any) -> dict[str, Any] | None: + if record is None: + return None + session = getattr(record, "session", None) + sender = getattr(record, "sender", None) + parts = getattr(record, "parts", None) + return { + "id": int(getattr(record, "id", 0) or 0), + "session": ( + self._serialize_session(session) + if isinstance(session, MessageSession) + else {} + ), + "sender": { + "sender_id": ( + str(getattr(sender, "sender_id", "")) + if getattr(sender, "sender_id", None) is not None + else None + ), + "sender_name": ( + str(getattr(sender, "sender_name", "")) + if getattr(sender, "sender_name", None) is not None + else None + ), + }, + "parts": ( + [component_to_payload_sync(part) for part in parts] + if isinstance(parts, list) + else [] + ), + "metadata": ( + dict(getattr(record, "metadata", {})) + if isinstance(getattr(record, "metadata", None), dict) + else {} + ), + "created_at": self._to_iso_datetime(getattr(record, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(record, "updated_at", None)), + "idempotency_key": ( + str(getattr(record, "idempotency_key", "")) + if getattr(record, "idempotency_key", None) is not None + else None + ), + } + + @staticmethod + def _parse_boundary(raw_value: Any, field_name: str) -> datetime: + text = str(raw_value or "").strip() + if not text: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires {field_name}" + ) + try: + return datetime.fromisoformat(text) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires an ISO datetime string" + ) from exc + + async def _message_history_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + raw_limit = self._optional_int(payload.get("limit")) + limit = 50 if raw_limit is None else raw_limit + if limit < 1: + raise AstrBotError.invalid_input("message_history.list requires limit >= 1") + page = await self._star_context.message_history_manager.list( + session, + cursor=( + str(payload.get("cursor")) + if payload.get("cursor") is not None + else None + ), + limit=limit, + ) + return { + "page": { + "records": [ + item + for item in ( + self._serialize_message_history_record(record) + for record in page.records + ) + if item is not None + ], + "next_cursor": page.next_cursor, + "total": page.total, + } + } + + async def _message_history_get_by_id( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + record_id = self._optional_int(payload.get("record_id")) + if record_id is None or record_id < 1: + raise AstrBotError.invalid_input( + "message_history.get_by_id requires record_id >= 1" + ) + record = await self._star_context.message_history_manager.get_by_id( + session, + record_id, + ) + return {"record": self._serialize_message_history_record(record)} + + async def _message_history_append( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + sender_payload = payload.get("sender") + if not isinstance(sender_payload, dict): + raise AstrBotError.invalid_input( + "message_history.append requires sender object" + ) + parts_payload = payload.get("parts") + if not isinstance(parts_payload, list) or any( + not isinstance(item, dict) for item in parts_payload + ): + raise AstrBotError.invalid_input( + "message_history.append requires parts array" + ) + metadata = payload.get("metadata") + if metadata is not None and not isinstance(metadata, dict): + raise AstrBotError.invalid_input( + "message_history.append requires metadata object when provided" + ) + record = await self._star_context.message_history_manager.append( + session, + parts=self._build_core_message_chain(parts_payload).chain, + sender=MessageHistorySender( + sender_id=( + str(sender_payload.get("sender_id")) + if sender_payload.get("sender_id") is not None + else None + ), + sender_name=( + str(sender_payload.get("sender_name")) + if sender_payload.get("sender_name") is not None + else None + ), + ), + metadata=dict(metadata or {}), + idempotency_key=( + str(payload.get("idempotency_key")) + if payload.get("idempotency_key") is not None + else None + ), + ) + return {"record": self._serialize_message_history_record(record)} + + async def _message_history_delete_before( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_before( + session, + before=self._parse_boundary(payload.get("before"), "delete_before"), + ) + return {"deleted_count": int(deleted_count)} + + async def _message_history_delete_after( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_after( + session, + after=self._parse_boundary(payload.get("after"), "delete_after"), + ) + return {"deleted_count": int(deleted_count)} + + async def _message_history_delete_all( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_all( + session + ) + return {"deleted_count": int(deleted_count)} + + def _register_message_history_capabilities(self) -> None: + self.register( + self._builtin_descriptor("message_history.list", "List message history"), + call_handler=self._message_history_list, + ) + self.register( + self._builtin_descriptor( + "message_history.get_by_id", + "Get message history by id", + ), + call_handler=self._message_history_get_by_id, + ) + self.register( + self._builtin_descriptor( + "message_history.append", "Append message history" + ), + call_handler=self._message_history_append, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_before", + "Delete message history before timestamp", + ), + call_handler=self._message_history_delete_before, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_after", + "Delete message history after timestamp", + ), + call_handler=self._message_history_delete_after, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_all", + "Delete all message history in session", + ), + call_handler=self._message_history_delete_all, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/permission.py b/astrbot/core/sdk_bridge/capabilities/permission.py new file mode 100644 index 0000000000..e7f153080c --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/permission.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class PermissionCapabilityMixin(CapabilityMixinHost): + def _register_permission_capabilities(self) -> None: + self.register( + self._builtin_descriptor("permission.check", "Check user permission role"), + call_handler=self._permission_check, + ) + self.register( + self._builtin_descriptor("permission.get_admins", "List admin ids"), + call_handler=self._permission_get_admins, + ) + self.register( + self._builtin_descriptor( + "permission.manager.add_admin", + "Add admin id", + ), + call_handler=self._permission_manager_add_admin, + ) + self.register( + self._builtin_descriptor( + "permission.manager.remove_admin", + "Remove admin id", + ), + call_handler=self._permission_manager_remove_admin, + ) + + @staticmethod + def _normalize_admin_ids(values: Any) -> list[str]: + if not isinstance(values, list): + return [] + normalized: list[str] = [] + for item in values: + user_id = str(item).strip() + if user_id: + normalized.append(user_id) + return normalized + + def _permission_config(self) -> Any: + get_config = getattr(self._star_context, "get_config", None) + if callable(get_config): + return get_config() + config = getattr(self._star_context, "_config", None) + if config is not None: + return config + raise AstrBotError.invalid_input("permission capabilities require core config") + + def _admin_ids_snapshot(self, config: Any) -> list[str]: + admins = self._normalize_admin_ids( + config.get("admins_id", []) if hasattr(config, "get") else [] + ) + config["admins_id"] = list(admins) + return admins + + @staticmethod + def _save_config(config: Any) -> None: + save_config = getattr(config, "save_config", None) + if callable(save_config): + save_config() + + @staticmethod + def _required_user_id(payload: dict[str, Any], capability_name: str) -> str: + user_id = str(payload.get("user_id", "")).strip() + if not user_id: + raise AstrBotError.invalid_input(f"{capability_name} requires user_id") + return user_id + + def _require_admin_event_context( + self, + request_id: str, + payload: dict[str, Any], + capability_name: str, + ) -> None: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or bool( + getattr(request_context, "cancelled", False) + ): + if bool(payload.get("_caller_is_admin", False)): + return + raise AstrBotError.invalid_input( + f"{capability_name} requires an active event context" + ) + event = getattr(request_context, "event", None) + if event is None or not callable(getattr(event, "is_admin", None)): + raise AstrBotError.invalid_input( + f"{capability_name} requires an active event context" + ) + # Prefer the authenticated event context whenever one is available. + # The payload hint is only a fallback for proactive calls that were + # created from an admin-triggered flow but no longer have a live event. + if not bool(event.is_admin()): + raise AstrBotError.invalid_input( + f"{capability_name} requires admin privileges" + ) + + async def _permission_check( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + user_id = self._required_user_id(payload, "permission.check") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + is_admin = user_id in admins + return { + "is_admin": is_admin, + "role": "admin" if is_admin else "member", + } + + async def _permission_get_admins( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + config = self._permission_config() + return {"admins": self._admin_ids_snapshot(config)} + + async def _permission_manager_add_admin( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "permission.manager.add_admin") + self._require_admin_event_context( + request_id, + payload, + "permission.manager.add_admin", + ) + user_id = self._required_user_id(payload, "permission.manager.add_admin") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + if user_id in admins: + return {"changed": False} + admins.append(user_id) + config["admins_id"] = admins + self._save_config(config) + return {"changed": True} + + async def _permission_manager_remove_admin( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "permission.manager.remove_admin") + self._require_admin_event_context( + request_id, + payload, + "permission.manager.remove_admin", + ) + user_id = self._required_user_id(payload, "permission.manager.remove_admin") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + if user_id not in admins: + return {"changed": False} + admins.remove(user_id) + config["admins_id"] = admins + self._save_config(config) + return {"changed": True} diff --git a/astrbot/core/sdk_bridge/capabilities/persona.py b/astrbot/core/sdk_bridge/capabilities/persona.py new file mode 100644 index 0000000000..94db89cabb --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/persona.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class PersonaCapabilityMixin(CapabilityMixinHost): + def _register_persona_capabilities(self) -> None: + self.register( + self._builtin_descriptor("persona.get", "Get persona"), + call_handler=self._persona_get, + ) + self.register( + self._builtin_descriptor("persona.list", "List personas"), + call_handler=self._persona_list, + ) + self.register( + self._builtin_descriptor("persona.create", "Create persona"), + call_handler=self._persona_create, + ) + self.register( + self._builtin_descriptor("persona.update", "Update persona"), + call_handler=self._persona_update, + ) + self.register( + self._builtin_descriptor("persona.delete", "Delete persona"), + call_handler=self._persona_delete, + ) + + async def _persona_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + persona_id = str(payload.get("persona_id", "")).strip() + try: + persona = await self._star_context.persona_manager.get_persona(persona_id) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"persona": self._serialize_persona(persona)} + + async def _persona_list( + self, + _request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, object]: + personas = await self._star_context.persona_manager.get_all_personas() + return { + "personas": [ + payload + for payload in ( + self._serialize_persona(persona) for persona in personas + ) + if payload is not None + ] + } + + async def _persona_create( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_persona = payload.get("persona") + if not isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.create requires persona object") + try: + persona = await self._star_context.persona_manager.create_persona( + persona_id=str(raw_persona.get("persona_id", "")), + system_prompt=str(raw_persona.get("system_prompt", "")), + begin_dialogs=self._normalize_persona_dialogs( + raw_persona.get("begin_dialogs") + ), + tools=( + [str(item) for item in raw_persona.get("tools", [])] + if isinstance(raw_persona.get("tools"), list) + else None + ), + skills=( + [str(item) for item in raw_persona.get("skills", [])] + if isinstance(raw_persona.get("skills"), list) + else None + ), + custom_error_message=( + str(raw_persona.get("custom_error_message")) + if raw_persona.get("custom_error_message") is not None + else None + ), + folder_id=( + str(raw_persona.get("folder_id")) + if raw_persona.get("folder_id") is not None + else None + ), + sort_order=int(raw_persona.get("sort_order", 0)), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"persona": self._serialize_persona(persona)} + + async def _persona_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_persona = payload.get("persona") + if not isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.update requires persona object") + persona = await self._star_context.persona_manager.update_persona( + persona_id=str(payload.get("persona_id", "")), + system_prompt=raw_persona.get("system_prompt"), + begin_dialogs=( + self._normalize_persona_dialogs(raw_persona.get("begin_dialogs")) + if "begin_dialogs" in raw_persona + else None + ), + tools=( + [str(item) for item in raw_persona.get("tools", [])] + if isinstance(raw_persona.get("tools"), list) + else raw_persona.get("tools") + ), + skills=( + [str(item) for item in raw_persona.get("skills", [])] + if isinstance(raw_persona.get("skills"), list) + else raw_persona.get("skills") + ), + custom_error_message=raw_persona.get("custom_error_message"), + ) + return {"persona": self._serialize_persona(persona)} + + async def _persona_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + persona_id = str(payload.get("persona_id", "")).strip() + try: + await self._star_context.persona_manager.delete_persona(persona_id) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/platform.py b/astrbot/core/sdk_bridge/capabilities/platform.py new file mode 100644 index 0000000000..68668ababc --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/platform.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.message.components import Image, Plain +from astrbot.core.message.message_event_result import MessageChain + +from ._host import CapabilityMixinHost + + +class PlatformCapabilityMixin(CapabilityMixinHost): + def _register_platform_capabilities(self) -> None: + self.register( + self._builtin_descriptor("platform.send", "Send plain text"), + call_handler=self._platform_send, + ) + self.register( + self._builtin_descriptor("platform.send_image", "Send image"), + call_handler=self._platform_send_image, + ) + self.register( + self._builtin_descriptor("platform.send_chain", "Send message chain"), + call_handler=self._platform_send_chain, + ) + self.register( + self._builtin_descriptor( + "platform.send_by_session", + "Send message chain to a specific session", + ), + call_handler=self._platform_send_by_session, + ) + self.register( + self._builtin_descriptor("platform.get_group", "Get current group data"), + call_handler=self._platform_get_group, + ) + self.register( + self._builtin_descriptor("platform.get_members", "Get group members"), + call_handler=self._platform_get_members, + ) + self.register( + self._builtin_descriptor( + "platform.list_instances", + "List available platform instances", + ), + call_handler=self._platform_list_instances, + ) + + def _register_platform_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "platform.manager.get_by_id", + "Get platform management snapshot by id", + ), + call_handler=self._platform_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "platform.manager.clear_errors", + "Clear platform error records", + ), + call_handler=self._platform_manager_clear_errors, + ) + self.register( + self._builtin_descriptor( + "platform.manager.get_stats", + "Get platform stats by id", + ), + call_handler=self._platform_manager_get_stats, + ) + + async def _platform_send( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + await self._star_context.send_message( + session, + MessageChain([Plain(str(payload.get("text", "")), convert=False)]), + ) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async def _platform_send_image( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_image", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + image_url = str(payload.get("image_url", "")) + component = ( + Image.fromURL(image_url) + if image_url.startswith(("http://", "https://")) + else Image.fromFileSystem(image_url) + ) + await self._star_context.send_message(session, MessageChain([component])) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async def _platform_send_chain( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_chain", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "platform.send_chain requires a chain array" + ) + await self._star_context.send_message( + session, + self._build_core_message_chain(chain_payload), + ) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async def _platform_send_by_session( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "platform.send_by_session requires a chain array" + ) + session = str(payload.get("session", "")) + if not session: + raise AstrBotError.invalid_input( + "platform.send_by_session requires a session" + ) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_by_session", + ) + request_context = self._resolve_event_request_context(request_id, payload) + dispatch_token = None + if request_context is not None and not request_context.cancelled: + dispatch_token = request_context.dispatch_token + self._plugin_bridge.before_platform_send(dispatch_token) + await self._star_context.send_message( + session, + self._build_core_message_chain(chain_payload), + ) + if dispatch_token is not None: + return { + "message_id": self._plugin_bridge.mark_platform_send(dispatch_token) + } + return {"message_id": f"sdk_proactive_{uuid.uuid4().hex}"} + + async def _platform_get_group( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_current_group_request_context( + request_id, payload + ) + if request_context is None: + return {"group": None} + group = await request_context.event.get_group() + return {"group": self._serialize_group(group)} + + async def _platform_get_members( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_current_group_request_context( + request_id, payload + ) + if request_context is None: + return {"members": []} + group = await request_context.event.get_group() + serialized_group = self._serialize_group(group) + if serialized_group is None: + return {"members": []} + members = serialized_group.get("members") + return {"members": list(members) if isinstance(members, list) else []} + + async def _platform_list_instances( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + platform_manager = getattr(self._star_context, "platform_manager", None) + if platform_manager is None or not hasattr(platform_manager, "get_insts"): + return {"platforms": []} + platforms_payload: list[dict[str, Any]] = [] + for platform in list(platform_manager.get_insts()): + meta = None + try: + meta = platform.meta() + except Exception: + continue + platform_id = str(getattr(meta, "id", "")).strip() + platform_type = str(getattr(meta, "name", "")).strip() + if not platform_id or not platform_type: + continue + if not self._plugin_supports_platform(plugin_id, platform_type): + continue + status = getattr(platform, "status", None) + status_value = getattr(status, "value", status) + display_name = str( + getattr(meta, "adapter_display_name", None) or platform_type + ) + platforms_payload.append( + { + "id": platform_id, + "name": display_name, + "type": platform_type, + "status": str(status_value or "unknown"), + } + ) + return {"platforms": platforms_payload} + + async def _platform_manager_get_by_id( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.get_by_id", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + return {"platform": self._serialize_platform_snapshot(platform)} + + async def _platform_manager_clear_errors( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.clear_errors", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + if platform is None: + raise AstrBotError.invalid_input("Unknown platform_id") + clear_errors = getattr(platform, "clear_errors", None) + if callable(clear_errors): + clear_errors() + return {} + + async def _platform_manager_get_stats( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.get_stats", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + if platform is None: + return {"stats": None} + get_stats = getattr(platform, "get_stats", None) + if not callable(get_stats): + return {"stats": None} + return {"stats": self._serialize_platform_stats(get_stats())} diff --git a/astrbot/core/sdk_bridge/capabilities/provider.py b/astrbot/core/sdk_bridge/capabilities/provider.py new file mode 100644 index 0000000000..b0edf8f5a7 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/provider.py @@ -0,0 +1,1372 @@ +from __future__ import annotations + +import asyncio +import base64 +import contextlib +import json +import uuid +from collections.abc import AsyncIterator +from typing import Any, cast + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.llm.entities import LLMToolSpec, ProviderMeta, ToolCallsResult +from astrbot_sdk.llm.entities import ProviderType as SDKProviderType +from astrbot_sdk.runtime.capability_router import StreamExecution + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from ..bridge_base import _get_runtime_provider_types, _get_runtime_tool_types +from ._host import CapabilityMixinHost + + +class ProviderCapabilityMixin(CapabilityMixinHost): + def _register_provider_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.get_using", "Get active provider"), + call_handler=self._provider_get_using, + ) + self.register( + self._builtin_descriptor("provider.get_by_id", "Get provider by id"), + call_handler=self._provider_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.get_current_chat_provider_id", + "Get active chat provider id", + ), + call_handler=self._provider_get_current_chat_provider_id, + ) + self.register( + self._builtin_descriptor("provider.list_all", "List chat providers"), + call_handler=self._provider_list_all, + ) + self.register( + self._builtin_descriptor("provider.list_all_tts", "List tts providers"), + call_handler=self._provider_list_all_tts, + ) + self.register( + self._builtin_descriptor("provider.list_all_stt", "List stt providers"), + call_handler=self._provider_list_all_stt, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_embedding", + "List embedding providers", + ), + call_handler=self._provider_list_all_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_rerank", + "List rerank providers", + ), + call_handler=self._provider_list_all_rerank, + ) + self.register( + self._builtin_descriptor( + "provider.get_using_tts", + "Get active tts provider", + ), + call_handler=self._provider_get_using_tts, + ) + self.register( + self._builtin_descriptor( + "provider.get_using_stt", + "Get active stt provider", + ), + call_handler=self._provider_get_using_stt, + ) + self.register( + self._builtin_descriptor( + "provider.stt.get_text", + "Transcribe audio with STT provider", + ), + call_handler=self._provider_stt_get_text, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio", + "Synthesize audio with TTS provider", + ), + call_handler=self._provider_tts_get_audio, + ) + self.register( + self._builtin_descriptor( + "provider.tts.support_stream", + "Check whether TTS provider supports native streaming", + ), + call_handler=self._provider_tts_support_stream, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio_stream", + "Stream audio with TTS provider", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_tts_get_audio_stream, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embedding", + "Get embedding vector", + ), + call_handler=self._provider_embedding_get_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embeddings", + "Get embedding vectors in batch", + ), + call_handler=self._provider_embedding_get_embeddings, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_dim", + "Get embedding dimension", + ), + call_handler=self._provider_embedding_get_dim, + ) + self.register( + self._builtin_descriptor( + "provider.rerank.rerank", + "Rerank documents", + ), + call_handler=self._provider_rerank_rerank, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.get", + "Get registered and active sdk llm tools", + ), + call_handler=self._llm_tool_manager_get, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.activate", + "Activate sdk llm tool", + ), + call_handler=self._llm_tool_manager_activate, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.deactivate", + "Deactivate sdk llm tool", + ), + call_handler=self._llm_tool_manager_deactivate, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.add", + "Register sdk llm tool metadata", + ), + call_handler=self._llm_tool_manager_add, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.remove", + "Unregister sdk llm tool metadata", + ), + call_handler=self._llm_tool_manager_remove, + ) + self.register( + self._builtin_descriptor("agent.tool_loop.run", "Run sdk tool loop agent"), + call_handler=self._agent_tool_loop_run, + ) + self.register( + self._builtin_descriptor("agent.registry.list", "List sdk agents"), + call_handler=self._agent_registry_list, + ) + self.register( + self._builtin_descriptor("agent.registry.get", "Get sdk agent"), + call_handler=self._agent_registry_get, + ) + + def _register_provider_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.manager.set", "Set active provider"), + call_handler=self._provider_manager_set, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_by_id", + "Get managed provider record by id", + ), + call_handler=self._provider_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_merged_provider_config", + "Get merged managed provider config by id", + ), + call_handler=self._provider_manager_get_merged_provider_config, + ) + self.register( + self._builtin_descriptor( + "provider.manager.load", + "Load a provider instance without persisting config", + ), + call_handler=self._provider_manager_load, + ) + self.register( + self._builtin_descriptor( + "provider.manager.terminate", + "Terminate a loaded provider instance", + ), + call_handler=self._provider_manager_terminate, + ) + self.register( + self._builtin_descriptor( + "provider.manager.create", + "Create and load a provider config", + ), + call_handler=self._provider_manager_create, + ) + self.register( + self._builtin_descriptor( + "provider.manager.update", + "Update and reload a provider config", + ), + call_handler=self._provider_manager_update, + ) + self.register( + self._builtin_descriptor( + "provider.manager.delete", + "Delete a provider config", + ), + call_handler=self._provider_manager_delete, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_insts", + "List loaded chat provider instances", + ), + call_handler=self._provider_manager_get_insts, + ) + self.register( + self._builtin_descriptor( + "provider.manager.watch_changes", + "Stream provider change events", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_manager_watch_changes, + ) + + @staticmethod + def _provider_to_payload(provider: Any | None) -> dict[str, Any] | None: + if provider is None: + return None + meta = provider.meta() + return ProviderCapabilityMixin._provider_meta_to_payload(meta) + + @staticmethod + def _normalize_sdk_provider_type(value: Any) -> SDKProviderType: + if isinstance(value, SDKProviderType): + return value + raw_provider_type = getattr(value, "provider_type", value) + provider_type_value = ( + str(raw_provider_type.value) + if hasattr(raw_provider_type, "value") + else str(raw_provider_type) + ) + try: + return SDKProviderType(provider_type_value) + except ValueError: + return SDKProviderType.CHAT_COMPLETION + + @classmethod + def _provider_meta_to_payload(cls, meta: Any) -> dict[str, Any]: + provider_type = cls._normalize_sdk_provider_type(meta) + return ProviderMeta( + id=str(getattr(meta, "id", "")), + model=( + str(getattr(meta, "model", "")) + if getattr(meta, "model", None) is not None + else None + ), + type=str(getattr(meta, "type", "")), + provider_type=provider_type, + ).to_payload() + + @classmethod + def _managed_provider_from_config( + cls, + provider_config: dict[str, Any] | None, + *, + loaded: bool, + ) -> dict[str, Any] | None: + if not isinstance(provider_config, dict): + return None + provider_id = str(provider_config.get("id", "")).strip() + provider_type_text = str(provider_config.get("type", "")).strip() + if not provider_id or not provider_type_text: + return None + provider_type = cls._normalize_sdk_provider_type( + provider_config.get("provider_type", SDKProviderType.CHAT_COMPLETION.value) + ) + return { + "id": provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": provider_type_text, + "provider_type": provider_type.value, + "loaded": bool(loaded), + "enabled": bool(provider_config.get("enable", True)), + "provider_source_id": ( + str(provider_config.get("provider_source_id")) + if provider_config.get("provider_source_id") is not None + else None + ), + } + + @classmethod + def _managed_provider_to_payload( + cls, provider: Any | None + ) -> dict[str, Any] | None: + if provider is None: + return None + meta_payload = cls._provider_to_payload(provider) + if meta_payload is None: + return None + provider_config = getattr(provider, "provider_config", None) + return { + **meta_payload, + "loaded": True, + "enabled": bool( + provider_config.get("enable", True) + if isinstance(provider_config, dict) + else True + ), + "provider_source_id": ( + str(provider_config.get("provider_source_id")) + if isinstance(provider_config, dict) + and provider_config.get("provider_source_id") is not None + else None + ), + } + + def _find_provider_config_by_id(self, provider_id: str) -> dict[str, Any] | None: + provider_manager = getattr(self._star_context, "provider_manager", None) + providers_config = getattr(provider_manager, "providers_config", None) + if not isinstance(providers_config, list): + return None + for item in providers_config: + if not isinstance(item, dict): + continue + if str(item.get("id", "")).strip() == provider_id: + return dict(item) + return None + + def _managed_provider_payload_by_id( + self, + provider_id: str, + *, + fallback_config: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + normalized_provider_id = str(provider_id).strip() + if not normalized_provider_id: + return None + provider = self._star_context.get_provider_by_id(normalized_provider_id) + payload = self._managed_provider_to_payload(provider) + if payload is not None: + return payload + provider_config = self._find_provider_config_by_id(normalized_provider_id) + if provider_config is None: + provider_config = ( + dict(fallback_config) if isinstance(fallback_config, dict) else None + ) + return self._managed_provider_from_config(provider_config, loaded=False) + + def _resolve_current_chat_provider_id( + self, + request_context: Any | None, + ) -> str | None: + if request_context is None: + return None + provider = self._star_context.get_using_provider( + request_context.event.unified_msg_origin + ) + if provider is None: + return None + meta = provider.meta() + return str(getattr(meta, "id", "") or "") + + async def _provider_get_using( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + async def _provider_get_current_chat_provider_id( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_provider(payload.get("umo")) + if provider is None: + return {"provider_id": None} + return {"provider_id": str(provider.meta().id)} + + async def _provider_get_by_id( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._get_provider_by_id(payload, "provider.get_by_id") + return {"provider": self._provider_to_payload(provider)} + + def _provider_list_payload(self, providers: list[Any]) -> dict[str, Any]: + return { + "providers": [ + payload + for payload in ( + self._provider_to_payload(provider) for provider in providers + ) + if payload is not None + ] + } + + async def _provider_list_all( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload(self._star_context.get_all_providers()) + + async def _provider_list_all_tts( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload(self._star_context.get_all_tts_providers()) + + async def _provider_list_all_stt( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload(self._star_context.get_all_stt_providers()) + + async def _provider_list_all_embedding( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload( + self._star_context.get_all_embedding_providers() + ) + + async def _provider_list_all_rerank( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload( + self._star_context.get_all_rerank_providers() + ) + + async def _provider_get_using_tts( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_tts_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + async def _provider_get_using_stt( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_stt_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + @staticmethod + def _tts_stream_texts_from_payload(payload: dict[str, Any]) -> list[str]: + text = payload.get("text") + if isinstance(text, str): + return [text] + text_chunks = payload.get("text_chunks") + if isinstance(text_chunks, list): + chunks = [str(item) for item in text_chunks] + if chunks: + return chunks + raise AstrBotError.invalid_input( + "provider.tts.get_audio_stream requires text or text_chunks" + ) + + def _get_provider_by_id( + self, + payload: dict[str, Any], + capability_name: str, + ) -> Any: + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + f"{capability_name} requires provider_id", + ) + provider = self._star_context.get_provider_by_id(provider_id) + if provider is None: + raise AstrBotError.invalid_input( + f"{capability_name} unknown provider_id: {provider_id}", + ) + return provider + + def _get_typed_provider( + self, + payload: dict[str, Any], + capability_name: str, + provider_label: str, + expected_type: type[Any], + ) -> Any: + provider = self._get_provider_by_id(payload, capability_name) + if not isinstance(provider, expected_type): + raise AstrBotError.invalid_input( + f"{capability_name} requires a {provider_label} provider", + ) + return provider + + async def _provider_stt_get_text( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stt_provider_cls, _, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.stt.get_text", + "speech_to_text", + stt_provider_cls, + ) + return {"text": await provider.get_text(str(payload.get("audio_url", "")))} + + async def _provider_tts_get_audio( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.get_audio", + "text_to_speech", + tts_provider_cls, + ) + return {"audio_path": await provider.get_audio(str(payload.get("text", "")))} + + async def _provider_tts_support_stream( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.support_stream", + "text_to_speech", + tts_provider_cls, + ) + return {"supported": bool(provider.support_stream())} + + async def _provider_tts_get_audio_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> StreamExecution: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.get_audio_stream", + "text_to_speech", + tts_provider_cls, + ) + texts = self._tts_stream_texts_from_payload(payload) + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue() + for text in texts: + await text_queue.put(text) + await text_queue.put(None) + state: dict[str, BaseException] = {} + + async def producer() -> None: + try: + await provider.get_audio_stream(text_queue, audio_queue) + except Exception as exc: # pragma: no cover - provider-specific failures + state["error"] = exc + finally: + await audio_queue.put(None) + + task = asyncio.create_task(producer()) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + token.raise_if_cancelled() + item = await audio_queue.get() + if item is None: + break + chunk_text: str | None = None + chunk_audio: bytes | bytearray + if isinstance(item, tuple): + chunk_text = str(item[0]) + chunk_audio = item[1] + else: + chunk_audio = item + yield { + "audio_base64": base64.b64encode(bytes(chunk_audio)).decode( + "ascii" + ), + "text": chunk_text, + } + error = state.get("error") + if error is not None: + raise error + finally: + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + else: + with contextlib.suppress(Exception): + await task + + def finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]: + return chunks[-1] if chunks else {"audio_base64": "", "text": None} + + return StreamExecution(iterator=iterator(), finalize=finalize) + + async def _provider_embedding_get_embedding( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_embedding", + "embedding", + embedding_provider_cls, + ) + return {"embedding": await provider.get_embedding(str(payload.get("text", "")))} + + async def _provider_embedding_get_embeddings( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_embeddings", + "embedding", + embedding_provider_cls, + ) + texts = payload.get("texts") + if not isinstance(texts, list): + raise AstrBotError.invalid_input( + "provider.embedding.get_embeddings requires texts", + ) + return { + "embeddings": await provider.get_embeddings([str(item) for item in texts]) + } + + async def _provider_embedding_get_dim( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_dim", + "embedding", + embedding_provider_cls, + ) + return {"dim": int(provider.get_dim())} + + async def _provider_rerank_rerank( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, _, rerank_provider_cls = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.rerank.rerank", + "rerank", + rerank_provider_cls, + ) + documents = payload.get("documents") + if not isinstance(documents, list): + raise AstrBotError.invalid_input( + "provider.rerank.rerank requires documents", + ) + normalized_documents = [str(item) for item in documents] + top_n = payload.get("top_n") + results = await provider.rerank( + str(payload.get("query", "")), + normalized_documents, + int(top_n) if top_n is not None else None, + ) + serialized = [] + for item in results: + index = int(getattr(item, "index", 0)) + serialized.append( + { + "index": index, + "score": float(getattr(item, "relevance_score", 0.0)), + "document": normalized_documents[index] + if 0 <= index < len(normalized_documents) + else "", + } + ) + return {"results": serialized} + + @staticmethod + def _normalize_provider_config_payload( + payload: Any, + capability_name: str, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires {field_name} object" + ) + return dict(payload) + + @staticmethod + def _core_provider_type(value: Any, capability_name: str): + from astrbot.core.provider.entities import ProviderType as CoreProviderType + + normalized = str(value).strip() + try: + return CoreProviderType(normalized) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires a valid provider_type" + ) from exc + + async def _provider_manager_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.set") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.set requires provider_id" + ) + await self._star_context.provider_manager.set_provider( + provider_id=provider_id, + provider_type=self._core_provider_type( + payload.get("provider_type"), + "provider.manager.set", + ), + umo=( + str(payload.get("umo")) + if payload.get("umo") is not None and str(payload.get("umo")).strip() + else None + ), + ) + return {} + + async def _provider_manager_get_by_id( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.get_by_id") + provider_id = str(payload.get("provider_id", "")).strip() + return {"provider": self._managed_provider_payload_by_id(provider_id)} + + async def _provider_manager_get_merged_provider_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "provider.manager.get_merged_provider_config", + ) + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config requires provider_id" + ) + provider_manager = getattr(self._star_context, "provider_manager", None) + get_merged_provider_config = getattr( + provider_manager, + "get_merged_provider_config", + None, + ) + if provider_manager is None or not callable(get_merged_provider_config): + raise AstrBotError.invalid_input( + "Provider manager does not support merged config lookup" + ) + provider_config = self._find_provider_config_by_id(provider_id) + if provider_config is None: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config unknown provider_id" + ) + merged_config = cast( + dict[str, Any], get_merged_provider_config(provider_config) + ) + return {"config": dict(merged_config)} + + async def _provider_manager_load( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.load") + provider_config = self._normalize_provider_config_payload( + payload.get("provider_config"), + "provider.manager.load", + "provider_config", + ) + await self._star_context.provider_manager.load_provider(provider_config) + provider_id = str(provider_config.get("id", "")).strip() + return { + "provider": self._managed_provider_payload_by_id( + provider_id, + fallback_config=provider_config, + ) + } + + async def _provider_manager_terminate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.terminate") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.terminate requires provider_id" + ) + await self._star_context.provider_manager.terminate_provider(provider_id) + return {} + + async def _provider_manager_create( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.create") + provider_config = self._normalize_provider_config_payload( + payload.get("provider_config"), + "provider.manager.create", + "provider_config", + ) + await self._star_context.provider_manager.create_provider(provider_config) + provider_id = str(provider_config.get("id", "")).strip() + return {"provider": self._managed_provider_payload_by_id(provider_id)} + + async def _provider_manager_update( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.update") + origin_provider_id = str(payload.get("origin_provider_id", "")).strip() + if not origin_provider_id: + raise AstrBotError.invalid_input( + "provider.manager.update requires origin_provider_id" + ) + new_config = self._normalize_provider_config_payload( + payload.get("new_config"), + "provider.manager.update", + "new_config", + ) + await self._star_context.provider_manager.update_provider( + origin_provider_id, + new_config, + ) + target_provider_id = str(new_config.get("id") or origin_provider_id).strip() + return {"provider": self._managed_provider_payload_by_id(target_provider_id)} + + async def _provider_manager_delete( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.delete") + provider_id = ( + str(payload.get("provider_id")).strip() + if payload.get("provider_id") is not None + else None + ) + provider_source_id = ( + str(payload.get("provider_source_id")).strip() + if payload.get("provider_source_id") is not None + else None + ) + if not provider_id and not provider_source_id: + raise AstrBotError.invalid_input( + "provider.manager.delete requires provider_id or provider_source_id" + ) + await self._star_context.provider_manager.delete_provider( + provider_id=provider_id or None, + provider_source_id=provider_source_id or None, + ) + return {} + + async def _provider_manager_get_insts( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.get_insts") + provider_manager = getattr(self._star_context, "provider_manager", None) + if provider_manager is None or not hasattr(provider_manager, "get_insts"): + return {"providers": []} + return { + "providers": [ + payload + for payload in ( + self._managed_provider_to_payload(provider) + for provider in list(provider_manager.get_insts()) + ) + if payload is not None + ] + } + + async def _provider_manager_watch_changes( + self, + request_id: str, + _payload: dict[str, Any], + token, + ) -> StreamExecution: + self._require_reserved_plugin(request_id, "provider.manager.watch_changes") + provider_manager = getattr(self._star_context, "provider_manager", None) + if provider_manager is None or not hasattr( + provider_manager, "register_provider_change_hook" + ): + raise AstrBotError.invalid_input("Provider manager does not support hooks") + unregister_hook = getattr( + provider_manager, + "unregister_provider_change_hook", + None, + ) + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + loop = asyncio.get_running_loop() + + def hook(provider_id: str, provider_type: Any, umo: str | None) -> None: + event = { + "provider_id": str(provider_id), + "provider_type": self._normalize_sdk_provider_type(provider_type).value, + "umo": str(umo) if umo is not None else None, + } + loop.call_soon_threadsafe(queue.put_nowait, event) + + provider_manager.register_provider_change_hook(hook) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + token.raise_if_cancelled() + yield await queue.get() + finally: + if callable(unregister_hook): + unregister_hook(hook) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + async def _llm_tool_manager_get( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "registered": [ + item.to_payload() + for item in self._plugin_bridge.get_registered_llm_tools(plugin_id) + ], + "active": [ + item.to_payload() + for item in self._plugin_bridge.get_active_llm_tools(plugin_id) + ], + } + + async def _llm_tool_manager_activate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "activated": self._plugin_bridge.activate_llm_tool( + plugin_id, str(payload.get("name", "")) + ) + } + + async def _llm_tool_manager_deactivate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "deactivated": self._plugin_bridge.deactivate_llm_tool( + plugin_id, str(payload.get("name", "")) + ) + } + + async def _llm_tool_manager_add( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + tools_payload = payload.get("tools") + if not isinstance(tools_payload, list): + raise AstrBotError.invalid_input("llm_tool.manager.add requires tools list") + tools = [ + LLMToolSpec.from_payload(item) + for item in tools_payload + if isinstance(item, dict) + ] + return {"names": self._plugin_bridge.add_llm_tools(plugin_id, tools)} + + async def _llm_tool_manager_remove( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "removed": self._plugin_bridge.remove_llm_tool( + plugin_id, + str(payload.get("name", "")), + ) + } + + async def _agent_registry_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "agents": [ + item.to_payload() + for item in self._plugin_bridge.get_registered_agents(plugin_id) + ] + } + + async def _agent_registry_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + agent = self._plugin_bridge.get_registered_agent( + plugin_id, str(payload.get("name", "")) + ) + return {"agent": agent.to_payload() if agent is not None else None} + + def _select_llm_tools_for_request( + self, + plugin_id: str, + payload: dict[str, Any], + ) -> list[LLMToolSpec]: + active_specs = { + item.name: item + for item in self._plugin_bridge.get_request_tool_specs(plugin_id) + } + requested = payload.get("tool_names") + if not isinstance(requested, list) or not requested: + return list(active_specs.values()) + names = [str(item) for item in requested if str(item).strip()] + return [active_specs[name] for name in names if name in active_specs] + + def _make_sdk_tool_handler( + self, + *, + plugin_id: str, + tool_spec: LLMToolSpec, + tool_call_timeout: int, + ): + async def _handler(event: AstrMessageEvent, **tool_args: Any) -> str | None: + get_plugin_session = getattr( + self._plugin_bridge, "get_plugin_session", None + ) + if callable(get_plugin_session): + session = get_plugin_session(plugin_id) + else: + record = getattr(self._plugin_bridge, "_records", {}).get(plugin_id) + session = None if record is None else getattr(record, "session", None) + if session is None: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content="SDK plugin worker is unavailable", + success=False, + ).to_payload(), + ensure_ascii=False, + ) + request_id = f"sdk_tool_{plugin_id}_{uuid.uuid4().hex}" + get_or_bind_dispatch_token = getattr( + self._plugin_bridge, + "get_or_bind_dispatch_token", + None, + ) + if callable(get_or_bind_dispatch_token): + dispatch_token = get_or_bind_dispatch_token(event) + else: + dispatch_token = ( + getattr( + self._plugin_bridge, "_get_dispatch_token", lambda _event: None + )(event) + or uuid.uuid4().hex + ) + get_overlay = getattr( + self._plugin_bridge, + "get_request_overlay_by_token", + lambda _dispatch_token: None, + ) + build_sdk_event_payload = getattr( + self._plugin_bridge, + "build_sdk_event_payload", + None, + ) + if callable(build_sdk_event_payload): + event_payload = build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + overlay=get_overlay(dispatch_token), + ) + else: + event_payload = self._plugin_bridge._build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + overlay=get_overlay(dispatch_token), + ) + call_payload = { + "plugin_id": plugin_id, + "tool_name": tool_spec.name, + "handler_ref": tool_spec.handler_ref, + "tool_args": json.loads( + json.dumps(tool_args, ensure_ascii=False, default=str) + ), + "event": event_payload, + } + try: + if tool_spec.handler_capability == "internal.mcp.local.execute": + handler_ref = json.loads(tool_spec.handler_ref or "{}") + output = await asyncio.wait_for( + self.execute( + "internal.mcp.local.execute", + { + "plugin_id": plugin_id, + "server_name": str( + handler_ref.get("server_name", "") + ).strip(), + "tool_name": str( + handler_ref.get("tool_name", "") + ).strip(), + "tool_args": call_payload["tool_args"], + }, + stream=False, + cancel_token=None, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + elif tool_spec.handler_capability: + output = await asyncio.wait_for( + record.session.invoke_capability( + tool_spec.handler_capability, + call_payload, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + else: + output = await asyncio.wait_for( + record.session.invoke_capability( + "internal.llm_tool.execute", + call_payload, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + except TimeoutError: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=( + f"Tool execution timeout after {tool_call_timeout} seconds" + ), + success=False, + ).to_payload(), + ensure_ascii=False, + ) + except Exception as exc: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=f"Tool execution failed: {exc}", + success=False, + ).to_payload(), + ensure_ascii=False, + ) + if not isinstance(output, dict): + return str(output) + content = output.get("content") + if output.get("success", True): + # Keep None distinct from an empty string so tools can signal + # "no content" without fabricating a textual result. + return None if content is None else str(content) + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=str(content or ""), + success=False, + ).to_payload(), + ensure_ascii=False, + ) + + return _handler + + def _build_sdk_toolset( + self, + *, + plugin_id: str, + payload: dict[str, Any], + tool_call_timeout: int, + ) -> Any | None: + tool_specs = self._select_llm_tools_for_request(plugin_id, payload) + if not tool_specs: + return None + function_tool_cls, tool_set_cls = _get_runtime_tool_types() + tool_set = tool_set_cls() + for tool_spec in tool_specs: + tool_set.add_tool( + function_tool_cls( + name=tool_spec.name, + description=tool_spec.description, + parameters=tool_spec.parameters_schema, + handler=self._make_sdk_tool_handler( + plugin_id=plugin_id, + tool_spec=tool_spec, + tool_call_timeout=tool_call_timeout, + ), + ) + ) + return tool_set + + def _llm_response_to_payload(self, response: Any) -> dict[str, Any]: + usage = None + if response.usage is not None: + usage = { + "input_tokens": response.usage.input, + "output_tokens": response.usage.output, + "total_tokens": response.usage.total, + } + return { + "text": response.completion_text, + "usage": usage, + "finish_reason": "tool_calls" if response.tools_call_ids else "stop", + "tool_calls": response.to_openai_tool_calls(), + "role": response.role, + "reasoning_content": response.reasoning_content or None, + "reasoning_signature": response.reasoning_signature, + } + + async def _agent_tool_loop_run( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None: + raise AstrBotError.invalid_input( + "tool_loop_agent currently requires a message-bound SDK request" + ) + provider_id = str( + payload.get("provider_id") or "" + ).strip() or self._resolve_current_chat_provider_id(request_context) + if not provider_id: + raise AstrBotError.invalid_input("No active chat provider is available") + tool_call_timeout = int(payload.get("tool_call_timeout") or 60) + llm_resp = await self._star_context.tool_loop_agent( + event=request_context.event, + chat_provider_id=provider_id, + prompt=( + str(payload.get("prompt")) + if payload.get("prompt") is not None + else None + ), + image_urls=[ + str(item) + for item in payload.get("image_urls", []) + if isinstance(item, str) + ], + tools=self._build_sdk_toolset( + plugin_id=plugin_id, + payload=payload, + tool_call_timeout=tool_call_timeout, + ), + system_prompt=str(payload.get("system_prompt") or ""), + contexts=[ + dict(item) + for item in payload.get("contexts", []) + if isinstance(item, dict) + ], + max_steps=int(payload.get("max_steps") or 30), + tool_call_timeout=tool_call_timeout, + ) + return self._llm_response_to_payload(llm_resp) diff --git a/astrbot/core/sdk_bridge/capabilities/session.py b/astrbot/core/sdk_bridge/capabilities/session.py new file mode 100644 index 0000000000..0f992ff757 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/session.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from ..bridge_base import _get_runtime_sp +from ._host import CapabilityMixinHost + + +class SessionCapabilityMixin(CapabilityMixinHost): + def _register_session_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "session.plugin.is_enabled", + "Get session plugin enabled state", + ), + call_handler=self._session_plugin_is_enabled, + ) + self.register( + self._builtin_descriptor( + "session.plugin.filter_handlers", + "Filter handler metadata by session plugin config", + ), + call_handler=self._session_plugin_filter_handlers, + ) + self.register( + self._builtin_descriptor( + "session.service.is_llm_enabled", + "Get session LLM enabled state", + ), + call_handler=self._session_service_is_llm_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_llm_status", + "Set session LLM enabled state", + ), + call_handler=self._session_service_set_llm_status, + ) + self.register( + self._builtin_descriptor( + "session.service.is_tts_enabled", + "Get session TTS enabled state", + ), + call_handler=self._session_service_is_tts_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_tts_status", + "Set session TTS enabled state", + ), + call_handler=self._session_service_set_tts_status, + ) + + async def _load_session_plugin_config(self, session_id: str) -> dict[str, Any]: + raw_config = await _get_runtime_sp().get_async( + scope="umo", + scope_id=session_id, + key="session_plugin_config", + default={}, + ) + return self._normalize_session_scoped_config(raw_config, session_id) + + async def _load_session_service_config(self, session_id: str) -> dict[str, Any]: + raw_config = await _get_runtime_sp().get_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + default={}, + ) + return self._normalize_session_scoped_config(raw_config, session_id) + + async def _session_plugin_is_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + plugin_name = str(payload.get("plugin_name", "")).strip() + config = await self._load_session_plugin_config(session_id) + enabled_plugins = { + str(item) for item in config.get("enabled_plugins", []) if str(item).strip() + } + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + if ( + plugin_name in disabled_plugins + and plugin_name not in self._reserved_plugin_names() + ): + return {"enabled": False} + if plugin_name in enabled_plugins: + return {"enabled": True} + return {"enabled": True} + + async def _session_plugin_filter_handlers( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + handlers = payload.get("handlers") + if not isinstance(handlers, list): + raise AstrBotError.invalid_input( + "session.plugin.filter_handlers requires a handlers array" + ) + config = await self._load_session_plugin_config(session_id) + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + reserved_plugins = self._reserved_plugin_names() + filtered = [] + for item in handlers: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_name", "")).strip() + if ( + plugin_name + and plugin_name in disabled_plugins + and plugin_name not in reserved_plugins + ): + continue + filtered.append(dict(item)) + return {"handlers": filtered} + + async def _session_service_is_llm_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + return {"enabled": bool(config.get("llm_enabled", True))} + + async def _session_service_set_llm_status( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + config["llm_enabled"] = bool(payload.get("enabled", False)) + await _get_runtime_sp().put_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + value=config, + ) + return {} + + async def _session_service_is_tts_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + return {"enabled": bool(config.get("tts_enabled", True))} + + async def _session_service_set_tts_status( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + config["tts_enabled"] = bool(payload.get("enabled", False)) + await _get_runtime_sp().put_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + value=config, + ) + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/skill.py b/astrbot/core/sdk_bridge/capabilities/skill.py new file mode 100644 index 0000000000..73fcbab614 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/skill.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from astrbot.core import logger + +from ._host import CapabilityMixinHost + + +class SkillCapabilityMixin(CapabilityMixinHost): + def _register_skill_capabilities(self) -> None: + self.register( + self._builtin_descriptor("skill.register", "Register SDK skill"), + call_handler=self._skill_register, + ) + self.register( + self._builtin_descriptor("skill.unregister", "Unregister SDK skill"), + call_handler=self._skill_unregister, + ) + self.register( + self._builtin_descriptor("skill.list", "List SDK skills"), + call_handler=self._skill_list, + ) + + async def _skill_register( + self, + request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, str]: + plugin_id = self._resolve_plugin_id(request_id) + result = self._plugin_bridge.register_skill( + plugin_id=plugin_id, + name=str(payload.get("name", "")), + path=str(payload.get("path", "")), + description=str(payload.get("description", "")), + ) + await self._sync_registered_skills_to_sandboxes() + return result + + async def _skill_unregister( + self, + request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, bool]: + plugin_id = self._resolve_plugin_id(request_id) + removed = self._plugin_bridge.unregister_skill( + plugin_id=plugin_id, + name=str(payload.get("name", "")), + ) + if removed: + await self._sync_registered_skills_to_sandboxes() + return {"removed": removed} + + async def _skill_list( + self, + request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, list[dict[str, str]]]: + plugin_id = self._resolve_plugin_id(request_id) + return {"skills": self._plugin_bridge.list_registered_skills(plugin_id)} + + async def _sync_registered_skills_to_sandboxes(self) -> None: + try: + from astrbot.core.computer.computer_client import ( + sync_skills_to_active_sandboxes, + ) + + await sync_skills_to_active_sandboxes() + except Exception as exc: + logger.warning( + "Failed to sync skills to active sandboxes after SDK skill update: %s", + exc, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/system.py b/astrbot/core/sdk_bridge/capabilities/system.py new file mode 100644 index 0000000000..7321e56be4 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/system.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +import asyncio +import uuid +from collections.abc import AsyncIterator +from pathlib import Path +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from ..bridge_base import ( + _EventStreamState, + _get_runtime_astrbot_config, + _get_runtime_file_token_service, + _get_runtime_html_renderer, +) +from ._host import CapabilityMixinHost + + +class SystemCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _overlay_request_id(request_id: str, payload: dict[str, Any]) -> str: + scope_request_id = payload.get("_request_scope_id") + if isinstance(scope_request_id, str) and scope_request_id.strip(): + return scope_request_id + return request_id + + def _register_system_capabilities(self) -> None: + self.register( + self._builtin_descriptor("system.get_data_dir", "Get plugin data dir"), + call_handler=self._system_get_data_dir, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.text_to_image", "Render text to image"), + call_handler=self._system_text_to_image, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.html_render", "Render html template"), + call_handler=self._system_html_render, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.register", "Register file token"), + call_handler=self._system_file_register, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.handle", "Resolve file token"), + call_handler=self._system_file_handle, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.register", + "Register sdk session waiter", + ), + call_handler=self._system_session_waiter_register, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.unregister", + "Unregister sdk session waiter", + ), + call_handler=self._system_session_waiter_unregister, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.react", "Send sdk event reaction"), + call_handler=self._system_event_react, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_typing", + "Send sdk event typing state", + ), + call_handler=self._system_event_send_typing, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming", + "Send sdk event streaming chunks", + ), + call_handler=self._system_event_send_streaming, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_chunk", + "Push sdk event streaming chunk", + ), + call_handler=self._system_event_send_streaming_chunk, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_close", + "Close sdk event streaming session", + ), + call_handler=self._system_event_send_streaming_close, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.get_state", + "Read sdk request llm state", + ), + call_handler=self._system_event_llm_get_state, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.request", + "Request default llm for current sdk request", + ), + call_handler=self._system_event_llm_request, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.get", + "Read sdk request result", + ), + call_handler=self._system_event_result_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.set", + "Write sdk request result", + ), + call_handler=self._system_event_result_set, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.clear", + "Clear sdk request result", + ), + call_handler=self._system_event_result_clear, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.get", + "Read sdk request handler whitelist", + ), + call_handler=self._system_event_handler_whitelist_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.set", + "Write sdk request handler whitelist", + ), + call_handler=self._system_event_handler_whitelist_set, + exposed=False, + ) + + def _register_registry_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "registry.get_handlers_by_event_type", + "List SDK handlers by event type", + ), + call_handler=self._registry_get_handlers_by_event_type, + ) + self.register( + self._builtin_descriptor( + "registry.get_handler_by_full_name", + "Get SDK handler metadata by full name", + ), + call_handler=self._registry_get_handler_by_full_name, + ) + self.register( + self._builtin_descriptor( + "registry.command.register", + "Register dynamic command route", + ), + call_handler=self._registry_command_register, + ) + + async def _system_get_data_dir( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + data_dir = Path(get_astrbot_data_path()) / "plugin_data" / plugin_id + data_dir.mkdir(parents=True, exist_ok=True) + return {"path": str(data_dir.resolve())} + + async def _system_text_to_image( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + config_obj = self._star_context.get_config() + template_name = None + if hasattr(config_obj, "get"): + try: + template_name = config_obj.get("t2i_active_template") + except Exception: + template_name = None + result = await _get_runtime_html_renderer().render_t2i( + str(payload.get("text", "")), + return_url=bool(payload.get("return_url", True)), + template_name=template_name, + ) + return {"result": result} + + async def _system_html_render( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + data = payload.get("data") + if not isinstance(data, dict): + raise AstrBotError.invalid_input("system.html_render requires object data") + options = payload.get("options") + if options is not None and not isinstance(options, dict): + raise AstrBotError.invalid_input( + "system.html_render options must be an object or null" + ) + result = await _get_runtime_html_renderer().render_custom_template( + str(payload.get("tmpl", "")), + data, + return_url=bool(payload.get("return_url", True)), + options=options, + ) + return {"result": result} + + async def _system_file_register( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + path = str(payload.get("path", "")).strip() + if not path: + raise AstrBotError.invalid_input("system.file.register requires path") + raw_timeout = payload.get("timeout") + timeout: float | None + if raw_timeout is None: + timeout = None + else: + try: + timeout = float(raw_timeout) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input( + "system.file.register timeout must be a number or null" + ) from exc + file_token = await _get_runtime_file_token_service().register_file( + path, timeout + ) + callback_host = _get_runtime_astrbot_config().get("callback_api_base") + if not callback_host: + raise AstrBotError.invalid_input( + "callback_api_base is required for system.file.register" + ) + base_url = str(callback_host).rstrip("/") + return {"token": file_token, "url": f"{base_url}/api/file/{file_token}"} + + async def _system_file_handle( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + file_token = str(payload.get("token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input("system.file.handle requires token") + path = await _get_runtime_file_token_service().handle_file(file_token) + return {"path": str(path)} + + async def _system_session_waiter_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.register_session_waiter( + plugin_id=plugin_id, + session_key=str(payload.get("session_key", "")), + ) + return {} + + async def _system_session_waiter_unregister( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.unregister_session_waiter( + plugin_id=plugin_id, + session_key=str(payload.get("session_key", "")), + ) + return {} + + async def _system_event_react( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + self._plugin_bridge.before_platform_send(request_context.dispatch_token) + await request_context.event.react(str(payload.get("emoji", ""))) + return { + "supported": bool( + self._plugin_bridge.mark_platform_send(request_context.dispatch_token) + ) + } + + async def _system_event_send_typing( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + if type(request_context.event).send_typing is AstrMessageEvent.send_typing: + return {"supported": False} + await request_context.event.send_typing() + return {"supported": True} + + async def _system_event_send_streaming( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + if ( + type(request_context.event).send_streaming + is AstrMessageEvent.send_streaming + ): + return {"supported": False} + self._plugin_bridge.before_platform_send(request_context.dispatch_token) + queue: asyncio.Queue[MessageChain | None] = asyncio.Queue() + + async def iterator() -> AsyncIterator[MessageChain]: + while True: + chunk = await queue.get() + if chunk is None or request_context.cancelled: + return + yield chunk + await asyncio.sleep(0) + + stream_id = uuid.uuid4().hex + task = asyncio.create_task( + request_context.event.send_streaming( + iterator(), + use_fallback=bool(payload.get("use_fallback", False)), + ) + ) + self._event_streams[stream_id] = _EventStreamState( + request_context=request_context, + queue=queue, + task=task, + ) + return {"supported": True, "stream_id": stream_id} + + async def _system_event_send_streaming_chunk( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stream_state = self._event_streams.get(str(payload.get("stream_id", ""))) + if stream_state is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + if stream_state.request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") + chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "system.event.send_streaming_chunk requires a chain array" + ) + await stream_state.queue.put(self._build_core_message_chain(chain_payload)) + return {} + + async def _system_event_send_streaming_close( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stream_id = str(payload.get("stream_id", "")) + stream_state = self._event_streams.pop(stream_id, None) + if stream_state is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + await stream_state.queue.put(None) + try: + await stream_state.task + finally: + self._event_streams.pop(stream_id, None) + return { + "supported": bool( + self._plugin_bridge.mark_platform_send( + stream_state.request_context.dispatch_token + ) + ) + } + + async def _system_event_llm_get_state( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._plugin_bridge.get_request_overlay_by_request_id( + overlay_request_id + ) + should_call_llm = self._plugin_bridge.get_should_call_llm_for_request( + overlay_request_id + ) + return { + "should_call_llm": bool(should_call_llm), + "requested_llm": bool(overlay.requested_llm) + if overlay is not None + else False, + } + + async def _system_event_llm_request( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + self._plugin_bridge.request_llm_for_request(overlay_request_id) + return await self._system_event_llm_get_state( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _system_event_result_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + return { + "result": self._plugin_bridge.get_result_payload_for_request( + overlay_request_id + ) + } + + async def _system_event_result_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + result_payload = payload.get("result") + if not isinstance(result_payload, dict): + raise AstrBotError.invalid_input( + "system.event.result.set requires an object result payload" + ) + overlay_request_id = self._overlay_request_id(request_id, payload) + if not self._plugin_bridge.set_result_for_request( + overlay_request_id, + result_payload, + ): + raise AstrBotError.cancelled("The SDK request overlay has been closed") + return { + "result": self._plugin_bridge.get_result_payload_for_request( + overlay_request_id + ) + } + + async def _system_event_result_clear( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + self._plugin_bridge.clear_result_for_request(overlay_request_id) + return {} + + async def _system_event_handler_whitelist_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + plugin_names = self._plugin_bridge.get_handler_whitelist_for_request( + overlay_request_id + ) + if plugin_names is None: + return {"plugin_names": None} + return {"plugin_names": sorted(plugin_names)} + + async def _system_event_handler_whitelist_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_names_payload = payload.get("plugin_names") + plugin_names: set[str] | None + if plugin_names_payload is None: + plugin_names = None + elif isinstance(plugin_names_payload, list): + plugin_names = { + str(item) for item in plugin_names_payload if str(item).strip() + } + else: + raise AstrBotError.invalid_input( + "system.event.handler_whitelist.set requires a string array or null" + ) + overlay_request_id = self._overlay_request_id(request_id, payload) + if not self._plugin_bridge.set_handler_whitelist_for_request( + overlay_request_id, + plugin_names, + ): + raise AstrBotError.cancelled("The SDK request overlay has been closed") + return await self._system_event_handler_whitelist_get( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _registry_get_handlers_by_event_type( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + event_type = str(payload.get("event_type", "")).strip() + return {"handlers": self._plugin_bridge.get_handlers_by_event_type(event_type)} + + async def _registry_get_handler_by_full_name( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + full_name = str(payload.get("full_name", "")).strip() + return {"handler": self._plugin_bridge.get_handler_by_full_name(full_name)} + + async def _registry_command_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + source_event_type = str(payload.get("source_event_type", "")).strip() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if bool(payload.get("ignore_prefix", False)): + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + priority_value = payload.get("priority", 0) + if isinstance(priority_value, bool) or not isinstance(priority_value, int): + raise AstrBotError.invalid_input( + "registry.command.register priority must be an integer" + ) + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.register_dynamic_command_route( + plugin_id=plugin_id, + command_name=str(payload.get("command_name", "")), + handler_full_name=str(payload.get("handler_full_name", "")), + desc=str(payload.get("desc", "")), + priority=priority_value, + use_regex=bool(payload.get("use_regex", False)), + ) + return {} diff --git a/astrbot/core/sdk_bridge/capability_bridge.py b/astrbot/core/sdk_bridge/capability_bridge.py new file mode 100644 index 0000000000..7368134cd4 --- /dev/null +++ b/astrbot/core/sdk_bridge/capability_bridge.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .bridge_base import CapabilityBridgeBase +from .capabilities import ( + BasicCapabilityMixin, + ConversationCapabilityMixin, + KnowledgeBaseCapabilityMixin, + LLMCapabilityMixin, + MCPCapabilityMixin, + MessageHistoryCapabilityMixin, + PermissionCapabilityMixin, + PersonaCapabilityMixin, + PlatformCapabilityMixin, + ProviderCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + SystemCapabilityMixin, +) + +if TYPE_CHECKING: + from astrbot.core.star.context import Context as StarContext + +__all__ = ["CoreCapabilityBridge"] + + +class CoreCapabilityBridge( + SystemCapabilityMixin, + ProviderCapabilityMixin, + MCPCapabilityMixin, + PlatformCapabilityMixin, + PermissionCapabilityMixin, + KnowledgeBaseCapabilityMixin, + MessageHistoryCapabilityMixin, + ConversationCapabilityMixin, + PersonaCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + LLMCapabilityMixin, + BasicCapabilityMixin, + CapabilityBridgeBase, +): + def __init__(self, *, star_context: StarContext, plugin_bridge) -> None: + self._star_context = star_context + self._plugin_bridge = plugin_bridge + self._event_streams: dict[str, Any] = {} + self._memory_backends_by_plugin: dict[str, Any] = {} + self._memory_index_by_plugin: dict[str, dict[str, dict[str, Any]]] = {} + self._memory_dirty_keys_by_plugin: dict[str, set[str]] = {} + self._memory_expires_at_by_plugin: dict[str, dict[str, Any]] = {} + # CapabilityRouter.__init__() registers the built-in capability groups + # declared by this bridge and its mixins before extended groups are added. + super().__init__() + self._register_provider_capabilities() + self._register_provider_manager_capabilities() + self._register_mcp_capabilities() + self._register_platform_manager_capabilities() + self._register_permission_capabilities() + self._register_persona_capabilities() + self._register_conversation_capabilities() + self._register_message_history_capabilities() + self._register_kb_capabilities() + self._register_skill_capabilities() + self._register_system_capabilities() + self._register_registry_capabilities() + self._register_db_capabilities() + self._register_memory_capabilities() + self._register_http_capabilities() + self._register_metadata_capabilities() diff --git a/astrbot/core/sdk_bridge/dispatch_engine.py b/astrbot/core/sdk_bridge/dispatch_engine.py new file mode 100644 index 0000000000..ced44ab532 --- /dev/null +++ b/astrbot/core/sdk_bridge/dispatch_engine.py @@ -0,0 +1,538 @@ +from __future__ import annotations + +import asyncio +import uuid +from typing import TYPE_CHECKING, Any + +from astrbot.core import logger +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.message.message_types import sdk_message_type +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import LLMResponse as CoreLLMResponse +from astrbot.core.provider.entities import ProviderRequest as CoreProviderRequest + +from .event_payload import extract_sdk_handler_result +from .runtime_store import ( + SdkDispatchResult, + SdkPluginRecord, + _DispatchState, + _InFlightRequest, + _RequestContext, +) + +if TYPE_CHECKING: + from .plugin_bridge import SdkPluginBridge + + +class SdkDispatchEngine: + def __init__(self, *, bridge: SdkPluginBridge) -> None: + self.bridge = bridge + + async def dispatch_message(self, event: AstrMessageEvent) -> SdkDispatchResult: + result = SdkDispatchResult() + if event.is_stopped(): + result.skipped_reason = self.bridge.SKIP_LEGACY_STOPPED + return result + if self.bridge._legacy_has_replied(event): + result.skipped_reason = self.bridge.SKIP_LEGACY_REPLIED + return result + + waiter_plugins = self.bridge._match_waiter_plugins(event.unified_msg_origin) + if waiter_plugins: + return await self.dispatch_waiter_event(event, waiter_plugins) + + dispatch_token = self.bridge.get_or_bind_dispatch_token(event) + overlay = self.bridge._ensure_request_overlay( + dispatch_token, + # 使用统一方法获取 LLM 意愿,避免到处重复 not event.call_llm 的反转逻辑 + should_call_llm=self.bridge.get_effective_should_call_llm(event), + ) + matches = self.bridge._match_handlers(event) + permission_denied = self.bridge._resolve_command_permission_denied(event) + if permission_denied is not None and not self.bridge._has_command_trigger_match( + matches + ): + dispatch_state = _DispatchState(event=event) + request_context = self.bridge._request_contexts.get(dispatch_token) + if request_context is None: + request_context = _RequestContext( + plugin_id=permission_denied["plugin_id"], + request_id="", + dispatch_token=dispatch_token, + dispatch_state=dispatch_state, + ) + self.bridge._request_contexts[dispatch_token] = request_context + else: + request_context.plugin_id = permission_denied["plugin_id"] + request_context.dispatch_state = dispatch_state + self.bridge._set_sdk_origin_plugin_id(event, permission_denied["plugin_id"]) + event.set_result(MessageEventResult().message(permission_denied["message"])) + event.stop_event() + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + overlay.should_call_llm = False + result.stopped = True + return result + group_fallback = self.bridge._resolve_group_root_fallback(event) + if group_fallback is not None and not self.bridge._has_command_trigger_match( + matches + ): + dispatch_state = _DispatchState(event=event) + request_context = self.bridge._request_contexts.get(dispatch_token) + if request_context is None: + request_context = _RequestContext( + plugin_id=group_fallback["plugin_id"], + request_id="", + dispatch_token=dispatch_token, + dispatch_state=dispatch_state, + ) + self.bridge._request_contexts[dispatch_token] = request_context + else: + request_context.plugin_id = group_fallback["plugin_id"] + request_context.dispatch_state = dispatch_state + self.bridge._set_sdk_origin_plugin_id(event, group_fallback["plugin_id"]) + event.set_result(MessageEventResult().message(group_fallback["help_text"])) + event.stop_event() + # 群组 fallback(如帮助文本)不应触发 LLM,直接阻止 + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + overlay.should_call_llm = False + result.stopped = True + return result + if not matches: + result.skipped_reason = self.bridge.SKIP_NO_MATCH + return result + result.matched_handlers = [ + {"plugin_id": match.plugin_id, "handler_id": match.handler_id} + for match in matches + ] + + dispatch_state = _DispatchState(event=event) + request_context = self.bridge._request_contexts.get(dispatch_token) + if request_context is None: + request_context = _RequestContext( + plugin_id="", + request_id="", + dispatch_token=dispatch_token, + dispatch_state=dispatch_state, + ) + self.bridge._request_contexts[dispatch_token] = request_context + else: + request_context.dispatch_state = dispatch_state + skipped_reason = None + for match in matches: + whitelist = ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + if whitelist is not None and match.plugin_id not in whitelist: + continue + record = self.bridge._records.get(match.plugin_id) + if record is None: + continue + if record.state == self.bridge.SDK_STATE_RELOADING: + skipped_reason = skipped_reason or self.bridge.SKIP_SDK_RELOADING + continue + if ( + record.state + in {self.bridge.SDK_STATE_FAILED, self.bridge.SDK_STATE_DISABLED} + or record.session is None + ): + skipped_reason = skipped_reason or self.bridge.SKIP_WORKER_FAILED + continue + + request_id = f"sdk_{record.plugin_id}_{uuid.uuid4().hex}" + request_context.plugin_id = record.plugin_id + request_context.request_id = request_id + request_context.cancelled = False + self.bridge._set_sdk_origin_plugin_id(event, record.plugin_id) + setattr(event, "_sdk_last_request_id", request_id) + payload = self.bridge.build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=record.plugin_id, + request_id=request_id, + overlay=overlay, + ) + task = asyncio.create_task( + record.session.invoke_handler( + match.handler_id, + payload, + request_id=request_id, + args=match.args, + ) + ) + self.bridge._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=record.plugin_id, + ) + self.bridge._plugin_requests.setdefault(record.plugin_id, {})[ + request_id + ] = _InFlightRequest( + request_id=request_id, + dispatch_token=dispatch_token, + task=task, + ) + try: + output = await task + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning( + "SDK handler failed: plugin=%s handler=%s error=%s", + record.plugin_id, + match.handler_id, + exc, + ) + skipped_reason = skipped_reason or self.bridge.SKIP_WORKER_FAILED + output = {} + finally: + inflight = self.bridge._plugin_requests.get(record.plugin_id, {}).pop( + request_id, + None, + ) + + if inflight is not None and inflight.logical_cancelled: + continue + + handler_result = extract_sdk_handler_result( + output if isinstance(output, dict) else {} + ) + if isinstance(output, dict) and "sdk_local_extras" in output: + self.bridge._persist_sdk_local_extras_from_handler( + overlay, + output.get("sdk_local_extras"), + plugin_id=record.plugin_id, + handler_id=match.handler_id, + ) + result.executed_handlers.append( + {"plugin_id": record.plugin_id, "handler_id": match.handler_id} + ) + dispatch_state.sent_message = ( + dispatch_state.sent_message or handler_result["sent_message"] + ) + dispatch_state.stopped = dispatch_state.stopped or handler_result["stop"] + if handler_result["call_llm"]: + overlay.requested_llm = True + overlay.should_call_llm = True + if handler_result["sent_message"] or handler_result["stop"]: + overlay.should_call_llm = False + if handler_result["stop"]: + break + + result.sent_message = dispatch_state.sent_message + result.stopped = dispatch_state.stopped + if not result.executed_handlers: + result.skipped_reason = skipped_reason or self.bridge.SKIP_NO_MATCH + if result.sent_message: + # 已发送消息:同步标记 event 和 overlay 的发送状态,防止 LLM 重复回复 + self.bridge.request_runtime._mark_event_send_operation(event) + overlay.should_call_llm = False + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + if result.stopped: + event.stop_event() + # 事件被 stop 后 LLM 不应再处理,双重写入 overlay 和 event + overlay.should_call_llm = False + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + return result + + async def dispatch_system_event( + self, + event_type: str, + payload: dict[str, Any] | None = None, + ) -> None: + normalized_platform = self.bridge._normalize_platform_name( + (payload or {}).get("platform") + ) + event_payload = { + "type": event_type, + "event_type": event_type, + "text": str((payload or {}).get("message_outline", "")), + "session_id": str((payload or {}).get("session_id", "")), + "platform": str((payload or {}).get("platform", "")), + "platform_id": str((payload or {}).get("platform_id", "")), + "message_type": sdk_message_type((payload or {}).get("message_type", "")), + "sender_name": str((payload or {}).get("sender_name", "")), + "self_id": str((payload or {}).get("self_id", "")), + "raw": {"event_type": event_type, **(payload or {})}, + } + for key, value in (payload or {}).items(): + event_payload[key] = value + matches = self.bridge._match_event_handlers( + event_type, + platform_name=normalized_platform, + ) + for record, descriptor in matches: + if record.session is None: + continue + try: + await record.session.invoke_handler( + descriptor.id, + event_payload, + request_id=f"sdk_event_{record.plugin_id}_{uuid.uuid4().hex}", + args={}, + ) + except Exception as exc: + logger.warning( + "SDK event handler failed: plugin=%s handler=%s error=%s", + record.plugin_id, + descriptor.id, + exc, + ) + + async def dispatch_message_event( + self, + event_type: str, + event: AstrMessageEvent, + payload: dict[str, Any] | None = None, + *, + provider_request: CoreProviderRequest | None = None, + llm_response: CoreLLMResponse | None = None, + event_result: MessageEventResult | None = None, + ) -> None: + dispatch_token = self.bridge._get_dispatch_token(event) + if not dispatch_token: + return + overlay = self.bridge.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return + normalized_platform = self.bridge._normalize_platform_name( + event.get_platform_name() + ) + matches = self.bridge._match_event_handlers( + event_type, + allowed_plugins=overlay.handler_whitelist, + platform_name=normalized_platform, + ) + for record, descriptor in matches: + if record.session is None: + continue + request_id = f"sdk_event_{record.plugin_id}_{uuid.uuid4().hex}" + request_context = self.bridge._request_contexts.get(dispatch_token) + if request_context is None: + request_context = _RequestContext( + plugin_id=record.plugin_id, + request_id=request_id, + dispatch_token=dispatch_token, + dispatch_state=_DispatchState(event=event), + ) + self.bridge._request_contexts[dispatch_token] = request_context + request_context.plugin_id = record.plugin_id + request_context.request_id = request_id + if request_context.dispatch_state is None: + request_context.dispatch_state = _DispatchState(event=event) + request_context.dispatch_state.event = event + request_context.cancelled = False + self.bridge._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=record.plugin_id, + ) + event_payload = self.bridge.build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=record.plugin_id, + request_id=request_id, + overlay=overlay, + raw_updates={"event_type": event_type, **(payload or {})}, + field_updates={ + "type": event_type, + "event_type": event_type, + **(payload or {}), + }, + ) + if provider_request is not None: + request_payload = self.bridge._core_provider_request_to_sdk_payload( + provider_request + ) + event_payload["provider_request"] = request_payload + if isinstance(event_payload["raw"], dict): + event_payload["raw"]["provider_request"] = request_payload + if llm_response is not None: + response_payload = self.bridge._core_llm_response_to_sdk_payload( + llm_response + ) + event_payload["llm_response"] = response_payload + if isinstance(event_payload["raw"], dict): + event_payload["raw"]["llm_response"] = response_payload + if event_result is not None: + result_payload = self.bridge._legacy_result_to_sdk_payload(event_result) + if result_payload is not None: + event_payload["event_result"] = result_payload + if isinstance(event_payload["raw"], dict): + event_payload["raw"]["event_result"] = result_payload + try: + output = await record.session.invoke_handler( + descriptor.id, + event_payload, + request_id=request_id, + args={}, + ) + if isinstance(output, dict): + handler_result = extract_sdk_handler_result(output) + if "sdk_local_extras" in output: + self.bridge._persist_sdk_local_extras_from_handler( + overlay, + output.get("sdk_local_extras"), + plugin_id=record.plugin_id, + handler_id=descriptor.id, + ) + request_payload = output.get("provider_request") + if provider_request is not None and isinstance( + request_payload, dict + ): + self.bridge._apply_sdk_provider_request_payload( + provider_request, + request_payload, + ) + result_payload = output.get("event_result") + if event_result is not None and isinstance(result_payload, dict): + if not self.bridge.set_result_for_request( + request_id, + result_payload, + ): + self.bridge._apply_sdk_result_payload( + event_result, + result_payload, + ) + if handler_result["stop"]: + event.stop_event() + if handler_result["call_llm"]: + overlay.requested_llm = True + overlay.should_call_llm = True + if handler_result["sent_message"]: + # 系统事件处理中发送了消息,标记到 event 供后续 pipeline 判断 + self.bridge.request_runtime._mark_event_send_operation(event) + if handler_result["sent_message"] or handler_result["stop"]: + overlay.should_call_llm = False + except Exception as exc: + logger.warning( + "SDK event handler failed: plugin=%s handler=%s error=%s", + record.plugin_id, + descriptor.id, + exc, + ) + + async def dispatch_waiter_event( + self, + event: AstrMessageEvent, + records: list[SdkPluginRecord], + ) -> SdkDispatchResult: + result = SdkDispatchResult() + dispatch_state = _DispatchState(event=event) + dispatch_token = self.bridge.get_or_bind_dispatch_token(event) + overlay = self.bridge._ensure_request_overlay( + dispatch_token, + should_call_llm=self.bridge.get_effective_should_call_llm(event), + ) + request_context = _RequestContext( + plugin_id="", + request_id="", + dispatch_token=dispatch_token, + dispatch_state=dispatch_state, + ) + self.bridge._request_contexts[dispatch_token] = request_context + for record in records: + if record.state in { + self.bridge.SDK_STATE_DISABLED, + self.bridge.SDK_STATE_FAILED, + self.bridge.SDK_STATE_RELOADING, + }: + continue + if record.session is None: + continue + whitelist = ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + if whitelist is not None and record.plugin_id not in whitelist: + continue + request_id = f"sdk_waiter_{record.plugin_id}_{uuid.uuid4().hex}" + request_context.plugin_id = record.plugin_id + request_context.request_id = request_id + request_context.cancelled = False + self.bridge._set_sdk_origin_plugin_id(event, record.plugin_id) + setattr(event, "_sdk_last_request_id", request_id) + payload = self.bridge.build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=record.plugin_id, + request_id=request_id, + overlay=overlay, + ) + self.bridge._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=record.plugin_id, + ) + try: + output = await record.session.invoke_handler( + "__sdk_session_waiter__", + payload, + request_id=request_id, + args={}, + ) + except Exception as exc: + logger.warning( + "SDK waiter dispatch failed: plugin=%s error=%s", + record.plugin_id, + exc, + ) + output = {} + handler_result = extract_sdk_handler_result( + output if isinstance(output, dict) else {} + ) + if isinstance(output, dict) and "sdk_local_extras" in output: + self.bridge._persist_sdk_local_extras_from_handler( + overlay, + output.get("sdk_local_extras"), + plugin_id=record.plugin_id, + handler_id="__sdk_session_waiter__", + ) + result.executed_handlers.append( + {"plugin_id": record.plugin_id, "handler_id": "__sdk_session_waiter__"} + ) + dispatch_state.sent_message = ( + dispatch_state.sent_message or handler_result["sent_message"] + ) + dispatch_state.stopped = dispatch_state.stopped or handler_result["stop"] + if handler_result["call_llm"]: + overlay.requested_llm = True + overlay.should_call_llm = True + if handler_result["sent_message"] or handler_result["stop"]: + overlay.should_call_llm = False + if handler_result["stop"]: + break + result.sent_message = dispatch_state.sent_message + result.stopped = dispatch_state.stopped + if not result.executed_handlers: + result.skipped_reason = self.bridge.SKIP_NO_MATCH + if result.sent_message: + # waiter dispatch 同样需要同步发送状态到 event,供后续 pipeline 阶段判断 + self.bridge.request_runtime._mark_event_send_operation(event) + overlay.should_call_llm = False + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + if result.stopped: + event.stop_event() + overlay.should_call_llm = False + self.bridge.request_runtime._set_event_default_llm_blocked( + event, + blocked=True, + ) + return result diff --git a/astrbot/core/sdk_bridge/event_payload.py b/astrbot/core/sdk_bridge/event_payload.py new file mode 100644 index 0000000000..3d6db223eb --- /dev/null +++ b/astrbot/core/sdk_bridge/event_payload.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import copy +import json +from dataclasses import dataclass +from datetime import datetime +from types import MappingProxyType +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from astrbot_sdk.message.components import component_to_payload_sync + +from astrbot.core.message.message_types import sdk_message_type + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +DROP_VALUE = object() + + +@dataclass(frozen=True, slots=True) +class InboundEventSnapshot: + text: str + user_id: str + group_id: str | None + platform: str + platform_id: str + session_id: str + self_id: str + message_type: str + sender_name: str + is_admin: bool + is_wake: bool + is_at_or_wake_command: bool + message_outline: str + messages: tuple[dict[str, Any], ...] + target: MappingProxyType + + def to_payload( + self, + *, + dispatch_token: str, + plugin_id: str, + request_id: str, + host_extras: dict[str, Any], + sdk_local_extras: dict[str, Any], + raw_updates: dict[str, Any] | None = None, + field_updates: dict[str, Any] | None = None, + ) -> dict[str, Any]: + raw = { + "dispatch_token": dispatch_token, + "plugin_id": plugin_id, + "request_id": request_id, + "platform_id": self.platform_id, + } + if raw_updates: + raw.update(copy.deepcopy(raw_updates)) + + merged_extras = dict(host_extras) + merged_extras.update(sdk_local_extras) + payload: dict[str, Any] = { + "text": self.text, + "user_id": self.user_id, + "group_id": self.group_id, + "platform": self.platform, + "platform_id": self.platform_id, + "session_id": self.session_id, + "self_id": self.self_id, + "message_type": self.message_type, + "sender_name": self.sender_name, + "is_admin": self.is_admin, + "is_wake": self.is_wake, + "is_at_or_wake_command": self.is_at_or_wake_command, + "message_outline": self.message_outline, + "raw": raw, + "target": { + "conversation_id": self.target["conversation_id"], + "platform": self.target["platform"], + "raw": dict(raw), + }, + "host_extras": copy.deepcopy(host_extras), + "sdk_local_extras": copy.deepcopy(sdk_local_extras), + "extras": merged_extras, + } + if self.messages: + payload["messages"] = copy.deepcopy(list(self.messages)) + if field_updates: + payload.update(copy.deepcopy(field_updates)) + return payload + + +def sanitize_sdk_extra_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + if isinstance(value, UUID): + return str(value) + if isinstance(value, (list, tuple)): + items = [] + for item in value: + normalized = sanitize_sdk_extra_value(item) + if normalized is not DROP_VALUE: + items.append(normalized) + return items + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized = sanitize_sdk_extra_value(item) + if normalized is not DROP_VALUE: + normalized_dict[str(key)] = normalized + return normalized_dict + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + return sanitize_sdk_extra_value(model_dump()) + except Exception: + return DROP_VALUE + dict_view = getattr(value, "__dict__", None) + if isinstance(dict_view, dict) and dict_view: + return sanitize_sdk_extra_value(dict_view) + try: + json.dumps(value) + except (TypeError, ValueError): + return DROP_VALUE + return value + + +def sanitize_sdk_extras(extras: dict[str, Any]) -> dict[str, Any]: + sanitized: dict[str, Any] = {} + for key, value in extras.items(): + normalized = sanitize_sdk_extra_value(value) + if normalized is not DROP_VALUE: + sanitized[str(key)] = normalized + return sanitized + + +def normalize_sdk_local_extras( + payload: Any, +) -> tuple[dict[str, Any], list[str]]: + if not isinstance(payload, dict): + return {}, [] + normalized: dict[str, Any] = {} + dropped_keys: list[str] = [] + for key, value in payload.items(): + normalized_value = sanitize_sdk_extra_value(value) + if normalized_value is DROP_VALUE: + dropped_keys.append(str(key)) + continue + normalized[str(key)] = normalized_value + return normalized, dropped_keys + + +def extract_sdk_handler_result(sdk_result: dict[str, Any] | None) -> dict[str, bool]: + if not sdk_result: + return {"sent_message": False, "stop": False, "call_llm": False} + return { + "sent_message": bool(sdk_result.get("sent_message", False)), + "stop": bool(sdk_result.get("stop", False)), + "call_llm": bool(sdk_result.get("call_llm", False)), + } + + +def build_inbound_event_snapshot(event: AstrMessageEvent) -> InboundEventSnapshot: + group_id = event.get_group_id() or None + user_id = event.get_sender_id() or "" + messages: list[dict[str, Any]] = [] + for component in event.get_messages(): + try: + messages.append(component_to_payload_sync(component)) + except Exception: + messages.append( + { + "type": "unknown", + "data": {"value": str(component)}, + } + ) + return InboundEventSnapshot( + text=event.get_message_str(), + user_id=user_id, + group_id=group_id, + platform=event.get_platform_name(), + platform_id=event.get_platform_id(), + session_id=event.unified_msg_origin, + self_id=event.get_self_id(), + message_type=sdk_message_type( + event.get_message_type(), + group_id=group_id, + user_id=user_id or None, + ), + sender_name=event.get_sender_name(), + is_admin=event.is_admin(), + is_wake=bool(event.is_wake), + is_at_or_wake_command=bool(event.is_at_or_wake_command), + message_outline=event.get_message_outline(), + messages=tuple(messages), + target=MappingProxyType( + { + "conversation_id": event.unified_msg_origin, + "platform": event.get_platform_name(), + } + ), + ) diff --git a/astrbot/core/sdk_bridge/lifecycle_manager.py b/astrbot/core/sdk_bridge/lifecycle_manager.py new file mode 100644 index 0000000000..00ba31e8b8 --- /dev/null +++ b/astrbot/core/sdk_bridge/lifecycle_manager.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +import contextlib +from typing import TYPE_CHECKING, Any + +from astrbot.core import logger + +if TYPE_CHECKING: + from .plugin_bridge import SdkPluginBridge + + +class SdkPluginLifecycleManager: + def __init__(self, *, bridge: SdkPluginBridge) -> None: + self.bridge = bridge + # Phase 1 lock: serialize discovery/planning so every operation builds its + # action plan from a coherent snapshot instead of racing on shared metadata. + self._plan_lock = asyncio.Lock() + # Phase 3 lock: serialize the short global refresh/commit tail after each + # plugin operation. This keeps command/native-platform refreshes ordered + # without holding a global lock during slow worker startup/shutdown. + self._commit_lock = asyncio.Lock() + # Phase 2 lock map: each plugin gets its own execution lock so unrelated + # plugins can load/teardown in parallel, while the same plugin remains + # strictly serialized across reload/enable/disable/worker-close flows. + self._plugin_locks: dict[str, asyncio.Lock] = {} + self._startup_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + if self.bridge._started: + return + self.bridge._sweep_stale_mcp_leases() + self.bridge._started = True + self._schedule_background_reload(reset_restart_budget=True) + + async def stop(self) -> None: + if not self.bridge._started and not self.bridge._records: + return + self.bridge._stopping = True + if self._startup_task is not None: + self._startup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._startup_task + self._startup_task = None + for plugin_id in list(self.bridge._records.keys()): + await self.bridge._cancel_plugin_requests(plugin_id) + await self.bridge._close_temporary_mcp_sessions(plugin_id) + for record in list(self.bridge._records.values()): + await self.bridge._shutdown_local_mcp_servers(record) + if record.session is not None: + await record.session.stop() + record.session = None + self.bridge._records.clear() + self.bridge._request_contexts.clear() + self.bridge._request_id_to_token.clear() + self.bridge._request_plugin_ids.clear() + for overlay in list(self.bridge._request_overlays.values()): + if overlay.cleanup_task is not None: + overlay.cleanup_task.cancel() + self.bridge._request_overlays.clear() + self.bridge._plugin_requests.clear() + self.bridge._http_routes.clear() + self.bridge._session_waiters.clear() + self.bridge._schedule_job_ids.clear() + self.bridge._temporary_mcp_sessions.clear() + self.bridge._started = False + self.bridge._stopping = False + + async def reload_all(self, *, reset_restart_budget: bool = False) -> None: + stale_plugin_ids, load_plan = await self._plan_reload_all() + + for plugin_id in stale_plugin_ids: + async with self._plugin_lock(plugin_id): + # The plugin may have been removed already by a concurrent operation. + if plugin_id not in self.bridge._records: + continue + await self.bridge._teardown_plugin(plugin_id) + self.bridge._records.pop(plugin_id, None) + + for load_order, plugin in load_plan: + async with self._plugin_lock(plugin.name): + await self.bridge._load_or_reload_plugin( + plugin, + load_order=load_order, + reset_restart_budget=reset_restart_budget, + ) + + await self._commit_runtime_refresh() + + async def reload_plugin(self, plugin_id: str) -> None: + load_order, plugin = await self._plan_single_plugin(plugin_id) + async with self._plugin_lock(plugin_id): + await self.bridge._load_or_reload_plugin( + plugin, + load_order=load_order, + reset_restart_budget=True, + ) + await self._commit_runtime_refresh() + + async def turn_off_plugin(self, plugin_id: str) -> None: + await self._plan_turn_off(plugin_id) + async with self._plugin_lock(plugin_id): + record = self.bridge._records.get(plugin_id) + if record is None: + raise ValueError(f"SDK plugin not found: {plugin_id}") + record.state = self.bridge.SDK_STATE_DISABLED + await self.bridge._cancel_plugin_requests(plugin_id) + await self.bridge._teardown_plugin(plugin_id) + record.failure_reason = "" + self.bridge._set_disabled_override(plugin_id, disabled=True) + await self._commit_runtime_refresh() + + async def turn_on_plugin(self, plugin_id: str) -> None: + load_order, plugin = await self._plan_single_plugin(plugin_id) + async with self._plugin_lock(plugin_id): + self.bridge._set_disabled_override(plugin_id, disabled=False) + await self.bridge._load_or_reload_plugin( + plugin, + load_order=load_order, + reset_restart_budget=True, + ) + record = self.bridge._records.get(plugin_id) + if record is not None and record.state == self.bridge.SDK_STATE_FAILED: + raise RuntimeError( + record.failure_reason or f"SDK plugin failed to start: {plugin_id}" + ) + await self._commit_runtime_refresh() + + async def handle_worker_closed(self, plugin_id: str) -> None: + async with self._plugin_lock(plugin_id): + if self.bridge._stopping: + return + await self.bridge._cancel_plugin_requests(plugin_id) + await self.bridge._close_temporary_mcp_sessions(plugin_id) + record = self.bridge._records.get(plugin_id) + if record is None: + return + await self.bridge._shutdown_local_mcp_servers(record) + record.session = None + if record.state in { + self.bridge.SDK_STATE_RELOADING, + self.bridge.SDK_STATE_DISABLED, + }: + await self._commit_runtime_refresh() + return + if not record.restart_attempted: + record.restart_attempted = True + logger.warning( + "SDK plugin worker closed unexpectedly, retrying once: %s", + plugin_id, + ) + await self.bridge._load_or_reload_plugin( + record.plugin, + load_order=record.load_order, + reset_restart_budget=False, + ) + await self._commit_runtime_refresh() + return + record.state = self.bridge.SDK_STATE_FAILED + self.bridge._http_routes.pop(plugin_id, None) + self.bridge._session_waiters.pop(plugin_id, None) + await self.bridge._unregister_schedule_jobs(plugin_id) + await self.bridge._clear_plugin_skills( + plugin_id=plugin_id, + record=record, + reason="worker failure cleanup", + ) + await self._commit_runtime_refresh() + + async def _plan_reload_all(self) -> tuple[list[str], list[tuple[int, Any]]]: + async with self._plan_lock: + discovered = self.bridge._discover_plugins() + self.bridge._set_discovery_issues(discovered.issues) + self.bridge.env_manager.plan(discovered.plugins) + known = {plugin.name for plugin in discovered.plugins} + self.bridge._make_skill_manager().prune_sdk_plugin_skills(known) + stale_plugin_ids = [ + plugin_id + for plugin_id in list(self.bridge._records.keys()) + if plugin_id not in known + ] + load_plan = list(enumerate(discovered.plugins)) + return stale_plugin_ids, load_plan + + async def _plan_single_plugin(self, plugin_id: str) -> tuple[int, Any]: + async with self._plan_lock: + discovered = self.bridge._discover_plugins() + self.bridge._set_discovery_issues(discovered.issues) + self.bridge.env_manager.plan(discovered.plugins) + for load_order, plugin in enumerate(discovered.plugins): + if plugin.name == plugin_id: + return load_order, plugin + raise ValueError(f"SDK plugin not found: {plugin_id}") + + async def _plan_turn_off(self, plugin_id: str) -> None: + async with self._plan_lock: + if self.bridge._records.get(plugin_id) is None: + raise ValueError(f"SDK plugin not found: {plugin_id}") + + async def _commit_runtime_refresh(self) -> None: + async with self._commit_lock: + self.bridge.refresh_command_compatibility_issues() + await self.bridge._refresh_native_platform_commands() + + def _plugin_lock(self, plugin_id: str) -> asyncio.Lock: + lock = self._plugin_locks.get(plugin_id) + if lock is None: + lock = asyncio.Lock() + self._plugin_locks[plugin_id] = lock + return lock + + def _schedule_background_reload(self, *, reset_restart_budget: bool) -> None: + if self._startup_task is not None and not self._startup_task.done(): + return + self._startup_task = asyncio.create_task( + self._background_reload(reset_restart_budget=reset_restart_budget), + name="sdk_plugin_bridge_startup", + ) + + async def _background_reload(self, *, reset_restart_budget: bool) -> None: + try: + await self.reload_all(reset_restart_budget=reset_restart_budget) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.error("SDK plugin background startup failed: %s", exc, exc_info=True) + finally: + self._startup_task = None diff --git a/astrbot/core/sdk_bridge/mcp_manager.py b/astrbot/core/sdk_bridge/mcp_manager.py new file mode 100644 index 0000000000..b753a9c76b --- /dev/null +++ b/astrbot/core/sdk_bridge/mcp_manager.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import asyncio +import contextlib +import uuid +from datetime import timedelta +from typing import TYPE_CHECKING, Any + +from astrbot_sdk.errors import AstrBotError + +from .runtime_store import ( + SdkPluginRecord, + _LocalMCPServerRuntime, + _TemporaryMCPSessionRuntime, +) + +if TYPE_CHECKING: + from .plugin_bridge import SdkPluginBridge + + +class SdkMcpManager: + def __init__(self, *, bridge: SdkPluginBridge) -> None: + self.bridge = bridge + + def get_local_mcp_server( + self, + plugin_id: str, + name: str, + ) -> dict[str, Any] | None: + runtime = self.bridge._local_mcp_record(plugin_id, name) + if runtime is None: + return None + return self.bridge._serialize_local_mcp_server(runtime) + + def list_local_mcp_servers(self, plugin_id: str) -> list[dict[str, Any]]: + record = self.bridge._records.get(plugin_id) + if record is None: + return [] + return [ + self.bridge._serialize_local_mcp_server(runtime) + for runtime in sorted( + record.local_mcp_servers.values(), + key=lambda item: item.name, + ) + ] + + async def connect_local_mcp_server( + self, + *, + plugin_id: str, + runtime: _LocalMCPServerRuntime, + timeout: float, + ) -> None: + runtime.ready_event.clear() + runtime.running = False + runtime.last_error = None + runtime.errlogs = [] + runtime.tools = [] + runtime.tool_specs = [] + self.bridge._remove_local_mcp_lease(runtime) + await self.bridge._cleanup_mcp_client(runtime.client) + runtime.client = None + + client = self.bridge._make_mcp_client() + client.name = runtime.name + try: + await asyncio.wait_for( + client.connect_to_server(dict(runtime.config), runtime.name), + timeout=timeout, + ) + await asyncio.wait_for(client.list_tools_and_save(), timeout=timeout) + except asyncio.CancelledError: + await self.bridge._cleanup_mcp_client(client) + raise + except TimeoutError: + runtime.last_error = ( + f"Local MCP server '{runtime.name}' did not become ready within " + f"{timeout:g} seconds" + ) + runtime.errlogs = [runtime.last_error] + await self.bridge._cleanup_mcp_client(client) + except Exception as exc: + runtime.last_error = str(exc) + runtime.errlogs = [runtime.last_error] + await self.bridge._cleanup_mcp_client(client) + else: + runtime.client = client + runtime.running = True + runtime.tools = [ + str(tool.name) for tool in client.tools if getattr(tool, "name", None) + ] + runtime.tool_specs = self.bridge._build_local_mcp_tool_specs( + runtime.name, + client, + ) + runtime.errlogs = list(client.server_errlogs) + if client.process_pid is not None: + runtime.lease_path = self.bridge._write_local_mcp_lease( + plugin_id=plugin_id, + server_name=runtime.name, + pid=client.process_pid, + ) + finally: + runtime.ready_event.set() + runtime.connect_task = None + + async def initialize_local_mcp_servers(self, record: SdkPluginRecord) -> None: + tasks: list[asyncio.Task[None]] = [] + for runtime in record.local_mcp_servers.values(): + if not runtime.active: + runtime.ready_event.set() + continue + task = asyncio.create_task( + self.connect_local_mcp_server( + plugin_id=record.plugin_id, + runtime=runtime, + timeout=30.0, + ) + ) + runtime.connect_task = task + tasks.append(task) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def shutdown_local_mcp_runtime( + self, + runtime: _LocalMCPServerRuntime, + ) -> None: + connect_task = runtime.connect_task + runtime.connect_task = None + if connect_task is not None and not connect_task.done(): + connect_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await connect_task + self.bridge._remove_local_mcp_lease(runtime) + await self.bridge._cleanup_mcp_client(runtime.client) + runtime.client = None + runtime.running = False + runtime.tools = [] + runtime.tool_specs = [] + runtime.ready_event.clear() + + async def shutdown_local_mcp_servers(self, record: SdkPluginRecord) -> None: + for runtime in record.local_mcp_servers.values(): + await self.shutdown_local_mcp_runtime(runtime) + + async def enable_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: float = 30.0, + ) -> dict[str, Any]: + runtime = self.bridge._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if runtime.active and runtime.running and runtime.connect_task is None: + return self.bridge._serialize_local_mcp_server(runtime) + if runtime.connect_task is not None and not runtime.connect_task.done(): + runtime.active = True + await runtime.connect_task + return self.bridge._serialize_local_mcp_server(runtime) + runtime.active = True + task = asyncio.create_task( + self.connect_local_mcp_server( + plugin_id=plugin_id, + runtime=runtime, + timeout=timeout, + ) + ) + runtime.connect_task = task + await task + return self.bridge._serialize_local_mcp_server(runtime) + + async def disable_local_mcp_server( + self, + plugin_id: str, + name: str, + ) -> dict[str, Any]: + runtime = self.bridge._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if not runtime.active and not runtime.running and runtime.connect_task is None: + return self.bridge._serialize_local_mcp_server(runtime) + runtime.active = False + await self.shutdown_local_mcp_runtime(runtime) + return self.bridge._serialize_local_mcp_server(runtime) + + async def wait_for_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: float, + ) -> dict[str, Any]: + runtime = self.bridge._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + await asyncio.wait_for(runtime.ready_event.wait(), timeout=timeout) + if not runtime.running: + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + return self.bridge._serialize_local_mcp_server(runtime) + + async def open_temporary_mcp_session( + self, + plugin_id: str, + *, + name: str, + config: dict[str, Any], + timeout: float, + ) -> tuple[str, list[str]]: + client = self.bridge._make_mcp_client() + client.name = name + try: + await asyncio.wait_for( + client.connect_to_server(dict(config), name), + timeout=timeout, + ) + await asyncio.wait_for(client.list_tools_and_save(), timeout=timeout) + except Exception: + await self.bridge._cleanup_mcp_client(client) + raise + session_id = f"{plugin_id}:{uuid.uuid4().hex}" + tools = [str(tool.name) for tool in client.tools if getattr(tool, "name", None)] + self.bridge._temporary_mcp_sessions[session_id] = _TemporaryMCPSessionRuntime( + plugin_id=plugin_id, + name=name, + client=client, + tools=tools, + ) + return session_id, tools + + async def close_temporary_mcp_session( + self, + plugin_id: str, + session_id: str, + ) -> None: + runtime = self.bridge._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id != plugin_id: + return + self.bridge._temporary_mcp_sessions.pop(session_id, None) + await self.bridge._cleanup_mcp_client(runtime.client) + + async def close_temporary_mcp_sessions(self, plugin_id: str) -> None: + session_ids = [ + session_id + for session_id, runtime in self.bridge._temporary_mcp_sessions.items() + if runtime.plugin_id == plugin_id + ] + for session_id in session_ids: + await self.close_temporary_mcp_session(plugin_id, session_id) + + def get_temporary_mcp_session_tools( + self, + plugin_id: str, + session_id: str, + ) -> list[str]: + runtime = self.bridge._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id != plugin_id: + raise AstrBotError.invalid_input("Unknown MCP session") + return list(runtime.tools) + + async def call_temporary_mcp_tool( + self, + plugin_id: str, + *, + session_id: str, + tool_name: str, + arguments: dict[str, Any], + ) -> dict[str, Any]: + runtime = self.bridge._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id != plugin_id: + raise AstrBotError.invalid_input("Unknown MCP session") + result = await runtime.client.call_tool_with_reconnect( + tool_name=tool_name, + arguments=arguments, + read_timeout_seconds=timedelta(seconds=60), + ) + text = self.bridge._mcp_call_result_to_text(result) + return {"content": text, "is_error": bool(getattr(result, "isError", False))} + + async def execute_local_mcp_tool( + self, + plugin_id: str, + *, + server_name: str, + tool_name: str, + tool_args: dict[str, Any], + timeout_seconds: int = 60, + ) -> dict[str, Any]: + runtime = self.bridge._local_mcp_record(plugin_id, server_name) + if ( + runtime is None + or not runtime.active + or not runtime.running + or runtime.client is None + ): + return { + "content": f"Local MCP server unavailable: {server_name}", + "success": False, + } + if tool_name not in runtime.tools: + return { + "content": f"Local MCP tool not found: {server_name}.{tool_name}", + "success": False, + } + try: + result = await runtime.client.call_tool_with_reconnect( + tool_name=tool_name, + arguments=tool_args, + read_timeout_seconds=timedelta(seconds=timeout_seconds), + ) + except Exception as exc: + return {"content": f"Tool execution failed: {exc}", "success": False} + text = self.bridge._mcp_call_result_to_text(result) + return { + "content": text, + "success": not bool(getattr(result, "isError", False)), + } diff --git a/astrbot/core/sdk_bridge/plugin_bridge.py b/astrbot/core/sdk_bridge/plugin_bridge.py new file mode 100644 index 0000000000..f75f2d30e9 --- /dev/null +++ b/astrbot/core/sdk_bridge/plugin_bridge.py @@ -0,0 +1,2923 @@ +from __future__ import annotations + +import asyncio +import contextlib +import json +import os +import re +import signal +import subprocess +import uuid +from collections.abc import Awaitable, Callable +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, cast + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.llm.agents import AgentSpec +from astrbot_sdk.llm.entities import LLMToolSpec +from astrbot_sdk.protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + HandlerDescriptor, + MessageTrigger, + PlatformFilterSpec, + ScheduleTrigger, +) +from astrbot_sdk.runtime._command_matching import command_root_name +from astrbot_sdk.runtime.loader import ( + PluginDiscoveryIssue, + PluginEnvironmentManager, + PluginSpec, + discover_plugins, + load_plugin_config, + load_plugin_config_schema, + save_plugin_config, +) +from astrbot_sdk.runtime.supervisor import WorkerSession + +from astrbot.core import astrbot_config, logger +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.command_compatibility import ( + CommandRegistration, + CrossSystemCommandConflict, + build_cross_system_conflicts, + collect_legacy_command_registrations, + collect_sdk_command_registrations, + match_legacy_command_registrations, +) +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import LLMResponse as CoreLLMResponse +from astrbot.core.provider.entities import ProviderRequest as CoreProviderRequest +from astrbot.core.skills.skill_manager import ( + SkillManager, +) +from astrbot.core.utils.astrbot_path import ( + get_astrbot_data_path, + get_astrbot_plugin_data_path, +) + +from .capability_bridge import CoreCapabilityBridge +from .dispatch_engine import SdkDispatchEngine +from .event_payload import ( + InboundEventSnapshot, +) +from .lifecycle_manager import SdkPluginLifecycleManager +from .mcp_manager import SdkMcpManager +from .registry_manager import SdkRegistryManager +from .request_runtime import SdkRequestRuntime +from .runtime_store import ( + SdkDispatchResult, + SdkDynamicCommandRoute, + SdkHandlerRef, + SdkHttpRoute, + SdkPluginRecord, + SdkRuntimeStore, + _LocalMCPServerRuntime, + _RequestContext, + _RequestOverlayState, +) +from .trigger_converter import TriggerConverter, TriggerMatch + +SDK_STATE_ENABLED = "enabled" +SDK_STATE_DISABLED = "disabled" +SDK_STATE_RELOADING = "reloading" +SDK_STATE_FAILED = "failed" +SDK_STATE_UNSUPPORTED_PARTIAL = "unsupported_partial" + +SKIP_LEGACY_STOPPED = "legacy_stopped" +SKIP_LEGACY_REPLIED = "legacy_replied" +SKIP_SDK_RELOADING = "sdk_reloading" +SKIP_NO_MATCH = "no_match" +SKIP_WORKER_FAILED = "worker_failed" +OVERLAY_TIMEOUT_SECONDS = 300 +SDK_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") +SUPPORTED_SYSTEM_EVENTS = { + "astrbot_loaded", + "platform_loaded", + "after_message_sent", + "waiting_llm_request", + "agent_begin", + "llm_request", + "llm_response", + "agent_done", + "streaming_delta", + "decorating_result", + "calling_func_tool", + "llm_tool_start", + "llm_tool_end", + "plugin_error", + "plugin_loaded", + "plugin_unloaded", +} +COMMAND_OVERRIDE_WARNING_TYPE = "legacy_sdk_command_override" + + +class SdkPluginBridge: + SDK_STATE_ENABLED = SDK_STATE_ENABLED + SDK_STATE_DISABLED = SDK_STATE_DISABLED + SDK_STATE_RELOADING = SDK_STATE_RELOADING + SDK_STATE_FAILED = SDK_STATE_FAILED + SDK_STATE_UNSUPPORTED_PARTIAL = SDK_STATE_UNSUPPORTED_PARTIAL + SKIP_LEGACY_STOPPED = SKIP_LEGACY_STOPPED + SKIP_LEGACY_REPLIED = SKIP_LEGACY_REPLIED + SKIP_SDK_RELOADING = SKIP_SDK_RELOADING + SKIP_NO_MATCH = SKIP_NO_MATCH + SKIP_WORKER_FAILED = SKIP_WORKER_FAILED + COMMAND_OVERRIDE_WARNING_TYPE = COMMAND_OVERRIDE_WARNING_TYPE + SDK_SKILL_NAME_RE = SDK_SKILL_NAME_RE + + def __init__(self, star_context) -> None: + self.star_context = star_context + self.logger = logger + self.plugins_dir = Path(get_astrbot_data_path()) / "sdk_plugins" + self.state_path = Path(get_astrbot_data_path()) / "sdk_plugins_state.json" + self.plugins_dir.mkdir(parents=True, exist_ok=True) + self._started = False + self._stopping = False + self._state_overrides = self._load_state_overrides() + self.env_manager = PluginEnvironmentManager(Path(__file__).resolve().parents[3]) + self._store = SdkRuntimeStore() + self.capability_bridge = CoreCapabilityBridge( + star_context=star_context, + plugin_bridge=self, + ) + self._records = self._store.records + self._request_contexts = self._store.request_contexts + self._request_id_to_token = self._store.request_id_to_token + self._request_plugin_ids = self._store.request_plugin_ids + self._request_overlays = self._store.request_overlays + self._plugin_requests = self._store.plugin_requests + self._http_routes = self._store.http_routes + self._session_waiters = self._store.session_waiters + self._schedule_job_ids = self._store.schedule_job_ids + self._discovery_issues = self._store.discovery_issues + self._temporary_mcp_sessions = self._store.temporary_mcp_sessions + self.request_runtime = SdkRequestRuntime( + bridge=self, + store=self._store, + overlay_timeout_seconds=OVERLAY_TIMEOUT_SECONDS, + ) + self.dispatch_engine = SdkDispatchEngine(bridge=self) + self.lifecycle = SdkPluginLifecycleManager(bridge=self) + self.mcp = SdkMcpManager(bridge=self) + self.registry = SdkRegistryManager(bridge=self) + + async def start(self) -> None: + await self.lifecycle.start() + + async def stop(self) -> None: + await self.lifecycle.stop() + + async def reload_all(self, *, reset_restart_budget: bool = False) -> None: + await self.lifecycle.reload_all(reset_restart_budget=reset_restart_budget) + + async def reload_plugin(self, plugin_id: str) -> None: + await self.lifecycle.reload_plugin(plugin_id) + + async def turn_off_plugin(self, plugin_id: str) -> None: + await self.lifecycle.turn_off_plugin(plugin_id) + + async def turn_on_plugin(self, plugin_id: str) -> None: + await self.lifecycle.turn_on_plugin(plugin_id) + + def _snapshot_records(self) -> list[SdkPluginRecord]: + with self._store.mutation_lock: + return list(self._records.values()) + + def _snapshot_records_sorted(self) -> list[SdkPluginRecord]: + with self._store.mutation_lock: + return sorted(self._records.values(), key=lambda item: item.load_order) + + def _snapshot_http_routes(self, plugin_id: str | None = None) -> list[SdkHttpRoute]: + with self._store.mutation_lock: + if plugin_id is None: + routes: list[SdkHttpRoute] = [] + for entries in self._http_routes.values(): + routes.extend(list(entries)) + return routes + return list(self._http_routes.get(plugin_id, [])) + + def list_plugins(self) -> list[dict[str, Any]]: + return self.registry.list_plugins() + + def get_plugin_metadata(self, plugin_id: str) -> dict[str, Any] | None: + return self.registry.get_plugin_metadata(plugin_id) + + def list_plugin_metadata(self) -> list[dict[str, Any]]: + return self.registry.list_plugin_metadata() + + def get_plugin_config(self, plugin_id: str) -> dict[str, Any] | None: + record = self._records.get(plugin_id) + if record is None: + return None + return dict(record.config) + + def get_plugin_config_schema(self, plugin_id: str) -> dict[str, Any] | None: + record = self._records.get(plugin_id) + if record is None: + return None + return dict(record.config_schema) + + def save_plugin_config( + self, + plugin_id: str, + payload: dict[str, Any], + ) -> dict[str, Any]: + record = self._records.get(plugin_id) + if record is None: + raise ValueError(f"SDK plugin not found: {plugin_id}") + normalized = save_plugin_config( + record.plugin, + payload, + schema=record.config_schema, + ) + record.config = dict(normalized) + return dict(record.config) + + def get_registered_llm_tools(self, plugin_id: str) -> list[LLMToolSpec]: + record = self._records.get(plugin_id) + if record is None: + return [] + return [item.model_copy(deep=True) for item in record.llm_tools.values()] + + def get_active_llm_tools(self, plugin_id: str) -> list[LLMToolSpec]: + record = self._records.get(plugin_id) + if record is None: + return [] + return [ + item.model_copy(deep=True) + for name, item in record.llm_tools.items() + if name in record.active_llm_tools + ] + + def get_llm_tool(self, plugin_id: str, name: str) -> LLMToolSpec | None: + record = self._records.get(plugin_id) + if record is None: + return None + spec = record.llm_tools.get(name) + if spec is None: + return None + return spec.model_copy(deep=True) + + def add_llm_tools(self, plugin_id: str, tools: list[LLMToolSpec]) -> list[str]: + record = self._records.get(plugin_id) + if record is None: + return [] + names: list[str] = [] + for spec in tools: + record.llm_tools[spec.name] = spec.model_copy(deep=True) + if spec.active: + record.active_llm_tools.add(spec.name) + else: + record.active_llm_tools.discard(spec.name) + names.append(spec.name) + return names + + def remove_llm_tool(self, plugin_id: str, name: str) -> bool: + record = self._records.get(plugin_id) + if record is None: + return False + removed = record.llm_tools.pop(name, None) is not None + record.active_llm_tools.discard(name) + return removed + + def activate_llm_tool(self, plugin_id: str, name: str) -> bool: + record = self._records.get(plugin_id) + if record is None: + return False + spec = record.llm_tools.get(name) + if spec is None: + return False + spec.active = True + record.active_llm_tools.add(name) + return True + + def deactivate_llm_tool(self, plugin_id: str, name: str) -> bool: + record = self._records.get(plugin_id) + if record is None: + return False + spec = record.llm_tools.get(name) + if spec is None: + return False + spec.active = False + record.active_llm_tools.discard(name) + return True + + def _local_mcp_record( + self, plugin_id: str, name: str + ) -> _LocalMCPServerRuntime | None: + record = self._records.get(plugin_id) + if record is None: + return None + return record.local_mcp_servers.get(name) + + @staticmethod + def _serialize_local_mcp_server( + runtime: _LocalMCPServerRuntime, + ) -> dict[str, Any]: + errlogs = list(runtime.errlogs) + if runtime.client is not None: + errlogs.extend(str(item) for item in runtime.client.server_errlogs) + return { + "name": runtime.name, + "scope": "local", + "active": runtime.active, + "running": runtime.running, + "config": dict(runtime.config), + "tools": list(runtime.tools), + "errlogs": errlogs, + "last_error": runtime.last_error, + } + + def get_local_mcp_server( + self, + plugin_id: str, + name: str, + ) -> dict[str, Any] | None: + return self.mcp.get_local_mcp_server(plugin_id, name) + + def list_local_mcp_servers(self, plugin_id: str) -> list[dict[str, Any]]: + return self.mcp.list_local_mcp_servers(plugin_id) + + def get_request_tool_specs(self, plugin_id: str) -> list[LLMToolSpec]: + record = self._records.get(plugin_id) + if record is None: + return [] + specs: dict[str, LLMToolSpec] = { + item.name: item.model_copy(deep=True) + for name, item in record.llm_tools.items() + if name in record.active_llm_tools + } + for runtime in record.local_mcp_servers.values(): + if not runtime.active or not runtime.running: + continue + for spec in runtime.tool_specs: + specs.setdefault(spec.name, spec.model_copy(deep=True)) + return list(specs.values()) + + def get_registered_agents(self, plugin_id: str) -> list[AgentSpec]: + record = self._records.get(plugin_id) + if record is None: + return [] + return [item.model_copy(deep=True) for item in record.agents.values()] + + def get_registered_agent(self, plugin_id: str, name: str) -> AgentSpec | None: + record = self._records.get(plugin_id) + if record is None: + return None + spec = record.agents.get(name) + if spec is None: + return None + return spec.model_copy(deep=True) + + def register_dynamic_command_route( + self, + *, + plugin_id: str, + command_name: str, + handler_full_name: str, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ) -> None: + record = self._records.get(plugin_id) + if record is None: + raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}") + if isinstance(priority, bool) or not isinstance(priority, int): + raise AstrBotError.invalid_input("priority must be an integer") + command_text = str(command_name).strip() + if not command_text: + raise AstrBotError.invalid_input("command_name must not be empty") + handler_text = str(handler_full_name).strip() + if not handler_text: + raise AstrBotError.invalid_input("handler_full_name must not be empty") + if not handler_text.startswith(f"{plugin_id}:"): + raise AstrBotError.invalid_input( + "handler_full_name must belong to the caller plugin" + ) + if self._find_handler_ref(record, handler_text) is None: + raise AstrBotError.invalid_input( + f"Unknown handler_full_name for plugin '{plugin_id}': {handler_text}" + ) + existing_order = next( + ( + route.declaration_order + for route in record.dynamic_command_routes + if route.command_name == command_text + and route.use_regex is bool(use_regex) + ), + len(record.dynamic_command_routes), + ) + updated = [ + route + for route in record.dynamic_command_routes + if not ( + route.command_name == command_text + and route.use_regex is bool(use_regex) + ) + ] + updated.append( + SdkDynamicCommandRoute( + command_name=command_text, + handler_full_name=handler_text, + desc=str(desc), + priority=priority, + use_regex=bool(use_regex), + declaration_order=existing_order, + ) + ) + updated.sort(key=lambda item: item.declaration_order) + record.dynamic_command_routes = updated + + def register_skill( + self, + *, + plugin_id: str, + name: str, + path: str, + description: str = "", + ) -> dict[str, str]: + return self.registry.register_skill( + plugin_id=plugin_id, + name=name, + path=path, + description=description, + ) + + def unregister_skill(self, *, plugin_id: str, name: str) -> bool: + return self.registry.unregister_skill(plugin_id=plugin_id, name=name) + + def list_registered_skills(self, plugin_id: str) -> list[dict[str, str]]: + return self.registry.list_registered_skills(plugin_id) + + def _publish_plugin_skills(self, plugin_id: str) -> None: + self.registry.publish_plugin_skills_impl(plugin_id) + + async def _clear_plugin_skills( + self, + *, + plugin_id: str, + record: SdkPluginRecord | Any | None, + reason: str, + ) -> None: + await self.registry.clear_plugin_skills( + plugin_id=plugin_id, + record=record, + reason=reason, + ) + + def register_http_api( + self, + *, + plugin_id: str, + route: str, + methods: list[str], + handler_capability: str, + description: str, + ) -> None: + self.registry.register_http_api( + plugin_id=plugin_id, + route=route, + methods=methods, + handler_capability=handler_capability, + description=description, + ) + + def unregister_http_api( + self, + *, + plugin_id: str, + route: str, + methods: list[str], + ) -> None: + self.registry.unregister_http_api( + plugin_id=plugin_id, + route=route, + methods=methods, + ) + + def list_http_apis(self, plugin_id: str) -> list[dict[str, Any]]: + return self.registry.list_http_apis(plugin_id) + + def _public_http_path(self, route: str) -> str: + normalized_route = self._normalize_http_route(route) + return f"/api/plug{normalized_route}" + + def _public_page_path(self, route: str) -> str: + normalized_route = self._normalize_http_route(route) + return f"/plug{normalized_route}" + + @staticmethod + def _parse_env_bool(value: str | None, default: bool) -> bool: + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + def _dashboard_public_base_url(self) -> str: + return self.registry.dashboard_public_base_url() + + def _public_http_url(self, route: str) -> str: + return f"{self._dashboard_public_base_url()}{self._public_http_path(route)}" + + def _public_page_url(self, route: str) -> str: + return f"{self._dashboard_public_base_url()}{self._public_page_path(route)}" + + def _plugin_entry_route(self, plugin_id: str) -> str | None: + plugin_root = f"/{plugin_id}" + for entry in self._http_routes.get(plugin_id, []): + if entry.route == plugin_root: + return entry.route + for entry in self._http_routes.get(plugin_id, []): + if "/api/" not in entry.route: + return entry.route + return None + + async def dispatch_http_request( + self, + route: str, + method: str, + ) -> dict[str, Any] | None: + return await self.registry.dispatch_http_request(route, method) + + def register_session_waiter(self, *, plugin_id: str, session_key: str) -> None: + if not session_key: + raise AstrBotError.invalid_input( + "session waiter registration requires session_key" + ) + self._session_waiters.setdefault(plugin_id, set()).add(session_key) + + def unregister_session_waiter(self, *, plugin_id: str, session_key: str) -> None: + plugin_waiters = self._session_waiters.get(plugin_id) + if plugin_waiters is None: + return + plugin_waiters.discard(session_key) + if not plugin_waiters: + self._session_waiters.pop(plugin_id, None) + + async def dispatch_message(self, event: AstrMessageEvent) -> SdkDispatchResult: + return await self.dispatch_engine.dispatch_message(event) + + def resolve_request_plugin_id(self, request_id: str) -> str: + return self.request_runtime.resolve_request_plugin_id(request_id) + + def resolve_request_session(self, request_id: str) -> _RequestContext | None: + return self.request_runtime.resolve_request_session(request_id) + + def get_request_context_by_token( + self, dispatch_token: str + ) -> _RequestContext | None: + return self.request_runtime.get_request_context_by_token(dispatch_token) + + def _bind_dispatch_token( + self, event: AstrMessageEvent, dispatch_token: str + ) -> None: + self.request_runtime.bind_dispatch_token(event, dispatch_token) + + def _get_dispatch_token(self, event: AstrMessageEvent) -> str | None: + return self.request_runtime.get_dispatch_token(event) + + def _schedule_overlay_cleanup( + self, dispatch_token: str + ) -> asyncio.Task[None] | None: + return self.request_runtime.schedule_overlay_cleanup(dispatch_token) + + def _ensure_request_overlay( + self, + dispatch_token: str, + *, + should_call_llm: bool, + ) -> _RequestOverlayState: + return self.request_runtime.ensure_request_overlay( + dispatch_token, + should_call_llm=should_call_llm, + ) + + def _track_request_scope( + self, + *, + dispatch_token: str, + request_id: str, + plugin_id: str, + ) -> None: + self.request_runtime.track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=plugin_id, + ) + + def _close_request_overlay(self, dispatch_token: str) -> None: + self.request_runtime.close_request_overlay(dispatch_token) + + def close_request_overlay_for_event(self, event: AstrMessageEvent) -> None: + self.request_runtime.close_request_overlay_for_event(event) + + def get_request_overlay_by_token( + self, dispatch_token: str + ) -> _RequestOverlayState | None: + return self.request_runtime.get_request_overlay_by_token(dispatch_token) + + def get_request_overlay_by_request_id( + self, request_id: str + ) -> _RequestOverlayState | None: + return self.request_runtime.get_request_overlay_by_request_id(request_id) + + def request_llm_for_request(self, request_id: str) -> bool: + return self.request_runtime.request_llm_for_request(request_id) + + def get_effective_should_call_llm(self, event: AstrMessageEvent) -> bool: + return self.request_runtime.get_effective_should_call_llm(event) + + def get_should_call_llm_for_request(self, request_id: str) -> bool | None: + return self.request_runtime.get_should_call_llm_for_request(request_id) + + def _set_overlay_stop_state( + self, + overlay: _RequestOverlayState, + *, + stopped: bool, + ) -> None: + self.request_runtime.set_overlay_stop_state(overlay, stopped=stopped) + + def _set_result_from_object( + self, + overlay: _RequestOverlayState, + result: MessageEventResult | None, + ) -> None: + self.request_runtime.set_result_from_object(overlay, result) + + def _bind_result_object( + self, + overlay: _RequestOverlayState, + result: MessageEventResult | None, + ) -> None: + self.request_runtime.bind_result_object(overlay, result) + + def _set_result_payload_on_overlay( + self, + overlay: _RequestOverlayState, + result_payload: dict[str, Any] | None, + ) -> None: + self.request_runtime.set_result_payload_on_overlay(overlay, result_payload) + + def _sync_overlay_payload_from_result_object( + self, + overlay: _RequestOverlayState, + ) -> None: + self.request_runtime.sync_overlay_payload_from_result_object(overlay) + + def _get_effective_result_for_token( + self, + dispatch_token: str, + ) -> MessageEventResult | None: + return self.request_runtime.get_effective_result_for_token(dispatch_token) + + def _set_result_for_dispatch_token( + self, + dispatch_token: str, + result: MessageEventResult | None, + ) -> None: + self.request_runtime.set_result_for_dispatch_token(dispatch_token, result) + + def _clear_result_for_dispatch_token(self, dispatch_token: str) -> None: + self.request_runtime.clear_result_for_dispatch_token(dispatch_token) + + def _stop_event_for_dispatch_token(self, dispatch_token: str) -> None: + self.request_runtime.stop_event_for_dispatch_token(dispatch_token) + + def _continue_event_for_dispatch_token(self, dispatch_token: str) -> None: + self.request_runtime.continue_event_for_dispatch_token(dispatch_token) + + def _is_stopped_for_dispatch_token(self, dispatch_token: str) -> bool: + return self.request_runtime.is_stopped_for_dispatch_token(dispatch_token) + + def set_result_for_request( + self, + request_id: str, + result_payload: dict[str, Any] | None, + ) -> bool: + return self.request_runtime.set_result_for_request(request_id, result_payload) + + def clear_result_for_request(self, request_id: str) -> bool: + return self.request_runtime.clear_result_for_request(request_id) + + def get_result_payload_for_request(self, request_id: str) -> dict[str, Any] | None: + return self.request_runtime.get_result_payload_for_request(request_id) + + def set_handler_whitelist_for_request( + self, + request_id: str, + plugin_names: set[str] | None, + ) -> bool: + return self.request_runtime.set_handler_whitelist_for_request( + request_id, + plugin_names, + ) + + def get_handler_whitelist_for_request(self, request_id: str) -> set[str] | None: + return self.request_runtime.get_handler_whitelist_for_request(request_id) + + def _get_handler_whitelist_for_event( + self, event: AstrMessageEvent + ) -> set[str] | None: + return self.request_runtime.get_handler_whitelist_for_event(event) + + @staticmethod + def _build_core_message_chain_from_payload( + chain_payload: list[dict[str, Any]], + ) -> MessageChain: + return SdkRequestRuntime.build_core_message_chain_from_payload(chain_payload) + + @classmethod + def _build_core_result_from_chain_payload( + cls, + chain_payload: list[dict[str, Any]], + ) -> MessageEventResult: + return SdkRequestRuntime.build_core_result_from_chain_payload(chain_payload) + + @staticmethod + def _legacy_result_to_sdk_payload( + result: MessageEventResult | None, + ) -> dict[str, Any] | None: + return SdkRequestRuntime.legacy_result_to_sdk_payload(result) + + @staticmethod + def _components_to_sdk_payload( + components: list[Any] | tuple[Any, ...] | None, + ) -> list[dict[str, Any]]: + return SdkRequestRuntime.components_to_sdk_payload(components) + + def _persist_sdk_local_extras_from_handler( + self, + overlay: _RequestOverlayState, + payload: Any, + *, + plugin_id: str, + handler_id: str, + ) -> None: + self.request_runtime.persist_sdk_local_extras_from_handler( + overlay, + payload, + plugin_id=plugin_id, + handler_id=handler_id, + ) + + @staticmethod + def _sanitize_host_extras(event: AstrMessageEvent) -> dict[str, Any]: + return SdkRequestRuntime.sanitize_host_extras(event) + + @staticmethod + def _set_sdk_origin_plugin_id( + event: AstrMessageEvent, + plugin_id: str, + ) -> None: + SdkRequestRuntime.set_sdk_origin_plugin_id(event, plugin_id) + + def _get_or_build_inbound_snapshot( + self, + event: AstrMessageEvent, + overlay: _RequestOverlayState | None, + ) -> InboundEventSnapshot: + return self.request_runtime.get_or_build_inbound_snapshot(event, overlay) + + def _build_sdk_event_payload( + self, + event: AstrMessageEvent, + *, + dispatch_token: str, + plugin_id: str, + request_id: str, + overlay: _RequestOverlayState | None, + raw_updates: dict[str, Any] | None = None, + field_updates: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return self.request_runtime.build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + overlay=overlay, + raw_updates=raw_updates, + field_updates=field_updates, + ) + + def build_sdk_event_payload( + self, + event: AstrMessageEvent, + *, + dispatch_token: str, + plugin_id: str, + request_id: str, + overlay: _RequestOverlayState | None, + raw_updates: dict[str, Any] | None = None, + field_updates: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return self.request_runtime.build_sdk_event_payload( + event, + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + overlay=overlay, + raw_updates=raw_updates, + field_updates=field_updates, + ) + + @staticmethod + def _core_provider_request_to_sdk_payload( + request: CoreProviderRequest, + ) -> dict[str, Any]: + return SdkRequestRuntime.core_provider_request_to_sdk_payload(request) + + @staticmethod + def _apply_sdk_provider_request_payload( + request: CoreProviderRequest, + payload: dict[str, Any], + ) -> None: + SdkRequestRuntime.apply_sdk_provider_request_payload(request, payload) + + @staticmethod + def _core_llm_response_to_sdk_payload( + response: CoreLLMResponse, + ) -> dict[str, Any]: + return SdkRequestRuntime.core_llm_response_to_sdk_payload(response) + + @classmethod + def _apply_sdk_result_payload( + cls, + result: MessageEventResult, + payload: dict[str, Any], + ) -> MessageEventResult: + return SdkRequestRuntime.apply_sdk_result_payload(result, payload) + + def get_effective_result( + self, event: AstrMessageEvent + ) -> MessageEventResult | None: + return self.request_runtime.get_effective_result(event) + + def before_platform_send(self, dispatch_token: str) -> None: + self.request_runtime.before_platform_send(dispatch_token) + + def mark_platform_send(self, dispatch_token: str) -> str: + return self.request_runtime.mark_platform_send(dispatch_token) + + def get_or_bind_dispatch_token(self, event: AstrMessageEvent) -> str: + return self.request_runtime.get_or_bind_dispatch_token(event) + + def get_plugin_session(self, plugin_id: str) -> WorkerSession | None: + record = self._records.get(plugin_id) + return None if record is None else record.session + + @staticmethod + def _legacy_has_replied(event: AstrMessageEvent) -> bool: + # 按优先级尝试新版方法 → 兼容方法 → 直接读内部字段, + # 确保 AstrMessageEvent 的 API 演进不会破坏旧版 bridge 逻辑 + has_send = getattr(event, "has_send_operation", None) + if callable(has_send): + return bool(has_send()) + has_send = getattr(event, "get_send_operation_state", None) + if callable(has_send): + return bool(has_send()) + return bool(getattr(event, "_has_send_oper", False)) + + def _match_handlers(self, event: AstrMessageEvent) -> list[TriggerMatch]: + matches: list[TriggerMatch] = [] + normalized_platform = self._normalize_platform_name(event.get_platform_name()) + for record in self._records.values(): + if record.state in {SDK_STATE_DISABLED, SDK_STATE_FAILED}: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + for handler in record.handlers: + match = TriggerConverter.match_handler( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + event=event, + load_order=record.load_order, + declaration_order=handler.declaration_order, + ) + if match is not None: + matches.append(match) + dynamic_base_order = len(record.handlers) + for route in getattr(record, "dynamic_command_routes", []): + match = self._match_dynamic_command_route( + record=record, + route=route, + event=event, + declaration_order=dynamic_base_order + route.declaration_order, + ) + if match is not None: + matches.append(match) + matches.sort(key=TriggerConverter.sort_key) + return matches + + def list_cross_system_command_conflicts( + self, + ) -> list[CrossSystemCommandConflict]: + return build_cross_system_conflicts( + collect_legacy_command_registrations(), + self._collect_sdk_command_registrations(), + ) + + def has_active_sdk_command_handlers(self) -> bool: + if not self._records: + return False + for record in self._snapshot_records(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if any( + isinstance(handler.descriptor.trigger, CommandTrigger) + for handler in record.handlers + ): + return True + if any( + not route.use_regex + for route in getattr(record, "dynamic_command_routes", []) + ): + return True + return False + + def refresh_command_compatibility_issues(self) -> None: + conflicts = self.list_cross_system_command_conflicts() + conflict_map: dict[str, list[CrossSystemCommandConflict]] = {} + for conflict in conflicts: + conflict_map.setdefault(conflict.sdk.plugin_name, []).append(conflict) + + for record in self._snapshot_records(): + record.issues = [ + issue + for issue in record.issues + if issue.get("warning_type") != self.COMMAND_OVERRIDE_WARNING_TYPE + ] + record_conflicts = conflict_map.get(record.plugin_id, []) + if record_conflicts: + for issue in self._build_command_compatibility_issues( + record.plugin_id, + record_conflicts, + ): + record.issues.append(issue) + logger.warning( + "SDK plugin command overrides legacy handlers: plugin=%s commands=%s", + record.plugin_id, + ", ".join( + sorted({conflict.command_name for conflict in record_conflicts}) + ), + ) + + def detect_legacy_command_conflict( + self, + event: AstrMessageEvent, + legacy_handlers: list[Any], + ) -> CrossSystemCommandConflict | None: + if not legacy_handlers or not self.has_active_sdk_command_handlers(): + return None + sdk_matches = self._match_handlers(event) + if not sdk_matches: + return None + legacy_registrations = match_legacy_command_registrations( + legacy_handlers, + event.get_message_str(), + ) + if not legacy_registrations: + return None + sdk_registrations = self._matched_sdk_command_registrations(sdk_matches) + if not sdk_registrations: + return None + conflicts = build_cross_system_conflicts( + legacy_registrations, + sdk_registrations, + ) + if not conflicts: + return None + conflicts.sort( + key=lambda item: ( + item.command_name, + item.legacy.plugin_name, + item.sdk.plugin_name, + item.sdk.handler_full_name, + ) + ) + return conflicts[0] + + def format_legacy_command_conflict_message( + self, + conflict: CrossSystemCommandConflict, + ) -> str: + legacy_name = conflict.legacy.plugin_display_name or conflict.legacy.plugin_name + sdk_name = conflict.sdk.plugin_display_name or conflict.sdk.plugin_name + if conflict.legacy.command_name == conflict.sdk.command_name: + command_detail = f"`/{conflict.legacy.command_name}`" + else: + command_detail = ( + f"`/{conflict.legacy.command_name}` 与 `/{conflict.sdk.command_name}`" + ) + return ( + "检测到旧插件与 SDK 插件存在命令冲突,当前不兼容:" + f"{command_detail} 分别来自 {legacy_name} 和 {sdk_name}。" + "请停用、卸载或重命名其中一个插件后再使用。" + ) + + def _collect_sdk_command_registrations(self) -> list[Any]: + registrations: list[Any] = [] + for record in self._snapshot_records_sorted(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + registrations.extend(self._sdk_record_command_registrations(record)) + return registrations + + def _sdk_record_command_registrations(self, record: SdkPluginRecord) -> list[Any]: + registrations: list[Any] = [] + plugin_display_name = str( + record.plugin.manifest_data.get("display_name") or record.plugin_id + ) + for handler in record.handlers: + registrations.extend( + collect_sdk_command_registrations( + plugin_name=record.plugin_id, + plugin_display_name=plugin_display_name, + handler_full_name=handler.descriptor.id, + descriptor=handler.descriptor, + ) + ) + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + registrations.extend( + collect_sdk_command_registrations( + plugin_name=record.plugin_id, + plugin_display_name=plugin_display_name, + handler_full_name=descriptor.id, + descriptor=descriptor, + ) + ) + return registrations + + def _matched_sdk_command_registrations( + self, + matches: list[TriggerMatch], + ) -> list[CommandRegistration]: + registrations: list[CommandRegistration] = [] + for match in matches: + if not match.matched_command_name: + continue + record = self._records.get(match.plugin_id) + if record is None: + continue + descriptor = self._descriptor_from_match(record, match) + if descriptor is None: + continue + registrations.append( + CommandRegistration( + runtime_kind="sdk", + plugin_name=record.plugin_id, + plugin_display_name=str( + record.plugin.manifest_data.get("display_name") + or record.plugin_id + ), + handler_full_name=descriptor.id, + command_name=match.matched_command_name, + ) + ) + return registrations + + def _descriptor_from_match( + self, + record: SdkPluginRecord, + match: TriggerMatch, + ) -> HandlerDescriptor | None: + for handler in record.handlers: + if ( + handler.descriptor.id == match.handler_id + and handler.declaration_order == match.declaration_order + ): + return handler.descriptor + + dynamic_order = match.declaration_order - len(record.handlers) + if dynamic_order < 0: + return None + for route in getattr(record, "dynamic_command_routes", []): + if route.declaration_order != dynamic_order: + continue + return self._build_dynamic_route_descriptor(record, route) + return None + + def _build_command_compatibility_issues( + self, + plugin_id: str, + conflicts: list[CrossSystemCommandConflict], + ) -> list[dict[str, Any]]: + issues: list[dict[str, Any]] = [] + for conflict in conflicts: + legacy_name = ( + conflict.legacy.plugin_display_name or conflict.legacy.plugin_name + ) + if conflict.legacy.command_name == conflict.sdk.command_name: + conflict_detail = f"Command '/{conflict.legacy.command_name}'" + else: + conflict_detail = ( + f"Commands '/{conflict.legacy.command_name}' and " + f"'/{conflict.sdk.command_name}'" + ) + issues.append( + { + "severity": "warning", + "phase": "compatibility", + "plugin_id": plugin_id, + "message": "SDK plugin command overrides a legacy plugin command", + "details": ( + f"{conflict_detail} are registered by both systems. " + f"The SDK plugin overrides legacy plugin '{legacy_name}' at runtime." + ), + "warning_type": self.COMMAND_OVERRIDE_WARNING_TYPE, + "command_name": conflict.command_name, + "legacy_command_name": conflict.legacy.command_name, + "sdk_command_name": conflict.sdk.command_name, + "legacy_plugin_name": conflict.legacy.plugin_name, + "legacy_plugin_display_name": conflict.legacy.plugin_display_name, + "legacy_handler_full_name": conflict.legacy.handler_full_name, + "sdk_handler_full_name": conflict.sdk.handler_full_name, + } + ) + return issues + + @staticmethod + def _descriptor_root_candidates(descriptor: HandlerDescriptor) -> list[str]: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return [] + candidates: list[str] = [] + route = descriptor.command_route + if route is not None and route.group_path: + root_name = str(route.group_path[0]).strip() + if root_name: + candidates.append(root_name) + for name in [trigger.command, *trigger.aliases]: + normalized = str(name).strip() + if " " not in normalized: + continue + root_name = normalized.split()[0].strip() + if root_name: + candidates.append(root_name) + return list(dict.fromkeys(candidates)) + + @classmethod + def _descriptor_help_entry( + cls, + descriptor: HandlerDescriptor, + ) -> tuple[str, str | None] | None: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return None + route = descriptor.command_route + display_command = ( + str(route.display_command).strip() + if route is not None and str(route.display_command).strip() + else str(trigger.command).strip() + ) + if not display_command: + return None + return display_command, cls._descriptor_description(descriptor) + + def _resolve_group_root_fallback( + self, + event: AstrMessageEvent, + ) -> dict[str, str] | None: + root_name = command_root_name(event.get_message_str()) + if not root_name: + return None + normalized_platform = self._normalize_platform_name(event.get_platform_name()) + for record in self._snapshot_records_sorted(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + help_text = self._build_group_root_help(record, event, root_name) + if help_text is None: + continue + return {"plugin_id": record.plugin_id, "help_text": help_text} + return None + + def _resolve_command_permission_denied( + self, + event: AstrMessageEvent, + ) -> dict[str, str] | None: + text = event.get_message_str().strip() + if not text: + return None + normalized_platform = self._normalize_platform_name(event.get_platform_name()) + for record in sorted(self._records.values(), key=lambda item: item.load_order): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + for handler in record.handlers: + descriptor = handler.descriptor + if not self._descriptor_requires_admin(descriptor): + continue + if not TriggerConverter._match_filters(descriptor, event): + continue + if not self._descriptor_matches_command_text(descriptor, text): + continue + help_entry = self._descriptor_help_entry(descriptor) + display_command = ( + help_entry[0] + if help_entry is not None + else str(getattr(descriptor.trigger, "command", "")).strip() + ) + if not display_command: + continue + return { + "plugin_id": record.plugin_id, + "message": (f"权限不足:`/{display_command}` 需要管理员权限。"), + } + return None + + def _has_command_trigger_match(self, matches: list[TriggerMatch]) -> bool: + for match in matches: + record = self._records.get(match.plugin_id) + if record is None: + continue + handler_ref = self._find_handler_ref(record, match.handler_id) + if handler_ref is not None and isinstance( + handler_ref.descriptor.trigger, CommandTrigger + ): + return True + return False + + def _build_group_root_help( + self, + record: SdkPluginRecord, + event: AstrMessageEvent, + root_name: str, + ) -> str | None: + entries: list[tuple[str, str | None]] = [] + seen_commands: set[str] = set() + for handler in record.handlers: + descriptor = handler.descriptor + if root_name not in self._descriptor_root_candidates(descriptor): + continue + if not TriggerConverter._match_filters(descriptor, event): + continue + if not self._descriptor_is_visible_to_event(descriptor, event): + continue + help_entry = self._descriptor_help_entry(descriptor) + if help_entry is None: + continue + command_name, description = help_entry + if command_name in seen_commands: + continue + seen_commands.add(command_name) + entries.append((command_name, description)) + if not entries: + return None + lines = [f"{root_name}命令:"] + for command_name, description in entries: + line = f"- /{command_name}" + if description: + line += f": {description}" + lines.append(line) + return "\n".join(lines) + + @staticmethod + def _descriptor_requires_admin(descriptor: HandlerDescriptor) -> bool: + required_role = descriptor.permissions.required_role + if required_role is None and descriptor.permissions.require_admin: + required_role = "admin" + return required_role == "admin" + + @classmethod + def _descriptor_is_visible_to_event( + cls, + descriptor: HandlerDescriptor, + event: AstrMessageEvent, + ) -> bool: + if cls._descriptor_requires_admin(descriptor) and not event.is_admin(): + return False + return True + + @staticmethod + def _descriptor_matches_command_text( + descriptor: HandlerDescriptor, + text: str, + ) -> bool: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return False + for command_name in [trigger.command, *trigger.aliases]: + if not command_name: + continue + if TriggerConverter._match_command_name(text, command_name) is not None: + return True + return False + + def _match_dynamic_command_route( + self, + *, + record: SdkPluginRecord, + route: SdkDynamicCommandRoute, + event: AstrMessageEvent, + declaration_order: int, + ) -> TriggerMatch | None: + handler_ref = self._find_handler_ref(record, route.handler_full_name) + if handler_ref is None: + return None + descriptor = handler_ref.descriptor.model_copy(deep=True) + descriptor.priority = route.priority + if route.use_regex: + descriptor.trigger = MessageTrigger(regex=route.command_name) + else: + descriptor.trigger = CommandTrigger( + command=route.command_name, + description=route.desc or None, + ) + return TriggerConverter.match_handler( + plugin_id=record.plugin_id, + descriptor=descriptor, + event=event, + load_order=record.load_order, + declaration_order=declaration_order, + ) + + @staticmethod + def _find_handler_ref( + record: SdkPluginRecord, + handler_full_name: str, + ) -> SdkHandlerRef | None: + for handler in record.handlers: + if handler.descriptor.id == handler_full_name: + return handler + return None + + async def dispatch_system_event( + self, + event_type: str, + payload: dict[str, Any] | None = None, + ) -> None: + await self.dispatch_engine.dispatch_system_event(event_type, payload) + + async def dispatch_message_event( + self, + event_type: str, + event: AstrMessageEvent, + payload: dict[str, Any] | None = None, + *, + provider_request: CoreProviderRequest | None = None, + llm_response: CoreLLMResponse | None = None, + event_result: MessageEventResult | None = None, + ) -> None: + await self.dispatch_engine.dispatch_message_event( + event_type, + event, + payload, + provider_request=provider_request, + llm_response=llm_response, + event_result=event_result, + ) + + def _match_event_handlers( + self, + event_type: str, + *, + allowed_plugins: set[str] | None = None, + platform_name: str = "", + ) -> list[tuple[SdkPluginRecord, HandlerDescriptor]]: + matches: list[tuple[int, int, int, SdkPluginRecord, HandlerDescriptor]] = [] + for record in self._snapshot_records(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if allowed_plugins is not None and record.plugin_id not in allowed_plugins: + continue + if not self._record_supports_platform(record, platform_name): + continue + for handler in record.handlers: + trigger = handler.descriptor.trigger + if not isinstance(trigger, EventTrigger): + continue + if trigger.event_type != event_type: + continue + if not self._descriptor_supports_platform( + handler.descriptor, + platform_name, + ): + continue + matches.append( + ( + -handler.descriptor.priority, + record.load_order, + handler.declaration_order, + record, + handler.descriptor, + ) + ) + matches.sort(key=lambda item: (item[0], item[1], item[2])) + return [(record, descriptor) for _, _, _, record, descriptor in matches] + + @staticmethod + def _descriptor_event_types(descriptor: HandlerDescriptor) -> list[str]: + trigger = descriptor.trigger + if isinstance(trigger, EventTrigger): + return [trigger.event_type] + return [] + + @staticmethod + def _descriptor_group_path(descriptor: HandlerDescriptor) -> list[str]: + route = getattr(descriptor, "command_route", None) + if route is None: + return [] + return list(route.group_path) + + @staticmethod + def _descriptor_description(descriptor: HandlerDescriptor) -> str | None: + description = str(descriptor.description or "").strip() + if description: + return description + trigger = descriptor.trigger + if isinstance(trigger, CommandTrigger): + command_description = str(trigger.description or "").strip() + if command_description: + return command_description + return None + + def _descriptor_metadata( + self, + *, + plugin_id: str, + descriptor: HandlerDescriptor, + ) -> dict[str, Any]: + return { + "plugin_name": plugin_id, + "handler_full_name": descriptor.id, + "trigger_type": getattr(descriptor.trigger, "type", ""), + "description": self._descriptor_description(descriptor), + "event_types": self._descriptor_event_types(descriptor), + "enabled": True, + "group_path": self._descriptor_group_path(descriptor), + "priority": descriptor.priority, + "kind": descriptor.kind, + "require_admin": descriptor.permissions.require_admin, + "required_role": descriptor.permissions.required_role, + } + + def get_handlers_by_event_type(self, event_type: str) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + for record in self._snapshot_records_sorted(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + for handler in record.handlers: + trigger = handler.descriptor.trigger + if ( + isinstance(trigger, EventTrigger) + and trigger.event_type == event_type + ): + entries.append( + self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + ) + ) + if event_type == "message": + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + entries.append( + self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=descriptor, + ) + ) + return entries + + def list_native_command_candidates( + self, + platform_name: str, + ) -> list[dict[str, Any]]: + """Expose SDK commands that can be surfaced in native platform menus. + + Native platform command menus are top-level and single-token, so grouped + SDK commands are exported as their root command (for example ``gf`` for + ``gf chat`` / ``gf affection``). + """ + normalized_platform = str(platform_name).strip().lower() + if not normalized_platform: + return [] + + entries: list[dict[str, Any]] = [] + seen_names: set[str] = set() + + for record in self._snapshot_records_sorted(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + + for handler in record.handlers: + for entry in self._descriptor_native_command_candidates( + handler.descriptor, + platform_name=normalized_platform, + ): + name = str(entry.get("name", "")).strip().lower() + if not name or name in seen_names: + continue + seen_names.add(name) + entries.append(entry) + + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + for entry in self._descriptor_native_command_candidates( + descriptor, + platform_name=normalized_platform, + ): + name = str(entry.get("name", "")).strip().lower() + if not name or name in seen_names: + continue + seen_names.add(name) + entries.append(entry) + + return entries + + def get_handler_by_full_name(self, full_name: str) -> dict[str, Any] | None: + for record in self._snapshot_records(): + for handler in record.handlers: + if handler.descriptor.id == full_name: + return self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + ) + return None + + def list_dashboard_commands(self) -> list[dict[str, Any]]: + items: list[dict[str, Any]] = [] + for record in self._snapshot_records_sorted(): + items.extend(self._build_dashboard_command_items(record)) + items.sort(key=lambda item: str(item.get("effective_command", "")).lower()) + return items + + def list_dashboard_tools(self) -> list[dict[str, Any]]: + tools: list[dict[str, Any]] = [] + for record in self._snapshot_records_sorted(): + display_name = str( + record.plugin.manifest_data.get("display_name") or record.plugin_id + ) + plugin_enabled = record.state not in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + } + for spec in sorted(record.llm_tools.values(), key=lambda item: item.name): + tools.append( + { + "tool_key": (f"sdk:{record.plugin_id}:{spec.name}"), + "name": spec.name, + "description": spec.description, + "parameters": dict(spec.parameters_schema), + "active": bool(spec.active) and plugin_enabled, + "origin": "sdk_plugin", + "origin_name": display_name, + "runtime_kind": "sdk", + "plugin_id": record.plugin_id, + } + ) + return tools + + def _build_dashboard_command_items( + self, + record: SdkPluginRecord, + ) -> list[dict[str, Any]]: + flat_commands: list[dict[str, Any]] = [] + for handler in record.handlers: + entry = self._build_dashboard_command_entry( + record=record, + descriptor=handler.descriptor, + ) + if entry is not None: + flat_commands.append(entry) + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + entry = self._build_dashboard_command_entry( + record=record, + descriptor=descriptor, + route=route, + ) + if entry is not None: + flat_commands.append(entry) + + groups: dict[str, dict[str, Any]] = {} + root_items: list[dict[str, Any]] = [] + for entry in flat_commands: + parent_signature = str(entry.get("parent_signature", "")).strip() + if not parent_signature: + root_items.append(entry) + continue + group_key = self._dashboard_group_key(record.plugin_id, parent_signature) + group = groups.get(group_key) + if group is None: + group = { + "command_key": group_key, + "handler_full_name": group_key, + "handler_name": parent_signature.split()[-1] or record.plugin_id, + "plugin": record.plugin_id, + "plugin_display_name": str( + record.plugin.manifest_data.get("display_name") + or record.plugin_id + ), + "module_path": str(record.plugin.plugin_dir), + "description": entry.pop("_group_help", "") or "", + "type": "group", + "parent_signature": "", + "parent_group_handler": "", + "original_command": parent_signature, + "current_fragment": parent_signature.split()[-1] + if parent_signature + else "", + "effective_command": parent_signature, + "aliases": [], + "permission": "everyone", + "enabled": bool(entry.get("enabled", False)), + "is_group": True, + "has_conflict": False, + "reserved": False, + "runtime_kind": "sdk", + "supports_toggle": False, + "supports_rename": False, + "supports_permission": False, + "sub_commands": [], + } + groups[group_key] = group + root_items.append(group) + elif not group.get("description") and entry.get("_group_help"): + group["description"] = entry["_group_help"] + + if entry.get("permission") == "admin": + group["permission"] = "admin" + group["enabled"] = bool(group["enabled"]) or bool( + entry.get("enabled", False) + ) + entry["parent_group_handler"] = group["handler_full_name"] + entry.pop("_group_help", None) + group["sub_commands"].append(entry) + + for group in groups.values(): + group["sub_commands"].sort( + key=lambda item: str(item.get("effective_command", "")).lower() + ) + for item in root_items: + item.pop("_group_help", None) + return root_items + + def _build_dashboard_command_entry( + self, + *, + record: SdkPluginRecord, + descriptor: HandlerDescriptor, + route: SdkDynamicCommandRoute | None = None, + ) -> dict[str, Any] | None: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return None + + route_meta = descriptor.command_route + effective_command = ( + str(route_meta.display_command).strip() + if route_meta is not None and str(route_meta.display_command).strip() + else str(trigger.command).strip() + ) + parent_signature = "" + group_help = "" + if route_meta is not None and route_meta.group_path: + parent_signature = " ".join( + str(item).strip() for item in route_meta.group_path if str(item).strip() + ).strip() + group_help = str(route_meta.group_help or "").strip() + + current_fragment = effective_command + if parent_signature and effective_command.startswith(f"{parent_signature} "): + current_fragment = effective_command[len(parent_signature) + 1 :].strip() + + enabled = record.state not in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + } + return { + "command_key": self._dashboard_command_key( + plugin_id=record.plugin_id, + handler_full_name=descriptor.id, + route=route, + ), + "handler_full_name": descriptor.id, + "handler_name": descriptor.id.rsplit(".", 1)[-1], + "plugin": record.plugin_id, + "plugin_display_name": str( + record.plugin.manifest_data.get("display_name") or record.plugin_id + ), + "module_path": descriptor.id.rsplit(".", 1)[0], + "description": self._descriptor_description(descriptor) or "", + "type": "sub_command" if parent_signature else "command", + "parent_signature": parent_signature, + "parent_group_handler": "", + "original_command": effective_command, + "current_fragment": current_fragment, + "effective_command": effective_command, + "aliases": list(trigger.aliases), + "permission": ( + "admin" if descriptor.permissions.require_admin else "everyone" + ), + "enabled": enabled, + "is_group": False, + "has_conflict": False, + "reserved": False, + "runtime_kind": "sdk", + "supports_toggle": False, + "supports_rename": False, + "supports_permission": False, + "sub_commands": [], + "_group_help": group_help, + } + + @staticmethod + def _dashboard_command_key( + *, + plugin_id: str, + handler_full_name: str, + route: SdkDynamicCommandRoute | None, + ) -> str: + if route is None: + return f"sdk:command:{plugin_id}:{handler_full_name}" + route_kind = "regex" if route.use_regex else "command" + return f"sdk:route:{plugin_id}:{handler_full_name}:{route_kind}:{route.command_name}" + + @staticmethod + def _dashboard_group_key(plugin_id: str, parent_signature: str) -> str: + return f"sdk:group:{plugin_id}:{parent_signature}" + + def _build_dynamic_route_descriptor( + self, + record: SdkPluginRecord, + route: SdkDynamicCommandRoute, + ) -> HandlerDescriptor | None: + handler_ref = self._find_handler_ref(record, route.handler_full_name) + if handler_ref is None: + return None + descriptor = handler_ref.descriptor.model_copy(deep=True) + descriptor.priority = route.priority + if route.use_regex: + descriptor.trigger = MessageTrigger(regex=route.command_name) + else: + descriptor.trigger = CommandTrigger( + command=route.command_name, + description=route.desc or None, + ) + return descriptor + + @staticmethod + def _normalize_platform_name(value: Any) -> str: + return str(value or "").strip().lower() + + @classmethod + def _normalized_platform_names(cls, values: Any) -> set[str]: + if not isinstance(values, list): + return set() + return { + cls._normalize_platform_name(item) + for item in values + if cls._normalize_platform_name(item) + } + + @classmethod + def _manifest_supported_platforms(cls, manifest_data: Any) -> set[str]: + if not isinstance(manifest_data, dict): + return set() + return cls._normalized_platform_names(manifest_data.get("support_platforms")) + + def plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + normalized_platform = self._normalize_platform_name(platform_name) + if not normalized_platform: + return True + record = self._records.get(str(plugin_id)) + if record is None: + return True + return self._record_supports_platform(record, normalized_platform) + + @staticmethod + def _record_supports_platform( + record: SdkPluginRecord, + platform_name: str, + ) -> bool: + normalized_platform = SdkPluginBridge._normalize_platform_name(platform_name) + if not normalized_platform: + return True + plugin = getattr(record, "plugin", None) + manifest_data = getattr(plugin, "manifest_data", None) + normalized = SdkPluginBridge._manifest_supported_platforms(manifest_data) + if not normalized: + return True + return normalized_platform in normalized + + @staticmethod + def _local_mcp_tool_name(server_name: str, tool_name: str) -> str: + return f"mcp.{server_name}.{tool_name}" + + @staticmethod + def _local_mcp_tool_ref(server_name: str, tool_name: str) -> str: + return json.dumps( + {"server_name": server_name, "tool_name": tool_name}, + ensure_ascii=True, + separators=(",", ":"), + ) + + @staticmethod + def _plugin_data_dir(plugin_id: str) -> Path: + return Path(get_astrbot_plugin_data_path()) / plugin_id + + @classmethod + def _plugin_mcp_lease_dir(cls, plugin_id: str) -> Path: + return cls._plugin_data_dir(plugin_id) / ".mcp_leases" + + def acknowledges_global_mcp_risk(self, plugin_id: str) -> bool: + record = self._records.get(plugin_id) + return bool(record and record.acknowledge_global_mcp_risk) + + def _load_local_mcp_configs(self, plugin: PluginSpec) -> dict[str, dict[str, Any]]: + config_path = plugin.plugin_dir / "mcp.json" + if not config_path.exists(): + return {} + try: + payload = json.loads(config_path.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning( + "Failed to read SDK plugin mcp.json %s: %s", config_path, exc + ) + return {} + if not isinstance(payload, dict): + logger.warning("Ignoring invalid SDK plugin mcp.json root: %s", config_path) + return {} + servers = payload.get("mcpServers") + if not isinstance(servers, dict): + logger.warning( + "Ignoring SDK plugin mcp.json without mcpServers: %s", config_path + ) + return {} + return { + str(name): dict(config) + for name, config in servers.items() + if str(name).strip() and isinstance(config, dict) + } + + @classmethod + def _build_local_mcp_tool_specs( + cls, + server_name: str, + client: MCPClient, + ) -> list[LLMToolSpec]: + specs: list[LLMToolSpec] = [] + for tool in client.tools: + raw_tool_name = str(getattr(tool, "name", "")).strip() + if not raw_tool_name: + continue + parameters_schema = getattr(tool, "inputSchema", None) + if not isinstance(parameters_schema, dict): + parameters_schema = {"type": "object", "properties": {}} + specs.append( + LLMToolSpec.create( + name=cls._local_mcp_tool_name(server_name, raw_tool_name), + description=str(getattr(tool, "description", "") or ""), + parameters_schema=dict(parameters_schema), + handler_ref=cls._local_mcp_tool_ref(server_name, raw_tool_name), + handler_capability="internal.mcp.local.execute", + active=True, + ) + ) + return specs + + @staticmethod + def _mcp_call_result_to_text(result: Any) -> str | None: + content_items = getattr(result, "content", None) + if not isinstance(content_items, list): + return None + chunks: list[str] = [] + for item in content_items: + text = getattr(item, "text", None) + if isinstance(text, str): + chunks.append(text) + continue + model_dump = getattr(item, "model_dump", None) + if callable(model_dump): + chunks.append(json.dumps(model_dump(), ensure_ascii=False)) + continue + if item is not None: + chunks.append(str(item)) + return "\n".join(part for part in chunks if part).strip() or None + + async def _cleanup_mcp_client(self, client: MCPClient | None) -> None: + if client is None: + return + with contextlib.suppress(Exception): + await client.cleanup() + + def _write_local_mcp_lease( + self, + *, + plugin_id: str, + server_name: str, + pid: int, + ) -> Path: + lease_dir = self._plugin_mcp_lease_dir(plugin_id) + lease_dir.mkdir(parents=True, exist_ok=True) + lease_path = lease_dir / f"{server_name}.json" + lease_path.write_text( + json.dumps( + { + "pid": int(pid), + "plugin_id": plugin_id, + "server_name": server_name, + }, + ensure_ascii=True, + indent=2, + ), + encoding="utf-8", + ) + return lease_path + + @staticmethod + def _remove_local_mcp_lease(runtime: _LocalMCPServerRuntime) -> None: + lease_path = runtime.lease_path + runtime.lease_path = None + if lease_path is None: + return + with contextlib.suppress(OSError): + lease_path.unlink() + + def _terminate_stale_mcp_pid(self, pid: int) -> None: + if pid <= 0: + return + if os.name == "nt": + # Windows 没有 SIGTERM,os.kill 在 Windows 上行为不稳定; + # 使用 taskkill /T /F 可以递归终止整个进程树,更可靠 + creation_flags = int(getattr(subprocess, "CREATE_NO_WINDOW", 0)) + completed = subprocess.run( + ["taskkill", "/PID", str(pid), "/T", "/F"], + capture_output=True, + text=True, + check=False, + creationflags=creation_flags, + ) + combined_output = " ".join( + item.strip() + for item in (completed.stdout, completed.stderr) + if isinstance(item, str) and item.strip() + ).lower() + # 进程已不存在("not found")也视为成功终止,避免误报 + if completed.returncode == 0 or "not found" in combined_output: + return + logger.warning( + "Failed to terminate stale MCP pid %s on Windows: rc=%s output=%s", + pid, + completed.returncode, + combined_output or "", + ) + return + # 非 Windows 平台使用 SIGTERM,简洁且可移植 + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return + except PermissionError: + logger.warning("Permission denied while terminating stale MCP pid %s", pid) + return + except OSError as exc: + logger.warning("Failed to terminate stale MCP pid %s: %s", pid, exc) + + def _sweep_stale_mcp_leases(self) -> None: + plugin_data_root = Path(get_astrbot_plugin_data_path()) + if not plugin_data_root.exists(): + return + for lease_path in plugin_data_root.glob("*/.mcp_leases/*.json"): + try: + payload = json.loads(lease_path.read_text(encoding="utf-8")) + except Exception: + payload = {} + pid = payload.get("pid") + if pid is not None: + with contextlib.suppress(TypeError, ValueError): + self._terminate_stale_mcp_pid(int(pid)) + with contextlib.suppress(OSError): + lease_path.unlink() + + async def _connect_local_mcp_server( + self, + *, + plugin_id: str, + runtime: _LocalMCPServerRuntime, + timeout: float, + ) -> None: + await self.mcp.connect_local_mcp_server( + plugin_id=plugin_id, + runtime=runtime, + timeout=timeout, + ) + + async def _initialize_local_mcp_servers(self, record: SdkPluginRecord) -> None: + await self.mcp.initialize_local_mcp_servers(record) + + async def _shutdown_local_mcp_runtime( + self, + runtime: _LocalMCPServerRuntime, + ) -> None: + await self.mcp.shutdown_local_mcp_runtime(runtime) + + async def _shutdown_local_mcp_servers(self, record: SdkPluginRecord) -> None: + await self.mcp.shutdown_local_mcp_servers(record) + + async def enable_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: float = 30.0, + ) -> dict[str, Any]: + return await self.mcp.enable_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + + async def disable_local_mcp_server( + self, + plugin_id: str, + name: str, + ) -> dict[str, Any]: + return await self.mcp.disable_local_mcp_server(plugin_id, name) + + async def wait_for_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: float, + ) -> dict[str, Any]: + return await self.mcp.wait_for_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + + async def open_temporary_mcp_session( + self, + plugin_id: str, + *, + name: str, + config: dict[str, Any], + timeout: float, + ) -> tuple[str, list[str]]: + return await self.mcp.open_temporary_mcp_session( + plugin_id, + name=name, + config=config, + timeout=timeout, + ) + + async def close_temporary_mcp_session( + self, + plugin_id: str, + session_id: str, + ) -> None: + await self.mcp.close_temporary_mcp_session(plugin_id, session_id) + + async def _close_temporary_mcp_sessions(self, plugin_id: str) -> None: + await self.mcp.close_temporary_mcp_sessions(plugin_id) + + def get_temporary_mcp_session_tools( + self, + plugin_id: str, + session_id: str, + ) -> list[str]: + return self.mcp.get_temporary_mcp_session_tools(plugin_id, session_id) + + async def call_temporary_mcp_tool( + self, + plugin_id: str, + *, + session_id: str, + tool_name: str, + arguments: dict[str, Any], + ) -> dict[str, Any]: + return await self.mcp.call_temporary_mcp_tool( + plugin_id, + session_id=session_id, + tool_name=tool_name, + arguments=arguments, + ) + + async def execute_local_mcp_tool( + self, + plugin_id: str, + *, + server_name: str, + tool_name: str, + tool_args: dict[str, Any], + timeout_seconds: int = 60, + ) -> dict[str, Any]: + return await self.mcp.execute_local_mcp_tool( + plugin_id, + server_name=server_name, + tool_name=tool_name, + tool_args=tool_args, + timeout_seconds=timeout_seconds, + ) + + @classmethod + def _descriptor_native_command_candidates( + cls, + descriptor: HandlerDescriptor, + *, + platform_name: str, + ) -> list[dict[str, Any]]: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return [] + if not cls._descriptor_supports_platform(descriptor, platform_name): + return [] + + names = [trigger.command, *trigger.aliases] + route = descriptor.command_route + root_candidates: list[str] = [] + + if route is not None and route.group_path: + root_candidates.append(str(route.group_path[0]).strip()) + + for name in names: + normalized = str(name).strip() + if " " not in normalized: + continue + root_candidates.append(normalized.split()[0].strip()) + + if root_candidates: + description = ( + str(route.group_help).strip() + if route is not None and route.group_help + else str(trigger.description or "").strip() + ) + root_name = next((item for item in root_candidates if item), "") + if not description and root_name: + description = f"Command group: {root_name}" + unique_roots = [ + item + for item in dict.fromkeys(root_candidates) + if isinstance(item, str) and item.strip() + ] + return [ + { + "name": item.strip(), + "description": description, + "is_group": True, + } + for item in unique_roots + ] + + description = str(trigger.description or "").strip() + if not description and trigger.command.strip(): + description = f"Command: {trigger.command.strip()}" + unique_names = [ + item for item in dict.fromkeys(str(name).strip() for name in names) if item + ] + return [ + { + "name": item, + "description": description, + "is_group": False, + } + for item in unique_names + ] + + @classmethod + def _descriptor_supports_platform( + cls, + descriptor: HandlerDescriptor, + platform_name: str, + ) -> bool: + normalized_platform = cls._normalize_platform_name(platform_name) + if not normalized_platform: + return True + trigger_platforms = getattr(descriptor.trigger, "platforms", []) + if isinstance(trigger_platforms, list): + normalized = cls._normalized_platform_names(trigger_platforms) + if normalized and normalized_platform not in normalized: + return False + for filter_spec in descriptor.filters: + if not cls._filter_supports_platform(filter_spec, normalized_platform): + return False + return True + + @classmethod + def _filter_supports_platform(cls, filter_spec, platform_name: str) -> bool: + if isinstance(filter_spec, PlatformFilterSpec): + normalized = { + str(item).strip().lower() + for item in filter_spec.platforms + if str(item).strip() + } + return not normalized or platform_name in normalized + if isinstance(filter_spec, CompositeFilterSpec): + platform_children = [ + child + for child in filter_spec.children + if isinstance(child, PlatformFilterSpec | CompositeFilterSpec) + ] + if not platform_children: + return True + results = [ + cls._filter_supports_platform(child, platform_name) + for child in platform_children + ] + if filter_spec.kind == "and": + return all(results) + return any(results) + return True + + async def _load_or_reload_plugin( + self, + plugin: PluginSpec, + *, + load_order: int, + reset_restart_budget: bool, + ) -> None: + current = self._records.get(plugin.name) + if current is not None: + current.state = SDK_STATE_RELOADING + await self._cancel_plugin_requests(plugin.name) + await self._teardown_plugin(plugin.name) + + disabled = bool( + self._state_overrides.get(plugin.name, {}).get("disabled", False) + ) + config_schema = load_plugin_config_schema(plugin) + local_mcp_configs = self._load_local_mcp_configs(plugin) + local_mcp_servers: dict[str, _LocalMCPServerRuntime] = {} + for server_name, server_config in local_mcp_configs.items(): + local_mcp_servers[server_name] = _LocalMCPServerRuntime( + name=server_name, + config=dict(server_config), + active=bool(server_config.get("active", True)), + ) + + record = SdkPluginRecord( + plugin=plugin, + load_order=load_order, + state=SDK_STATE_DISABLED if disabled else SDK_STATE_ENABLED, + unsupported_features=[], + config_schema=config_schema, + config=load_plugin_config(plugin, schema=config_schema), + handlers=[], + llm_tools={}, + active_llm_tools=set(), + agents={}, + restart_attempted=False + if reset_restart_budget + else (current.restart_attempted if current is not None else False), + issues=[dict(item) for item in self._discovery_issues.get(plugin.name, [])], + local_mcp_servers=local_mcp_servers, + ) + self._records[plugin.name] = record + self._publish_plugin_skills(plugin.name) + if disabled: + self._persist_state_overrides() + return + + try: + + def _schedule_closed(plugin_id: str = plugin.name) -> None: + asyncio.create_task(self._handle_worker_closed(plugin_id)) + + session = WorkerSession( + plugin=plugin, + repo_root=Path(__file__).resolve().parents[3], + env_manager=self.env_manager, + capability_router=self.capability_bridge, + on_closed=_schedule_closed, + ) + await session.start() + session.start_close_watch() + record.session = session + remote_metadata = ( + dict(session.peer.remote_metadata) + if session.peer is not None + and isinstance(session.peer.remote_metadata, dict) + else {} + ) + record.acknowledge_global_mcp_risk = bool( + remote_metadata.get("acknowledge_global_mcp_risk", False) + ) + unsupported_features: set[str] = set() + for index, descriptor in enumerate(session.handlers): + if ( + isinstance(descriptor.trigger, EventTrigger) + and descriptor.trigger.event_type not in SUPPORTED_SYSTEM_EVENTS + ): + unsupported_features.add("event_trigger") + record.handlers.append( + SdkHandlerRef( + descriptor=descriptor, + declaration_order=index, + ) + ) + for item in session.llm_tools: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_id") or plugin.name) + if plugin_name != plugin.name: + continue + normalized = dict(item) + normalized.pop("plugin_id", None) + spec = LLMToolSpec.from_payload(normalized) + record.llm_tools[spec.name] = spec + if spec.active: + record.active_llm_tools.add(spec.name) + for item in session.agents: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_id") or plugin.name) + if plugin_name != plugin.name: + continue + normalized = dict(item) + normalized.pop("plugin_id", None) + spec = AgentSpec.from_payload(normalized) + record.agents[spec.name] = spec + await self._register_schedule_handlers(record) + await self._initialize_local_mcp_servers(record) + record.issues.extend(issue.to_payload() for issue in session.issues) + record.unsupported_features = sorted(unsupported_features) + record.state = ( + SDK_STATE_UNSUPPORTED_PARTIAL + if record.unsupported_features + else SDK_STATE_ENABLED + ) + record.failure_reason = "" + registered_http_apis = self.list_http_apis(plugin.name) + if registered_http_apis: + api_base_url = self._public_http_url(f"/{plugin.name}") + entry_route = self._plugin_entry_route(plugin.name) + if entry_route is not None: + logger.info( + "SDK plugin HTTP routes ready: plugin=%s total=%s page=%s api_base=%s", + plugin.name, + len(registered_http_apis), + self._public_page_url(entry_route), + api_base_url, + ) + else: + logger.info( + "SDK plugin HTTP routes ready: plugin=%s total=%s api_base=%s", + plugin.name, + len(registered_http_apis), + api_base_url, + ) + except Exception as exc: + record.session = None + record.state = SDK_STATE_FAILED + record.failure_reason = str(exc) + record.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件 worker 启动失败", + details=str(exc), + ).to_payload() + ) + logger.warning("Failed to start SDK plugin %s: %s", plugin.name, exc) + finally: + self._persist_state_overrides() + + async def _teardown_plugin(self, plugin_id: str) -> None: + record = self._records.get(plugin_id) + self._http_routes.pop(plugin_id, None) + self._session_waiters.pop(plugin_id, None) + await self._unregister_schedule_jobs(plugin_id) + await self._close_temporary_mcp_sessions(plugin_id) + await self._clear_plugin_skills( + plugin_id=plugin_id, + record=record, + reason="teardown", + ) + if record is None or record.session is None: + if record is not None: + await self._shutdown_local_mcp_servers(record) + return + try: + await self._shutdown_local_mcp_servers(record) + await record.session.stop() + finally: + record.session = None + + async def _register_schedule_handlers(self, record: SdkPluginRecord) -> None: + cron_manager = getattr(self.star_context, "cron_manager", None) + if cron_manager is None: + return + for handler in record.handlers: + trigger = handler.descriptor.trigger + if not isinstance(trigger, ScheduleTrigger): + continue + schedule_key = f"{record.plugin_id}:{handler.handler_id}" + job_ref: dict[str, Any] = {"job": None} + job = await cron_manager.add_basic_job( + name=trigger.name or schedule_key, + cron_expression=trigger.cron, + interval_seconds=trigger.interval_seconds, + handler=self._build_schedule_runner( + plugin_id=record.plugin_id, + handler_id=handler.handler_id, + trigger=trigger, + job_ref=job_ref, + ), + description=handler.descriptor.description + or f"SDK schedule handler {handler.handler_id}", + timezone=trigger.timezone, + enabled=True, + persistent=False, + ) + job_ref["job"] = job + self._schedule_job_ids.setdefault(record.plugin_id, set()).add(job.job_id) + + async def _unregister_schedule_jobs(self, plugin_id: str) -> None: + cron_manager = getattr(self.star_context, "cron_manager", None) + if cron_manager is None: + return + for job_id in list(self._schedule_job_ids.pop(plugin_id, set())): + try: + await cron_manager.delete_job(job_id) + except Exception: + logger.debug("Failed to remove SDK schedule job {}", job_id) + + def _build_schedule_runner( + self, + *, + plugin_id: str, + handler_id: str, + trigger: ScheduleTrigger, + job_ref: dict[str, Any] | None = None, + ): + async def _run(**_scheduler_payload: Any) -> None: + # CronJobManager stores scheduler metadata such as interval_seconds in the + # job payload and replays that payload into basic handlers. SDK schedule + # handlers do not consume those transport-level kwargs, so the bridge + # must swallow them here and only forward the synthesized schedule event. + invoke_kwargs = { + "plugin_id": plugin_id, + "handler_id": handler_id, + "trigger": trigger, + } + job = (job_ref or {}).get("job") + if job is not None: + invoke_kwargs["job"] = job + await self._invoke_schedule_handler( + **invoke_kwargs, + ) + + return _run + + def _set_discovery_issues(self, issues: list[PluginDiscoveryIssue]) -> None: + grouped: dict[str, list[dict[str, Any]]] = {} + for issue in issues: + grouped.setdefault(issue.plugin_id, []).append(issue.to_payload()) + self._discovery_issues = grouped + + # TODO: 平台适配器目前仍用 legacy 的 @register_platform_adapter,不走 SDK 协议。 + # 长期来看可以把平台适配器也纳入 SDK 的 capability 体系,实现完全统一的插件/平台注册机制。 + # 但是目前先保持现状,等平台适配器的 SDK 能力稳定后再做迁移,以避免不必要的重复开发和潜在风险。 + async def _refresh_native_platform_commands( + self, platforms: set[str] | None = None + ) -> None: + platform_manager = getattr(self.star_context, "platform_manager", None) + if platform_manager is None: + return + refresh_commands = getattr(platform_manager, "refresh_native_commands", None) + if not callable(refresh_commands): + return + refresh_commands_async = cast( + Callable[..., Awaitable[Any]], + refresh_commands, + ) + try: + await refresh_commands_async(platforms=platforms) + except Exception as exc: + logger.warning("Failed to refresh native platform commands: %s", exc) + + async def _invoke_schedule_handler( + self, + *, + plugin_id: str, + handler_id: str, + trigger: ScheduleTrigger, + job: Any | None = None, + ) -> None: + record = self._records.get(plugin_id) + if ( + record is None + or record.session is None + or record.state + in {SDK_STATE_DISABLED, SDK_STATE_FAILED, SDK_STATE_RELOADING} + ): + return + dispatch_token = uuid.uuid4().hex + request_id = f"sdk_schedule_{plugin_id}_{uuid.uuid4().hex}" + self._ensure_request_overlay(dispatch_token, should_call_llm=False) + self._request_contexts[dispatch_token] = _RequestContext( + plugin_id=plugin_id, + request_id=request_id, + dispatch_token=dispatch_token, + dispatch_state=None, + ) + self._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=plugin_id, + ) + payload = self._build_schedule_payload( + plugin_id=plugin_id, + handler_id=handler_id, + trigger=trigger, + job=job, + ) + try: + await record.session.invoke_handler( + handler_id, + payload, + request_id=request_id, + args={}, + ) + except Exception as exc: + logger.warning( + "SDK schedule handler failed: plugin=%s handler=%s error=%s", + plugin_id, + handler_id, + exc, + ) + finally: + # 无论调度 handler 成功与否,都要关闭 overlay, + # 防止已结束的调度任务一直占用 overlay 槽位导致内存泄漏 + self._close_request_overlay(dispatch_token) + + @staticmethod + def _build_schedule_payload( + *, + plugin_id: str, + handler_id: str, + trigger: ScheduleTrigger, + job: Any | None = None, + ) -> dict[str, Any]: + scheduled_at = datetime.now(timezone.utc).isoformat() + job_name = str(getattr(job, "name", "")).strip() or f"{plugin_id}:{handler_id}" + job_id = str(getattr(job, "job_id", "")).strip() or None + description = getattr(job, "description", None) + if description is not None: + description = str(description).strip() or None + job_type = str(getattr(job, "job_type", "")).strip() or "basic" + timezone_name = getattr(job, "timezone", None) + if isinstance(timezone_name, str): + timezone_name = timezone_name.strip() or None + else: + timezone_name = None + if timezone_name is None: + timezone_name = trigger.timezone + return { + "type": "schedule", + "event_type": "schedule", + "text": "", + "session_id": "", + "platform": "", + "platform_id": "", + "message_type": "other", + "sender_name": "", + "self_id": "", + "raw": {"event_type": "schedule"}, + "schedule": { + "schedule_id": f"{plugin_id}:{handler_id}", + "job_id": job_id, + "plugin_id": plugin_id, + "handler_id": handler_id, + "name": job_name, + "description": description, + "job_type": job_type, + "trigger_kind": "cron" if trigger.cron is not None else "interval", + "cron": trigger.cron, + "interval_seconds": trigger.interval_seconds, + "timezone": timezone_name, + "scheduled_at": scheduled_at, + }, + } + + async def _cancel_plugin_requests(self, plugin_id: str) -> None: + requests = list(self._plugin_requests.get(plugin_id, {}).values()) + for inflight in requests: + request_context = self._request_contexts.get(inflight.dispatch_token) + if request_context is not None: + request_context.cancelled = True + self._close_request_overlay(inflight.dispatch_token) + record = self._records.get(plugin_id) + if ( + record is not None + and record.session is not None + and record.session.peer is not None + and not inflight.task.done() + ): + try: + await record.session.cancel(inflight.request_id) + except Exception: + logger.debug( + "Failed to forward SDK cancel for %s", inflight.request_id + ) + inflight.task.cancel() + else: + inflight.logical_cancelled = True + self._plugin_requests.pop(plugin_id, None) + + async def _handle_worker_closed(self, plugin_id: str) -> None: + await self.lifecycle.handle_worker_closed(plugin_id) + + def _record_to_dashboard_item(self, record: SdkPluginRecord) -> dict[str, Any]: + manifest = record.plugin.manifest_data + support_platforms = manifest.get("support_platforms") + installed_at = None + try: + installed_at = datetime.fromtimestamp( + record.plugin.plugin_dir.stat().st_mtime, + timezone.utc, + ).isoformat() + except OSError: + installed_at = None + handlers = [ + self._handler_to_dashboard_item(handler) for handler in record.handlers + ] + return { + "name": record.plugin_id, + "repo": str(manifest.get("repo") or ""), + "author": str(manifest.get("author") or ""), + "desc": str(manifest.get("desc") or manifest.get("description") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "reserved": False, + "activated": record.state not in {SDK_STATE_DISABLED, SDK_STATE_FAILED}, + "online_vesion": "", + "handlers": handlers, + "display_name": str(manifest.get("display_name") or record.plugin_id), + "logo": None, + "support_platforms": [ + str(item) for item in support_platforms if isinstance(item, str) + ] + if isinstance(support_platforms, list) + else [], + "astrbot_version": ( + str(manifest.get("astrbot_version")) + if manifest.get("astrbot_version") is not None + else "" + ), + "installed_at": installed_at, + "runtime_kind": "sdk", + "source_kind": "local_dir", + "managed_by": "sdk_bridge", + "state": record.state, + "trigger_summary": [item["cmd"] for item in handlers], + "unsupported_features": list(record.unsupported_features), + "failure_reason": record.failure_reason, + "issues": [dict(item) for item in record.issues], + } + + def _failed_issue_to_dashboard_item( + self, + plugin_id: str, + issues: list[dict[str, Any]], + ) -> dict[str, Any]: + issue = issues[0] if issues else {} + failure_reason = str(issue.get("details") or issue.get("message") or "") + return { + "name": plugin_id, + "repo": "", + "author": "", + "desc": str(issue.get("message", "")), + "version": "0.0.0", + "reserved": False, + "activated": False, + "online_vesion": "", + "handlers": [], + "display_name": plugin_id, + "logo": None, + "support_platforms": [], + "astrbot_version": "", + "installed_at": None, + "runtime_kind": "sdk", + "source_kind": "local_dir", + "managed_by": "sdk_bridge", + "state": SDK_STATE_FAILED, + "trigger_summary": [], + "unsupported_features": [], + "failure_reason": failure_reason, + "issues": [dict(item) for item in issues], + } + + def _handler_to_dashboard_item(self, handler: SdkHandlerRef) -> dict[str, Any]: + trigger = handler.descriptor.trigger + description = self._descriptor_description(handler.descriptor) + if not description and isinstance(trigger, CommandTrigger): + description = f"Command: {trigger.command}" + if not description: + description = "无描述" + if isinstance(trigger, CommandTrigger): + event_type = "SDKCommandEvent" + event_type_h = "SDK 指令触发" + elif isinstance(trigger, MessageTrigger): + event_type = "SDKMessageEvent" + event_type_h = "SDK 消息触发" + elif isinstance(trigger, EventTrigger): + event_type = "SDKEventTrigger" + event_type_h = "SDK 事件触发" + elif isinstance(trigger, ScheduleTrigger): + event_type = "SDKScheduleEvent" + event_type_h = "SDK 定时触发" + else: + event_type = "SDKHandler" + event_type_h = "SDK 行为触发" + + base = { + "event_type": event_type, + "event_type_h": event_type_h, + "handler_full_name": handler.handler_id, + "desc": description, + "handler_name": handler.handler_name, + "has_admin": handler.descriptor.permissions.require_admin, + } + if isinstance(trigger, CommandTrigger): + return {**base, "type": "指令", "cmd": trigger.command} + if isinstance(trigger, MessageTrigger): + if trigger.regex: + return {**base, "type": "正则匹配", "cmd": trigger.regex} + if trigger.keywords: + return {**base, "type": "关键词", "cmd": ", ".join(trigger.keywords)} + return {**base, "type": "消息", "cmd": "任意消息"} + if isinstance(trigger, EventTrigger): + return {**base, "type": "事件", "cmd": trigger.event_type} + if isinstance(trigger, ScheduleTrigger): + return { + **base, + "type": "定时", + "cmd": trigger.cron or str(trigger.interval_seconds), + } + return {**base, "type": "未知", "cmd": "未知"} + + def _load_state_overrides(self) -> dict[str, dict[str, Any]]: + if not self.state_path.exists(): + return {} + try: + data = json.loads(self.state_path.read_text(encoding="utf-8")) + except Exception: + return {} + plugins = data.get("plugins") + return dict(plugins) if isinstance(plugins, dict) else {} + + def _persist_state_overrides(self) -> None: + self.state_path.write_text( + json.dumps( + {"plugins": self._state_overrides}, ensure_ascii=False, indent=2 + ), + encoding="utf-8", + ) + + def _set_disabled_override(self, plugin_id: str, *, disabled: bool) -> None: + plugin_state = dict(self._state_overrides.get(plugin_id, {})) + if disabled: + plugin_state["disabled"] = True + self._state_overrides[plugin_id] = plugin_state + else: + plugin_state.pop("disabled", None) + if plugin_state: + self._state_overrides[plugin_id] = plugin_state + else: + self._state_overrides.pop(plugin_id, None) + self._persist_state_overrides() + + def _discover_plugins(self): + return discover_plugins(self.plugins_dir) + + @staticmethod + def _make_mcp_client() -> MCPClient: + return MCPClient() + + @staticmethod + def _make_skill_manager() -> SkillManager: + return SkillManager() + + @staticmethod + def _get_dashboard_config(): + return astrbot_config + + @staticmethod + def _normalize_http_route(route: str) -> str: + route_text = str(route).strip() + if not route_text: + raise AstrBotError.invalid_input("http route must not be empty") + if not route_text.startswith("/"): + route_text = f"/{route_text}" + return route_text + + @staticmethod + def _normalize_http_methods(methods: list[str]) -> tuple[str, ...]: + normalized = tuple( + sorted({str(method).upper() for method in methods if method}) + ) + if not normalized: + raise AstrBotError.invalid_input("http methods must not be empty") + return normalized + + def _ensure_http_route_available( + self, + *, + plugin_id: str, + route: str, + methods: tuple[str, ...], + ) -> None: + for legacy_route, _view_handler, legacy_methods, _desc in getattr( + self.star_context, "registered_web_apis", [] + ): + if route != legacy_route: + continue + if set(methods) & {str(method).upper() for method in legacy_methods}: + raise AstrBotError.invalid_input( + f"HTTP route conflict with legacy plugin route: {route}" + ) + for owner, entries in self._http_routes.items(): + for entry in entries: + if ( + owner == plugin_id + and entry.route == route + and entry.methods == methods + ): + continue + if entry.route != route: + continue + if set(entry.methods) & set(methods): + raise AstrBotError.invalid_input( + f"HTTP route conflict with SDK plugin route: {route}" + ) + + def _resolve_http_route( + self, + route: str, + method: str, + ) -> tuple[SdkPluginRecord, SdkHttpRoute] | None: + normalized_route = self._normalize_http_route(route) + normalized_method = str(method).upper() + for record in sorted(self._records.values(), key=lambda item: item.load_order): + for entry in self._http_routes.get(record.plugin_id, []): + if ( + entry.route == normalized_route + and normalized_method in entry.methods + ): + return record, entry + return None + + def _match_waiter_plugins(self, session_key: str) -> list[SdkPluginRecord]: + matches: list[SdkPluginRecord] = [] + for record in sorted(self._records.values(), key=lambda item: item.load_order): + if session_key in self._session_waiters.get(record.plugin_id, set()): + matches.append(record) + return matches + + async def _dispatch_waiter_event( + self, + event: AstrMessageEvent, + records: list[SdkPluginRecord], + ) -> SdkDispatchResult: + return await self.dispatch_engine.dispatch_waiter_event(event, records) diff --git a/astrbot/core/sdk_bridge/registry_manager.py b/astrbot/core/sdk_bridge/registry_manager.py new file mode 100644 index 0000000000..e08fd0ccd9 --- /dev/null +++ b/astrbot/core/sdk_bridge/registry_manager.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +import os +import tempfile +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from astrbot_sdk._internal.plugin_ids import ( + capability_belongs_to_plugin, + http_route_belongs_to_plugin, + plugin_capability_prefix, + plugin_http_route_root, +) +from astrbot_sdk.errors import AstrBotError + +from astrbot.core import logger +from astrbot.core.skills.skill_manager import ( + _parse_frontmatter_description, +) +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from .runtime_store import ( + SdkHttpRoute, + SdkPluginRecord, + SdkRegisteredSkill, +) + +if TYPE_CHECKING: + from .plugin_bridge import SdkPluginBridge + + +class SdkRegistryManager: + def __init__(self, *, bridge: SdkPluginBridge) -> None: + self.bridge = bridge + + def list_plugins(self) -> list[dict[str, Any]]: + records = sorted( + self.bridge._records.values(), key=lambda item: item.load_order + ) + items = [self.bridge._record_to_dashboard_item(record) for record in records] + for plugin_id, issues in sorted(self.bridge._discovery_issues.items()): + if plugin_id in self.bridge._records: + continue + items.append(self.bridge._failed_issue_to_dashboard_item(plugin_id, issues)) + return items + + def get_plugin_metadata(self, plugin_id: str) -> dict[str, Any] | None: + record = self.bridge._records.get(plugin_id) + if record is not None: + manifest = record.plugin.manifest_data + support_platforms = manifest.get("support_platforms") + return { + "name": plugin_id, + "display_name": str(manifest.get("display_name") or plugin_id), + "description": str( + manifest.get("desc") or manifest.get("description") or "" + ), + "repo": str(manifest.get("repo") or ""), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": record.state not in {"disabled", "failed"}, + "support_platforms": [ + str(item) for item in support_platforms if isinstance(item, str) + ] + if isinstance(support_platforms, list) + else [], + "astrbot_version": ( + str(manifest.get("astrbot_version")) + if manifest.get("astrbot_version") is not None + else None + ), + "runtime_kind": "sdk", + "issues": [dict(item) for item in record.issues], + } + for plugin in self.bridge.star_context.get_all_stars(): + if plugin.name == plugin_id: + return { + "name": plugin.name, + "display_name": plugin.display_name, + "description": plugin.desc, + "repo": plugin.repo, + "author": plugin.author, + "version": plugin.version, + "enabled": plugin.activated, + "support_platforms": list(plugin.support_platforms), + "astrbot_version": plugin.astrbot_version, + "runtime_kind": "legacy", + } + if plugin_id in self.bridge._discovery_issues: + issue = self.bridge._discovery_issues[plugin_id][0] + return { + "name": plugin_id, + "display_name": plugin_id, + "description": str(issue.get("message", "")), + "repo": "", + "author": "", + "version": "0.0.0", + "enabled": False, + "support_platforms": [], + "astrbot_version": None, + "runtime_kind": "sdk", + "issues": [ + dict(item) for item in self.bridge._discovery_issues[plugin_id] + ], + } + return None + + def list_plugin_metadata(self) -> list[dict[str, Any]]: + metadata = [] + for plugin in self.bridge.star_context.get_all_stars(): + metadata.append( + { + "name": plugin.name, + "display_name": plugin.display_name, + "description": plugin.desc, + "repo": plugin.repo, + "author": plugin.author, + "version": plugin.version, + "enabled": plugin.activated, + "support_platforms": list(plugin.support_platforms), + "astrbot_version": plugin.astrbot_version, + "runtime_kind": "legacy", + } + ) + for plugin_id in sorted(self.bridge._records.keys()): + plugin_metadata = self.get_plugin_metadata(plugin_id) + if plugin_metadata is not None: + metadata.append(plugin_metadata) + for plugin_id in sorted(self.bridge._discovery_issues.keys()): + if plugin_id in self.bridge._records: + continue + plugin_metadata = self.get_plugin_metadata(plugin_id) + if plugin_metadata is not None: + metadata.append(plugin_metadata) + return metadata + + def register_skill( + self, + *, + plugin_id: str, + name: str, + path: str, + description: str = "", + ) -> dict[str, str]: + record = self.bridge._records.get(plugin_id) + if record is None: + raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}") + + skill_name = str(name).strip() + if not skill_name or not self.bridge.SDK_SKILL_NAME_RE.fullmatch(skill_name): + raise AstrBotError.invalid_input( + "skill.register requires a name matching [A-Za-z0-9._-]+" + ) + + path_text = str(path).strip() + if not path_text: + raise AstrBotError.invalid_input("skill.register requires path") + + plugin_root = record.plugin.plugin_dir.resolve() + requested_path = Path(path_text) + resolved_path = ( + requested_path.resolve() + if requested_path.is_absolute() + else (plugin_root / requested_path).resolve() + ) + + skill_dir = resolved_path if resolved_path.is_dir() else resolved_path.parent + skill_md_path = ( + resolved_path / "SKILL.md" if resolved_path.is_dir() else resolved_path + ) + if skill_md_path.name != "SKILL.md" or not skill_md_path.is_file(): + raise AstrBotError.invalid_input( + "skill.register path must point to a skill directory containing SKILL.md or to SKILL.md itself" + ) + if not skill_dir.is_dir(): + raise AstrBotError.invalid_input( + "skill.register resolved skill_dir is not a directory" + ) + if not skill_md_path.is_relative_to(plugin_root): + raise AstrBotError.invalid_input( + "skill.register path must stay inside the plugin directory" + ) + + normalized_description = str(description).strip() + if not normalized_description: + try: + normalized_description = _parse_frontmatter_description( + skill_md_path.read_text(encoding="utf-8") + ) + except Exception: + normalized_description = "" + + record.skills[skill_name] = SdkRegisteredSkill( + name=skill_name, + description=normalized_description, + skill_dir=skill_dir, + skill_md_path=skill_md_path, + ) + self.bridge._publish_plugin_skills(plugin_id) + return { + "name": skill_name, + "description": normalized_description, + "path": str(skill_md_path), + "skill_dir": str(skill_dir), + } + + def unregister_skill(self, *, plugin_id: str, name: str) -> bool: + record = self.bridge._records.get(plugin_id) + if record is None: + raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}") + removed = record.skills.pop(str(name).strip(), None) is not None + if removed: + self.bridge._publish_plugin_skills(plugin_id) + return removed + + def list_registered_skills(self, plugin_id: str) -> list[dict[str, str]]: + record = self.bridge._records.get(plugin_id) + if record is None: + return [] + return [ + record.skills[name].to_registry_payload() + for name in sorted(record.skills.keys()) + ] + + def publish_plugin_skills_impl(self, plugin_id: str) -> None: + record = self.bridge._records.get(plugin_id) + manager = self.bridge._make_skill_manager() + if record is None or not record.skills: + manager.remove_sdk_plugin_skills(plugin_id) + return + manager.replace_sdk_plugin_skills( + plugin_id, + [skill.to_registry_payload() for skill in record.skills.values()], + ) + + async def clear_plugin_skills( + self, + *, + plugin_id: str, + record: SdkPluginRecord | Any | None, + reason: str, + ) -> None: + if record is None or not getattr(record, "skills", None): + return + record.skills.clear() + self.bridge._publish_plugin_skills(plugin_id) + try: + from astrbot.core.computer.computer_client import ( + sync_skills_to_active_sandboxes, + ) + + await sync_skills_to_active_sandboxes() + except Exception as exc: + logger.warning( + "Failed to sync skills after SDK plugin %s %s: %s", + plugin_id, + reason, + exc, + ) + + def register_http_api( + self, + *, + plugin_id: str, + route: str, + methods: list[str], + handler_capability: str, + description: str, + ) -> None: + normalized_route = self.bridge._normalize_http_route(route) + normalized_methods = self.bridge._normalize_http_methods(methods) + if not handler_capability: + raise AstrBotError.invalid_input( + "http.register_api requires handler_capability" + ) + self._validate_http_route_namespace(normalized_route, plugin_id) + self._validate_http_handler_namespace(handler_capability, plugin_id) + self.bridge._ensure_http_route_available( + plugin_id=plugin_id, + route=normalized_route, + methods=normalized_methods, + ) + route_entry = SdkHttpRoute( + plugin_id=plugin_id, + route=normalized_route, + methods=normalized_methods, + handler_capability=handler_capability, + description=description, + ) + plugin_routes = [ + entry + for entry in self.bridge._http_routes.get(plugin_id, []) + if not ( + entry.route == normalized_route and entry.methods == normalized_methods + ) + ] + plugin_routes.append(route_entry) + self.bridge._http_routes[plugin_id] = plugin_routes + logger.info( + "SDK HTTP route registered: plugin=%s route=%s methods=%s handler=%s", + plugin_id, + route_entry.route, + ",".join(route_entry.methods), + handler_capability, + ) + + @staticmethod + def _validate_http_route_namespace(route: str, plugin_id: str) -> None: + if http_route_belongs_to_plugin(route, plugin_id): + return + route_root = plugin_http_route_root(plugin_id) + raise AstrBotError.invalid_input( + "http.register_api requires route to use the current plugin namespace: " + f"route={route!r}, plugin_id={plugin_id!r}, expected={route_root!r} " + f"or {route_root + '/...'}" + ) + + @staticmethod + def _validate_http_handler_namespace( + handler_capability: str, + plugin_id: str, + ) -> None: + if capability_belongs_to_plugin(handler_capability, plugin_id): + return + expected_prefix = plugin_capability_prefix(plugin_id) + raise AstrBotError.invalid_input( + "http.register_api requires handler_capability to belong to the current " + "plugin: " + f"capability={handler_capability!r}, plugin_id={plugin_id!r}, " + f"expected_prefix={expected_prefix!r}" + ) + + def unregister_http_api( + self, + *, + plugin_id: str, + route: str, + methods: list[str], + ) -> None: + normalized_route = self.bridge._normalize_http_route(route) + normalized_methods = {method.upper() for method in methods if method} + updated: list[SdkHttpRoute] = [] + for entry in self.bridge._http_routes.get(plugin_id, []): + if entry.route != normalized_route: + updated.append(entry) + continue + if not normalized_methods: + continue + remaining = tuple( + method for method in entry.methods if method not in normalized_methods + ) + if remaining: + updated.append( + SdkHttpRoute( + plugin_id=entry.plugin_id, + route=entry.route, + methods=remaining, + handler_capability=entry.handler_capability, + description=entry.description, + ) + ) + if updated: + self.bridge._http_routes[plugin_id] = updated + else: + self.bridge._http_routes.pop(plugin_id, None) + + def list_http_apis(self, plugin_id: str) -> list[dict[str, Any]]: + return [ + { + "route": entry.route, + "methods": list(entry.methods), + "handler_capability": entry.handler_capability, + "description": entry.description, + } + for entry in self.bridge._http_routes.get(plugin_id, []) + ] + + def dashboard_public_base_url(self) -> str: + dashboard_config_source = self.bridge._get_dashboard_config() + dashboard_config = dashboard_config_source.get("dashboard", {}) + if not isinstance(dashboard_config, dict): + dashboard_config = {} + ssl_config = dashboard_config.get("ssl", {}) + if not isinstance(ssl_config, dict): + ssl_config = {} + + port = ( + os.environ.get("DASHBOARD_PORT") + or os.environ.get("ASTRBOT_DASHBOARD_PORT") + or dashboard_config.get("port", 6185) + ) + host = ( + os.environ.get("DASHBOARD_HOST") + or os.environ.get("ASTRBOT_DASHBOARD_HOST") + or dashboard_config.get("host", "0.0.0.0") + ) + ssl_enabled = self.bridge._parse_env_bool( + os.environ.get("DASHBOARD_SSL_ENABLE") + or os.environ.get("ASTRBOT_DASHBOARD_SSL_ENABLE"), + bool(ssl_config.get("enable", False)), + ) + scheme = "https" if ssl_enabled else "http" + host_text = str(host).strip() or "localhost" + if host_text in {"0.0.0.0", "::", "[::]"}: + host_text = "localhost" + if ":" in host_text and not host_text.startswith("["): + host_text = f"[{host_text}]" + return f"{scheme}://{host_text}:{int(port)}" + + async def dispatch_http_request( + self, + route: str, + method: str, + ) -> dict[str, Any] | None: + resolved = self.bridge._resolve_http_route(route, method) + if resolved is None: + return None + record, route_entry = resolved + if record.session is None: + raise AstrBotError.invalid_input("SDK HTTP route worker is unavailable") + from quart import request as quart_request + + text_body = await quart_request.get_data(as_text=True) + form_payload = (await quart_request.form).to_dict(flat=False) + upload_dir = Path(get_astrbot_data_path()) / "temp" / "sdk_http_uploads" + upload_dir.mkdir(parents=True, exist_ok=True) + file_payloads: list[dict[str, Any]] = [] + request_files = await quart_request.files + for field_name in request_files: + for storage in request_files.getlist(field_name): + original_name = str(storage.filename or "").strip() + suffix = Path(original_name).suffix + temp_file = tempfile.NamedTemporaryFile( + delete=False, + dir=upload_dir, + suffix=suffix, + ) + temp_path = Path(temp_file.name) + temp_file.close() + storage.save(temp_path) + file_payloads.append( + { + "field_name": str(field_name), + "filename": original_name, + "content_type": str(storage.content_type or ""), + "path": str(temp_path), + "size": temp_path.stat().st_size, + } + ) + payload = { + "method": method.upper(), + "route": route_entry.route, + "path": quart_request.path, + "query": quart_request.args.to_dict(flat=False), + "headers": dict(quart_request.headers), + "form": form_payload, + "files": file_payloads, + "json_body": await quart_request.get_json(silent=True), + "text_body": text_body, + } + output = await record.session.invoke_capability( + route_entry.handler_capability, + payload, + request_id=f"sdk_http_{record.plugin_id}_{uuid.uuid4().hex}", + ) + if not isinstance(output, dict): + raise AstrBotError.invalid_input("SDK HTTP handler must return an object") + return output diff --git a/astrbot/core/sdk_bridge/request_runtime.py b/astrbot/core/sdk_bridge/request_runtime.py new file mode 100644 index 0000000000..d020e460d7 --- /dev/null +++ b/astrbot/core/sdk_bridge/request_runtime.py @@ -0,0 +1,897 @@ +from __future__ import annotations + +import asyncio +import copy +import json +import uuid +from typing import TYPE_CHECKING, Any + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.message.components import component_to_payload_sync + +from astrbot.core import logger +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import LLMResponse as CoreLLMResponse +from astrbot.core.provider.entities import ProviderRequest as CoreProviderRequest + +from .bridge_base import _build_message_chain_from_payload +from .event_payload import ( + InboundEventSnapshot, + build_inbound_event_snapshot, + normalize_sdk_local_extras, + sanitize_sdk_extras, +) +from .runtime_store import ( + SdkRuntimeStore, + _RequestContext, + _RequestOverlayState, +) + +if TYPE_CHECKING: + from .plugin_bridge import SdkPluginBridge + + +class _EventResultBinding: + def __init__(self, *, runtime: SdkRequestRuntime, dispatch_token: str) -> None: + self.runtime = runtime + self.dispatch_token = dispatch_token + + def is_active(self) -> bool: + return ( + self.runtime.get_request_overlay_by_token(self.dispatch_token) is not None + ) + + def has_result_state(self) -> bool: + overlay = self.runtime.get_request_overlay_by_token(self.dispatch_token) + return bool(overlay is not None and overlay.result_is_set) + + def get_result(self) -> MessageEventResult | None: + return self.runtime.get_effective_result_for_token(self.dispatch_token) + + def set_result(self, result: MessageEventResult) -> None: + self.runtime.set_result_for_dispatch_token(self.dispatch_token, result) + + def clear_result(self) -> None: + self.runtime.clear_result_for_dispatch_token(self.dispatch_token) + + def stop_event(self) -> None: + self.runtime.stop_event_for_dispatch_token(self.dispatch_token) + + def continue_event(self) -> None: + self.runtime.continue_event_for_dispatch_token(self.dispatch_token) + + def is_stopped(self) -> bool: + return self.runtime.is_stopped_for_dispatch_token(self.dispatch_token) + + +class SdkRequestRuntime: + def __init__( + self, + *, + bridge: SdkPluginBridge, + store: SdkRuntimeStore, + overlay_timeout_seconds: int, + ) -> None: + self.bridge = bridge + self.store = store + self.overlay_timeout_seconds = overlay_timeout_seconds + + def get_or_bind_dispatch_token(self, event: AstrMessageEvent) -> str: + dispatch_token = self.get_dispatch_token(event) or uuid.uuid4().hex + self.bind_dispatch_token(event, dispatch_token) + return dispatch_token + + def bind_dispatch_token(self, event: AstrMessageEvent, dispatch_token: str) -> None: + setattr(event, "_sdk_dispatch_token", dispatch_token) + setattr( + event, + "_sdk_result_binding", + _EventResultBinding(runtime=self, dispatch_token=dispatch_token), + ) + + def get_dispatch_token(self, event: AstrMessageEvent) -> str | None: + token = getattr(event, "_sdk_dispatch_token", None) + return str(token) if token else None + + def schedule_overlay_cleanup( + self, dispatch_token: str + ) -> asyncio.Task[None] | None: + async def _cleanup_later() -> None: + try: + await asyncio.sleep(self.overlay_timeout_seconds) + except asyncio.CancelledError: + return + self.close_request_overlay(dispatch_token) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return None + return loop.create_task(_cleanup_later()) + + def ensure_request_overlay( + self, + dispatch_token: str, + *, + should_call_llm: bool, + ) -> _RequestOverlayState: + # 整个方法加锁,防止并发调度为同一 token 创建多个 overlay + with self.store.mutation_lock: + overlay = self.store.request_overlays.get(dispatch_token) + if overlay is not None: + if overlay.closed: + overlay.closed = False + if overlay.cleanup_task is None or overlay.cleanup_task.done(): + overlay.cleanup_task = self.schedule_overlay_cleanup(dispatch_token) + return overlay + overlay = _RequestOverlayState( + dispatch_token=dispatch_token, + should_call_llm=should_call_llm, + cleanup_task=self.schedule_overlay_cleanup(dispatch_token), + ) + self.store.request_overlays[dispatch_token] = overlay + return overlay + + def track_request_scope( + self, + *, + dispatch_token: str, + request_id: str, + plugin_id: str, + ) -> None: + with self.store.mutation_lock: + self.store.request_id_to_token[request_id] = dispatch_token + self.store.request_plugin_ids[request_id] = plugin_id + overlay = self.store.request_overlays.get(dispatch_token) + if overlay is not None: + overlay.request_scope_ids.add(request_id) + + def close_request_overlay(self, dispatch_token: str) -> None: + # 第一阶段(加锁):从 store 中原子性地移除 overlay 和 context, + # 确保其他线程在锁释放后无法再读到已关闭的状态 + with self.store.mutation_lock: + request_context = self.store.request_contexts.get(dispatch_token) + dispatch_state = ( + getattr(request_context, "dispatch_state", None) + if request_context is not None + else None + ) + bound_event = None + # 在锁内快照结果和 LLM 状态,锁外再写回 event,避免长耗时操作阻塞其他请求 + persisted_result: MessageEventResult | None = None + default_llm_allowed: bool | None = None + if dispatch_state is not None: + bound_event = dispatch_state.event + persisted_result = self.get_effective_result_for_token(dispatch_token) + default_llm_allowed = self.get_effective_should_call_llm(bound_event) + + overlay = self.store.request_overlays.pop(dispatch_token, None) + if overlay is not None: + overlay.closed = True + if overlay.cleanup_task is not None: + overlay.cleanup_task.cancel() + for request_id in overlay.request_scope_ids: + self.store.request_id_to_token.pop(request_id, None) + self.store.request_plugin_ids.pop(request_id, None) + request_context = self.store.request_contexts.pop(dispatch_token, None) + if request_context is not None: + request_context.cancelled = True + + # 第二阶段(无锁):将快照的结果状态写回原始 event 对象。 + # event 本身不属于 store 共享状态,这里通过鸭子类型适配新老 API, + # 保证即使 AstrMessageEvent 接口变更也不会崩溃 + if bound_event is not None: + if hasattr(bound_event, "_sdk_result_binding"): + delattr(bound_event, "_sdk_result_binding") + if persisted_result is None: + clear_result = getattr(bound_event, "clear_result", None) + if callable(clear_result): + clear_result() + else: + setattr(bound_event, "_result", None) + else: + set_result = getattr(bound_event, "set_result", None) + if callable(set_result): + set_result(persisted_result) + else: + setattr(bound_event, "_result", persisted_result) + if default_llm_allowed is not None: + self._set_event_default_llm_blocked( + bound_event, + blocked=not default_llm_allowed, + ) + + def close_request_overlay_for_event(self, event: AstrMessageEvent) -> None: + dispatch_token = self.get_dispatch_token(event) + if dispatch_token: + self.close_request_overlay(dispatch_token) + + def resolve_request_plugin_id(self, request_id: str) -> str: + with self.store.mutation_lock: + plugin_id = self.store.request_plugin_ids.get(request_id) + if plugin_id is not None: + return plugin_id + token = self.store.request_id_to_token.get(request_id) + if token is not None and token in self.store.request_contexts: + return self.store.request_contexts[token].plugin_id + raise AstrBotError.invalid_input(f"Unknown SDK request id: {request_id}") + + def resolve_request_session(self, request_id: str) -> _RequestContext | None: + with self.store.mutation_lock: + token = self.store.request_id_to_token.get(request_id) + if token is None: + return None + return self.store.request_contexts.get(token) + + def get_request_context_by_token( + self, dispatch_token: str + ) -> _RequestContext | None: + with self.store.mutation_lock: + return self.store.request_contexts.get(dispatch_token) + + def get_request_overlay_by_token( + self, dispatch_token: str + ) -> _RequestOverlayState | None: + with self.store.mutation_lock: + overlay = self.store.request_overlays.get(dispatch_token) + if overlay is None or overlay.closed: + return None + return overlay + + def get_request_overlay_by_request_id( + self, request_id: str + ) -> _RequestOverlayState | None: + token = self.store.request_id_to_token.get(request_id) + if not token: + return None + return self.get_request_overlay_by_token(token) + + def request_llm_for_request(self, request_id: str) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.requested_llm = True + if not overlay.result_stopped: + overlay.should_call_llm = True + return True + + def get_effective_should_call_llm(self, event: AstrMessageEvent) -> bool: + dispatch_token = self.get_dispatch_token(event) + if dispatch_token: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is not None: + return overlay.should_call_llm + return self._event_should_call_default_llm(event) + + def get_should_call_llm_for_request(self, request_id: str) -> bool | None: + # 读操作也加锁,确保与 close_request_overlay 的写操作互斥 + with self.store.mutation_lock: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return None + return overlay.should_call_llm + + @staticmethod + def set_overlay_stop_state( + overlay: _RequestOverlayState, + *, + stopped: bool, + ) -> None: + overlay.result_stopped = stopped + if stopped: + overlay.should_call_llm = False + + def set_result_from_object( + self, + overlay: _RequestOverlayState, + result: MessageEventResult | None, + ) -> None: + overlay.result_object = result + overlay.result_is_set = True + self.set_overlay_stop_state( + overlay, + stopped=bool(result is not None and result.is_stopped()), + ) + self.sync_overlay_payload_from_result_object(overlay) + + def bind_result_object( + self, + overlay: _RequestOverlayState, + result: MessageEventResult | None, + ) -> None: + overlay.result_object = result + overlay.result_is_set = True + self.set_overlay_stop_state( + overlay, + stopped=bool(result is not None and result.is_stopped()), + ) + + def set_result_payload_on_overlay( + self, + overlay: _RequestOverlayState, + result_payload: dict[str, Any] | None, + ) -> None: + if result_payload is None: + overlay.result_payload = None + overlay.result_object = None + overlay.result_is_set = True + self.set_overlay_stop_state(overlay, stopped=False) + return + normalized_payload = json.loads(json.dumps(result_payload)) + overlay.result_payload = normalized_payload + chain_payload = normalized_payload.get("chain") + overlay.result_object = ( + self.build_core_result_from_chain_payload(chain_payload) + if isinstance(chain_payload, list) + else None + ) + overlay.result_is_set = True + self.set_overlay_stop_state( + overlay, + stopped=bool(normalized_payload.get("stop", False)), + ) + + def sync_overlay_payload_from_result_object( + self, + overlay: _RequestOverlayState, + ) -> None: + overlay.result_payload = self.bridge._legacy_result_to_sdk_payload( + overlay.result_object + ) + self.set_overlay_stop_state( + overlay, + stopped=bool( + overlay.result_object is not None and overlay.result_object.is_stopped() + ), + ) + + def get_effective_result_for_token( + self, + dispatch_token: str, + ) -> MessageEventResult | None: + # 整个读取 + 延迟构建过程放在锁内,避免 overlay 在读取过程中被另一个线程关闭 + with self.store.mutation_lock: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None or not overlay.result_is_set: + # 没有显式设置结果时,从原始 event 的 get_result() 取, + # 兼容老插件直接操作 event._result 的路径 + request_context = self.store.request_contexts.get(dispatch_token) + if ( + request_context is not None + and request_context.dispatch_state is not None + ): + return request_context.dispatch_state.event.get_result() + return None + # 延迟反序列化:只在首次访问时从 payload 构建结果对象 + if overlay.result_object is None and overlay.result_payload is not None: + chain_payload = overlay.result_payload.get("chain") + if isinstance(chain_payload, list): + overlay.result_object = self.build_core_result_from_chain_payload( + chain_payload + ) + if overlay.result_object is None: + if overlay.result_stopped: + stopped_result = MessageEventResult() + stopped_result.stop_event() + overlay.result_object = stopped_result + else: + return None + if overlay.result_stopped and not overlay.result_object.is_stopped(): + overlay.result_object.stop_event() + elif not overlay.result_stopped and overlay.result_object.is_stopped(): + overlay.result_object.continue_event() + return overlay.result_object + + def set_result_for_dispatch_token( + self, + dispatch_token: str, + result: MessageEventResult | None, + ) -> None: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is not None: + self.set_result_from_object(overlay, result) + + def clear_result_for_dispatch_token(self, dispatch_token: str) -> None: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return + overlay.result_payload = None + overlay.result_object = None + overlay.result_is_set = True + self.set_overlay_stop_state(overlay, stopped=False) + + def stop_event_for_dispatch_token(self, dispatch_token: str) -> None: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return + self.set_overlay_stop_state(overlay, stopped=True) + overlay.result_is_set = True + if overlay.result_object is not None and not overlay.result_object.is_stopped(): + overlay.result_object.stop_event() + + def continue_event_for_dispatch_token(self, dispatch_token: str) -> None: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return + overlay.result_is_set = True + self.set_overlay_stop_state(overlay, stopped=False) + if overlay.result_object is not None and overlay.result_object.is_stopped(): + overlay.result_object.continue_event() + + def is_stopped_for_dispatch_token(self, dispatch_token: str) -> bool: + with self.store.mutation_lock: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is not None and overlay.result_is_set: + return overlay.result_stopped + # 回退到 event 的原始结果,使用 get_result() 而非直接访问 _result, + # 以兼容 SDK result binding 机制 + request_context = self.store.request_contexts.get(dispatch_token) + if ( + request_context is not None + and request_context.dispatch_state is not None + ): + result = request_context.dispatch_state.event.get_result() + return bool(result is not None and result.is_stopped()) + return False + + def set_result_for_request( + self, + request_id: str, + result_payload: dict[str, Any] | None, + ) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + self.set_result_payload_on_overlay(overlay, result_payload) + return True + + def clear_result_for_request(self, request_id: str) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.result_payload = None + overlay.result_object = None + overlay.result_is_set = True + self.set_overlay_stop_state(overlay, stopped=False) + return True + + def get_result_payload_for_request(self, request_id: str) -> dict[str, Any] | None: + overlay = self.get_request_overlay_by_request_id(request_id) + request_context = self.resolve_request_session(request_id) + request_context_has_event = False + if request_context is not None: + has_event = getattr(request_context, "has_event", None) + request_context_has_event = ( + bool(has_event) + if has_event is not None + else hasattr(request_context, "event") + ) + if overlay is not None and overlay.result_is_set: + if overlay.result_object is not None: + self.sync_overlay_payload_from_result_object(overlay) + return ( + copy.deepcopy(overlay.result_payload) + if overlay.result_payload is not None + else None + ) + if request_context is None or not request_context_has_event: + return None + return self.bridge._legacy_result_to_sdk_payload( + request_context.event.get_result() + ) + + def set_handler_whitelist_for_request( + self, + request_id: str, + plugin_names: set[str] | None, + ) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.handler_whitelist = None if plugin_names is None else set(plugin_names) + return True + + def get_handler_whitelist_for_request(self, request_id: str) -> set[str] | None: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return None + return ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + + def get_handler_whitelist_for_event( + self, event: AstrMessageEvent + ) -> set[str] | None: + dispatch_token = self.get_dispatch_token(event) + if not dispatch_token: + return None + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return None + return ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + + @staticmethod + def build_core_message_chain_from_payload( + chain_payload: list[dict[str, Any]], + ) -> MessageChain: + return _build_message_chain_from_payload(chain_payload) + + @classmethod + def build_core_result_from_chain_payload( + cls, + chain_payload: list[dict[str, Any]], + ) -> MessageEventResult: + chain = cls.build_core_message_chain_from_payload(chain_payload) + result = MessageEventResult() + setattr(result, "chain", chain) + result.use_t2i_ = chain.use_t2i_ + result.type = chain.type + return result + + @staticmethod + def legacy_result_to_sdk_payload( + result: MessageEventResult | None, + ) -> dict[str, Any] | None: + if result is None: + return None + chain = ( + result.chain.chain + if isinstance(result.chain, MessageChain) + else result.chain + ) + payload = { + "type": "chain" if chain else "empty", + "chain": SdkRequestRuntime.components_to_sdk_payload(chain), + } + if result.is_stopped(): + payload["stop"] = True + return payload + + @staticmethod + def components_to_sdk_payload( + components: list[Any] | tuple[Any, ...] | None, + ) -> list[dict[str, Any]]: + return [ + component_to_payload_sync(component) for component in (components or []) + ] + + def persist_sdk_local_extras_from_handler( + self, + overlay: _RequestOverlayState, + payload: Any, + *, + plugin_id: str, + handler_id: str, + ) -> None: + if payload is None: + overlay.sdk_local_extras = {} + return + if not isinstance(payload, dict): + logger.warning( + "SDK event handler returned invalid sdk_local_extras: plugin=%s handler=%s payload_type=%s", + plugin_id, + handler_id, + type(payload).__name__, + ) + return + normalized, dropped_keys = normalize_sdk_local_extras(payload) + overlay.sdk_local_extras = normalized + for key in dropped_keys: + value = payload.get(key) + logger.warning( + "Dropped sdk_local_extras entry during SDK bridge serialization: " + "plugin=%s handler=%s key=%s value_type=%s reason=%s " + "recommended_fix=%s", + plugin_id, + handler_id, + key, + type(value).__name__, + "sdk_local_extras only preserves JSON-serializable values across " + "handler and lifecycle boundaries", + "store plain dict/list/scalar payloads, or serialize framework " + "objects such as message components before calling set_extra()", + ) + + @staticmethod + def sanitize_host_extras(event: AstrMessageEvent) -> dict[str, Any]: + extras = event.get_extra() + if not isinstance(extras, dict) or not extras: + return {} + return sanitize_sdk_extras(extras) + + @staticmethod + def set_sdk_origin_plugin_id( + event: AstrMessageEvent, + plugin_id: str, + ) -> None: + setter = getattr(event, "set_extra", None) + if callable(setter): + setter("_sdk_origin_plugin_id", plugin_id) + return + setattr(event, "_sdk_origin_plugin_id", plugin_id) + + def get_or_build_inbound_snapshot( + self, + event: AstrMessageEvent, + overlay: _RequestOverlayState | None, + ) -> InboundEventSnapshot: + if overlay is not None and overlay.inbound_snapshot is not None: + return overlay.inbound_snapshot + snapshot = build_inbound_event_snapshot(event) + if overlay is not None: + overlay.inbound_snapshot = snapshot + return snapshot + + def build_sdk_event_payload( + self, + event: AstrMessageEvent, + *, + dispatch_token: str, + plugin_id: str, + request_id: str, + overlay: _RequestOverlayState | None, + raw_updates: dict[str, Any] | None = None, + field_updates: dict[str, Any] | None = None, + ) -> dict[str, Any]: + snapshot = self.get_or_build_inbound_snapshot(event, overlay) + sdk_local_extras = dict(overlay.sdk_local_extras) if overlay is not None else {} + return snapshot.to_payload( + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + host_extras=self.sanitize_host_extras(event), + sdk_local_extras=sdk_local_extras, + raw_updates=raw_updates, + field_updates=field_updates, + ) + + @staticmethod + def core_provider_request_to_sdk_payload( + request: CoreProviderRequest, + ) -> dict[str, Any]: + tool_calls_result: list[dict[str, Any]] = [] + raw_results = request.tool_calls_result + if raw_results is not None: + if not isinstance(raw_results, list): + raw_results = [raw_results] + for item in raw_results: + if not getattr(item, "tool_calls_result", None): + continue + tool_name_by_id: dict[str, str] = {} + tool_calls_info = getattr(item, "tool_calls_info", None) + raw_tool_calls = getattr(tool_calls_info, "tool_calls", None) + if isinstance(raw_tool_calls, list): + for tool_call in raw_tool_calls: + if isinstance(tool_call, dict): + tool_call_id = tool_call.get("id") + function_payload = tool_call.get("function") + tool_name = ( + function_payload.get("name") + if isinstance(function_payload, dict) + else None + ) + else: + tool_call_id = getattr(tool_call, "id", None) + function_payload = getattr(tool_call, "function", None) + tool_name = getattr(function_payload, "name", None) + if tool_call_id is None or tool_name is None: + continue + tool_name_by_id[str(tool_call_id)] = str(tool_name) + for tool_result in item.tool_calls_result: + tool_call_id = getattr(tool_result, "tool_call_id", None) + content = getattr(tool_result, "content", "") + tool_calls_result.append( + { + "tool_call_id": str(tool_call_id) + if tool_call_id is not None + else None, + "tool_name": tool_name_by_id.get(str(tool_call_id), "") + if tool_call_id is not None + else "", + "content": str(content or ""), + "success": True, + } + ) + return { + "prompt": request.prompt, + "system_prompt": request.system_prompt or None, + "session_id": request.session_id or None, + "contexts": copy.deepcopy(request.contexts or []), + "image_urls": list(request.image_urls or []), + "tool_calls_result": tool_calls_result, + "model": request.model, + } + + @staticmethod + def apply_sdk_provider_request_payload( + request: CoreProviderRequest, + payload: dict[str, Any], + ) -> None: + prompt = payload.get("prompt") + request.prompt = None if prompt is None else str(prompt) + system_prompt = payload.get("system_prompt") + request.system_prompt = "" if system_prompt is None else str(system_prompt) + session_id = payload.get("session_id") + request.session_id = None if session_id is None else str(session_id) + + contexts = payload.get("contexts") + if isinstance(contexts, list): + request.contexts = copy.deepcopy(contexts) + + image_urls = payload.get("image_urls") + if isinstance(image_urls, list): + request.image_urls = [str(item) for item in image_urls] + + model = payload.get("model") + request.model = None if model is None else str(model) + + @staticmethod + def core_llm_response_to_sdk_payload( + response: CoreLLMResponse, + ) -> dict[str, Any]: + usage_payload = None + if response.usage is not None: + usage_payload = { + "input_tokens": response.usage.input, + "output_tokens": response.usage.output, + "total_tokens": response.usage.total, + "input_cached_tokens": response.usage.input_cached, + } + tool_calls: list[dict[str, Any]] = [] + for idx, tool_name in enumerate(response.tools_call_name): + tool_calls.append( + { + "id": ( + response.tools_call_ids[idx] + if idx < len(response.tools_call_ids) + else None + ), + "name": tool_name, + "arguments": ( + response.tools_call_args[idx] + if idx < len(response.tools_call_args) + else {} + ), + "extra_content": ( + response.tools_call_extra_content.get( + response.tools_call_ids[idx] + ) + if idx < len(response.tools_call_ids) + else None + ), + } + ) + return { + "text": response.completion_text or "", + "usage": usage_payload, + "finish_reason": "tool_calls" if tool_calls else "stop", + "tool_calls": tool_calls, + "role": response.role, + "reasoning_content": response.reasoning_content or None, + "reasoning_signature": response.reasoning_signature, + } + + @classmethod + def apply_sdk_result_payload( + cls, + result: MessageEventResult, + payload: dict[str, Any], + ) -> MessageEventResult: + chain_payload = payload.get("chain") + updated = ( + cls.build_core_result_from_chain_payload(chain_payload) + if isinstance(chain_payload, list) + else MessageEventResult() + ) + result.chain = updated.chain + result.use_t2i_ = updated.use_t2i_ + result.type = updated.type + if bool(payload.get("stop", False)): + result.stop_event() + else: + result.continue_event() + return result + + def get_effective_result( + self, event: AstrMessageEvent + ) -> MessageEventResult | None: + dispatch_token = self.get_dispatch_token(event) + if dispatch_token: + return self.get_effective_result_for_token(dispatch_token) + return event.get_result() + + def before_platform_send(self, dispatch_token: str) -> None: + # 发送前置校验加锁,防止 overlay 在校验过程中被并发关闭 + with self.store.mutation_lock: + request_context = self.store.request_contexts.get(dispatch_token) + if request_context is None: + raise AstrBotError.invalid_input( + "Unknown SDK dispatch token for platform send" + ) + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + raise AstrBotError.cancelled("The SDK request overlay has been closed") + if request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") + + def mark_platform_send(self, dispatch_token: str) -> str: + with self.store.mutation_lock: + request_context = self.store.request_contexts.get(dispatch_token) + if request_context is None: + raise AstrBotError.invalid_input( + "Unknown SDK dispatch token for platform send" + ) + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + raise AstrBotError.cancelled("The SDK request overlay has been closed") + if request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") + if request_context.dispatch_state is not None: + request_context.dispatch_state.sent_message = True + # 发送消息后默认不再调用 LLM——消息已经发出去了,LLM 调用多余 + overlay.should_call_llm = False + if request_context.has_event: + self._mark_event_send_operation(request_context.event) + return f"sdk_{dispatch_token}" + + @staticmethod + def _event_should_call_default_llm(event: AstrMessageEvent) -> bool: + """读取 event 的 LLM 调用意愿,按新 API → 兼容 API → 直接读字段的优先级适配。""" + getter = getattr(event, "should_call_default_llm", None) + if callable(getter): + return bool(getter()) + # 旧版 event 只有 call_llm 布尔字段,语义反转:True = 阻止 LLM + return not bool(getattr(event, "call_llm", False)) + + @staticmethod + def _set_event_default_llm_blocked( + event: AstrMessageEvent, + *, + blocked: bool, + ) -> None: + """将 LLM 阻塞状态写回 event,按新 API → 兼容 API → 直接写字段的优先级适配。""" + setter = getattr(event, "set_default_llm_blocked", None) + if callable(setter): + setter(blocked) + return + setter = getattr(event, "set_default_llm_allowed", None) + if callable(setter): + setter(not blocked) + return + setter = getattr(event, "disable_default_llm", None) + if callable(setter): + setter(blocked) + return + legacy = getattr(event, "should_call_llm", None) + if callable(legacy): + legacy(blocked) + return + setattr(event, "call_llm", bool(blocked)) + + @staticmethod + def _mark_event_send_operation(event: AstrMessageEvent) -> None: + """标记 event 已发送消息,按新 API → 兼容 API → 直接写字段的优先级适配。""" + setter = getattr(event, "set_send_operation_state", None) + if callable(setter): + setter(True) + return + marker = getattr(event, "mark_send_operation", None) + if callable(marker): + marker() + return + setattr(event, "_has_send_oper", True) + + @staticmethod + def event_has_send_operation(event: AstrMessageEvent) -> bool: + """读取 event 是否已发送消息,按新 API → 直接读字段的优先级适配。""" + getter = getattr(event, "has_send_operation", None) + if callable(getter): + return bool(getter()) + return bool(getattr(event, "_has_send_oper", False)) diff --git a/astrbot/core/sdk_bridge/runtime_store.py b/astrbot/core/sdk_bridge/runtime_store.py new file mode 100644 index 0000000000..7df5734131 --- /dev/null +++ b/astrbot/core/sdk_bridge/runtime_store.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import asyncio +import threading +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.llm.agents import AgentSpec +from astrbot_sdk.llm.entities import LLMToolSpec +from astrbot_sdk.protocol.descriptors import HandlerDescriptor +from astrbot_sdk.runtime.loader import PluginSpec +from astrbot_sdk.runtime.supervisor import WorkerSession + +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.message.message_event_result import MessageEventResult + +from .event_payload import InboundEventSnapshot + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +@dataclass(slots=True) +class SdkHandlerRef: + descriptor: HandlerDescriptor + declaration_order: int + + @property + def handler_id(self) -> str: + return self.descriptor.id + + @property + def handler_name(self) -> str: + return self.descriptor.id.rsplit(".", 1)[-1] + + +@dataclass(slots=True) +class SdkDispatchResult: + matched_handlers: list[dict[str, str]] = field(default_factory=list) + executed_handlers: list[dict[str, str]] = field(default_factory=list) + sent_message: bool = False + stopped: bool = False + skipped_reason: str | None = None + + +@dataclass(slots=True) +class _DispatchState: + event: AstrMessageEvent + sent_message: bool = False + stopped: bool = False + + +@dataclass(slots=True) +class _RequestContext: + plugin_id: str + request_id: str + dispatch_token: str + dispatch_state: _DispatchState | None + cancelled: bool = False + + @property + def has_event(self) -> bool: + return self.dispatch_state is not None + + @property + def event(self) -> AstrMessageEvent: + if self.dispatch_state is None: + raise AstrBotError.invalid_input( + "The current SDK request is not bound to a message event" + ) + return self.dispatch_state.event + + +@dataclass(slots=True) +class _InFlightRequest: + request_id: str + dispatch_token: str + task: asyncio.Task[dict[str, Any]] + logical_cancelled: bool = False + + +@dataclass(slots=True) +class _LocalMCPServerRuntime: + name: str + config: dict[str, Any] + active: bool + running: bool = False + client: MCPClient | None = None + tools: list[str] = field(default_factory=list) + tool_specs: list[LLMToolSpec] = field(default_factory=list) + errlogs: list[str] = field(default_factory=list) + last_error: str | None = None + ready_event: asyncio.Event = field(default_factory=asyncio.Event) + connect_task: asyncio.Task[None] | None = None + lease_path: Path | None = None + + +@dataclass(slots=True) +class _TemporaryMCPSessionRuntime: + plugin_id: str + name: str + client: MCPClient + tools: list[str] + + +@dataclass(slots=True) +class _RequestOverlayState: + dispatch_token: str + should_call_llm: bool + requested_llm: bool = False + sdk_local_extras: dict[str, Any] = field(default_factory=dict) + inbound_snapshot: InboundEventSnapshot | None = None + result_payload: dict[str, Any] | None = None + result_object: MessageEventResult | None = None + result_is_set: bool = False + result_stopped: bool = False + handler_whitelist: set[str] | None = None + request_scope_ids: set[str] = field(default_factory=set) + closed: bool = False + cleanup_task: asyncio.Task[None] | None = None + + +@dataclass(slots=True) +class SdkRegisteredSkill: + name: str + description: str + skill_dir: Path + skill_md_path: Path + + def to_registry_payload(self) -> dict[str, str]: + return { + "name": self.name, + "description": self.description, + "path": str(self.skill_md_path), + "skill_dir": str(self.skill_dir), + } + + +@dataclass(slots=True) +class SdkDynamicCommandRoute: + command_name: str + handler_full_name: str + desc: str + priority: int + use_regex: bool + declaration_order: int + + +@dataclass(slots=True) +class SdkPluginRecord: + plugin: PluginSpec + load_order: int + state: str + unsupported_features: list[str] + config_schema: dict[str, Any] + config: dict[str, Any] + handlers: list[SdkHandlerRef] + llm_tools: dict[str, LLMToolSpec] = field(default_factory=dict) + active_llm_tools: set[str] = field(default_factory=set) + agents: dict[str, AgentSpec] = field(default_factory=dict) + skills: dict[str, SdkRegisteredSkill] = field(default_factory=dict) + dynamic_command_routes: list[SdkDynamicCommandRoute] = field(default_factory=list) + session: WorkerSession | None = None + restart_attempted: bool = False + failure_reason: str = "" + issues: list[dict[str, Any]] = field(default_factory=list) + local_mcp_servers: dict[str, _LocalMCPServerRuntime] = field(default_factory=dict) + acknowledge_global_mcp_risk: bool = False + + @property + def plugin_id(self) -> str: + return self.plugin.name + + +@dataclass(slots=True) +class SdkHttpRoute: + plugin_id: str + route: str + methods: tuple[str, ...] + handler_capability: str + description: str + + +@dataclass(slots=True) +class SdkRuntimeStore: + # 可重入锁:保护所有 request_overlays / request_contexts 等字典的并发读写。 + # 使用 RLock 而非 Lock 是因为同一线程内可能嵌套调用(如 close_request_overlay + # 内部调用 get_effective_result_for_token),RLock 允许同线程重入不死锁。 + mutation_lock: threading.RLock = field(default_factory=threading.RLock) + records: dict[str, SdkPluginRecord] = field(default_factory=dict) + request_contexts: dict[str, _RequestContext] = field(default_factory=dict) + request_id_to_token: dict[str, str] = field(default_factory=dict) + request_plugin_ids: dict[str, str] = field(default_factory=dict) + request_overlays: dict[str, _RequestOverlayState] = field(default_factory=dict) + plugin_requests: dict[str, dict[str, _InFlightRequest]] = field( + default_factory=dict + ) + http_routes: dict[str, list[SdkHttpRoute]] = field(default_factory=dict) + session_waiters: dict[str, set[str]] = field(default_factory=dict) + schedule_job_ids: dict[str, set[str]] = field(default_factory=dict) + discovery_issues: dict[str, list[dict[str, Any]]] = field(default_factory=dict) + temporary_mcp_sessions: dict[str, _TemporaryMCPSessionRuntime] = field( + default_factory=dict + ) + + def snapshot_records(self) -> list[SdkPluginRecord]: + with self.mutation_lock: + return list(self.records.values()) + + def snapshot_records_sorted(self) -> list[SdkPluginRecord]: + with self.mutation_lock: + return sorted(self.records.values(), key=lambda item: item.load_order) + + def snapshot_http_routes(self, plugin_id: str | None = None) -> list[SdkHttpRoute]: + with self.mutation_lock: + if plugin_id is None: + routes: list[SdkHttpRoute] = [] + for entries in self.http_routes.values(): + routes.extend(list(entries)) + return routes + return list(self.http_routes.get(plugin_id, [])) diff --git a/astrbot/core/sdk_bridge/trigger_converter.py b/astrbot/core/sdk_bridge/trigger_converter.py new file mode 100644 index 0000000000..eca9dc2581 --- /dev/null +++ b/astrbot/core/sdk_bridge/trigger_converter.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import inspect +import re +import shlex +import typing +from dataclasses import dataclass +from typing import Any, get_type_hints + +from astrbot_sdk._message_types import normalize_message_type +from astrbot_sdk.events import MessageEvent as SdkMessageEvent +from astrbot_sdk.protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + HandlerDescriptor, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + ParamSpec, + PlatformFilterSpec, +) +from astrbot_sdk.runtime._command_matching import match_command_name + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +@dataclass(slots=True) +class TriggerMatch: + plugin_id: str + handler_id: str + args: dict[str, Any] + priority: int + load_order: int + declaration_order: int + matched_command_name: str | None = None + + +class TriggerConverter: + @staticmethod + def _message_type_name(event: AstrMessageEvent) -> str: + return normalize_message_type( + event.get_message_type(), + group_id=event.get_group_id() or None, + user_id=event.get_sender_id() or None, + empty_default="other", + ) + + @staticmethod + def _match_command_name(text: str, command_name: str) -> str | None: + return match_command_name(text, command_name) + + @staticmethod + def _split_command_remainder(remainder: str) -> list[str]: + try: + return shlex.split(remainder) + except ValueError: + return remainder.split() + + @classmethod + def _build_command_args(cls, handler, remainder: str) -> dict[str, Any]: + param_specs = getattr(handler, "param_specs", None) + if not isinstance(param_specs, list): + names = cls._legacy_arg_parameter_names(handler) + if not names or not remainder: + return {} + if len(names) == 1: + return {names[0]: remainder} + parts = cls._split_command_remainder(remainder) + return { + name: parts[index] + for index, name in enumerate(names) + if index < len(parts) + } + if not param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = cls._split_command_remainder(remainder) + args: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + args[spec.name] = " ".join(parts[index:]) + break + args[spec.name] = parts[index] + return args + + @classmethod + def _build_regex_args(cls, handler, match: re.Match[str]) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + param_specs = getattr(handler, "param_specs", None) + if isinstance(param_specs, list): + names = [spec.name for spec in param_specs if spec.name not in named] + else: + names = [ + name + for name in cls._legacy_arg_parameter_names(handler) + if name not in named + ] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + @classmethod + def _build_descriptor_command_args( + cls, + param_specs: list[ParamSpec], + remainder: str, + ) -> dict[str, Any]: + if not param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = cls._split_command_remainder(remainder) + args: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + args[spec.name] = " ".join(parts[index:]) + break + args[spec.name] = parts[index] + return args + + @classmethod + def _build_descriptor_regex_args( + cls, + param_specs: list[ParamSpec], + match: re.Match[str], + ) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + names = [spec.name for spec in param_specs if spec.name not in named] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + @classmethod + def _match_filters( + cls, + descriptor: HandlerDescriptor, + event: AstrMessageEvent, + ) -> bool: + for filter_spec in descriptor.filters: + if not cls._match_filter_spec(filter_spec, event): + return False + return True + + @classmethod + def _match_filter_spec(cls, filter_spec, event: AstrMessageEvent) -> bool: + if isinstance(filter_spec, PlatformFilterSpec): + return event.get_platform_name() in filter_spec.platforms + if isinstance(filter_spec, MessageTypeFilterSpec): + return cls._message_type_name(event) in filter_spec.message_types + if isinstance(filter_spec, LocalFilterRefSpec): + # Local filter refs point at plugin-process callables. The host bridge + # cannot execute them, so trigger matching must stay fail-open here. + return True + if isinstance(filter_spec, CompositeFilterSpec): + results = [ + cls._match_filter_spec(child, event) for child in filter_spec.children + ] + if filter_spec.kind == "and": + return all(results) + return any(results) + return True + + @classmethod + def _legacy_arg_parameter_names(cls, handler) -> list[str]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = get_type_hints(handler) + except Exception: + type_hints = {} + names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if cls._is_injected_parameter( + parameter.name, type_hints.get(parameter.name) + ): + continue + names.append(parameter.name) + return names + + @classmethod + def _is_injected_parameter(cls, name: str, annotation: Any) -> bool: + if name in {"event", "ctx", "context"}: + return True + normalized = cls._unwrap_optional(annotation) + if normalized is None: + return False + if normalized in {AstrMessageEvent, SdkMessageEvent}: + return True + if isinstance(normalized, type) and issubclass( + normalized, + (AstrMessageEvent, SdkMessageEvent), + ): + return True + return False + + @staticmethod + def _unwrap_optional(annotation: Any) -> Any: + if annotation is None: + return None + origin = typing.get_origin(annotation) + if origin is typing.Union: + options = [ + item for item in typing.get_args(annotation) if item is not type(None) + ] + if len(options) == 1: + return options[0] + return annotation + + @classmethod + def match_handler( + cls, + *, + plugin_id: str, + handler=None, + descriptor: HandlerDescriptor, + event: AstrMessageEvent, + load_order: int, + declaration_order: int, + ) -> TriggerMatch | None: + trigger = descriptor.trigger + + required_role = descriptor.permissions.required_role + if required_role is None and descriptor.permissions.require_admin: + required_role = "admin" + if required_role == "admin" and not event.is_admin(): + return None + if not cls._match_filters(descriptor, event): + return None + + if isinstance(trigger, CommandTrigger): + text = event.get_message_str().strip() + for command_name in [trigger.command, *trigger.aliases]: + if not command_name: + continue + remainder = cls._match_command_name(text, command_name) + if remainder is None: + continue + return TriggerMatch( + plugin_id=plugin_id, + handler_id=descriptor.id, + args=( + cls._build_command_args(handler, remainder) + if handler is not None + else cls._build_descriptor_command_args( + descriptor.param_specs, + remainder, + ) + ), + priority=descriptor.priority, + load_order=load_order, + declaration_order=declaration_order, + matched_command_name=str(command_name).strip() or None, + ) + return None + + if isinstance(trigger, MessageTrigger): + text = event.get_message_str() + if trigger.regex: + match = re.search(trigger.regex, text) + if match is None: + return None + args = ( + cls._build_regex_args(handler, match) if handler is not None else {} + ) + if handler is None: + args = cls._build_descriptor_regex_args( + descriptor.param_specs, match + ) + else: + if trigger.keywords and not any( + keyword in text for keyword in trigger.keywords + ): + return None + args = {} + return TriggerMatch( + plugin_id=plugin_id, + handler_id=descriptor.id, + args=args, + priority=descriptor.priority, + load_order=load_order, + declaration_order=declaration_order, + ) + + return None + + @staticmethod + def sort_key(match: TriggerMatch) -> tuple[int, int, int]: + return (-match.priority, match.load_order, match.declaration_order) diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index a8121c42a4..686a127748 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -22,10 +22,12 @@ SKILLS_CONFIG_FILENAME = "skills.json" SANDBOX_SKILLS_CACHE_FILENAME = "sandbox_skills_cache.json" +SDK_PLUGIN_SKILLS_FILENAME = "sdk_plugin_skills.json" DEFAULT_SKILLS_CONFIG: dict[str, dict] = {"skills": {}} SANDBOX_SKILLS_ROOT = "skills" SANDBOX_WORKSPACE_ROOT = "/workspace" _SANDBOX_SKILLS_CACHE_VERSION = 1 +_SDK_PLUGIN_SKILLS_VERSION = 1 _SKILL_NAME_RE = re.compile(r"^[\w.-]+$") @@ -99,6 +101,16 @@ class SkillInfo: sandbox_exists: bool = False +@dataclass(frozen=True, slots=True) +class LocalSkillSource: + name: str + skill_dir: Path + skill_md_path: Path + owner_type: str = "standalone" + description_override: str = "" + plugin_id: str | None = None + + def _parse_frontmatter_description(text: str) -> str: """Extract the ``description`` value from YAML frontmatter. @@ -279,8 +291,221 @@ def __init__(self, skills_root: str | None = None) -> None: data_path = Path(get_astrbot_data_path()) self.config_path = str(data_path / SKILLS_CONFIG_FILENAME) self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME) + self.sdk_plugin_skills_path = str(data_path / SDK_PLUGIN_SKILLS_FILENAME) os.makedirs(self.skills_root, exist_ok=True) + def _read_skill_description(self, skill_md_path: Path) -> str: + try: + content = skill_md_path.read_text(encoding="utf-8") + except Exception: + return "" + return _parse_frontmatter_description(content) + + def _discover_standalone_skill_sources(self) -> dict[str, LocalSkillSource]: + sources: dict[str, LocalSkillSource] = {} + skills_root = Path(self.skills_root) + if not skills_root.exists(): + return sources + + for entry in sorted(skills_root.iterdir()): + if not entry.is_dir(): + continue + skill_md_path = _normalize_skill_markdown_path(entry) + if skill_md_path is None: + continue + sources[entry.name] = LocalSkillSource( + name=entry.name, + skill_dir=entry, + skill_md_path=skill_md_path, + owner_type="standalone", + ) + return sources + + def _load_sdk_plugin_skills_registry(self) -> dict[str, object]: + if not os.path.exists(self.sdk_plugin_skills_path): + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + try: + with open(self.sdk_plugin_skills_path, encoding="utf-8") as f: + data = json.load(f) + except Exception: + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + if not isinstance(data, dict): + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + plugins = data.get("plugins", {}) + if not isinstance(plugins, dict): + plugins = {} + return { + "version": int(data.get("version", _SDK_PLUGIN_SKILLS_VERSION)), + "plugins": plugins, + } + + def _save_sdk_plugin_skills_registry(self, registry: dict[str, object]) -> None: + registry["version"] = _SDK_PLUGIN_SKILLS_VERSION + with open(self.sdk_plugin_skills_path, "w", encoding="utf-8") as f: + json.dump(registry, f, ensure_ascii=False, indent=2) + + def replace_sdk_plugin_skills( + self, + plugin_id: str, + skills: list[dict[str, str]], + ) -> None: + plugin_name = str(plugin_id).strip() + if not plugin_name: + raise ValueError("plugin_id must not be empty") + + normalized_skills: list[dict[str, str]] = [] + for item in skills: + if not isinstance(item, dict): + continue + skill_name = str(item.get("name", "")).strip() + skill_dir_text = str(item.get("skill_dir", "")).strip() + if not skill_name or not _SKILL_NAME_RE.fullmatch(skill_name): + continue + if not skill_dir_text: + continue + skill_dir = Path(skill_dir_text).resolve() + skill_md_path = Path( + str(item.get("path", "")).strip() or str(skill_dir / "SKILL.md") + ).resolve() + normalized_skills.append( + { + "name": skill_name, + "description": str(item.get("description", "") or ""), + "path": str(skill_md_path), + "skill_dir": str(skill_dir), + } + ) + + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + plugins = {} + previous_items = plugins.get(plugin_name, []) + previous_names = { + str(item.get("name", "")).strip() + for item in previous_items + if isinstance(item, dict) + } + if normalized_skills: + plugins[plugin_name] = sorted( + normalized_skills, + key=lambda item: str(item.get("name", "")), + ) + else: + plugins.pop(plugin_name, None) + registry["plugins"] = plugins + self._save_sdk_plugin_skills_registry(registry) + + current_names = {item["name"] for item in normalized_skills} + for removed_name in sorted(previous_names - current_names): + self._remove_skill_from_sandbox_cache(removed_name) + + def remove_sdk_plugin_skills(self, plugin_id: str) -> None: + self.replace_sdk_plugin_skills(plugin_id, []) + + def prune_sdk_plugin_skills(self, active_plugin_ids: set[str]) -> None: + normalized_ids = { + str(item).strip() for item in active_plugin_ids if str(item).strip() + } + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + return + + removed_skill_names: set[str] = set() + updated_plugins: dict[str, object] = {} + for plugin_id, items in plugins.items(): + plugin_name = str(plugin_id).strip() + if not plugin_name: + continue + if plugin_name in normalized_ids: + updated_plugins[plugin_name] = items + continue + if isinstance(items, list): + removed_skill_names.update( + str(item.get("name", "")).strip() + for item in items + if isinstance(item, dict) + ) + + registry["plugins"] = updated_plugins + self._save_sdk_plugin_skills_registry(registry) + for removed_name in sorted(name for name in removed_skill_names if name): + self._remove_skill_from_sandbox_cache(removed_name) + + def _discover_sdk_plugin_skill_sources(self) -> dict[str, LocalSkillSource]: + sources: dict[str, LocalSkillSource] = {} + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + return sources + for plugin_id, items in plugins.items(): + if not isinstance(items, list): + continue + for item in items: + if not isinstance(item, dict): + continue + skill_name = str(item.get("name", "")).strip() + skill_dir_text = str(item.get("skill_dir", "")).strip() + path_text = str(item.get("path", "")).strip() + if not skill_name or not _SKILL_NAME_RE.fullmatch(skill_name): + continue + if not skill_dir_text: + continue + skill_dir = Path(skill_dir_text) + skill_md_path = Path(path_text or str(skill_dir / "SKILL.md")) + if not skill_dir.is_dir() or not skill_md_path.is_file(): + continue + sources.setdefault( + skill_name, + LocalSkillSource( + name=skill_name, + skill_dir=skill_dir, + skill_md_path=skill_md_path, + owner_type="sdk_registered", + description_override=str(item.get("description", "") or ""), + plugin_id=str(plugin_id), + ), + ) + return sources + + def list_local_skill_sources(self) -> list[LocalSkillSource]: + sources = self._discover_standalone_skill_sources() + for name, source in self._discover_sdk_plugin_skill_sources().items(): + sources.setdefault(name, source) + return [sources[name] for name in sorted(sources)] + + def get_local_skill_source(self, name: str) -> LocalSkillSource | None: + for source in self.list_local_skill_sources(): + if source.name == name: + return source + return None + + def materialize_local_skill_bundle( + self, + bundle_root: Path, + *, + skill_names: list[str] | None = None, + ) -> list[LocalSkillSource]: + selected_names = ( + {name for name in skill_names if name} if skill_names is not None else None + ) + bundle_root.mkdir(parents=True, exist_ok=True) + + copied_sources: list[LocalSkillSource] = [] + for source in self.list_local_skill_sources(): + if selected_names is not None and source.name not in selected_names: + continue + target_dir = bundle_root / source.name + if target_dir.exists(): + shutil.rmtree(target_dir) + # SDK-registered skills may live inside plugin packages, so bundle + # them under the public skill id to give sandbox/runtime a stable + # path that is independent from the plugin's internal layout. + shutil.copytree(source.skill_dir, target_dir) + copied_sources.append(source) + return copied_sources + def _load_config(self) -> dict: if not os.path.exists(self.config_path): self._save_config(DEFAULT_SKILLS_CONFIG.copy()) @@ -388,25 +613,17 @@ def list_skills( sandbox_cached_descriptions[name] = str(item.get("description", "") or "") sandbox_cached_paths[name] = path - for entry in sorted(Path(self.skills_root).iterdir()): - if not entry.is_dir(): - continue - skill_name = entry.name - skill_md = _normalize_skill_markdown_path(entry) - if skill_md is None: - continue + for source in self.list_local_skill_sources(): + skill_name = source.name active = skill_configs.get(skill_name, {}).get("active", True) if skill_name not in skill_configs: skill_configs[skill_name] = {"active": active} modified = True if active_only and not active: continue - description = "" - try: - content = skill_md.read_text(encoding="utf-8") - description = _parse_frontmatter_description(content) - except Exception: - description = "" + description = source.description_override or self._read_skill_description( + source.skill_md_path + ) sandbox_exists = ( runtime == "sandbox" and skill_name in sandbox_cached_descriptions ) @@ -417,7 +634,7 @@ def list_skills( skill_name ) or _default_sandbox_skill_path(skill_name) else: - path_str = str(skill_md) + path_str = str(source.skill_md_path) path_str = path_str.replace("\\", "/") skills_by_name[skill_name] = SkillInfo( name=skill_name, @@ -473,9 +690,7 @@ def list_skills( return [skills_by_name[name] for name in sorted(skills_by_name)] def is_sandbox_only_skill(self, name: str) -> bool: - skill_dir = Path(self.skills_root) / name - skill_md_exists = _normalize_skill_markdown_path(skill_dir) is not None - if skill_md_exists: + if self.get_local_skill_source(name) is not None: return False cache = self._load_sandbox_skills_cache() skills = cache.get("skills", []) @@ -522,9 +737,14 @@ def delete_skill(self, name: str) -> None: "Sandbox preset skill cannot be deleted from local skill management." ) - skill_dir = Path(self.skills_root) / name - if skill_dir.exists(): - shutil.rmtree(skill_dir) + source = self.get_local_skill_source(name) + if source is not None and source.owner_type != "standalone": + raise PermissionError( + "SDK-registered skill cannot be deleted here. Disable or update the owning plugin instead." + ) + + if source is not None and source.skill_dir.exists(): + shutil.rmtree(source.skill_dir) # Ensure UI consistency even when there is no active sandbox session # to refresh cache from runtime side. diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 796e0bd683..f9a7417c21 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -1,11 +1,23 @@ -# 兼容导出: Provider 从 provider 模块重新导出 -from astrbot.core.provider import Provider +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any -from .base import Star -from .context import Context from .star import StarMetadata, star_map, star_registry -from .star_manager import PluginManager -from .star_tools import StarTools + +if TYPE_CHECKING: + from astrbot.core.provider import Provider + + from .base import Star + from .context import Context + from .star_manager import PluginManager + from .star_tools import StarTools +else: + Provider: Any + Star: Any + Context: Any + PluginManager: Any + StarTools: Any __all__ = [ "Context", @@ -17,3 +29,17 @@ "star_map", "star_registry", ] + + +def __getattr__(name: str) -> Any: + if name == "Provider": + return import_module("astrbot.core.provider").Provider + if name == "Star": + return import_module(".base", __name__).Star + if name == "Context": + return import_module(".context", __name__).Context + if name == "PluginManager": + return import_module(".star_manager", __name__).PluginManager + if name == "StarTools": + return import_module(".star_tools", __name__).StarTools + raise AttributeError(name) diff --git a/astrbot/core/star/command_management.py b/astrbot/core/star/command_management.py index c60af9ea26..f73ed65600 100644 --- a/astrbot/core/star/command_management.py +++ b/astrbot/core/star/command_management.py @@ -4,8 +4,7 @@ from dataclasses import dataclass, field from typing import Any -from astrbot.api import sp -from astrbot.core import db_helper, logger +from astrbot.core import db_helper, logger, sp from astrbot.core.db.po import CommandConfig from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 606f46dd73..64adaa7645 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -5,25 +5,18 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, Protocol +from astrbot_sdk.message.components import component_to_payload_sync from deprecated import deprecated from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.db import BaseDatabase -from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.persona_mgr import PersonaManager -from astrbot.core.platform import Platform -from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion -from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.message.message_types import sdk_message_type +from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager -from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( EmbeddingProvider, Provider, @@ -35,7 +28,6 @@ ADAPTER_NAME_2_TYPE, PlatformAdapterType, ) -from astrbot.core.subagent_orchestrator import SubAgentOrchestrator from ..exceptions import ProviderNotFoundError from .filter.command import CommandFilter @@ -46,7 +38,19 @@ logger = logging.getLogger("astrbot") if TYPE_CHECKING: + from astrbot.core.astrbot_config_mgr import AstrBotConfigManager + from astrbot.core.config.astrbot_config import AstrBotConfig + from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron.manager import CronJobManager + from astrbot.core.db import BaseDatabase + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + from astrbot.core.persona_mgr import PersonaManager + from astrbot.core.platform import Platform + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager + from astrbot.core.provider.manager import ProviderManager + from astrbot.core.sdk_bridge.plugin_bridge import SdkPluginBridge + from astrbot.core.subagent_orchestrator import SubAgentOrchestrator class PlatformManagerProtocol(Protocol): @@ -100,6 +104,8 @@ def __init__( self.cron_manager = cron_manager """Cron job manager, initialized by core lifecycle.""" self.subagent_orchestrator = subagent_orchestrator + self.sdk_plugin_bridge: SdkPluginBridge | None = None + """SDK plugin bridge, initialized by core lifecycle when available.""" async def llm_generate( self, @@ -151,7 +157,7 @@ async def tool_loop_agent( image_urls: list[str] | None = None, tools: ToolSet | None = None, system_prompt: str | None = None, - contexts: list[Message] | None = None, + contexts: list[Message | dict[str, Any]] | None = None, max_steps: int = 30, tool_call_timeout: int = 120, **kwargs: Any, @@ -342,6 +348,10 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts + def get_all_rerank_providers(self) -> list[RerankProvider]: + """获取所有用于 Rerank 任务的 Provider。""" + return self.provider_manager.rerank_provider_insts + def get_using_provider(self, umo: str | None = None) -> Provider | None: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 @@ -454,6 +464,32 @@ async def send_message( for platform in self.platform_manager.platform_insts: if platform.meta().id == session.platform_name: await platform.send_by_session(session, message_chain) + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event( + "after_message_sent", + { + "session_id": str(session), + "platform": platform.meta().name, + "platform_id": platform.meta().id, + "message_type": sdk_message_type(session.message_type), + "message_outline": message_chain.get_plain_text( + with_other_comps_mark=True + ), + "sent_message_outline": message_chain.get_plain_text( + with_other_comps_mark=True + ), + "sent_messages": [ + component_to_payload_sync(component) + for component in message_chain.chain + ], + }, + ) + except Exception as exc: + logger.warning( + "SDK after_message_sent dispatch failed for proactive send: %s", + exc, + ) return True logger.warning( f"cannot find platform for session {str(session)}, message not sent" diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 25df73f642..591f7c0bb6 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -11,9 +11,11 @@ import sys import tempfile import traceback +from pathlib import Path from types import ModuleType import yaml +from astrbot_sdk.runtime.loader import load_plugin_spec, validate_plugin_spec from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import InvalidVersion, Version @@ -30,6 +32,7 @@ from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import ( get_astrbot_config_path, + get_astrbot_data_path, get_astrbot_path, get_astrbot_plugin_path, get_astrbot_temp_path, @@ -459,6 +462,156 @@ def _get_plugin_dir_name_from_metadata(plugin_path: str) -> str: PluginManager._validate_importable_name(plugin_dir_name) return plugin_dir_name + @staticmethod + def _detect_plugin_type(plugin_path: str) -> tuple[str, str]: + """根据插件清单文件识别安装目标。 + + Why: + 旧版插件和 SDK 插件分别由不同加载器管理,安装阶段必须先按 + `metadata.yaml` / `plugin.yaml` 分流,否则 SDK 插件会被误送到 + `data/plugins`,后续无法被 SDK 桥接层发现。 + """ + plugin_dir = Path(plugin_path) + plugin_manifest_path = plugin_dir / "plugin.yaml" + legacy_metadata_path = plugin_dir / "metadata.yaml" + + if plugin_manifest_path.exists(): + plugin_spec = load_plugin_spec(plugin_dir) + validate_plugin_spec(plugin_spec) + return "sdk", plugin_spec.name + + if legacy_metadata_path.exists(): + return "legacy", PluginManager._get_plugin_dir_name_from_metadata( + plugin_path + ) + + raise Exception( + "无法识别插件类型:插件目录中既没有 plugin.yaml,也没有 metadata.yaml。" + ) + + @staticmethod + def _read_plugin_readme(plugin_path: str, plugin_label: str) -> str | None: + plugin_dir = Path(plugin_path) + + for readme_name in ("README.md", "readme.md"): + readme_path = plugin_dir / readme_name + if not readme_path.exists(): + continue + try: + return readme_path.read_text(encoding="utf-8") + except Exception as exc: + logger.warning( + "读取插件 %s 的 %s 文件失败: %s", + plugin_label, + readme_name, + exc, + ) + return None + + return None + + @staticmethod + def _build_plugin_install_result( + *, + name: str, + repo: str | None, + readme: str | None, + plugin_type: str, + ) -> dict[str, str | None]: + return { + "repo": repo, + "readme": readme, + "name": name, + "type": plugin_type, + } + + async def _install_sdk_plugin( + self, + *, + temp_plugin_path: str, + plugin_name: str, + repo_url: str | None, + ) -> dict[str, str | None]: + """安装 SDK 插件到 data/sdk_plugins 并触发桥接层重新发现。""" + sdk_plugins_dir = Path(get_astrbot_data_path()) / "sdk_plugins" + target_plugin_path = sdk_plugins_dir / plugin_name + + if target_plugin_path.exists(): + raise Exception(f"安装失败:SDK 插件 {plugin_name} 已存在。") + + sdk_plugins_dir.mkdir(parents=True, exist_ok=True) + Path(temp_plugin_path).rename(target_plugin_path) + + sdk_plugin_bridge = getattr(self.context, "sdk_plugin_bridge", None) + if sdk_plugin_bridge is not None: + await sdk_plugin_bridge.reload_all(reset_restart_budget=True) + else: + logger.warning( + "SDK 插件 %s 已写入 %s,但当前未找到 sdk_plugin_bridge," + "需等待后续生命周期重载。", + plugin_name, + target_plugin_path, + ) + + return self._build_plugin_install_result( + name=plugin_name, + repo=repo_url, + readme=self._read_plugin_readme(str(target_plugin_path), plugin_name), + plugin_type="sdk", + ) + + async def _migrate_legacy_plugin_to_sdk_runtime( + self, + *, + legacy_plugin: StarMetadata, + legacy_plugin_path: Path, + sdk_plugin_name: str, + ) -> None: + """将已更新为 SDK 清单的 legacy 插件迁移到 SDK 运行时目录。""" + if legacy_plugin.root_dir_name is None or legacy_plugin.module_path is None: + raise Exception( + f"插件 {legacy_plugin.name} 缺少 root_dir_name 或 module_path,无法迁移到 SDK 运行时。" + ) + + logger.info( + "检测到 legacy 插件 %s 已切换为 SDK 清单,开始迁移到 data/sdk_plugins/%s", + legacy_plugin.name, + sdk_plugin_name, + ) + + try: + await self._terminate_plugin(legacy_plugin) + except Exception as exc: + logger.warning(traceback.format_exc()) + logger.warning( + "插件 %s 在迁移到 SDK 运行时前未被正常终止: %s", + legacy_plugin.name, + exc, + ) + + await self._unbind_plugin(legacy_plugin.name, legacy_plugin.module_path) + + sdk_plugins_dir = Path(get_astrbot_data_path()) / "sdk_plugins" + target_plugin_path = sdk_plugins_dir / sdk_plugin_name + if target_plugin_path.exists(): + raise Exception(f"迁移失败:SDK 插件 {sdk_plugin_name} 已存在。") + + sdk_plugins_dir.mkdir(parents=True, exist_ok=True) + legacy_plugin_path.rename(target_plugin_path) + + sdk_plugin_bridge = getattr(self.context, "sdk_plugin_bridge", None) + if sdk_plugin_bridge is not None: + await sdk_plugin_bridge.reload_all(reset_restart_budget=True) + if not legacy_plugin.activated: + await sdk_plugin_bridge.turn_off_plugin(sdk_plugin_name) + else: + logger.warning( + "SDK 插件 %s 已迁移到 %s,但当前未找到 sdk_plugin_bridge," + "需等待后续生命周期重载。", + sdk_plugin_name, + target_plugin_path, + ) + @staticmethod def _validate_astrbot_version_specifier( version_spec: str | None, @@ -1061,6 +1214,19 @@ async def load( await handler.handler(metadata) except Exception: logger.error(traceback.format_exc()) + sdk_plugin_bridge = getattr(self.context, "sdk_plugin_bridge", None) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_system_event( + "plugin_loaded", + { + "plugin_name": metadata.name, + "display_name": metadata.display_name or metadata.name, + "version": metadata.version, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_loaded dispatch failed: %s", exc) except BaseException as e: logger.error(f"----- 插件 {root_dir_name} 载入失败 -----") @@ -1238,6 +1404,7 @@ async def install_plugin( async with self._pm_lock: plugin_path = "" dir_name = "" + should_track_failed_install_dir = True try: _, repo_name, _ = self.updator.parse_github_url(repo_url) repo_name = self.updator.format_name(repo_name) @@ -1248,21 +1415,36 @@ async def install_plugin( ) plugin_path = await self.updator.install(repo_url, proxy) - # reload the plugin - dir_name = os.path.basename(plugin_path) - metadata_dir_name = self._get_plugin_dir_name_from_metadata(plugin_path) + plugin_type, plugin_name = self._detect_plugin_type(plugin_path) + logger.info( + "插件安装类型识别完成:repo=%s, type=%s, name=%s", + repo_url, + plugin_type, + plugin_name, + ) + dir_name = plugin_name + if plugin_type == "sdk": + should_track_failed_install_dir = False + return await self._install_sdk_plugin( + temp_plugin_path=plugin_path, + plugin_name=plugin_name, + repo_url=repo_url, + ) + + # Why: + # 旧版插件的导入路径依赖目录名与 metadata.yaml 中的 name 一致, + # 因此在加载前必须完成重命名;SDK 插件则已在前面的分支单独处理。 target_plugin_path = os.path.join( self.plugin_store_path, - metadata_dir_name, + plugin_name, ) if target_plugin_path != plugin_path and os.path.exists( target_plugin_path ): - raise Exception(f"安装失败:目录 {metadata_dir_name} 已存在。") + raise Exception(f"安装失败:目录 {plugin_name} 已存在。") if target_plugin_path != plugin_path: os.rename(plugin_path, target_plugin_path) plugin_path = target_plugin_path - dir_name = metadata_dir_name await self._ensure_plugin_requirements( plugin_path, dir_name, @@ -1286,36 +1468,25 @@ async def install_plugin( plugin = star break - # Extract README.md content if exists - readme_content = None - readme_path = os.path.join(plugin_path, "README.md") - if not os.path.exists(readme_path): - readme_path = os.path.join(plugin_path, "readme.md") - - if os.path.exists(readme_path): - try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() - except Exception as e: - logger.warning( - f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", - ) + readme_content = self._read_plugin_readme(plugin_path, dir_name) plugin_info = None if plugin: - plugin_info = { - "repo": plugin.repo, - "readme": readme_content, - "name": plugin.name, - } + plugin_info = self._build_plugin_install_result( + name=plugin.name, + repo=plugin.repo, + readme=readme_content, + plugin_type="legacy", + ) return plugin_info except Exception as e: - self._track_failed_install_dir( - dir_name=dir_name, - plugin_path=plugin_path, - error=e, - ) + if should_track_failed_install_dir: + self._track_failed_install_dir( + dir_name=dir_name, + plugin_path=plugin_path, + error=e, + ) if dir_name and plugin_path: logger.warning( f"安装插件 {dir_name} 失败,插件安装目录:{plugin_path}", @@ -1507,9 +1678,17 @@ async def update_plugin(self, plugin_name: str, proxy="") -> None: await self.updator.update(plugin, proxy=proxy) if plugin.root_dir_name: - plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) + plugin_dir_path = Path(self.plugin_store_path) / plugin.root_dir_name + plugin_type, detected_name = self._detect_plugin_type(str(plugin_dir_path)) + if plugin_type == "sdk": + await self._migrate_legacy_plugin_to_sdk_runtime( + legacy_plugin=plugin, + legacy_plugin_path=plugin_dir_path, + sdk_plugin_name=detected_name, + ) + return await self._ensure_plugin_requirements( - plugin_dir_path, + str(plugin_dir_path), plugin_name, ) await self.reload(plugin_name) @@ -1601,6 +1780,24 @@ def _log_del_exception(fut: asyncio.Future) -> None: await handler.handler(star_metadata) except Exception: logger.error(traceback.format_exc()) + sdk_plugin_bridge = ( + getattr(star_metadata.star_cls.context, "sdk_plugin_bridge", None) + if getattr(star_metadata, "star_cls", None) + else None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_system_event( + "plugin_unloaded", + { + "plugin_name": star_metadata.name, + "display_name": star_metadata.display_name + or star_metadata.name, + "version": star_metadata.version, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_unloaded dispatch failed: %s", exc) async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) @@ -1636,26 +1833,41 @@ async def install_plugin_from_file( dir=self.plugin_store_path, prefix="plugin_upload_" ) temp_desti_dir = desti_dir + should_track_failed_install_dir = True try: self.updator.unzip_file(zip_file_path, desti_dir) - metadata_dir_name = self._get_plugin_dir_name_from_metadata(desti_dir) + try: + os.remove(zip_file_path) + except BaseException as e: + logger.warning(f"删除插件压缩包失败: {e!s}") + + plugin_type, plugin_name = self._detect_plugin_type(desti_dir) + logger.info( + "上传插件安装类型识别完成:type=%s, name=%s, file=%s", + plugin_type, + plugin_name, + zip_file_path, + ) + dir_name = plugin_name + if plugin_type == "sdk": + should_track_failed_install_dir = False + return await self._install_sdk_plugin( + temp_plugin_path=desti_dir, + plugin_name=plugin_name, + repo_url=None, + ) + target_plugin_path = os.path.join( self.plugin_store_path, - metadata_dir_name, + plugin_name, ) if target_plugin_path != desti_dir and os.path.exists(target_plugin_path): - raise Exception(f"安装失败:目录 {metadata_dir_name} 已存在。") + raise Exception(f"安装失败:目录 {plugin_name} 已存在。") if target_plugin_path != desti_dir: os.rename(desti_dir, target_plugin_path) - dir_name = metadata_dir_name desti_dir = target_plugin_path - # remove the zip - try: - os.remove(zip_file_path) - except BaseException as e: - logger.warning(f"删除插件压缩包失败: {e!s}") await self._ensure_plugin_requirements(desti_dir, dir_name) # await self.reload() success, error_message = await self.load( @@ -1677,26 +1889,16 @@ async def install_plugin_from_file( plugin = star break - # Extract README.md content if exists - readme_content = None - readme_path = os.path.join(desti_dir, "README.md") - if not os.path.exists(readme_path): - readme_path = os.path.join(desti_dir, "readme.md") - - if os.path.exists(readme_path): - try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() - except Exception as e: - logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") + readme_content = self._read_plugin_readme(desti_dir, dir_name) plugin_info = None if plugin: - plugin_info = { - "repo": plugin.repo, - "readme": readme_content, - "name": plugin.name, - } + plugin_info = self._build_plugin_install_result( + name=plugin.name, + repo=plugin.repo, + readme=readme_content, + plugin_type="legacy", + ) if plugin.repo: asyncio.create_task( @@ -1708,14 +1910,13 @@ async def install_plugin_from_file( return plugin_info except Exception as e: - self._track_failed_install_dir( - dir_name=dir_name, - plugin_path=desti_dir, - error=e, - ) - logger.warning( - f"安装插件 {dir_name} 失败,插件安装目录:{desti_dir}", - ) + if should_track_failed_install_dir: + self._track_failed_install_dir( + dir_name=dir_name, + plugin_path=desti_dir, + error=e, + ) + logger.warning(f"安装插件 {dir_name} 失败,插件安装目录:{desti_dir}") raise finally: if temp_desti_dir != desti_dir and os.path.isdir(temp_desti_dir): diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 4d85131fc6..94237620d7 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -28,12 +28,6 @@ from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( - AiocqhttpMessageEvent, -) -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( - AiocqhttpAdapter, -) from astrbot.core.star.context import Context from astrbot.core.star.star import star_map from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -103,6 +97,13 @@ async def send_message_by_id( raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + adapter = next( (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None, @@ -183,6 +184,13 @@ async def create_event( raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + adapter = next( (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None, diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index 987ce110a5..05d22bc22c 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -5,6 +5,7 @@ 数据目录路径:固定为根目录下的 data 目录 配置文件路径:固定为数据目录下的 config 目录 插件目录路径:固定为数据目录下的 plugins 目录 +SDK 插件目录路径:固定为数据目录下的 sdk_plugins 目录 插件数据目录路径:固定为数据目录下的 plugin_data 目录 T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录 WebChat 数据目录路径:固定为数据目录下的 webchat 目录 @@ -49,6 +50,11 @@ def get_astrbot_plugin_path() -> str: return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) +def get_astrbot_sdk_plugins_path() -> str: + """获取Astrbot SDK 插件目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "sdk_plugins")) + + def get_astrbot_plugin_data_path() -> str: """获取Astrbot插件数据目录路径""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data")) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..82e4ea0744 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -9,7 +9,6 @@ import zipfile from pathlib import Path -import aiohttp import certifi import psutil from PIL import Image @@ -19,6 +18,12 @@ logger = logging.getLogger("astrbot") +def _get_aiohttp(): + import aiohttp + + return aiohttp + + def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat @@ -70,6 +75,7 @@ async def download_image_by_url( path: str | None = None, ) -> str: """下载图片, 返回 path""" + aiohttp = _get_aiohttp() try: ssl_context = ssl.create_default_context( cafile=certifi.where(), @@ -125,6 +131,7 @@ async def download_image_by_url( async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" + aiohttp = _get_aiohttp() try: ssl_context = ssl.create_default_context( cafile=certifi.where(), diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 8fb1464284..a3ebd40e7e 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -3,12 +3,21 @@ import sys import uuid -import aiohttp - -from astrbot.core import db_helper, logger from astrbot.core.config import VERSION +def _get_aiohttp(): + import aiohttp + + return aiohttp + + +def _get_runtime_dependencies(): + from astrbot.core import db_helper, logger + + return db_helper, logger + + class Metric: _iid_cache = None @@ -45,6 +54,7 @@ async def upload(**kwargs) -> None: Powered by TickStats. """ + db_helper, logger = _get_runtime_dependencies() if os.environ.get("ASTRBOT_DISABLE_METRICS", "0") == "1": return base_url = "https://tickstats.soulter.top/api/metric/90a6c2a1" @@ -69,6 +79,7 @@ async def upload(**kwargs) -> None: logger.error(f"保存指标到数据库失败: {e}") try: + aiohttp = _get_aiohttp() async with aiohttp.ClientSession(trust_env=True) as session: async with session.post(base_url, json=payload, timeout=3) as response: if response.status != 200: diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 2fa2351291..c50c3b08a2 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -1,17 +1,23 @@ -import re import os -import aiohttp +import re import ssl -import certifi -from io import BytesIO -from typing import List, Tuple from abc import ABC, abstractmethod +from io import BytesIO + +import certifi +from PIL import Image, ImageDraw, ImageFont + from astrbot.core.config import VERSION +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import save_temp_img from . import RenderStrategy -from PIL import ImageFont, Image, ImageDraw -from astrbot.core.utils.io import save_temp_img -from astrbot.core.utils.astrbot_path import get_astrbot_data_path + + +def _get_aiohttp(): + import aiohttp + + return aiohttp class FontManager: @@ -20,7 +26,7 @@ class FontManager: _font_cache = {} @classmethod - def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont: + def get_font(cls, size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: """获取指定大小的字体,优先从缓存获取""" if size in cls._font_cache: return cls._font_cache[size] @@ -66,7 +72,9 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]: + def get_text_size( + text: str, font: ImageFont.FreeTypeFont | ImageFont.ImageFont + ) -> tuple[int, int]: """获取文本的尺寸""" # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0 @@ -75,7 +83,7 @@ def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) - @staticmethod def split_text_to_fit_width( - text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int + text: str, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, max_width: int ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] @@ -293,7 +301,10 @@ def render( # 倾斜变换,使用仿射变换实现斜体效果 # 变换矩阵: [1, 0.2, 0, 0, 1, 0] italic_img = text_img.transform( - text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC + text_img.size, + Image.Transform.AFFINE, + (1, 0.2, 0, 0, 1, 0), + Image.Resampling.BICUBIC, ) # 粘贴到原图像 @@ -629,6 +640,7 @@ def __init__(self, content: str, image_url: str): async def load_image(self): """加载图片""" try: + aiohttp = _get_aiohttp() ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 53d9441fab..828fa597a7 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -2,8 +2,6 @@ import logging import random -import aiohttp - from astrbot.core.config import VERSION from astrbot.core.utils.http_ssl import build_tls_connector from astrbot.core.utils.io import download_image_by_url @@ -16,6 +14,12 @@ logger = logging.getLogger("astrbot") +def _get_aiohttp(): + import aiohttp + + return aiohttp + + class NetworkRenderStrategy(RenderStrategy): def __init__(self, base_url: str | None = None) -> None: super().__init__() @@ -38,6 +42,7 @@ async def get_template(self, name: str = "base") -> str: async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: + aiohttp = _get_aiohttp() async with aiohttp.ClientSession( trust_env=True, connector=build_tls_connector(), @@ -89,6 +94,7 @@ async def render_custom_template( last_exception = None for endpoint in endpoints: try: + aiohttp = _get_aiohttp() if return_url: async with ( aiohttp.ClientSession( diff --git a/astrbot/dashboard/routes/command.py b/astrbot/dashboard/routes/command.py index cbc565c476..8222a90bf5 100644 --- a/astrbot/dashboard/routes/command.py +++ b/astrbot/dashboard/routes/command.py @@ -1,5 +1,6 @@ from quart import request +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star.command_management import ( list_command_conflicts, list_commands, @@ -18,8 +19,13 @@ class CommandRoute(Route): - def __init__(self, context: RouteContext) -> None: + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: super().__init__(context) + self.core_lifecycle = core_lifecycle self.routes = { "/commands": ("GET", self.get_commands), "/commands/conflicts": ("GET", self.get_conflicts), @@ -30,7 +36,7 @@ def __init__(self, context: RouteContext) -> None: self.register_routes() async def get_commands(self): - commands = await list_commands() + commands = await _list_dashboard_commands(self.core_lifecycle) summary = { "total": len(commands), "disabled": len([cmd for cmd in commands if not cmd["enabled"]]), @@ -39,67 +45,174 @@ async def get_commands(self): return Response().ok({"items": commands, "summary": summary}).__dict__ async def get_conflicts(self): - conflicts = await list_command_conflicts() + conflicts = await _list_dashboard_conflicts(self.core_lifecycle) return Response().ok(conflicts).__dict__ async def toggle_command(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) enabled = data.get("enabled") - if handler_full_name is None or enabled is None: - return Response().error("handler_full_name 与 enabled 均为必填。").__dict__ + if command_key is None or enabled is None: + return Response().error("command_key 与 enabled 均为必填。").__dict__ if isinstance(enabled, str): enabled = enabled.lower() in ("1", "true", "yes", "on") + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": + return ( + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ + ) + try: - await toggle_command_service(handler_full_name, bool(enabled)) + await toggle_command_service(command_key, bool(enabled)) except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ async def rename_command(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) new_name = data.get("new_name") aliases = data.get("aliases") - if not handler_full_name or not new_name: - return Response().error("handler_full_name 与 new_name 均为必填。").__dict__ + if not command_key or not new_name: + return Response().error("command_key 与 new_name 均为必填。").__dict__ + + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": + return ( + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ + ) try: - await rename_command_service(handler_full_name, new_name, aliases=aliases) + await rename_command_service(command_key, new_name, aliases=aliases) except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ async def update_permission(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) permission = data.get("permission") - if not handler_full_name or not permission: + if not command_key or not permission: + return Response().error("command_key 与 permission 均为必填。").__dict__ + + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": return ( - Response().error("handler_full_name 与 permission 均为必填。").__dict__ + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ ) try: - await update_command_permission_service(handler_full_name, permission) + await update_command_permission_service(command_key, permission) except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ -async def _get_command_payload(handler_full_name: str): - commands = await list_commands() - for cmd in commands: - if cmd["handler_full_name"] == handler_full_name: +def _resolve_command_key(data: dict | None) -> str | None: + if not isinstance(data, dict): + return None + command_key = data.get("command_key") + if command_key: + return str(command_key) + handler_full_name = data.get("handler_full_name") + if handler_full_name: + return str(handler_full_name) + return None + + +async def _list_dashboard_commands( + core_lifecycle: AstrBotCoreLifecycle, +) -> list[dict]: + commands = _decorate_legacy_commands(await list_commands()) + sdk_bridge = getattr(core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + commands.extend(sdk_bridge.list_dashboard_commands()) + _apply_conflict_flags(commands) + commands.sort(key=lambda item: str(item.get("effective_command", "")).lower()) + return commands + + +async def _list_dashboard_conflicts( + core_lifecycle: AstrBotCoreLifecycle, +) -> list[dict]: + conflicts = list(await list_command_conflicts()) + sdk_bridge = getattr(core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is None or not hasattr( + sdk_bridge, "list_cross_system_command_conflicts" + ): + return conflicts + conflicts.extend( + conflict.to_dashboard_payload() + for conflict in sdk_bridge.list_cross_system_command_conflicts() + ) + return conflicts + + +def _decorate_legacy_commands(commands: list[dict]) -> list[dict]: + for item in commands: + _decorate_legacy_command_item(item) + return commands + + +def _decorate_legacy_command_item(item: dict) -> None: + item["command_key"] = str(item.get("handler_full_name", "")) + item["runtime_kind"] = "legacy" + item["supports_toggle"] = True + item["supports_rename"] = True + item["supports_permission"] = True + sub_commands = item.get("sub_commands") + if not isinstance(sub_commands, list): + return + for sub in sub_commands: + if isinstance(sub, dict): + _decorate_legacy_command_item(sub) + + +def _apply_conflict_flags(commands: list[dict]) -> None: + counts: dict[str, int] = {} + for item in _walk_command_items(commands): + command_name = str(item.get("effective_command", "")).strip() + if not command_name or not bool(item.get("enabled", False)): + continue + counts[command_name] = counts.get(command_name, 0) + 1 + + for item in _walk_command_items(commands): + command_name = str(item.get("effective_command", "")).strip() + item["has_conflict"] = bool(command_name and counts.get(command_name, 0) > 1) + + +def _walk_command_items(commands: list[dict]): + for item in commands: + yield item + sub_commands = item.get("sub_commands") + if not isinstance(sub_commands, list): + continue + yield from _walk_command_items(sub_commands) + + +async def _get_command_payload( + core_lifecycle: AstrBotCoreLifecycle, + command_key: str, +): + commands = await _list_dashboard_commands(core_lifecycle) + for cmd in _walk_command_items(commands): + if cmd.get("command_key") == command_key: return cmd return {} diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..72a45d27c6 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1043,7 +1043,7 @@ async def post_plugin_configs(self): plugin_name = request.args.get("plugin_name", "unknown") try: await self._save_plugin_configs(post_configs, plugin_name) - await self.core_lifecycle.plugin_manager.reload(plugin_name) + await self._reload_plugin_after_config_save(plugin_name) return ( Response() .ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在热重载插件。") @@ -1058,6 +1058,16 @@ def _get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None: return plugin_md return None + def _sdk_bridge(self): + return getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + + async def _reload_plugin_after_config_save(self, plugin_name: str) -> None: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None and sdk_bridge.get_plugin_metadata(plugin_name): + await sdk_bridge.reload_plugin(plugin_name) + return + await self.core_lifecycle.plugin_manager.reload(plugin_name) + def _resolve_config_file_scope( self, ) -> tuple[str, str, str, StarMetadata, AstrBotConfig]: @@ -1516,6 +1526,26 @@ async def _get_plugin_config(self, plugin_name: str): } break + if ret["metadata"] is not None: + return ret + + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return ret + + schema = sdk_bridge.get_plugin_config_schema(plugin_name) + if schema is None or not schema: + return ret + config = sdk_bridge.get_plugin_config(plugin_name) or {} + ret["config"] = config + ret["metadata"] = { + plugin_name: { + "description": f"{plugin_name} 配置", + "type": "object", + "items": schema, + }, + } + return ret async def _save_astrbot_configs( @@ -1542,18 +1572,40 @@ async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> No if plugin_md.name == plugin_name: md = plugin_md - if not md: + if md: + if not md.config: + raise ValueError(f"插件 {plugin_name} 没有注册配置") + assert md.config is not None + + try: + errors, post_configs = validate_config( + post_configs, getattr(md.config, "schema", {}), is_core=False + ) + if errors: + raise ValueError(f"格式校验未通过: {errors}") + md.config.save_config(post_configs) + return + except Exception as e: + raise e + + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + raise ValueError(f"插件 {plugin_name} 不存在") + + schema = sdk_bridge.get_plugin_config_schema(plugin_name) + if schema is None: raise ValueError(f"插件 {plugin_name} 不存在") - if not md.config: + if not schema: raise ValueError(f"插件 {plugin_name} 没有注册配置") - assert md.config is not None try: errors, post_configs = validate_config( - post_configs, getattr(md.config, "schema", {}), is_core=False + post_configs, + schema, + is_core=False, ) if errors: raise ValueError(f"格式校验未通过: {errors}") - md.config.save_config(post_configs) + sdk_bridge.save_plugin_config(plugin_name, post_configs) except Exception as e: raise e diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index d151bbe6f6..50b7f37652 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,4 +1,5 @@ import asyncio +import base64 import hashlib import json import os @@ -14,6 +15,7 @@ from astrbot.api import sp from astrbot.core import DEMO_MODE, file_token_service, logger +from astrbot.core.config.default import VERSION from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter @@ -28,6 +30,7 @@ get_astrbot_data_path, get_astrbot_temp_path, ) +from astrbot.core.zip_updator import RepoZipUpdator from .route import Response, Route, RouteContext @@ -86,6 +89,19 @@ def __init__( } self._logo_cache = {} + self._remote_doc_cache: dict[tuple[str, str], str] = {} + self._repo_updator = RepoZipUpdator() + + def _sdk_bridge(self): + return getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + + def _is_sdk_plugin(self, plugin_name: str) -> bool: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return False + return any( + plugin["name"] == plugin_name for plugin in sdk_bridge.list_plugins() + ) async def check_plugin_compatibility(self): try: @@ -146,9 +162,19 @@ async def reload_plugins(self): data = await request.get_json() plugin_name = data.get("name", None) try: - success, message = await self.plugin_manager.reload(plugin_name) - if not success: - return Response().error(message or "插件重载失败").__dict__ + if plugin_name and self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__ + await sdk_bridge.reload_plugin(plugin_name) + else: + success, message = await self.plugin_manager.reload(plugin_name) + if not success: + return Response().error(message or "插件重载失败").__dict__ + if plugin_name is None: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None: + await sdk_bridge.reload_all(reset_restart_budget=True) return Response().ok(None, "重载成功。").__dict__ except Exception as e: logger.error(f"/api/plugin/reload: {traceback.format_exc()}") @@ -367,6 +393,105 @@ def _resolve_plugin_dir(self, plugin) -> Path | None: return None return plugin_dir + def _resolve_sdk_plugin_dir(self, plugin_name: str) -> Path | None: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return None + records = getattr(sdk_bridge, "_records", None) + if not isinstance(records, dict): + return None + record = records.get(plugin_name) + plugin = getattr(record, "plugin", None) + plugin_dir = getattr(plugin, "plugin_dir", None) + if plugin_dir is None: + return None + resolved = Path(plugin_dir) + if not resolved.is_dir(): + return None + return resolved + + def _find_legacy_plugin(self, plugin_name: str): + for plugin in self.plugin_manager.context.get_all_stars(): + if plugin.name == plugin_name: + return plugin + return None + + def _resolve_plugin_content_dir(self, plugin_name: str) -> Path | None: + for plugin in self.plugin_manager.context.get_all_stars(): + if plugin.name != plugin_name: + continue + return self._resolve_plugin_dir(plugin) + return self._resolve_sdk_plugin_dir(plugin_name) + + def _resolve_plugin_repo_url(self, plugin_name: str) -> str | None: + for plugin in self.plugin_manager.context.get_all_stars(): + if plugin.name != plugin_name: + continue + repo = getattr(plugin, "repo", None) + if isinstance(repo, str) and repo.strip(): + return repo.strip() + + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None: + get_plugin_metadata = getattr(sdk_bridge, "get_plugin_metadata", None) + if callable(get_plugin_metadata): + metadata = get_plugin_metadata(plugin_name) + if isinstance(metadata, dict): + repo = metadata.get("repo") + if isinstance(repo, str) and repo.strip(): + return repo.strip() + records = getattr(sdk_bridge, "_records", None) + if isinstance(records, dict): + record = records.get(plugin_name) + plugin = getattr(record, "plugin", None) + manifest = getattr(plugin, "manifest_data", None) + if isinstance(manifest, dict): + repo = manifest.get("repo") + if isinstance(repo, str) and repo.strip(): + return repo.strip() + return None + + async def _fetch_github_repo_readme(self, repo_url: str) -> str: + cache_key = ("readme", repo_url) + cached = self._remote_doc_cache.get(cache_key) + if cached is not None: + return cached + + owner, repo, branch = self._repo_updator.parse_github_url(repo_url) + params = {"ref": branch} if branch else None + headers = { + "Accept": "application/vnd.github+json", + "User-Agent": f"AstrBot/{VERSION}", + "X-GitHub-Api-Version": "2022-11-28", + } + api_url = f"https://api.github.com/repos/{owner}/{repo}/readme" + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + headers=headers, + ) as session, + session.get(api_url, params=params) as response, + ): + if response.status != 200: + message = await response.text() + raise ValueError( + f"GitHub README 获取失败,状态码 {response.status}: {message}" + ) + payload = await response.json() + + encoding = str(payload.get("encoding") or "").lower() + content = payload.get("content") + if encoding != "base64" or not isinstance(content, str): + raise ValueError("GitHub README 返回格式不受支持。") + + decoded = base64.b64decode(content).decode("utf-8") + self._remote_doc_cache[cache_key] = decoded + return decoded + def _get_plugin_installed_at(self, plugin) -> str | None: plugin_dir = self._resolve_plugin_dir(plugin) if plugin_dir is None: @@ -420,6 +545,12 @@ async def get_plugins(self): ): continue _plugin_resp.append(_t) + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None: + for plugin in sdk_bridge.list_plugins(): + if plugin_name and plugin["name"] != plugin_name: + continue + _plugin_resp.append(plugin) return ( Response() .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) @@ -515,6 +646,8 @@ async def install_plugin(self): ignore_version_check=ignore_version_check, ) # self.core_lifecycle.restart() + if plugin_info and plugin_info.get("type") == "sdk": + logger.info("SDK 插件 %s 安装成功", plugin_info.get("name")) logger.info(f"安装插件 {repo_url} 成功。") return Response().ok(plugin_info, "安装成功。").__dict__ except PluginVersionIncompatibleError as e: @@ -556,6 +689,8 @@ async def install_plugin_upload(self): ignore_version_check=ignore_version_check, ) # self.core_lifecycle.restart() + if plugin_info and plugin_info.get("type") == "sdk": + logger.info("SDK 插件 %s 上传安装成功", plugin_info.get("name")) logger.info(f"安装插件 {file.filename} 成功") return Response().ok(plugin_info, "安装成功。").__dict__ except PluginVersionIncompatibleError as e: @@ -583,6 +718,10 @@ async def uninstall_plugin(self): plugin_name = post_data["name"] delete_config = post_data.get("delete_config", False) delete_data = post_data.get("delete_data", False) + if self._is_sdk_plugin(plugin_name): + return Response().error( + "SDK 插件在 MVP 中不支持卸载,请手动移除目录" + ).__dict__, 400 try: logger.info(f"正在卸载插件 {plugin_name}") await self.plugin_manager.uninstall_plugin( @@ -635,6 +774,8 @@ async def update_plugin(self): post_data = await request.get_json() plugin_name = post_data["name"] proxy: str = post_data.get("proxy", None) + if self._is_sdk_plugin(plugin_name): + return Response().error("SDK 插件在 MVP 中不支持更新").__dict__, 400 try: logger.info(f"正在更新插件 {plugin_name}") await self.plugin_manager.update_plugin(plugin_name, proxy) @@ -709,6 +850,19 @@ async def off_plugin(self): post_data = await request.get_json() plugin_name = post_data["name"] + if self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__, 500 + try: + await sdk_bridge.turn_off_plugin(plugin_name) + except ValueError as exc: + return Response().error(str(exc)).__dict__, 404 + except Exception as exc: + logger.error(f"/api/plugin/off: {traceback.format_exc()}") + return Response().error(str(exc)).__dict__ + logger.info(f"停用 SDK 插件 {plugin_name} 。") + return Response().ok(None, "停用成功。").__dict__ try: await self.plugin_manager.turn_off_plugin(plugin_name) logger.info(f"停用插件 {plugin_name} 。") @@ -727,9 +881,22 @@ async def on_plugin(self): post_data = await request.get_json() plugin_name = post_data["name"] + if self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__, 500 + try: + await sdk_bridge.turn_on_plugin(plugin_name) + except ValueError as exc: + return Response().error(str(exc)).__dict__, 404 + except Exception as exc: + logger.error(f"/api/plugin/on: {traceback.format_exc()}") + return Response().error(str(exc)).__dict__ + logger.info(f"启用 SDK 插件 {plugin_name}") + return Response().ok(None, "启用成功。").__dict__ try: await self.plugin_manager.turn_on_plugin(plugin_name) - logger.info(f"启用插件 {plugin_name} 。") + logger.info(f"启用插件 {plugin_name}") return Response().ok(None, "启用成功。").__dict__ except Exception as e: logger.error(f"/api/plugin/on: {traceback.format_exc()}") @@ -737,50 +904,83 @@ async def on_plugin(self): async def get_plugin_readme(self): plugin_name = request.args.get("name") + repo_url = str(request.args.get("repo_url") or "").strip() or None logger.debug(f"正在获取插件 {plugin_name} 的README文件内容") - if not plugin_name: - logger.warning("插件名称为空") - return Response().error("插件名称不能为空").__dict__ + if not plugin_name and not repo_url: + logger.warning("插件名称和仓库地址均为空") + return Response().error("插件名称或仓库地址不能为空").__dict__ - plugin_obj = None - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name == plugin_name: - plugin_obj = plugin - break + legacy_plugin = self._find_legacy_plugin(plugin_name) if plugin_name else None + if legacy_plugin is not None: + if not legacy_plugin.root_dir_name: + logger.warning(f"插件 {plugin_name} 目录不存在") + return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ - if not plugin_obj: - logger.warning(f"插件 {plugin_name} 不存在") - return Response().error(f"插件 {plugin_name} 不存在").__dict__ + if legacy_plugin.reserved: + plugin_dir = os.path.join( + self.plugin_manager.reserved_plugin_path, + legacy_plugin.root_dir_name, + ) + else: + plugin_dir = os.path.join( + self.plugin_manager.plugin_store_path, + legacy_plugin.root_dir_name, + ) - if not plugin_obj.root_dir_name: - logger.warning(f"插件 {plugin_name} 目录不存在") - return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ + if not os.path.isdir(plugin_dir): + logger.warning(f"无法找到插件目录: {plugin_dir}") + return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ - if plugin_obj.reserved: - plugin_dir = os.path.join( - self.plugin_manager.reserved_plugin_path, - plugin_obj.root_dir_name, - ) - else: - plugin_dir = os.path.join( - self.plugin_manager.plugin_store_path, - plugin_obj.root_dir_name, - ) + readme_path = os.path.join(plugin_dir, "README.md") + if not os.path.isfile(readme_path): + logger.warning(f"插件 {plugin_name} 没有README文件") + return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ - if not os.path.isdir(plugin_dir): - logger.warning(f"无法找到插件目录: {plugin_dir}") - return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ + try: + with open(readme_path, encoding="utf-8") as f: + readme_content = f.read() - readme_path = os.path.join(plugin_dir, "README.md") + return ( + Response() + .ok({"content": readme_content}, "成功获取README内容") + .__dict__ + ) + except Exception as e: + logger.error(f"/api/plugin/readme: {traceback.format_exc()}") + return Response().error(f"读取README文件失败: {e!s}").__dict__ - if not os.path.isfile(readme_path): + if repo_url is None and plugin_name: + repo_url = self._resolve_plugin_repo_url(plugin_name) + + if repo_url is not None: + try: + readme_content = await self._fetch_github_repo_readme(repo_url) + return ( + Response() + .ok({"content": readme_content}, "成功获取README内容") + .__dict__ + ) + except Exception as exc: + if not plugin_name: + logger.error(f"/api/plugin/readme: {traceback.format_exc()}") + return Response().error(f"读取README文件失败: {exc!s}").__dict__ + logger.warning( + "从 GitHub 获取 SDK 插件 %s README 失败: %s", plugin_name, exc + ) + + plugin_dir = self._resolve_sdk_plugin_dir(plugin_name) if plugin_name else None + if plugin_dir is None: + logger.warning(f"插件 {plugin_name or repo_url} 不存在") + return Response().error(f"插件 {plugin_name or repo_url} 不存在").__dict__ + + readme_path = plugin_dir / "README.md" + if not readme_path.is_file(): logger.warning(f"插件 {plugin_name} 没有README文件") return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = readme_path.read_text(encoding="utf-8") return ( Response() @@ -803,44 +1003,58 @@ async def get_plugin_changelog(self): logger.warning("插件名称为空") return Response().error("插件名称不能为空").__dict__ - # 查找插件 - plugin_obj = None - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name == plugin_name: - plugin_obj = plugin - break + legacy_plugin = self._find_legacy_plugin(plugin_name) + if legacy_plugin is not None: + if not legacy_plugin.root_dir_name: + logger.warning(f"插件 {plugin_name} 目录不存在") + return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ - if not plugin_obj: - logger.warning(f"插件 {plugin_name} 不存在") - return Response().error(f"插件 {plugin_name} 不存在").__dict__ + if legacy_plugin.reserved: + plugin_dir = os.path.join( + self.plugin_manager.reserved_plugin_path, + legacy_plugin.root_dir_name, + ) + else: + plugin_dir = os.path.join( + self.plugin_manager.plugin_store_path, + legacy_plugin.root_dir_name, + ) - if not plugin_obj.root_dir_name: - logger.warning(f"插件 {plugin_name} 目录不存在") - return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ + if not os.path.isdir(plugin_dir): + logger.warning(f"无法找到插件目录: {plugin_dir}") + return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ + + changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] + for name in changelog_names: + changelog_path = os.path.join(plugin_dir, name) + if os.path.isfile(changelog_path): + try: + with open(changelog_path, encoding="utf-8") as f: + changelog_content = f.read() + return ( + Response() + .ok({"content": changelog_content}, "成功获取更新日志") + .__dict__ + ) + except Exception as e: + logger.error(f"/api/plugin/changelog: {traceback.format_exc()}") + return Response().error(f"读取更新日志失败: {e!s}").__dict__ - if plugin_obj.reserved: - plugin_dir = os.path.join( - self.plugin_manager.reserved_plugin_path, - plugin_obj.root_dir_name, - ) - else: - plugin_dir = os.path.join( - self.plugin_manager.plugin_store_path, - plugin_obj.root_dir_name, - ) + logger.warning(f"插件 {plugin_name} 没有更新日志文件") + return Response().ok({"content": None}, "该插件没有更新日志文件").__dict__ - if not os.path.isdir(plugin_dir): - logger.warning(f"无法找到插件目录: {plugin_dir}") - return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ + plugin_dir = self._resolve_sdk_plugin_dir(plugin_name) + if plugin_dir is None: + logger.warning(f"插件 {plugin_name} 不存在") + return Response().error(f"插件 {plugin_name} 不存在").__dict__ # 尝试多种可能的文件名 changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] for name in changelog_names: - changelog_path = os.path.join(plugin_dir, name) - if os.path.isfile(changelog_path): + changelog_path = plugin_dir / name + if changelog_path.is_file(): try: - with open(changelog_path, encoding="utf-8") as f: - changelog_content = f.read() + changelog_content = changelog_path.read_text(encoding="utf-8") return ( Response() .ok({"content": changelog_content}, "成功获取更新日志") diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index abae13e33b..77bcf40698 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -2,6 +2,7 @@ import re import shutil import traceback +import uuid from collections.abc import Awaitable, Callable from pathlib import Path from typing import Any @@ -388,24 +389,28 @@ async def download_skill(self): .__dict__ ) - skill_dir = Path(skill_mgr.skills_root) / name - skill_md = skill_dir / "SKILL.md" - if not skill_dir.is_dir() or not skill_md.exists(): + if skill_mgr.get_local_skill_source(name) is None: return Response().error("Local skill not found").__dict__ export_dir = Path(get_astrbot_temp_path()) / "skill_exports" export_dir.mkdir(parents=True, exist_ok=True) zip_base = export_dir / name zip_path = zip_base.with_suffix(".zip") + bundle_dir = export_dir / f"{name}_{uuid.uuid4().hex}" if zip_path.exists(): zip_path.unlink() - shutil.make_archive( - str(zip_base), - "zip", - root_dir=str(skill_mgr.skills_root), - base_dir=name, - ) + try: + skill_mgr.materialize_local_skill_bundle(bundle_dir, skill_names=[name]) + shutil.make_archive( + str(zip_base), + "zip", + root_dir=str(bundle_dir), + base_dir=name, + ) + finally: + if bundle_dir.exists(): + shutil.rmtree(bundle_dir, ignore_errors=True) return await send_file( str(zip_path), diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 84f8dcc6d7..825abc005f 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -445,14 +445,20 @@ async def get_tool_list(self): origin_name = "unknown" tool_info = { + "tool_key": _build_legacy_tool_key(tool, origin, origin_name), "name": tool.name, "description": tool.description, "parameters": tool.parameters, "active": tool.active, "origin": origin, "origin_name": origin_name, + "runtime_kind": "legacy", + "plugin_id": None, } tools_dict.append(tool_info) + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + tools_dict.extend(sdk_bridge.list_dashboard_tools()) return Response().ok(data=tools_dict).__dict__ except Exception as e: logger.error(traceback.format_exc()) @@ -463,28 +469,65 @@ async def toggle_tool(self): try: data = await request.json tool_name = data.get("name") + tool_key = data.get("tool_key") action = data.get("activate") # True or False + runtime_kind = str(data.get("runtime_kind", "legacy") or "legacy") + plugin_id = data.get("plugin_id") - if not tool_name or action is None: + if (not tool_name and not tool_key) or action is None: return ( Response() - .error("Missing required parameters: name or activate") + .error("Missing required parameters: tool_key/name or activate") .__dict__ ) - if action: - try: - ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) - except ValueError as e: - return Response().error(f"Failed to activate tool: {e!s}").__dict__ + if runtime_kind == "sdk": + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is None: + return Response().error("SDK bridge is unavailable.").__dict__ + if not plugin_id or not tool_name: + return ( + Response() + .error("SDK tool toggle requires plugin_id and name") + .__dict__ + ) + plugin_metadata = sdk_bridge.get_plugin_metadata(str(plugin_id)) + if ( + action + and plugin_metadata is not None + and not plugin_metadata.get("enabled", False) + ): + return ( + Response() + .error( + "The SDK plugin is disabled. Enable the plugin before activating its tool." + ) + .__dict__ + ) + if action: + ok = sdk_bridge.activate_llm_tool(str(plugin_id), str(tool_name)) + else: + ok = sdk_bridge.deactivate_llm_tool(str(plugin_id), str(tool_name)) else: - ok = self.tool_mgr.deactivate_llm_tool(tool_name) + if action: + try: + ok = self.tool_mgr.activate_llm_tool( + str(tool_name), star_map=star_map + ) + except ValueError as e: + return ( + Response().error(f"Failed to activate tool: {e!s}").__dict__ + ) + else: + ok = self.tool_mgr.deactivate_llm_tool(str(tool_name)) if ok: return Response().ok(None, "Operation successful.").__dict__ return ( Response() - .error(f"Tool {tool_name} does not exist or the operation failed.") + .error( + f"Tool {tool_key or tool_name} does not exist or the operation failed." + ) .__dict__ ) @@ -510,3 +553,11 @@ async def sync_provider(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"Sync failed: {e!s}").__dict__ + + +def _build_legacy_tool_key(tool, origin: str, origin_name: str) -> str: + if origin == "mcp" and origin_name: + return f"mcp:{origin_name}:{tool.name}" + if origin == "plugin" and getattr(tool, "handler_module_path", None): + return f"plugin:{tool.handler_module_path}:{tool.name}" + return f"tool:{tool.name}" diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index cbb7296bd0..053130dc27 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -13,6 +13,7 @@ from hypercorn.asyncio import serve from hypercorn.config import Config as HyperConfig from quart import Quart, g, jsonify, request +from quart import Response as QuartResponse from quart.logging import default_handler from astrbot.core import logger @@ -108,7 +109,7 @@ def __init__( core_lifecycle, core_lifecycle.plugin_manager, ) - self.command_route = CommandRoute(self.context) + self.command_route = CommandRoute(self.context, core_lifecycle) self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) @@ -145,23 +146,128 @@ def __init__( view_func=self.srv_plug_route, methods=["GET", "POST"], ) + self.app.add_url_rule( + "/plug/", + view_func=self.srv_public_plug_route, + methods=["GET"], + ) self.shutdown_event = shutdown_event self._init_jwt_secret() async def srv_plug_route(self, subpath, *args, **kwargs): - """插件路由""" + """插件路由(需要认证)""" + auth_error = self._require_bearer_auth() + if auth_error is not None: + return auth_error + output = await self._dispatch_plugin_route(subpath, *args, **kwargs) + if output is not None: + return self._build_sdk_plugin_response(output) + return jsonify(Response().error("未找到该路由").__dict__) + + async def srv_public_plug_route(self, subpath, *args, **kwargs): + """公开插件页面路由""" + output = await self._dispatch_plugin_route(subpath, *args, **kwargs) + if output is None: + return jsonify(Response().error("未找到该路由").__dict__) + if not self._is_public_plugin_page_response(output): + r = jsonify(Response().error("该路由需要通过 /api/plug 访问").__dict__) + r.status_code = 403 + return r + return self._build_sdk_plugin_response(output) + + async def _dispatch_plugin_route(self, subpath, *args, **kwargs): registered_web_apis = self.core_lifecycle.star_context.registered_web_apis for api in registered_web_apis: route, view_handler, methods, _ = api if route == f"/{subpath}" and request.method in methods: return await view_handler(*args, **kwargs) - return jsonify(Response().error("未找到该路由").__dict__) + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + return await sdk_bridge.dispatch_http_request(f"/{subpath}", request.method) + return None + + @staticmethod + def _is_public_plugin_page_response(output: dict[str, object]) -> bool: + headers = output.get("headers") + if not isinstance(headers, dict): + headers = {} + content_type = str( + headers.get("Content-Type", headers.get("content-type", "")) + ).lower() + body = output.get("body") + if isinstance(body, str) and "text/html" in content_type: + return True + return isinstance(body, (bytes, bytearray)) and content_type.startswith( + "image/" + ) + + @staticmethod + def _build_sdk_plugin_response(output: dict) -> QuartResponse: + status = int(output.get("status", 200)) + headers = output.get("headers") + if headers is None: + headers = {} + if not isinstance(headers, dict): + raise ValueError("SDK HTTP handler headers must be an object") + + body = output.get("body") + if isinstance(body, (dict, list)): + response = jsonify(body) + response.status_code = status + response.headers.setdefault("Content-Type", "application/json") + elif isinstance(body, str): + response = QuartResponse( + body, + status=status, + content_type="text/plain; charset=utf-8", + ) + elif isinstance(body, (bytes, bytearray)): + response = QuartResponse( + bytes(body), + status=status, + content_type=str( + headers.get("Content-Type") + or headers.get("content-type") + or "application/octet-stream" + ), + ) + elif body is None: + response = QuartResponse("", status=status) + else: + raise ValueError( + "SDK HTTP handler body must be object, array, string, bytes or null" + ) + + for key, value in headers.items(): + response.headers[str(key)] = str(value) + return response + + def _require_bearer_auth(self): + """检查 Bearer token,无效时返回 401 响应,有效时返回 None。""" + token = request.headers.get("Authorization") + if not token: + r = jsonify(Response().error("未授权").__dict__) + r.status_code = 401 + return r + token = token.removeprefix("Bearer ") + try: + payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + g.username = payload["username"] + except (jwt.InvalidTokenError, KeyError): + r = jsonify(Response().error("未授权").__dict__) + r.status_code = 401 + return r + return None async def auth_middleware(self): if not request.path.startswith("/api"): return None + # SDK plugin HTTP routes are proxied under /api/plug and must be able to + # implement their own authentication flow, including public login pages. + if request.path.startswith("/api/plug/"): + return None if request.path.startswith("/api/v1"): raw_key = self._extract_raw_api_key() if not raw_key: diff --git a/dashboard/src/components/extension/componentPanel/components/CommandTable.vue b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue index 32eebb746b..d9d281e971 100644 --- a/dashboard/src/components/extension/componentPanel/components/CommandTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue @@ -90,6 +90,10 @@ const getRowProps = ({ item }: { item: CommandItem }) => { } return classes.length > 0 ? { class: classes.join(' ') } : {}; }; + +const canToggle = (cmd: CommandItem): boolean => cmd.supports_toggle !== false; +const canRename = (cmd: CommandItem): boolean => cmd.supports_rename !== false; +const canEditPermission = (cmd: CommandItem): boolean => cmd.supports_permission !== false;