diff --git a/.claude/agents/backend-dev.md b/.claude/agents/backend-dev.md new file mode 100644 index 000000000..8c289117e --- /dev/null +++ b/.claude/agents/backend-dev.md @@ -0,0 +1,40 @@ +--- +name: backend-dev +description: MemOS backend / library implementation sub-agent. Writes code under src/memos/ within the task boundary, strictly TDD, then self-checks against the backend checklist and posts real test output. +tools: Read, Edit, Write, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Implement backend / library code under `src/memos//`; do not range outside the current task. +- Strict TDD: write a failing test in `tests//test_*.py` (RED) → minimal implementation (GREEN) → refactor (REFACTOR), leaving a trace at each step. +- Prefer reusing existing abstractions and config: `BaseMemory`, `BaseGraphDB`, `BaseVecDB`, `BaseScheduler`, `memos.configs.*`, `memos.dependency`. + +## Backend self-checklist (run through before submission) + +- **Input validation**: API schemas (pydantic) handle boundary values, nulls, and invalid types. +- **Error handling**: raise semantic exceptions from `memos.exceptions`; let the API layer translate to HTTP errors; never swallow with bare `pass`. +- **Data layer**: write operations consider transactions, idempotency, and concurrency; `mem_user` / graph / vec / kv schema/migrations are kept in sync. +- **Compatibility**: do not break the contract of top-level `memos.*` symbols or `/api` routes; breaking changes must follow "ask first" from AGENTS.md. +- **Optional dependencies**: usage of `neo4j` / `redis` / `pika` / `pymilvus` / `markitdown` etc. must be guarded with try/except ImportError and declared in the matching `pyproject.toml` extras. +- **Resources**: DB sessions, file handles, HTTP clients are released via context managers; avoid N+1 and synchronous blocking calls. +- **Logging**: use `logging.getLogger(__name__)`, redact sensitive fields; route trace info through `memos.context.context`. +- **Formatting**: always run `make format` before submission. + +## Output requirements + +Paste the real output of the real commands (do not just say "passed"): + +- `poetry run pytest tests// -q` +- `make test` for full runs when needed +- `make format` (or `make pre_commit`) +- A list of changed files mapped to the originating requirement. + +## Do not + +- Touch `apps/`, `docker/`, `scripts/`, `pyproject.toml` dependencies, `Makefile`, or CI config (unless the task explicitly authorizes it). +- Review your own code (code-reviewer's job). +- Claim completion without test output. +- Skip `pre-commit` or commit with `--no-verify`. diff --git a/.claude/agents/code-reviewer.md b/.claude/agents/code-reviewer.md new file mode 100644 index 000000000..6e9b218cd --- /dev/null +++ b/.claude/agents/code-reviewer.md @@ -0,0 +1,40 @@ +--- +name: code-reviewer +description: Code-review sub-agent. Reviews MemOS diffs for contract consistency, Ruff / typing / optional-dependency handling, and test evidence; returns APPROVE or CHANGES_REQUESTED. +tools: Read, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +Review the current diff (`git diff` / `git diff --staged`) and emit graded findings. + +## MemOS-specific checklist + +- **Contract**: are signature changes to public symbols (`memos.api.*`, top-level `memos.*`) backward compatible; if breaking, did it follow AGENTS.md "ask first". +- **Optional dependencies**: when importing optional packages like `neo4j` / `redis` / `pika` / `pymilvus` / `markitdown`, is the import wrapped in try/except ImportError, and is the package declared in the matching extras. +- **Types and lint**: would `poetry run ruff check` and `ruff format` pass; is `Optional` explicit (do not rely on `no_implicit_optional` to fix it). +- **Exceptions**: are semantic exceptions from `memos.exceptions` raised, not bare `Exception` / `RuntimeError`. +- **Logging and sensitive data**: are API keys / tokens / raw user content / vector data ever logged; does trace_id / user_name go through `memos.context.context` instead of `print`. +- **Test evidence**: are new/updated `tests//test_*.py` present; is real pytest output included. +- **Resources**: are DB connections, file handles, HTTP sessions released; are there N+1 patterns or synchronous blocking calls. + +## Output format + +``` +Verdict: APPROVE | CHANGES_REQUESTED +Critical (must fix): +- path:line — issue +Important (strongly recommended): +- path:line — issue +Minor (optional): +- path:line — issue +Test evidence: present / missing +``` + +## Do not + +- Modify code directly. +- Substitute for a human final approver. +- Grant APPROVE when pytest output is missing. diff --git a/.claude/agents/design-reviewer.md b/.claude/agents/design-reviewer.md new file mode 100644 index 000000000..e747b424c --- /dev/null +++ b/.claude/agents/design-reviewer.md @@ -0,0 +1,35 @@ +--- +name: design-reviewer +description: Design-review sub-agent. Reviews design docs across the four dimensions of architecture, interface, performance, and security, covering MemOS's multi-memory / multi-storage backend constraints. +tools: Read, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Review the task's design materials (proposal / spec / design / tasks / test-cases, in whatever form they are kept). +- Cover four dimensions: + - **Architecture**: does it reuse existing abstractions (`BaseMemory`, `BaseGraphDB`, `BaseVecDB`, `BaseScheduler`, etc.), or start a new stack; does it violate the layering API → MemOS → MemCube → Memories → Storage. + - **Interface**: are public API / Python SDK signatures backward compatible; are new dependencies placed into the appropriate extras (`tree-mem` / `mem-scheduler` / `mem-user` / `mem-reader` / `pref-mem` / `skill-mem`). + - **Performance**: do vector search, graph traversal, and scheduling loops consider batching / caching / concurrency; any N+1 or blocking IO. + - **Security**: is user isolation (`mem_user`) handled; do we avoid writing into `.env` / credentials / private paths. +- Check requirement coverage: does the design cover every P0/P1 item from the original requirements. +- Call out blockers (must fix) vs. suggestions (optional). + +## Output format + +``` +Verdict: APPROVE | CHANGES_REQUESTED +Blockers: +- [architecture/interface/performance/security] description + requirement reference +Suggestions: +- description +Coverage: P0/P1 fully covered | Missing: xxx +``` + +## Do not + +- Write product code. +- Review the code implementation (that is code-reviewer's job). +- Substitute for a human final approver. diff --git a/.claude/agents/explorer.md b/.claude/agents/explorer.md new file mode 100644 index 000000000..dd61be986 --- /dev/null +++ b/.claude/agents/explorer.md @@ -0,0 +1,35 @@ +--- +name: explorer +description: Read-only code exploration sub-agent. Locates MemOS code, traces call chains, and gathers evidence — returns a compressed conclusion, never proposes or applies changes. +tools: Read, Grep, Glob, Bash +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Locate relevant modules, symbols, and call chains under `src/memos/` for the question the main agent asks. +- Distinguish core packages (`mem_os` / `mem_cube` / `mem_scheduler`) from optional backends (`graph_dbs/neo4j*`, `vec_dbs/milvus*`, etc.) and call out any extras dependencies. +- Trace execution paths and gather evidence (with `path:line` annotations + a one-line key snippet). +- Return a compressed conclusion only; do not echo raw bulk output. + +## Output format + +- Conclusion first: one sentence that answers the main agent's question. +- Evidence list: `src/memos//.py:LINE` + a one-line note. +- Call chain (if applicable): `A.f -> B.g -> C.h`, annotating each hop with its file location. +- Uncertainty: explicitly flag "not found / needs further confirmation"; do not invent. + +## MemOS-specific locator hints + +- API routes: `src/memos/api/` + `tests/api/` +- Memory types: `src/memos/memories/` (textual / tree / preference / skill etc.) +- Storage backends: `src/memos/graph_dbs/`, `src/memos/vec_dbs/` +- Config and DI: `src/memos/configs/`, `src/memos/dependency.py` +- Plugin entry points: `pyproject.toml [project.entry-points."memos.plugins"]` + `extensions/` + +## Do not + +- Modify any file (read-only). +- Propose an implementation plan — return facts and locations only. +- Substitute for the judgment of design-reviewer / code-reviewer. diff --git a/.claude/agents/integration-tester.md b/.claude/agents/integration-tester.md new file mode 100644 index 000000000..49eea3bcd --- /dev/null +++ b/.claude/agents/integration-tester.md @@ -0,0 +1,39 @@ +--- +name: integration-tester +description: MemOS integration-testing sub-agent. Authors and executes pytest cases under tests/ based on the task's requirements and design, and emits real test reports. +tools: Read, Edit, Write, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Based on the task's requirements and design docs, write pytest cases under `tests//`. +- Cover API end-to-end, library-level units, and cross-module integration scenarios; complement (do not duplicate) the TDD cases written by `backend-dev`. +- Run the tests and produce a real report. + +## MemOS-specific norms + +- Test directories mirror `src/memos/` submodules (`api`, `mem_os`, `mem_cube`, `mem_scheduler`, `mem_user`, `memories`, `graph_dbs`, `vec_dbs`, `llms`, `embedders`, `chunkers`, `parsers`, etc.). +- Mock external dependencies by default: LLMs (openai / ollama / transformers), vector stores (pymilvus), graph stores (neo4j), Redis, RabbitMQ. +- Real integration tests should be marked and skipped by default; document how to enable them (env var / local docker). +- Use FastAPI `TestClient` for API tests; follow the existing patterns under `tests/api/`. +- Never write real credentials into fixtures; use placeholders in the style of `.env.example`. + +## Output format + +``` +Test file: tests//test_.py +Coverage map: +- Requirement 1.1 → test_xxx +Command: poetry run pytest tests//test_.py -q +Output: + +Result: N passed, M failed +``` + +## Do not + +- Modify product code under `src/memos/` (backend-dev's job). +- Substitute for code-reviewer. +- Claim completion without real pytest output. diff --git a/.codex/agents/backend-dev.toml b/.codex/agents/backend-dev.toml new file mode 100644 index 000000000..510de8a0e --- /dev/null +++ b/.codex/agents/backend-dev.toml @@ -0,0 +1,33 @@ +name = "backend-dev" +description = "MemOS backend / library implementation sub-agent. Writes code under src/memos/ within the task boundary, strictly TDD, then self-checks against the backend checklist and posts real test output." +sandbox_mode = "workspace-write" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Implement backend / library code under src/memos//; do not range outside the current task. +- Strict TDD: write a failing test in tests//test_*.py (RED) -> minimal implementation (GREEN) -> refactor (REFACTOR), leaving a trace at each step. +- Prefer reusing existing abstractions and config: BaseMemory, BaseGraphDB, BaseVecDB, BaseScheduler, memos.configs.*, memos.dependency. + +Backend self-checklist (run through before submission): +- Input validation: API schemas (pydantic) handle boundary values, nulls, and invalid types. +- Error handling: raise semantic exceptions from memos.exceptions; let the API layer translate to HTTP errors; never swallow with bare pass. +- Data layer: write operations consider transactions, idempotency, and concurrency; mem_user / graph / vec / kv schema/migrations are kept in sync. +- Compatibility: do not break the contract of top-level memos.* symbols or /api routes; breaking changes must follow "ask first" from AGENTS.md. +- Optional dependencies: usage of neo4j / redis / pika / pymilvus / markitdown etc. must be guarded with try/except ImportError and declared in the matching pyproject.toml extras. +- Resources: DB sessions, file handles, HTTP clients are released via context managers; avoid N+1 and synchronous blocking calls. +- Logging: use logging.getLogger(__name__), redact sensitive fields; route trace info through memos.context.context. +- Formatting: always run make format before submission. + +Output requirements (paste the real output of the real commands): +- poetry run pytest tests// -q +- make test for full runs when needed +- make format (or make pre_commit) +- A list of changed files mapped to the originating requirement. + +Do not: +- Touch apps/, docker/, scripts/, pyproject.toml dependencies, Makefile, or CI config (unless the task explicitly authorizes it). +- Review your own code (code-reviewer's job). +- Claim completion without test output. +- Skip pre-commit or commit with --no-verify. +""" diff --git a/.codex/agents/code-reviewer.toml b/.codex/agents/code-reviewer.toml new file mode 100644 index 000000000..8a713b4e9 --- /dev/null +++ b/.codex/agents/code-reviewer.toml @@ -0,0 +1,29 @@ +name = "code-reviewer" +description = "Code-review sub-agent. Reviews MemOS diffs for contract consistency, Ruff / typing / optional-dependency handling, and test evidence; returns APPROVE or CHANGES_REQUESTED." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: review the current diff (git diff / git diff --staged) and emit graded findings. + +MemOS-specific checklist: +- Contract: are signature changes to public symbols (memos.api.*, top-level memos.*) backward compatible; if breaking, did it follow AGENTS.md "ask first". +- Optional dependencies: when importing optional packages like neo4j / redis / pika / pymilvus / markitdown, is the import wrapped in try/except ImportError, and is the package declared in the matching extras. +- Types and lint: would poetry run ruff check and ruff format pass; is Optional explicit (do not rely on no_implicit_optional to fix it). +- Exceptions: are semantic exceptions from memos.exceptions raised, not bare Exception / RuntimeError. +- Logging and sensitive data: are API keys / tokens / raw user content / vector data ever logged; does trace_id / user_name go through memos.context.context instead of print. +- Test evidence: are new/updated tests//test_*.py present; is real pytest output included. +- Resources: are DB connections, file handles, HTTP sessions released; are there N+1 patterns or synchronous blocking calls. + +Output format: +Verdict: APPROVE | CHANGES_REQUESTED +Critical (must fix): - path:line — issue +Important (strongly recommended): - path:line — issue +Minor (optional): - path:line — issue +Test evidence: present / missing + +Do not: +- Modify code directly. +- Substitute for a human final approver. +- Grant APPROVE when pytest output is missing. +""" diff --git a/.codex/agents/design-reviewer.toml b/.codex/agents/design-reviewer.toml new file mode 100644 index 000000000..49c9b7be7 --- /dev/null +++ b/.codex/agents/design-reviewer.toml @@ -0,0 +1,27 @@ +name = "design-reviewer" +description = "Design-review sub-agent. Reviews design docs across the four dimensions of architecture, interface, performance, and security, covering MemOS's multi-memory / multi-storage backend constraints." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Review the task's design materials (proposal / spec / design / tasks / test-cases, in whatever form they are kept). +- Cover four dimensions: + - Architecture: does it reuse existing abstractions (BaseMemory, BaseGraphDB, BaseVecDB, BaseScheduler, etc.), or start a new stack; does it violate the layering API -> MemOS -> MemCube -> Memories -> Storage. + - Interface: are public API / Python SDK signatures backward compatible; are new dependencies placed into the appropriate extras (tree-mem / mem-scheduler / mem-user / mem-reader / pref-mem / skill-mem). + - Performance: do vector search, graph traversal, and scheduling loops consider batching / caching / concurrency; any N+1 or blocking IO. + - Security: is user isolation (mem_user) handled; do we avoid writing into .env / credentials / private paths. +- Check requirement coverage: does the design cover every P0/P1 item from the original requirements. +- Call out blockers (must fix) vs. suggestions (optional). + +Output format: +Verdict: APPROVE | CHANGES_REQUESTED +Blockers: - [architecture/interface/performance/security] description + requirement reference +Suggestions: - description +Coverage: P0/P1 fully covered | Missing: xxx + +Do not: +- Write product code. +- Review the code implementation (that is code-reviewer's job). +- Substitute for a human final approver. +""" diff --git a/.codex/agents/explorer.toml b/.codex/agents/explorer.toml new file mode 100644 index 000000000..b8a94a3b1 --- /dev/null +++ b/.codex/agents/explorer.toml @@ -0,0 +1,30 @@ +name = "explorer" +description = "Read-only code exploration sub-agent. Locates MemOS code, traces call chains, gathers evidence, and returns a compressed conclusion — never proposes or applies changes." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Locate relevant modules, symbols, and call chains under src/memos/ for the question the main agent asks. +- Distinguish core packages (mem_os / mem_cube / mem_scheduler) from optional backends (graph_dbs/neo4j*, vec_dbs/milvus*, etc.) and call out any extras dependencies. +- Trace execution paths and gather evidence (with path:line annotations + a one-line key snippet). +- Return a compressed conclusion only; do not echo raw bulk output. + +Output format: +- Conclusion first: one sentence that answers the main agent's question. +- Evidence list: src/memos//.py:LINE + a one-line note. +- Call chain (if applicable): A.f -> B.g -> C.h, annotating each hop with its file location. +- Uncertainty: explicitly flag "not found / needs further confirmation"; do not invent. + +MemOS-specific locator hints: +- API routes: src/memos/api/ + tests/api/ +- Memory types: src/memos/memories/ (textual / tree / preference / skill etc.) +- Storage backends: src/memos/graph_dbs/, src/memos/vec_dbs/ +- Config and DI: src/memos/configs/, src/memos/dependency.py +- Plugin entry points: pyproject.toml [project.entry-points."memos.plugins"] + extensions/ + +Do not: +- Modify any file (read-only). +- Propose an implementation plan — return facts and locations only. +- Substitute for the judgment of design-reviewer / code-reviewer. +""" diff --git a/.codex/agents/integration-tester.toml b/.codex/agents/integration-tester.toml new file mode 100644 index 000000000..5baa4621c --- /dev/null +++ b/.codex/agents/integration-tester.toml @@ -0,0 +1,30 @@ +name = "integration-tester" +description = "MemOS integration-testing sub-agent. Authors and executes pytest cases under tests/ based on the task's requirements and design, and emits real test reports." +sandbox_mode = "workspace-write" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Based on the task's requirements and design docs, write pytest cases under tests//. +- Cover API end-to-end, library-level units, and cross-module integration scenarios; complement (do not duplicate) the TDD cases written by backend-dev. +- Run the tests and produce a real report. + +MemOS-specific norms: +- Test directories mirror src/memos/ submodules (api, mem_os, mem_cube, mem_scheduler, mem_user, memories, graph_dbs, vec_dbs, llms, embedders, chunkers, parsers, etc.). +- Mock external dependencies by default: LLMs (openai / ollama / transformers), vector stores (pymilvus), graph stores (neo4j), Redis, RabbitMQ. +- Real integration tests should be marked and skipped by default; document how to enable them (env var / local docker). +- Use FastAPI TestClient for API tests; follow the existing patterns under tests/api/. +- Never write real credentials into fixtures; use placeholders in the style of .env.example. + +Output format: +Test file: tests//test_.py +Coverage map: Requirement 1.1 -> test_xxx +Command: poetry run pytest tests//test_.py -q +Output: +Result: N passed, M failed + +Do not: +- Modify product code under src/memos/ (backend-dev's job). +- Substitute for code-reviewer. +- Claim completion without real pytest output. +""" diff --git a/.private-paths b/.private-paths new file mode 100644 index 000000000..1df5fa57d --- /dev/null +++ b/.private-paths @@ -0,0 +1,11 @@ +# Paths exclusive to the enterprise repo (one per line). +# These will NOT be synced to the public repository. +# This file itself is also excluded from the public repo. + +extensions/ +pyproject.toml +poetry.lock +.private-paths +scripts/sync-public.sh +scripts/check-public-push.sh +Makefile diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..cd885b3c4 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,155 @@ +# AGENTS.md + +> Single source of truth for the project across AI runtimes. Claude Code, Codex, Cursor, Copilot, etc. all defer to this file. +> Runtime-specific adaptation belongs in each runtime's own file (Claude reads `CLAUDE.md`); do not mix it in here. + +## Project Overview + +**MemOS / MemoryOS**: a memory operating system for LLM agents. Python library plus a FastAPI service, providing multiple memory types (textual / tree / preference / skill / KV cache / LoRA parametric) plus scheduling, version management, and vector & graph storage. + +- **Repository**: https://github.com/MemTensor/MemOS +- **Documentation**: https://memos-docs.openmem.net/home/overview/ +- **PyPI**: https://pypi.org/project/MemoryOS/ +- **License**: Apache-2.0 +- **Top-level package**: `src/memos/`. Distribution name `MemoryOS`; import name `memos`. +- **CLI**: `memos` (entry `memos.cli:main`) +- **API service**: `memos.api.start_api:app` + +## Repository Layout + +| Path | Purpose | +|------|---------| +| `src/memos/mem_os/` | `MOS` / `MOSCore` — top-level Memory OS entry | +| `src/memos/mem_cube/` | `GeneralMemCube` — memory container aggregating multiple memory types | +| `src/memos/memories/` | Memory implementations: `textual/`, `activation/`, `parametric/` | +| `src/memos/mem_scheduler/` | Memory scheduler + monitors + ORM + task scheduling | +| `src/memos/mem_user/` | User / multi-tenant management (MySQL / Redis backends) | +| `src/memos/mem_chat/` `mem_reader/` `mem_agent/` `mem_feedback/` `multi_mem_cube/` | Chat sessions, ingest pipeline, agent integration, feedback channel, multi-cube routing | +| `src/memos/llms/` `embedders/` `vec_dbs/` `graph_dbs/` `chunkers/` `parsers/` `reranker/` | Provider implementations (`base.py` + `factory.py` + each backend) | +| `src/memos/api/` | FastAPI service (routers / handlers / middleware / MCP server) | +| `src/memos/configs/` | All pydantic configuration classes (one-to-one with the modules above) | +| `src/memos/context/` | Cross-thread context (trace_id / user / env) | +| `tests/` | pytest cases, subdirectories mirror `src/memos/` | +| `apps/` | Independent sub-projects, each with its own README; not part of the main Harness flow | +| `extensions/` | Official plugin examples | +| `docker/` `docs/` `evaluation/` `scripts/` | Deployment, documentation, evaluation, helper scripts | +| `.claude/agents/`, `.codex/agents/` | Project-recommended AI sub-agent definitions | + +## Command Cheatsheet + +- Install: `make install` (= `poetry install --extras all --with dev --with test` + pre-commit + push hook) +- Start API: `make serve` +- Export OpenAPI: `make openapi` (writes to `docs/openapi.json`) +- Run full tests: `make test` +- Run a single test: `poetry run pytest tests//test_xxx.py -q` +- Lint + format: `make format` +- Full pre-commit: `make pre_commit` +- Build: `poetry build` (publishing is automated by `python-release.yml` on GitHub release) + +## Core API + +### Python top-level entries (`from memos import ...`) + +| Symbol | Purpose | Source | +|--------|---------|--------| +| `MOS` | Memory OS top-level entry (inherits `MOSCore`) | `memos.mem_os.main` | +| `GeneralMemCube` | General memory container | `memos.mem_cube.general` | +| `MOSConfig` / `GeneralMemCubeConfig` | Primary configs | `memos.configs.mem_os` / `memos.configs.mem_cube` | +| `GeneralScheduler` / `SchedulerFactory` / `SchedulerConfigFactory` | Scheduler and factories | `memos.mem_scheduler.*` | + +Common `MOS` methods: `MOS.simple()` (auto-configure from env), `register_mem_cube(cube)`, `add(...)`, `search(...)`, `chat(...)`, `create_user(...)` / `list_users()`. + +### API entry + +- ASGI app: `memos.api.start_api:app` +- Routers: `src/memos/api/routers/` (`admin_router`, `product_router`, `server_router`) +- OpenAPI contract: `docs/openapi.json` (must run `make openapi` after touching the API) + +## Import Patterns + +| Use | Import | +|-----|--------| +| Top-level entries | `from memos import MOS, GeneralMemCube, MOSConfig` | +| Config classes | `from memos.configs. import ` | +| Any provider factory | `from memos..factory import Factory` | +| Logger | `from memos.log import get_logger`; `logger = get_logger(__name__)` | +| Context (trace) | `from memos.context.context import get_current_trace_id, get_current_user_name` | +| Exceptions | `from memos.exceptions import ` | + +## Provider Matrix + +Every provider follows the same three-piece pattern: `base.py` abstract class + `factory.py` registry + `configs/.py` config. The authoritative list of registered backends is the factory's `backend_to_class`; the snapshot below is provided for quick reference: + +| Category | Base class | Factory | Registered backends | +|----------|-----------|---------|---------------------| +| LLM | `BaseLLM` | `LLMFactory` | `openai` / `openai_new` / `azure` / `ollama` / `huggingface` / `huggingface_singleton` / `vllm` / `qwen` / `deepseek` | +| Embedder | `BaseEmbedder` | `EmbedderFactory` | `ollama` / `sentence_transformer` / `ark` / `universal_api` | +| Vector DB | `BaseVecDB` | `VecDBFactory` | `qdrant` / `milvus` | +| Graph DB | `BaseGraphDB` | `GraphStoreFactory` | `neo4j` / `neo4j_community` / `nebular` / `polardb` / `postgres` | +| Chunker | `BaseChunker` | `ChunkerFactory` | `sentence` / `markdown` / `simple` / `charactertext` | +| Parser | `BaseParser` | `ParserFactory` | `markitdown` | +| Reranker | `BaseReranker` | `RerankerFactory` | `cosine_local` / `http_bge` / `http_bge_strategy` / `concat` / `noop` | +| Memory | `BaseMemory` (+ `BaseTextMemory` / `BaseActMemory` / `BaseParaMemory`) | `MemoryFactory` | `naive_text` / `general_text` / `tree_text` / `simple_tree_text` / `pref_text` / `simple_pref_text` / `kv_cache` / `vllm_kv_cache` / `lora` | +| Scheduler | `BaseScheduler` | `SchedulerFactory` | `general` / `optimized` | + +## Adding a New Provider + +Mirror any existing provider in the same category: + +1. Implement `src/memos//.py`, inheriting the `base.py` abstract class and matching the signatures of existing providers. +2. Add a pydantic config in `src/memos/configs/.py` and register it in `ConfigFactory.backend_to_class`. +3. Register the implementation in `Factory.backend_to_class` in `src/memos//factory.py`. +4. Third-party dependencies **must** go into an optional extras group in `pyproject.toml` (`tree-mem` / `mem-scheduler` / `mem-user` / `mem-reader` / `pref-mem` / `skill-mem`) and be added to `all`; guard the import with try/except ImportError and raise a clear "install extras X" message on failure. +5. Add tests under `tests//test_.py`; external HTTP / model loading must be mocked. + +## Behavior Boundaries + +### Always do + +- Write a failing test first (TDD), placed under `tests//test_*.py`. +- Before claiming a task is done, run verification commands and paste the real output (at minimum `make format` plus the relevant pytest run). +- Keep changes within the directories the current task authorizes; cross-module edits need to be called out and approved first. +- Use `memos.log.get_logger(__name__)` for logging; route trace info through `memos.context.context` — do not `print`. +- Optional third-party dependencies (neo4j / redis / pika / pymilvus / markitdown, etc.) must be guarded with try/except ImportError and declared in the matching extras group. +- After touching `src/memos/api/`, run `make openapi` to refresh `docs/openapi.json`. + +### Ask first + +- Modifying `pyproject.toml` dependencies or the Python version constraint. +- Touching public routes, request/response models, or the OpenAPI contract under `src/memos/api/`. +- Changing DB schema, migrations, `mem_user` tables, or `graph_dbs` graph models. +- Deleting files or doing wide-scope renames of public APIs (`memos.*` top-level symbols). +- Editing `Makefile`, `.pre-commit-config.yaml`, `pyproject.toml [tool.*]`, or `.github/workflows/`. + +### Never do (IMPORTANT) + +- **Never** commit `.env`, `private/`, `.private-paths`, `tmp/`, `*.log`, secrets, tokens, or model credentials. +- Do not log or include real API keys, raw user data, or vector contents in tests/fixtures. +- Do not skip `pre-commit` or push with `--no-verify` (the `scripts/check-public-push.sh` pre-push hook is enforced). +- Do not claim tests pass without real pytest output as evidence. +- Do not add third-party dependencies to core `dependencies` — they must go into optional extras. +- Do not run wide-scope `rm -rf` outside `src/`; do not `git push --force` or `git reset --hard origin/*`. + +## Code Style + +- Format and lint with Ruff (configured in `pyproject.toml [tool.ruff]`); `make format` must pass before commit. +- Type annotations are required on public functions, API schemas, and config classes; implicit `Optional` is not allowed (enforced via pre-commit). +- All configs and API schemas use Pydantic v2. +- Logging: `logger.info("... %s", x)` form — do not pre-format with f-strings before passing to the logger. +- Exceptions: library code raises semantic exceptions from `memos.exceptions`, never bare `Exception` / `RuntimeError`; the API layer translates them to HTTP errors in `memos.api.exceptions`. +- File naming: source `snake_case.py`, tests `test_.py`. + +## Change → Test Mapping + +- Edit `src/memos//`: at minimum run `pytest tests// -q`; run `make test` once more before merging. +- Edit `src/memos/api/`: run `tests/api/` and `make openapi` to confirm the OpenAPI spec did not change unexpectedly. +- Edit `pyproject.toml` dependencies: `poetry lock --no-update`, then `make test`. +- Edit `Makefile` / pre-commit / Ruff config: run `make pre_commit` locally over the whole tree. + +## Git Conventions + +- Commits: Conventional Commits (`feat:` / `fix:` / `chore:` / `refactor:` / `docs:`), subject line ≤ 72 chars. +- Branches: `feat/` / `fix/` / `dev-YYYYMMDD-v`. +- `main` is protected — all changes go through PRs; never force-push to `main`; do not skip git hooks. +- Do not commit paths listed in `.private-paths`. +- The PR template lives at `.github/PULL_REQUEST_TEMPLATE.md` — its checklist must be fully ticked. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..c2402f7c7 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,23 @@ +# CLAUDE.md + +## Claude Code Entry + +Project facts live in `AGENTS.md`. This file only covers Claude Code runtime adaptation. + +## Sub-agents + +Five project-recommended sub-agents live under `.claude/agents/*.md`. Claude Code loads them automatically; the main agent should dispatch by task boundary: + +| Agent | Permissions | When to use | +|-------|-------------|-------------| +| `explorer` | Read-only | Locate code, trace call chains, gather evidence | +| `design-reviewer` | Read-only | Review design docs (architecture / interface / performance / security / requirement coverage) | +| `code-reviewer` | Read-only | Review diffs and return APPROVE or CHANGES_REQUESTED | +| `backend-dev` | Read-write | Implement backend / library code under `src/memos/` (TDD) | +| `integration-tester` | Read-write | Author and run integration / end-to-end cases under `tests/` | + +The main repo has no frontend stack, so no `frontend-dev` is provided; TypeScript sub-projects under `apps/` use their own AI configuration. + +## Project knowledge + +Before starting a task, run `ls docs/`. `docs/openapi.json` is the source of truth for the API contract; after touching `src/memos/api/`, run `make openapi` to regenerate it. diff --git a/Makefile b/Makefile index 57ede5838..178a4c19a 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ install: poetry install --extras all --with dev --with test poetry run pre-commit install --install-hooks + cp scripts/check-public-push.sh .git/hooks/pre-push + chmod +x .git/hooks/pre-push clean: rm -rf .memos @@ -25,3 +27,6 @@ serve: openapi: poetry run memos export_openapi --output docs/openapi.json + +sync-public: + @bash scripts/sync-public.sh "$(msg)" $(commit) diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 000000000..84c7de004 --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,615 @@ +# MemOS 插件开发 README + +这是一份面向新人的标准操作流程。按顺序执行即可完成以下全流程: + +`环境搭建 -> 插件开发 -> 测试 -> 提交 -> 部署` + +如需了解设计原理,请参考《MemOS 插件系统设计与开发指南》。 + +## 快速导航 + +1. [环境搭建](#环境搭建) +2. [创建插件](#创建插件) +3. [注册插件](#注册插件) +4. [编写测试](#编写测试) +5. [代码提交](#代码提交) +6. [部署与验证](#部署与验证) +7. [开发 Checklist](#开发-checklist) +8. [常用命令速查](#常用命令速查) +9. [文件模板清单](#文件模板清单) + +## 环境搭建 + +这是一次性操作,首次接入时完成即可。 + +### 1. 克隆企业仓库 + +```bash +git clone git@github.com:MemTensor/MemOS-Enterprise.git +cd MemOS-Enterprise +``` + +### 2. 添加公开仓库 remote + +```bash +git remote add public git@github.com:MemTensor/MemOS.git +``` + +### 3. 安装依赖与 Git Hooks + +`make install` 会同时完成依赖安装,以及以下 Git Hooks 配置: + +- `pre-commit`:代码检查 +- `pre-push`:私有代码拦截 + +```bash +make install +``` + +### 4. 配置同步 alias(可选,推荐) + +```bash +git config alias.sync-public '!bash scripts/sync-public.sh' +``` + +### 5. 环境验证 + +#### 插件框架测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest tests/plugins/ -v +``` + +#### Demo 插件测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_demo_plugin/tests/ -v +``` + +#### 启动服务并验证插件加载 + +```bash +uvicorn memos.api.server_api:app --port 8001 +curl http://127.0.0.1:8001/demo/health +``` + +预期返回: + +```json +{"status":"ok","plugin":"demo","version":"0.1.0"} +``` + +## 创建插件 + +以下以开发 `memos_foo_plugin` 为例,实际使用时将 `foo` 替换为你的插件名。 + +### 1. 创建目录 + +```bash +mkdir -p extensions/memos_foo_plugin/tests +touch extensions/memos_foo_plugin/__init__.py +touch extensions/memos_foo_plugin/tests/__init__.py +``` + +### 2. 包入口 + +文件:`extensions/memos_foo_plugin/__init__.py` + +```python +from memos_foo_plugin.plugin import FooPlugin + +__all__ = ["FooPlugin"] +``` + +### 3. 编写 Plugin 主类 + +文件:`extensions/memos_foo_plugin/plugin.py` + +```python +import logging +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H + +logger = logging.getLogger(__name__) + + +class FooPlugin(MemOSPlugin): + name = "foo" + version = "0.1.0" + description = "Foo plugin - brief description" + + def on_load(self) -> None: + self.counter: dict[str, int] = {} + logger.info("[Foo] plugin loaded") + + def init_app(self) -> None: + from memos_foo_plugin.hooks import on_add_after + from memos_foo_plugin.routes import create_router + + self.register_router(create_router(self)) + self.register_hook(H.ADD_AFTER, partial(on_add_after, self)) + + # from memos_foo_plugin.middleware import FooMiddleware + # self.register_middleware(FooMiddleware) + + logger.info("[Foo] plugin initialized") + + def on_shutdown(self) -> None: + logger.info("[Foo] plugin shutdown") +``` + +### 4. 编写路由 + +文件:`extensions/memos_foo_plugin/routes.py` + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + + +def create_router(plugin: FooPlugin) -> APIRouter: + router = APIRouter(prefix="/foo", tags=["foo"]) + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name} + + @router.get("/stats") + async def stats(): + return {"counter": plugin.counter} + + return router +``` + +### 5. 编写 Hook 回调 + +文件:`extensions/memos_foo_plugin/hooks.py` + +```python +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + +logger = logging.getLogger(__name__) + + +def on_add_after(plugin: FooPlugin, *, request, result, **kw) -> None: + """[add.after] Count add operations per user.""" + uid = getattr(request, "user_id", "unknown") + plugin.counter[uid] = plugin.counter.get(uid, 0) + 1 + logger.info("[Foo] add counted user=%s total=%d", uid, plugin.counter[uid]) +``` + +### 6. 编写中间件(可选) + +文件:`extensions/memos_foo_plugin/middleware.py` + +```python +import logging +import time + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +logger = logging.getLogger(__name__) + + +class FooMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + start = time.time() + response = await call_next(request) + elapsed_ms = (time.time() - start) * 1000 + logger.info( + "[Foo] %s %s -> %s (%.1fms)", + request.method, + request.url.path, + response.status_code, + elapsed_ms, + ) + return response +``` + +### 7. 自定义 Hook(可选) + +如果插件需要自己定义并触发 Hook,而不只是监听 CE 提供的 Hook,可以新增 `hook_defs.py`。 + +文件:`extensions/memos_foo_plugin/hook_defs.py` + +```python +from memos.plugins.hook_defs import define_hook + + +class FooH: + """Foo plugin hook name constants.""" + + RESULT_ENRICH = "foo.result.enrich" + + +define_hook( + FooH.RESULT_ENRICH, + description="Enrich result data after processing", + params=["user_id", "result"], + pipe_key="result", +) +``` + +在路由或业务逻辑中触发: + +```python +from memos.plugins.hooks import trigger_hook +from memos_foo_plugin.hook_defs import FooH + +rv = trigger_hook(FooH.RESULT_ENRICH, user_id="alice", result=data) +data = rv if rv is not None else data +``` + +在 `plugin.py` 中注册回调: + +```python +from memos_foo_plugin.hook_defs import FooH + +self.register_hook(FooH.RESULT_ENRICH, partial(enrich_result, self)) +``` + +## 注册插件 + +需要在 `pyproject.toml` 中添加两处配置。 + +### 1. 声明包路径 + +```toml +[tool.poetry] +packages = [ + {include = "memos", from = "src"}, + {include = "memos_foo_plugin", from = "extensions"}, +] +``` + +### 2. 注册 entry point + +```toml +[project.entry-points."memos.plugins"] +demo = "memos_demo_plugin:DemoPlugin" +foo = "memos_foo_plugin:FooPlugin" +``` + +### 3. 重新安装使 entry point 生效 + +```bash +pip install -e . +``` + +> 注意: +> 仅修改已安装插件的代码时,在 editable 模式下通常重启服务即可。 +> 如果是新增插件,或修改了 `pyproject.toml`,则必须重新安装。 + +## 编写测试 + +### 1. `conftest.py` + +文件:`extensions/memos_foo_plugin/tests/conftest.py` + +```python +"""Ensure hooks used by FooPlugin are declared for testing.""" + +from memos.plugins.hooks import hookable + +# Declare CE hooks (normally declared at import time of handler modules) +hookable("add") +hookable("search") + +# If plugin has custom hook_defs, import to trigger declarations: +# import memos_foo_plugin.hook_defs # noqa: F401 +``` + +### 2. 生命周期测试 + +文件:`extensions/memos_foo_plugin/tests/test_lifecycle.py` + +```python +from fastapi import FastAPI + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +class TestFooPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_foo_plugin.plugin import FooPlugin + + plugin = FooPlugin() + assert plugin.name == "foo" + assert plugin.version == "0.1.0" + + def test_on_load_state(self): + from memos_foo_plugin.plugin import FooPlugin + + plugin = FooPlugin() + plugin.on_load() + assert plugin.counter == {} + + def test_full_lifecycle(self): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + paths = [r.path for r in app.routes] + assert "/foo/health" in paths + assert "/foo/stats" in paths + + plugin.on_shutdown() +``` + +### 3. Hook 回调测试 + +文件:`extensions/memos_foo_plugin/tests/test_hooks.py` + +```python +from fastapi import FastAPI + + +def _make_plugin(): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return plugin + + +class TestHookCallbacks: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_add_after_counts(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "alice" + + trigger_hook("add.after", request=Req(), result={}) + trigger_hook("add.after", request=Req(), result={}) + assert plugin.counter["alice"] == 2 +``` + +### 4. 路由测试 + +文件:`extensions/memos_foo_plugin/tests/test_routes.py` + +```python +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def _make_app(): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return app, plugin + + +class TestRoutes: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_health(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/foo/health") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + def test_stats_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/foo/stats") + assert resp.json()["counter"] == {} +``` + +### 5. 运行测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_foo_plugin/tests/ -v +``` + +## 代码提交 + +### 1. 提交到企业仓库 + +```bash +git add -A +git commit -m "feat: add foo plugin" +git push origin feature/foo +``` + +这是标准 Git 流程,完整代码推送到企业仓库。 + +### 2. 同步 CE 代码到公开仓库 + +如果本次改动包含 CE 代码,例如 `src/memos/`、`tests/plugins/` 等,需要执行同步。 + +同步最近一次 commit 的 CE 改动: + +```bash +git sync-public "feat: add plugin framework enhancement" +``` + +同步指定 commit: + +```bash +git sync-public "fix: hook trigger" abc1234 +``` + +或使用 `make`: + +```bash +make sync-public msg="feat: add plugin framework enhancement" +``` + +推送后,在 GitHub 创建 PR 合入 `public/main`。 + +### 3. 判断是否需要 `sync-public` + +| 改动内容 | 需要 `sync-public` | +| --- | --- | +| `extensions/` 下的插件代码 | ❌ | +| `pyproject.toml` / `poetry.lock` | ❌ | +| `scripts/` / `Makefile` / `.private-paths` | ❌ | +| `src/memos/plugins/` 框架代码 | ✅ | +| `src/memos/api/` 中新增 `@hookable` | ✅ | +| `tests/plugins/` 框架测试 | ✅ | + +### 4. 新增私有路径 + +如果新增了不应同步到公开仓库的文件或目录,请编辑 `.private-paths`,每行添加一个路径。 + +```text +extensions/ +pyproject.toml +poetry.lock +.private-paths +scripts/sync-public.sh +scripts/check-public-push.sh +Makefile +docs/internal/ +``` + +## 部署与验证 + +### 1. 启动服务 + +```bash +uvicorn memos.api.server_api:app --port 8001 +``` + +启动日志中应看到: + +```text +INFO: Plugin discovered: foo v0.1.0 +INFO: Plugin initialized: foo +``` + +### 2. 验证接口 + +插件健康检查: + +```bash +curl http://127.0.0.1:8001/foo/health +``` + +插件业务接口: + +```bash +curl http://127.0.0.1:8001/foo/stats +``` + +### 3. 验证 Hook 生效 + +通过调用 CE 接口触发 Hook,再检查插件状态。 + +触发 `add` 接口,使插件的 `add.after` hook 被调用: + +```bash +curl -X POST http://127.0.0.1:8001/product/add \ + -H "Content-Type: application/json" \ + -d '{"user_id": "test_user", ...}' +``` + +查看插件统计: + +```bash +curl http://127.0.0.1:8001/foo/stats +``` + +预期返回: + +```json +{"counter": {"test_user": 1}} +``` + +## 开发 Checklist + +开发完成后,逐项确认: + +- [ ] `Plugin` 类继承 `MemOSPlugin` +- [ ] 已实现 `name`、`version`、`description` +- [ ] 已在 `init_app()` 中注册路由 +- [ ] Hook 回调使用 `self.register_hook(...)` 正确注册 +- [ ] 如有中间件,已使用 `self.register_middleware(...)` 注册 +- [ ] 已在 `pyproject.toml` 中声明包路径 +- [ ] 已在 entry points 中注册插件 +- [ ] 测试通过:插件测试可完整运行 +- [ ] 服务启动日志出现插件发现与初始化信息 +- [ ] 插件接口返回预期结果 +- [ ] 代码已推送到企业仓库 +- [ ] 如涉及 CE 代码,已完成 `sync-public` + +## 常用命令速查 + +| 操作 | 命令 | +| --- | --- | +| 安装依赖 + hooks | `make install` | +| 运行全部测试 | `PYTHONPATH="src:extensions" python -m pytest tests/plugins/ extensions/ -v` | +| 运行单个插件测试 | `PYTHONPATH="src:extensions" python -m pytest extensions/memos_foo_plugin/tests/ -v` | +| 启动服务 | `uvicorn memos.api.server_api:app --port 8001` | +| 代码格式化 | `make format` | +| 代码检查 | `make pre_commit` | +| 提交到企业仓库 | `git commit + git push origin ` | +| 同步 CE 到公开仓库 | `git sync-public "message"` | +| 同步指定 commit | `git sync-public "message" ` | + +## 文件模板清单 + +新建插件时,通常需要创建如下文件: + +```text +extensions/memos_foo_plugin/ +├── __init__.py # 必须:包入口,re-export Plugin 类 +├── plugin.py # 必须:继承 MemOSPlugin,注册能力 +├── routes.py # 按需:FastAPI 路由 +├── hooks.py # 按需:Hook 回调函数 +├── middleware.py # 按需:Starlette 中间件 +├── hook_defs.py # 按需:插件自有 Hook 声明(有自定义 Hook 时需要) +└── tests/ + ├── __init__.py # 必须 + ├── conftest.py # 必须:声明测试中用到的 Hook + ├── test_lifecycle.py # 推荐:生命周期测试 + ├── test_hooks.py # 推荐:Hook 回调测试 + └── test_routes.py # 推荐:路由端点测试 +``` + +也可以直接复制 `extensions/memos_demo_plugin/` 作为模板,然后全局替换 `demo -> foo`、`Demo -> Foo`。 diff --git a/extensions/memos_demo_plugin/__init__.py b/extensions/memos_demo_plugin/__init__.py new file mode 100644 index 000000000..3e8d0ff78 --- /dev/null +++ b/extensions/memos_demo_plugin/__init__.py @@ -0,0 +1,6 @@ +"""memos-demo-plugin — a complete example plugin demonstrating the MemOS plugin system.""" + +from memos_demo_plugin.plugin import DemoPlugin + + +__all__ = ["DemoPlugin"] diff --git a/extensions/memos_demo_plugin/hook_defs.py b/extensions/memos_demo_plugin/hook_defs.py new file mode 100644 index 000000000..7ae920f48 --- /dev/null +++ b/extensions/memos_demo_plugin/hook_defs.py @@ -0,0 +1,36 @@ +"""Demo plugin-owned hook declarations. + +Hooks that the plugin declares, triggers, and registers callbacks for are defined here. +CE-exposed hooks (e.g. add.before/after) are managed by CE's hook_defs.py; the plugin only needs to reference them. +""" + +from memos.plugins.hook_defs import define_hook + + +class DemoH: + """Demo plugin hook name constants.""" + + # @hookable("demo.test") — auto-generates before/after + TEST_BEFORE = "demo.test.before" + TEST_AFTER = "demo.test.after" + + # Manually triggered via trigger_hook + TEST_POST_PROCESS = "demo.test.post_process" + REPORT_ENRICH = "demo.report.enrich" + + +# ── Custom hook declarations (@hookable-generated before/after need not be declared here) ── + +define_hook( + DemoH.TEST_POST_PROCESS, + description="post-process result after demo test endpoint business logic runs", + params=["request", "result"], + pipe_key="result", +) + +define_hook( + DemoH.REPORT_ENRICH, + description="after user activity report is generated, allows callbacks to extend report data", + params=["user_id", "report"], + pipe_key="report", +) diff --git a/extensions/memos_demo_plugin/hooks.py b/extensions/memos_demo_plugin/hooks.py new file mode 100644 index 000000000..776f9e877 --- /dev/null +++ b/extensions/memos_demo_plugin/hooks.py @@ -0,0 +1,87 @@ +"""Demo plugin hook callbacks. + +Two groups: + 1. CE hook responders — plugin listens to CE-exposed extension points (add/search etc.) + 2. Plugin-owned hooks — extension points the plugin declares and triggers (demo.test / demo.report) + +All callbacks are bound to the plugin instance via functools.partial(callback, plugin_instance). +""" + +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from memos_demo_plugin.plugin import DemoPlugin + +logger = logging.getLogger(__name__) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# 1. CE hook responders — listen to CE-exposed @hookable / trigger_hook extension points +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +def log_operation(plugin: DemoPlugin, *, request, **kw) -> None: + """[add.before / search.before] Log operation (notification-style).""" + uid = getattr(request, "user_id", "unknown") + plugin.request_log.append({"user_id": uid}) + logger.info("[Demo] operation logged user=%s", uid) + + +def count_add(plugin: DemoPlugin, *, request, result, **kw) -> None: + """[add.after] Count add calls per user (notification-style).""" + uid = getattr(request, "user_id", "unknown") + plugin.add_counter[uid] = plugin.add_counter.get(uid, 0) + 1 + logger.info("[Demo] add count user=%s total=%d", uid, plugin.add_counter[uid]) + + +def post_process_add(plugin: DemoPlugin, *, request, result, **kw): + """[add.memories.post_process] Post-process add_memories result (pipe-style, returns result).""" + uid = getattr(request, "user_id", "unknown") + plugin.post_process_log.append({"user_id": uid, "result_count": len(result)}) + logger.info("[Demo] post_process_add user=%s count=%d", uid, len(result)) + return result + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# 2. Plugin-owned hooks — declared by plugin and triggered in routes.py +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +def on_test_before(plugin: DemoPlugin, *, request, **kw): + """[demo.test.before] @hookable auto-triggered, can modify request (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "before", "user_id": uid}) + logger.info("[Demo] test.before user=%s", uid) + return request + + +def on_test_after(plugin: DemoPlugin, *, request, result, **kw): + """[demo.test.after] @hookable auto-triggered, can modify result (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "after", "user_id": uid}) + result["hook_after_injected"] = True + logger.info("[Demo] test.after user=%s", uid) + return result + + +def on_test_post_process(plugin: DemoPlugin, *, request, result, **kw): + """[demo.test.post_process] trigger_hook manual trigger, can modify result (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "post_process", "user_id": uid}) + result["hook_post_process_injected"] = True + logger.info("[Demo] test.post_process user=%s", uid) + return result + + +def enrich_report(plugin: DemoPlugin, *, user_id, report, **kw): + """[demo.report.enrich] trigger_hook manual trigger, extend user activity report (pipe-style).""" + report["total_users_tracked"] = len(plugin.add_counter) + report["is_active_user"] = plugin.add_counter.get(user_id, 0) > 0 + report["enriched_by"] = plugin.name + logger.info("[Demo] enrich_report user=%s", user_id) + return report diff --git a/extensions/memos_demo_plugin/middleware.py b/extensions/memos_demo_plugin/middleware.py new file mode 100644 index 000000000..f2d4a73cd --- /dev/null +++ b/extensions/memos_demo_plugin/middleware.py @@ -0,0 +1,25 @@ +"""Demo plugin middleware — global audit logging for each request's method, path, status code, and duration.""" + +import logging +import time + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + + +logger = logging.getLogger(__name__) + + +class DemoAuditMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + start = time.time() + response = await call_next(request) + elapsed_ms = (time.time() - start) * 1000 + logger.info( + "[Demo Audit] %s %s → %s (%.1fms)", + request.method, + request.url.path, + response.status_code, + elapsed_ms, + ) + return response diff --git a/extensions/memos_demo_plugin/plugin.py b/extensions/memos_demo_plugin/plugin.py new file mode 100644 index 000000000..d3ef11c8d --- /dev/null +++ b/extensions/memos_demo_plugin/plugin.py @@ -0,0 +1,78 @@ +""" +Demo plugin main logic — complete demonstration of MemOS plugin's three extension capabilities. + +Scope: + 1. Register routes — self.register_router() + 2. Register middleware — self.register_middleware() + 3. Register hooks — self.register_hook() / self.register_hooks() + +Both community developers and enterprise self-hosted deployments can reference this plugin structure. +Package naming convention: memos-xx-plugin / memos_xx_plugin. +""" + +import logging + +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H +from memos_demo_plugin.hook_defs import DemoH + + +logger = logging.getLogger(__name__) + + +class DemoPlugin(MemOSPlugin): + name = "demo" + version = "0.1.0" + description = "Demo plugin — showcases routes, middleware, and hooks" + + # ── Lifecycle ──────────────────────────────────────────────────── + + def on_load(self) -> None: + self.add_counter: dict[str, int] = {} + self.request_log: list[dict] = [] + self.post_process_log: list[dict] = [] + self.hook_test_log: list[dict] = [] + logger.info("[Demo] plugin loaded") + + def init_app(self) -> None: + from memos_demo_plugin.hooks import ( + count_add, + enrich_report, + log_operation, + on_test_after, + on_test_before, + on_test_post_process, + post_process_add, + ) + from memos_demo_plugin.middleware import DemoAuditMiddleware + from memos_demo_plugin.routes import create_router + + # 1) Routes + self.register_router(create_router(self)) + + # 2) Middleware + self.register_middleware(DemoAuditMiddleware) + + # 3) Hooks — respond to CE @hookable extension points + self.register_hook(H.ADD_AFTER, partial(count_add, self)) + self.register_hooks([H.ADD_BEFORE, H.SEARCH_BEFORE], partial(log_operation, self)) + self.register_hook(H.ADD_MEMORIES_POST_PROCESS, partial(post_process_add, self)) + + # 4) Hooks — plugin-owned extension points (constants from DemoH) + self.register_hook(DemoH.TEST_BEFORE, partial(on_test_before, self)) + self.register_hook(DemoH.TEST_AFTER, partial(on_test_after, self)) + self.register_hook(DemoH.TEST_POST_PROCESS, partial(on_test_post_process, self)) + self.register_hook(DemoH.REPORT_ENRICH, partial(enrich_report, self)) + + logger.info("[Demo] plugin initialized") + + def on_shutdown(self) -> None: + logger.info( + "[Demo] plugin shutdown — users=%d, ops=%d, post_process=%d, hook_tests=%d", + len(self.add_counter), + len(self.request_log), + len(self.post_process_log), + len(self.hook_test_log), + ) diff --git a/extensions/memos_demo_plugin/routes.py b/extensions/memos_demo_plugin/routes.py new file mode 100644 index 000000000..af6a18983 --- /dev/null +++ b/extensions/memos_demo_plugin/routes.py @@ -0,0 +1,92 @@ +"""Demo plugin routes — demonstrates full usage of plugin-registered routes + both hook trigger styles.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter +from pydantic import BaseModel + +from memos.plugins.hooks import hookable, trigger_hook +from memos_demo_plugin.hook_defs import DemoH + + +if TYPE_CHECKING: + from memos_demo_plugin.plugin import DemoPlugin + + +# ── Request models ──────────────────────────────────────────────────────── + + +class TestHookRequest(BaseModel): + user_id: str = "anonymous" + message: str = "hello" + + +# ── Router factory ──────────────────────────────────────────────────────── + + +def create_router(plugin: DemoPlugin) -> APIRouter: + router = APIRouter(prefix="/demo", tags=["demo"]) + + # ── Basic routes ── + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name, "version": plugin.version} + + @router.get("/stats") + async def stats(): + return { + "add_counter": plugin.add_counter, + "total_adds": sum(plugin.add_counter.values()), + "recent_requests": plugin.request_log[-20:], + } + + # ── Hook demo routes ── + + class _HookDemoHandler: + """Demonstrates @hookable decorator: auto-triggers demo.test.before / demo.test.after.""" + + @hookable("demo.test") + def handle(self, request: TestHookRequest): + result = { + "user_id": request.user_id, + "echo": request.message, + "processed": True, + } + rv = trigger_hook(DemoH.TEST_POST_PROCESS, request=request, result=result) + return rv if rv is not None else result + + handler = _HookDemoHandler() + + @router.post("/test-hook") + async def test_hook(req: TestHookRequest): + """Full hook demo endpoint. + + Call chain: + 1. demo.test.before — @hookable auto, pipe-style, can modify request + 2. handler business logic + 3. demo.test.post_process — trigger_hook manual, pipe-style, can modify result + 4. demo.test.after — @hookable auto, pipe-style, can modify result + 5. demo.report.enrich — trigger_hook manual, pipe-style, can modify report + """ + result = handler.handle(req) + + report = { + "user_id": req.user_id, + "add_count": plugin.add_counter.get(req.user_id, 0), + "operation_count": sum( + 1 for r in plugin.request_log if r.get("user_id") == req.user_id + ), + } + rv = trigger_hook(DemoH.REPORT_ENRICH, user_id=req.user_id, report=report) + report = rv if rv is not None else report + + return { + "hook_test": result, + "user_report": report, + "plugin_state": {"hook_test_log": plugin.hook_test_log[-10:]}, + } + + return router diff --git a/extensions/memos_demo_plugin/tests/__init__.py b/extensions/memos_demo_plugin/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extensions/memos_demo_plugin/tests/conftest.py b/extensions/memos_demo_plugin/tests/conftest.py new file mode 100644 index 000000000..91aaf61af --- /dev/null +++ b/extensions/memos_demo_plugin/tests/conftest.py @@ -0,0 +1,14 @@ +"""memos_demo_plugin tests — ensure hooks used by the plugin are declared. + +CE @hookable declarations are triggered by manually calling hookable(). +Plugin-owned hook declarations are triggered by importing the hook_defs module (module-level define_hook calls). +""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("demo.test") + +import memos_demo_plugin.hook_defs # noqa: E402, F401 — triggers plugin-owned hook declarations diff --git a/extensions/memos_demo_plugin/tests/test_hooks.py b/extensions/memos_demo_plugin/tests/test_hooks.py new file mode 100644 index 000000000..40b2197a4 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_hooks.py @@ -0,0 +1,110 @@ +"""DemoPlugin hook callback verification — including @hookable before/after and custom trigger_hook.""" + +import logging + +from fastapi import FastAPI + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_plugin(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return plugin + + +class TestHookCallbacks: + """Verify business logic of each hook callback in the Demo plugin.""" + + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_add_after_counts(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "alice" + + trigger_hook("add.after", request=Req(), result={}) + trigger_hook("add.after", request=Req(), result={}) + + assert plugin.add_counter["alice"] == 2 + + def test_add_before_logs(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "bob" + + trigger_hook("add.before", request=Req()) + + assert len(plugin.request_log) == 1 + assert plugin.request_log[0]["user_id"] == "bob" + + def test_search_before_logs(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "charlie" + + trigger_hook("search.before", request=Req()) + + assert len(plugin.request_log) == 1 + assert plugin.request_log[0]["user_id"] == "charlie" + + def test_multiple_users(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class ReqA: + user_id = "alice" + + class ReqB: + user_id = "bob" + + trigger_hook("add.before", request=ReqA()) + trigger_hook("add.after", request=ReqA(), result={}) + trigger_hook("add.before", request=ReqB()) + trigger_hook("add.after", request=ReqB(), result={}) + trigger_hook("add.before", request=ReqA()) + trigger_hook("add.after", request=ReqA(), result={}) + + assert plugin.add_counter == {"alice": 2, "bob": 1} + assert len(plugin.request_log) == 3 + + def test_post_process_hook_is_pipeline(self): + """add.memories.post_process is a pipeline-style hook; callbacks can modify and return result.""" + from memos.plugins.hook_defs import H + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "dave" + + original = [{"id": 1}, {"id": 2}] + rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=Req(), result=original) + + assert rv is original + assert len(plugin.post_process_log) == 1 + assert plugin.post_process_log[0]["user_id"] == "dave" + assert plugin.post_process_log[0]["result_count"] == 2 diff --git a/extensions/memos_demo_plugin/tests/test_lifecycle.py b/extensions/memos_demo_plugin/tests/test_lifecycle.py new file mode 100644 index 000000000..4c6156980 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_lifecycle.py @@ -0,0 +1,99 @@ +"""DemoPlugin lifecycle & PluginManager integration tests.""" + +import logging + +from fastapi import FastAPI + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + """Simulate the PluginManager initialization flow.""" + plugin._bind_app(app) + plugin.init_app() + + +class TestDemoPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + assert plugin.name == "demo" + assert plugin.version == "0.1.0" + + def test_on_load_initializes_state(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + plugin.on_load() + + assert plugin.add_counter == {} + assert plugin.request_log == [] + assert plugin.post_process_log == [] + assert plugin.hook_test_log == [] + + def test_on_shutdown_no_error(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + plugin.on_load() + plugin.on_shutdown() + + def test_full_lifecycle(self): + """Full lifecycle: on_load → init_app → normal operation → on_shutdown.""" + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + assert "/demo/test-hook" in paths + + plugin.on_shutdown() + + +class TestPluginManager: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_manual_registration_and_init(self): + from memos.plugins.manager import PluginManager + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + manager = PluginManager() + + plugin = DemoPlugin() + plugin.on_load() + manager._plugins[plugin.name] = plugin + + assert "demo" in manager.plugins + + manager.init_app(app) + + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + + def test_shutdown(self): + from memos.plugins.manager import PluginManager + from memos_demo_plugin.plugin import DemoPlugin + + manager = PluginManager() + plugin = DemoPlugin() + plugin.on_load() + manager._plugins[plugin.name] = plugin + + manager.shutdown() diff --git a/extensions/memos_demo_plugin/tests/test_middleware.py b/extensions/memos_demo_plugin/tests/test_middleware.py new file mode 100644 index 000000000..45082fdb9 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_middleware.py @@ -0,0 +1,62 @@ +"""DemoPlugin middleware integration tests.""" + +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_app(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return app, plugin + + +class TestMiddlewareRegistration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_audit_middleware_logs(self, caplog): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + return {"ok": True} + + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + client = TestClient(app) + with caplog.at_level(logging.INFO): + resp = client.get("/test") + + assert resp.status_code == 200 + assert any("[Demo Audit]" in r.message for r in caplog.records) + + def test_audit_middleware_on_plugin_routes(self, caplog): + app, _ = _make_app() + + client = TestClient(app) + with caplog.at_level(logging.INFO): + client.get("/demo/health") + + assert any( + "[Demo Audit]" in r.message and "/demo/health" in r.message for r in caplog.records + ) diff --git a/extensions/memos_demo_plugin/tests/test_routes.py b/extensions/memos_demo_plugin/tests/test_routes.py new file mode 100644 index 000000000..82e83d558 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_routes.py @@ -0,0 +1,194 @@ +"""DemoPlugin routes and /demo/test-hook endpoint tests.""" + +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_app(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return app, plugin + + +# ========================================================================= # +# Route registration verification +# ========================================================================= # + + +class TestRouteRegistration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_routes_exist(self): + app, _ = _make_app() + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + assert "/demo/test-hook" in paths + + def test_health_endpoint(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/demo/health") + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["plugin"] == "demo" + assert data["version"] == "0.1.0" + + def test_stats_endpoint_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/demo/stats") + + assert resp.status_code == 200 + data = resp.json() + assert data["add_counter"] == {} + assert data["total_adds"] == 0 + assert data["recent_requests"] == [] + + def test_stats_endpoint_after_hooks(self): + from memos.plugins.hooks import trigger_hook + + app, plugin = _make_app() + + class FakeRequest: + user_id = "user_42" + + trigger_hook("add.before", request=FakeRequest()) + trigger_hook("add.after", request=FakeRequest(), result={"ok": True}) + trigger_hook("add.before", request=FakeRequest()) + trigger_hook("add.after", request=FakeRequest(), result={"ok": True}) + + client = TestClient(app) + resp = client.get("/demo/stats") + data = resp.json() + + assert data["add_counter"]["user_42"] == 2 + assert data["total_adds"] == 2 + assert len(data["recent_requests"]) == 2 + + +# ========================================================================= # +# /demo/test-hook endpoint — @hookable + custom trigger_hook full chain +# ========================================================================= # + + +class TestHookEndpoint: + """Verify full hook call chain of the test endpoint: + demo.test.before → business logic → demo.test.post_process → demo.test.after + """ + + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_basic_response(self): + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "tester", "message": "ping"}) + assert resp.status_code == 200 + + data = resp.json() + hook_result = data["hook_test"] + assert hook_result["user_id"] == "tester" + assert hook_result["echo"] == "ping" + assert hook_result["processed"] is True + + def test_after_hook_injects_field(self): + """demo.test.after callback injects hook_after_injected=True.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u1", "message": "hi"}) + assert resp.json()["hook_test"]["hook_after_injected"] is True + + def test_post_process_hook_injects_field(self): + """demo.test.post_process custom hook injects hook_post_process_injected=True.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u2", "message": "world"}) + assert resp.json()["hook_test"]["hook_post_process_injected"] is True + + def test_records_all_three_phases(self): + """plugin.hook_test_log should record all three phases: before / post_process / after.""" + app, plugin = _make_app() + client = TestClient(app) + + client.post("/demo/test-hook", json={"user_id": "u3", "message": "test"}) + + phases = [entry["phase"] for entry in plugin.hook_test_log] + assert "before" in phases + assert "post_process" in phases + assert "after" in phases + + def test_state_in_response(self): + """plugin_state.hook_test_log in response should contain records for all three phases.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u4", "message": "check"}) + data = resp.json() + + log = data["plugin_state"]["hook_test_log"] + assert len(log) >= 3 + assert any(e["phase"] == "before" for e in log) + assert any(e["phase"] == "post_process" for e in log) + assert any(e["phase"] == "after" for e in log) + + def test_multiple_calls_accumulate(self): + """hook_test_log should accumulate after multiple calls.""" + app, plugin = _make_app() + client = TestClient(app) + + client.post("/demo/test-hook", json={"user_id": "a"}) + client.post("/demo/test-hook", json={"user_id": "b"}) + + assert len(plugin.hook_test_log) >= 6 + + def test_default_values(self): + """Call with default parameters.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={}) + data = resp.json() + + assert data["hook_test"]["user_id"] == "anonymous" + assert data["hook_test"]["echo"] == "hello" + + def test_custom_hook_enrich_report(self): + """demo.report.enrich custom hook example — response contains user_report extended by callback.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "alice"}) + data = resp.json() + + report = data["user_report"] + assert report["user_id"] == "alice" + assert "add_count" in report + assert "operation_count" in report + assert report["enriched_by"] == "demo" + assert "total_users_tracked" in report + assert "is_active_user" in report diff --git a/extensions/memos_prompt_strategy_plugin/README.md b/extensions/memos_prompt_strategy_plugin/README.md new file mode 100644 index 000000000..862d082fe --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/README.md @@ -0,0 +1,102 @@ +# Prompt Strategy Plugin + +检测对话中的身份/亲属关系命名模式(如"我叫xxx"、"我的儿子叫xxx"),自动切换为专用的身份关系提取 prompt。 + +## 解决什么问题 + +默认的 mem-reader prompt 是通用型的,在遇到"我叫王沐辰,我的儿子叫王明泽"这类包含姓名和关系信息的对话时,可能无法精确提取出所有人名和关系。本插件增加一条专用规则:一旦检测到身份/关系命名句式,就替换为专门强调"不遗漏任何人名和关系"的提取 prompt。 + +如果对话不包含这类句式,插件不做任何改动,走 CE 默认流程。 + +## 工作原理 + +``` +消息进入 mem-reader + ↓ +CE: _get_llm_response() 构建默认 prompt + ↓ +CE: trigger_hook("mem_reader.pre_extract") ← 通用扩展点 + ↓ +插件回调 on_pre_extract(): + 1. MessageClassifier 检查是否命中身份/关系命名规则 + 2. 命中 → 返回专用 identity_relation prompt + 3. 不命中 → 返回 None,CE 使用默认 prompt + ↓ +CE: LLM.generate(prompt) +``` + +## 命中规则 + +| 模式 | 示例 | +|------|------| +| 自我命名(中文) | 我叫xxx、我是xxx、我的名字是xxx | +| 亲属/社交关系命名(中文) | 我的儿子叫xxx、我老婆是xxx、我妈妈叫xxx、我朋友叫xxx | +| 自我命名(英文) | My name is xxx、I'm xxx、Call me xxx | +| 关系命名(英文) | My son is called xxx、My wife's name is xxx | + +支持的关系词:儿子、女儿、老婆、老公、爸爸、妈妈、哥哥、姐姐、弟弟、妹妹、爷爷、奶奶、朋友、同事、同学、宠物等。 + +## 文件结构 + +``` +extensions/memos_prompt_strategy_plugin/ +├── __init__.py # 包入口,导出 PromptStrategyPlugin +├── plugin.py # 插件主类:生命周期 + 注册 +├── hooks.py # Hook 回调:on_pre_extract +├── classifier.py # 身份/关系命名规则检测 +├── strategies.py # identity_relation prompt 模板(中英文) +├── routes.py # 管理接口 +├── example.py # 可直接运行的检测演示 +└── tests/ + ├── conftest.py + ├── test_classifier.py + ├── test_strategies.py + └── test_lifecycle.py +``` + +## 快速体验 + +```bash +PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py +``` + +## 安装与注册 + +`pyproject.toml` 中已包含以下配置: + +```toml +[tool.poetry] +packages = [ + {include = "memos_prompt_strategy_plugin", from = "extensions"}, +] + +[project.entry-points."memos.plugins"] +prompt_strategy = "memos_prompt_strategy_plugin:PromptStrategyPlugin" +``` + +首次安装需执行: + +```bash +pip install -e . +``` + +## 运行测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_prompt_strategy_plugin/tests/ -v +``` + +## 管理接口 + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/prompt_strategy/health` | GET | 插件健康检查 | +| `/prompt_strategy/stats` | GET | 查看 identity_relation 命中次数 | + +## CE 依赖 + +本插件依赖 CE 侧的一个扩展点: + +- `mem_reader.pre_extract`:在 `MultiModalStructMemReader._get_llm_response()` 中,LLM 调用前触发 + +该扩展点声明在 `src/memos/plugins/hook_defs.py`。 diff --git a/extensions/memos_prompt_strategy_plugin/__init__.py b/extensions/memos_prompt_strategy_plugin/__init__.py new file mode 100644 index 000000000..f4bd07412 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/__init__.py @@ -0,0 +1,4 @@ +from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + +__all__ = ["PromptStrategyPlugin"] diff --git a/extensions/memos_prompt_strategy_plugin/classifier.py b/extensions/memos_prompt_strategy_plugin/classifier.py new file mode 100644 index 000000000..2f7d0cba9 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/classifier.py @@ -0,0 +1,122 @@ +"""Message classifier — rule-chain architecture with extensible rules. + +Currently only one rule is registered (identity/relation naming detection). +To add a new rule, write a static method that returns a category string on +match or ``None`` on miss, then append it to ``self._rules`` in ``__init__``. +""" + +from __future__ import annotations + +import logging +import re + +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = logging.getLogger(__name__) + +# ── Category constants ────────────────────────────────────────────── +IDENTITY_RELATION = "identity_relation" + +# ── Regex patterns for identity / relation detection ──────────────── +_SELF_NAME_RE = re.compile( + r"我(?:的名字)?(?:是|叫)\s*(?P\S+)", +) + +_RELATION_WORDS = ( + "儿子|女儿|孩子|小孩" + "|老婆|妻子|老公|丈夫|爱人|伴侣|对象" + "|爸爸|妈妈|父亲|母亲|爸|妈" + "|哥哥|姐姐|弟弟|妹妹|哥|姐|弟|妹" + "|爷爷|奶奶|外公|外婆|姥姥|姥爷" + "|叔叔|阿姨|舅舅|舅妈|姑姑|姑父" + "|朋友|同事|同学|室友|闺蜜|兄弟" + "|男朋友|女朋友|前任" + "|宠物|狗|猫" +) +_RELATION_NAME_RE = re.compile( + rf"我(?:的)?(?:{_RELATION_WORDS})(?:的名字)?(?:是|叫)\s*(?P\S+)", +) + +_MY_NAME_IS_EN = re.compile( + r"(?:my name is|i'?m|call me)\s+(?P[A-Z]\w+)", + re.IGNORECASE, +) +_MY_RELATION_IS_EN = re.compile( + r"my\s+(?:son|daughter|wife|husband|father|mother|brother|sister|friend)" + r"(?:'s name)?\s+is\s+(?P[A-Z]\w+)", + re.IGNORECASE, +) + + +def _extract_text(sources: list) -> str: + """Extract raw text content from sources for rule matching.""" + parts: list[str] = [] + for src in sources: + if isinstance(src, str): + parts.append(src) + elif hasattr(src, "content"): + parts.append(str(src.content)) + elif isinstance(src, dict): + parts.append(str(src.get("content", ""))) + return "\n".join(parts) + + +class MessageClassifier: + """Rule-chain classifier. + + Rules are evaluated in registration order; the first match wins. + If no rule matches, ``classify()`` returns ``None`` so the caller + keeps the default prompt unchanged. + + To add a new rule: + 1. Define a static/class method ``_check_xxx(sources, text) -> str | None`` + 2. Append ``("category_name", self._check_xxx)`` to ``self._rules`` + """ + + def __init__(self) -> None: + self._rules: list[tuple[str, Callable[[list, str], str | None]]] = [ + ("identity_relation", self._check_identity_relation), + ] + + def classify( + self, + sources: list, + mem_str: str, + default_prompt_type: str, + info: dict[str, Any], + ) -> str | None: + """Walk the rule chain; return the first matching category or ``None``.""" + text = _extract_text(sources) if sources else mem_str + if not text: + return None + + for _name, rule_fn in self._rules: + result = rule_fn(sources, text) + if result is not None: + return result + + return None + + # ── Rules ─────────────────────────────────────────────────────── + + @staticmethod + def _check_identity_relation(sources: list, text: str) -> str | None: + self_names = [m.group("name") for m in _SELF_NAME_RE.finditer(text)] + self_names += [m.group("name") for m in _MY_NAME_IS_EN.finditer(text)] + relation_names = [m.group("name") for m in _RELATION_NAME_RE.finditer(text)] + relation_names += [m.group("name") for m in _MY_RELATION_IS_EN.finditer(text)] + + if self_names or relation_names: + logger.info( + "[PromptStrategy] Identity/relation pattern detected — " + "self_names=%s, relation_names=%s", + self_names, + relation_names, + ) + return IDENTITY_RELATION + + return None diff --git a/extensions/memos_prompt_strategy_plugin/example.py b/extensions/memos_prompt_strategy_plugin/example.py new file mode 100644 index 000000000..9e948fcdb --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/example.py @@ -0,0 +1,90 @@ +"""Quick demo — run directly to see identity/relation detection in action. + +Usage: + PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py +""" + +from memos_prompt_strategy_plugin.classifier import MessageClassifier +from memos_prompt_strategy_plugin.strategies import build_identity_relation_prompt + + +def _src(role: str, content: str): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + +DEMO_CONVERSATIONS = [ + { + "label": "自我介绍 + 亲属关系", + "sources": [_src("user", "你好,我叫王沐辰,我的儿子叫王明泽")], + "mem_str": "你好,我叫王沐辰,我的儿子叫王明泽", + }, + { + "label": "仅自我介绍", + "sources": [_src("user", "我是李明,今年30岁")], + "mem_str": "我是李明,今年30岁", + }, + { + "label": "英文自我介绍 + 关系", + "sources": [_src("user", "Hi, my name is Alice. My son is called Bob.")], + "mem_str": "Hi, my name is Alice. My son is called Bob.", + }, + { + "label": "多种关系", + "sources": [_src("user", "我叫张三,我老婆叫李四,我女儿叫张小花,我妈妈叫王秀英")], + "mem_str": "我叫张三,我老婆叫李四,我女儿叫张小花,我妈妈叫王秀英", + }, + { + "label": "普通闲聊(不应命中)", + "sources": [_src("user", "今天天气不错,出去走走吧")], + "mem_str": "今天天气不错,出去走走吧", + }, + { + "label": "任务型(不应命中)", + "sources": [_src("user", "请帮我安排明天下午3点的会议")], + "mem_str": "请帮我安排明天下午3点的会议", + }, +] + +SEPARATOR = "=" * 72 + + +def main(): + clf = MessageClassifier() + + print(SEPARATOR) + print(" Prompt Strategy Plugin — Identity/Relation Detection Demo") + print(SEPARATOR) + + for data in DEMO_CONVERSATIONS: + label = data["label"] + sources = data["sources"] + mem_str = data["mem_str"] + + category = clf.classify(sources, mem_str, "chat", {}) + hit = category is not None + + print(f"\n{'—' * 72}") + print(f" Scenario: {label}") + print(f" Input : {mem_str[:80]}{'...' if len(mem_str) > 80 else ''}") + print(f" Hit : {'YES → identity_relation' if hit else 'NO → use default prompt'}") + + if hit: + lang = "zh" if any("\u4e00" <= c <= "\u9fff" for c in mem_str) else "en" + prompt = build_identity_relation_prompt(lang=lang, mem_str=mem_str) + print(f" Prompt : {prompt[:120]}...") + + print(f"{'—' * 72}") + + print(f"\n{SEPARATOR}") + print(" Done.") + print(SEPARATOR) + + +if __name__ == "__main__": + main() diff --git a/extensions/memos_prompt_strategy_plugin/hooks.py b/extensions/memos_prompt_strategy_plugin/hooks.py new file mode 100644 index 000000000..75f50cb6f --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/hooks.py @@ -0,0 +1,84 @@ +"""Prompt strategy plugin hook callbacks. + +All callbacks are bound to the plugin instance via functools.partial(callback, plugin_instance). +""" + +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + +logger = logging.getLogger(__name__) + +_IDENTITY_SUPPLEMENT_ZH = """ + +【特别注意 - 身份与关系提取】 +检测到当前对话包含用户的姓名或亲属/社交关系信息。请在完成上述所有处理步骤的同时, +**额外确保**以下内容被完整提取,绝对不能遗漏: +1. 用户本人的姓名 +2. 用户提及的所有关系人(关系类型 + 姓名) +3. 每个身份/关系信息需要作为独立的记忆条目 +4. tags 中必须包含 "identity" 或 "relationship" +""" + +_IDENTITY_SUPPLEMENT_EN = """ + +[IMPORTANT - Identity & Relationship Extraction] +The current conversation contains the user's name or family/social relationship information. +In addition to all the above processing steps, **make sure** to extract the following completely +— do NOT miss any: +1. The user's own name +2. All people mentioned with their relationship type and name +3. Each identity/relationship should be a separate memory item +4. Tags must include "identity" or "relationship" +""" + + +def on_pre_extract( + plugin: PromptStrategyPlugin, + *, + prompt: str, + prompt_type: str, + mem_str: str, + lang: str, + sources: list, + **_kw: Any, +) -> str | None: + """[mem_reader.pre_extract] If a classifier rule matches: + - For normal extraction: swap in the specialised identity/relation prompt. + - For version pipeline: append identity/relation emphasis to the existing prompt. + If no rule matches, return None to keep the default.""" + category = plugin.classifier.classify(sources, mem_str, prompt_type, info={}) + + if category is None: + return None + + plugin.stats[category] += 1 + logger.info( + "[PromptStrategy] Matched rule: %s | prompt_type=%s, lang=%s, text=%s", + category, + prompt_type, + lang, + mem_str[:120] + ("..." if len(mem_str) > 120 else ""), + ) + + if prompt_type == "version": + supplement = _IDENTITY_SUPPLEMENT_ZH if lang == "zh" else _IDENTITY_SUPPLEMENT_EN + logger.info("[PromptStrategy] Version pipeline — appending identity/relation supplement") + return prompt + supplement + + custom_prompt = plugin.registry.build_prompt( + category=category, + lang=lang, + mem_str=mem_str, + ) + if custom_prompt is not None: + logger.info("[PromptStrategy] Prompt swapped to strategy: %s", category) + return custom_prompt + + return None diff --git a/extensions/memos_prompt_strategy_plugin/plugin.py b/extensions/memos_prompt_strategy_plugin/plugin.py new file mode 100644 index 000000000..ace8d4f4e --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/plugin.py @@ -0,0 +1,44 @@ +"""PromptStrategyPlugin — rule-chain classifier + strategy registry.""" + +from __future__ import annotations + +import logging + +from collections import defaultdict +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H + + +logger = logging.getLogger(__name__) + + +class PromptStrategyPlugin(MemOSPlugin): + name = "prompt_strategy" + version = "0.2.0" + description = "Rule-chain classifier with strategy registry for specialised extraction prompts" + + def on_load(self) -> None: + from memos_prompt_strategy_plugin.classifier import MessageClassifier + from memos_prompt_strategy_plugin.strategies import StrategyRegistry + + self.classifier = MessageClassifier() + self.registry = StrategyRegistry() + self.registry.register_defaults() + self.stats: dict[str, int] = defaultdict(int) + logger.info("[PromptStrategy] plugin loaded") + + def init_app(self) -> None: + from memos_prompt_strategy_plugin.hooks import on_pre_extract + from memos_prompt_strategy_plugin.routes import create_router + + self.register_router(create_router(self)) + self.register_hook(H.MEM_READER_PRE_EXTRACT, partial(on_pre_extract, self)) + logger.info("[PromptStrategy] plugin initialized") + + def on_shutdown(self) -> None: + logger.info( + "[PromptStrategy] plugin shutdown — stats: %s", + dict(self.stats), + ) diff --git a/extensions/memos_prompt_strategy_plugin/routes.py b/extensions/memos_prompt_strategy_plugin/routes.py new file mode 100644 index 000000000..649b08370 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/routes.py @@ -0,0 +1,34 @@ +"""Admin routes for the Prompt Strategy plugin.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter + + +if TYPE_CHECKING: + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + +def create_router(plugin: PromptStrategyPlugin) -> APIRouter: + router = APIRouter(prefix="/prompt_strategy", tags=["prompt_strategy"]) + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name, "version": plugin.version} + + @router.get("/strategies") + async def list_strategies(): + """Return all registered prompt strategies.""" + return { + name: {"description": s.description} + for name, s in plugin.registry.all_strategies().items() + } + + @router.get("/stats") + async def classification_stats(): + """Return per-category classification hit counts.""" + return {"stats": dict(plugin.stats)} + + return router diff --git a/extensions/memos_prompt_strategy_plugin/strategies.py b/extensions/memos_prompt_strategy_plugin/strategies.py new file mode 100644 index 000000000..00337cdd7 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/strategies.py @@ -0,0 +1,364 @@ +"""Prompt strategy registry — maps classifier categories to specialised prompts. + +Currently only one strategy is registered (identity_relation). To add a new +strategy, create a ``PromptStrategy`` and call ``register()`` or append it to +``_DEFAULT_STRATEGIES``. + +All strategy prompts produce the same JSON output format as the default +mem-reader (memory list + summary) so downstream processing stays unchanged. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PromptStrategy: + name: str + template_en: str + template_zh: str + description: str + + +class StrategyRegistry: + """Registry that maps category labels to prompt strategies.""" + + def __init__(self): + self._strategies: dict[str, PromptStrategy] = {} + + def register(self, strategy: PromptStrategy) -> None: + self._strategies[strategy.name] = strategy + logger.info("[PromptStrategy] Registered strategy: %s", strategy.name) + + def get(self, name: str) -> PromptStrategy | None: + return self._strategies.get(name) + + def all_strategies(self) -> dict[str, PromptStrategy]: + return dict(self._strategies) + + def build_prompt( + self, + category: str, + lang: str, + mem_str: str, + custom_tags: list[str] | None = None, + ) -> str | None: + """Build a prompt for *category*. Returns ``None`` when the category + has no registered strategy (caller should fall back to the default).""" + strategy = self._strategies.get(category) + if strategy is None: + return None + + template = strategy.template_zh if lang == "zh" else strategy.template_en + prompt = template.replace("${conversation}", mem_str) + prompt = prompt.replace("{chunk_text}", mem_str) + + if custom_tags: + tags_instruction = ( + f"\n额外要求:提取的记忆请尽量关联以下标签:{custom_tags}" + if lang == "zh" + else f"\nAdditional: associate extracted memories with these tags: {custom_tags}" + ) + else: + tags_instruction = "" + prompt = prompt.replace("${custom_tags_prompt}", tags_instruction) + prompt = prompt.replace("{custom_tags_prompt}", tags_instruction) + + return prompt + + def register_defaults(self) -> None: + """Register built-in strategies.""" + for strategy in _DEFAULT_STRATEGIES: + self.register(strategy) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Default prompt templates +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +_IDENTITY_RELATION_EN = """\ +You are a memory extraction expert. Extract all kinds of memories, including accurate identity and relationship information about people. + +Your task is to extract memories from the perspective of the user based on the conversation between the user and the assistant. This means identifying information that the user may remember — including the user’s own experiences, thoughts, plans, or relevant statements and actions made by others (such as the assistant) that affect the user or are acknowledged by the user. + +Please perform the following: + +1. If the current conversation contains the user’s self-reported name or information about family/social relationships, the extracted content must precisely include: + - The user’s own name (e.g., "My name is xxx", "I am xxx") + - All related persons mentioned by the user: relationship type + name (e.g., "My son’s name is xxx", "My wife is xxx") + - If there are further relationship descriptions among related persons, extract them as well + - If other content exists, extract it as usual + + Extraction requirements: + - **Absolutely do not omit any name or relationship** + - Use third person ("The user’s son is named Wang Mingze" rather than "My son is Wang Mingze") + - Each identity/relationship item must be extracted as a separate memory + +2. Identify information that reflects the user’s experiences, beliefs, concerns, decisions, plans, or reactions — including meaningful factual information from the assistant that the user acknowledges or responds to. +If the message is from the user, extract memories related to the user. If the message is from the assistant, only extract factual memories that are acknowledged or responded to by the user. + +3. Clearly resolve all references to time, people, and events: + - If possible, use message timestamps to convert relative temporal expressions (such as “yesterday” or “next Friday”) into absolute dates. + - Clearly distinguish between event time and message time. + - If uncertainty exists, explicitly state it (e.g., “around June 2025”, “exact date unclear”). + - If a specific location is mentioned, include it. + - Resolve all pronouns, aliases, and vague references into full names or explicit identities. + - If there are multiple people with the same name, distinguish them clearly. + +4. Always write from the third-person perspective, using “the user” or the user’s name to refer to the user, rather than first person (“I”, “we”, “my”). +For example, write “The user feels tired...” rather than “I feel tired...”. + +5. Do not omit any information that the user may remember. + - Include all key experiences, thoughts, emotional reactions, and plans — even if they seem minor. + - Prioritize completeness and fidelity over brevity. + - Do not generalize or skip details that may have personal significance to the user. + +6. Please avoid including any content in the extracted memories that violates laws or regulations or involves politically sensitive information. + +Return a valid JSON object with the following structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The fields `key`, `value`, `tags`, and `summary` must match the main language of the input conversation. **If the input is Chinese, output in Chinese.** +- `memory_type` must remain in English. + +${custom_tags_prompt} + +Example: +Conversation: +user: [June 26, 2025, 3:00 PM]: Hi Jerry, my name is Tom! Yesterday at 3:00 PM I had a meeting with my team to discuss a new project. +assistant: Do you think the team can finish by December 15? +user: [June 26, 2025, 3:00 PM]: I’m a little worried. The backend won’t be finished until December 10, so testing time will be tight. +assistant: [June 26, 2025, 3:00 PM]: Maybe suggest postponing it? +user: [June 26, 2025, 4:21 PM]: Good idea. I’ll bring it up at tomorrow’s 9:30 AM meeting — maybe push the deadline to January 5. + +Output: +{ + "memory list": [ + { + "key": "User name", + "memory_type": "UserMemory", + "value": "The user’s name is Tom.", + "tags": ["identity information", "name"] + }, + { + "key": "Initial project meeting", + "memory_type": "LongTermMemory", + "value": "On June 25, 2025 at 3:00 PM, Tom met with his team to discuss a new project. The meeting involved the timeline and raised concerns about whether the December 15, 2025 deadline was feasible.", + "tags": ["project", "timeline", "meeting", "deadline"] + }, + { + "key": "Planned deadline adjustment", + "memory_type": "UserMemory", + "value": "Tom plans to suggest at the June 27, 2025 9:30 AM meeting that the team reprioritize work and postpone the project deadline to January 5, 2026.", + "tags": ["plan", "deadline change", "prioritization"] + } + ], + "summary": "Tom is currently focused on managing a new project with a tight schedule. After the team meeting on June 25, 2025, he realized that the original December 15, 2025 deadline might not be achievable because the backend is expected to be completed only by December 10, leaving very little time for testing. Because of this concern, Tom accepted Jerry’s suggestion to propose a delay. He plans to raise the idea of postponing the deadline to January 5, 2026 at the next morning’s meeting. His actions reflect concern about the timeline as well as a proactive, team-oriented approach to problem solving." +} + +Input: + user: [July 1, 2025, 10:00 AM]: My name is Li Ming. My wife’s name is Wang Ting, and my son’s name is Li Haoran. Next week we are planning to travel to Shanghai together. + assistant: That sounds great. How many days are you planning to stay? + user: [July 1, 2025, 10:05 AM]: About three days. + +Output: +{ + "memory list": [ + { + "key": "User name", + "memory_type": "UserMemory", + "value": "The user’s name is Li Ming.", + "tags": ["identity information", "name"] + }, + { + "key": "Spouse's name", + "memory_type": "UserMemory", + "value": "The user’s wife is named Wang Ting.", + "tags": ["relationship information", "wife", "name"] + }, + { + "key": "Son's name", + "memory_type": "UserMemory", + "value": "The user’s son is named Li Haoran.", + "tags": ["relationship information", "son", "name"] + }, + { + "key": "Family travel plan", + "memory_type": "LongTermMemory", + "value": "The user plans to travel to Shanghai together with his wife Wang Ting and son Li Haoran during the week following July 1, 2025, and expects the trip to last about three days. The exact departure date is not specified.", + "tags": ["travel", "family", "plan", "Shanghai"] + } + ], + "summary": "Li Ming is planning a family trip to Shanghai in the week after July 1, 2025, and expects the trip to last about three days. The conversation explicitly states that the user’s wife is named Wang Ting and the user’s son is named Li Haoran. This indicates that the user has a near-term travel plan involving close family members." +} + +Please always reply in the same language as the conversation. + +Conversation: +${conversation} + +Your output: +""" + +_IDENTITY_RELATION_ZH = """\ + +您是记忆提取专家,提取各类记忆,包括准确的人物身份和关系信息。 +您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 + +请执行以下操作: +1. 如果当前对话中包含用户自述的姓名或亲属/社交关系信息。你提取的内容需要精确包括 + - 用户本人的姓名(如"我叫xxx"、"我是xxx") + - 用户提及的所有关系人:关系类型 + 姓名(如"我的儿子叫xxx"、"我老婆是xxx") + - 关系人之间如果存在进一步的关系描述,也要提取; + - 其他内容如果存在照常提取; + 提取要求: + - **绝对不能遗漏任何人名和关系** + - 使用第三人称("用户的儿子叫王明泽"而非"我的儿子叫王明泽") + - 每组身份/关系信息单独作为一条记忆 + +2. 识别反映用户经历、信念、关切、决策、计划或反应的信息——包括用户认可或回应的来自助手的有意义信息。 +如果消息来自用户,请提取与用户相关的记忆;如果来自助手,则仅提取用户认可或回应的事实性记忆。 + +3. 清晰解析所有时间、人物和事件的指代: + - 如果可能,使用消息时间戳将相对时间表达(如“昨天”、“下周五”)转换为绝对日期。 + - 明确区分事件时间和消息时间。 + - 如果存在不确定性,需明确说明(例如,“约2025年6月”,“具体日期不详”)。 + - 若提及具体地点,请包含在内。 + - 将所有代词、别名和模糊指代解析为全名或明确身份。 + - 如有同名人物,需加以区分。 + +4. 始终以第三人称视角撰写,使用“用户”或提及的姓名来指代用户,而不是使用第一人称(“我”、“我们”、“我的”)。 +例如,写“用户感到疲惫……”而不是“我感到疲惫……”。 + +5. 不要遗漏用户可能记住的任何信息。 + - 包括所有关键经历、想法、情绪反应和计划——即使看似微小。 + - 优先考虑完整性和保真度,而非简洁性。 + - 不要泛化或跳过对用户具有个人意义的细节。 + +6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 + +返回一个有效的JSON对象,结构如下: + +{ + "memory list": [ + { + "key": <字符串,唯一且简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, + "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + }, + ... + ], + "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +${custom_tags_prompt} + +示例: +对话: +user: [2025年6月26日下午3:00]:嗨Jerry,我叫Tom!昨天下午3点我和团队开了个会,讨论新项目。 +assistant: 你觉得团队能在12月15日前完成吗? +user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 +assistant: [2025年6月26日下午3:00]:也许提议延期? +user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 + +输出: +{ + "memory list": [ + { + "key": "用户姓名", + "memory_type": "UserMemory", + "value": "用户名叫Tom。", + "tags": ["身份信息", "姓名"] + }, + { + "key": "项目初期会议", + "memory_type": "LongTermMemory", + "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] + }, + { + "key": "计划调整范围", + "memory_type": "UserMemory", + "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", + "tags": ["计划", "截止日期变更", "功能优先级"] + } + ], + "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" +} + +输入: + user: [2025年7月1日上午10:00]:我叫李明,我老婆叫王婷,我儿子叫李浩然。下周我们打算一起去上海旅游。 + assistant: 听起来很不错,你们准备去几天? + user: [2025年7月1日上午10:05]:大概三天。 + +输出: +{ + "memory list": [ + { + "key": "用户姓名", + "memory_type": "UserMemory", + "value": "用户名叫李明。", + "tags": ["身份信息", "姓名"] + }, + { + "key": "配偶姓名", + "memory_type": "UserMemory", + "value": "用户的妻子叫王婷。", + "tags": ["关系信息", "妻子", "姓名"] + }, + { + "key": "儿子姓名", + "memory_type": "UserMemory", + "value": "用户的儿子叫李浩然。", + "tags": ["关系信息", "儿子", "姓名"] + }, + { + "key": "家庭出行计划", + "memory_type": "LongTermMemory", + "value": "用户计划于2025年7月8日所在周与妻子王婷和儿子李浩然一起前往上海旅游,预计行程约三天。具体出发日期未明确。", + "tags": ["旅行", "家庭", "计划", "上海"] + } + ], + "summary": "李明计划在2025年7月1日之后的下一周与家人一起去上海旅游,预计停留约三天。对话中明确提到用户的妻子名叫王婷,儿子名叫李浩然。这表明用户近期有一项与家庭相关的出行安排。" +} + +请始终使用与对话相同的语言进行回复。 + +对话: +${conversation} + +您的输出: +""" + +_DEFAULT_STRATEGIES = [ + PromptStrategy( + name="identity_relation", + template_en=_IDENTITY_RELATION_EN, + template_zh=_IDENTITY_RELATION_ZH, + description="Precise extraction of names and family/social relationships", + ), +] diff --git a/extensions/memos_prompt_strategy_plugin/tests/__init__.py b/extensions/memos_prompt_strategy_plugin/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extensions/memos_prompt_strategy_plugin/tests/conftest.py b/extensions/memos_prompt_strategy_plugin/tests/conftest.py new file mode 100644 index 000000000..1729d8fdf --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/conftest.py @@ -0,0 +1,9 @@ +"""Ensure hooks used by PromptStrategyPlugin are declared for testing.""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") + +import memos.plugins.hook_defs # noqa: E402, F401 — triggers CE hook declarations diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py new file mode 100644 index 000000000..930a83fb9 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py @@ -0,0 +1,112 @@ +"""Tests for the single-rule identity/relation classifier.""" + +from memos_prompt_strategy_plugin.classifier import ( + IDENTITY_RELATION, + MessageClassifier, +) + + +def _src(role: str, content: str): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + +class TestIdentityRelationRule: + def setup_method(self): + self.clf = MessageClassifier() + + # ── Chinese: self-naming ──────────────────────────────────── + + def test_wo_jiao(self): + sources = [_src("user", "你好,我叫王沐辰")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wo_shi(self): + sources = [_src("user", "我是李明,今年30岁")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wo_de_mingzi_shi(self): + sources = [_src("user", "我的名字是张三")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── Chinese: relation naming ──────────────────────────────── + + def test_son(self): + sources = [_src("user", "我的儿子叫王明泽")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_daughter(self): + sources = [_src("user", "我女儿叫小红")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wife(self): + sources = [_src("user", "我老婆是刘芳")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_mother(self): + sources = [_src("user", "我妈妈叫李秀英")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_friend(self): + sources = [_src("user", "我朋友叫赵磊")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_pet(self): + sources = [_src("user", "我的猫叫小花")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── Chinese: combined self + relation ─────────────────────── + + def test_combined(self): + sources = [_src("user", "我叫王沐辰,我的儿子叫王明泽")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── English ───────────────────────────────────────────────── + + def test_my_name_is(self): + sources = [_src("user", "Hi, my name is Alice")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_im(self): + sources = [_src("user", "I'm Bob, nice to meet you")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_call_me(self): + sources = [_src("user", "Just call me Charlie")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_my_son_is(self): + sources = [_src("user", "My son is called David")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_my_wife_is(self): + sources = [_src("user", "My wife's name is Emma")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── No match → None ───────────────────────────────────────── + + def test_no_identity_returns_none(self): + sources = [_src("user", "今天天气不错")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_task_text_returns_none(self): + sources = [_src("user", "请帮我安排明天的会议")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_code_returns_none(self): + sources = [_src("user", "```python\nprint('hello')\n```")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_empty_returns_none(self): + assert self.clf.classify([], "", "chat", {}) is None + + # ── mem_str fallback ──────────────────────────────────────── + + def test_uses_mem_str_when_no_sources(self): + result = self.clf.classify([], "我叫王沐辰,我的儿子叫王明泽", "chat", {}) + assert result == IDENTITY_RELATION diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py new file mode 100644 index 000000000..e68ba8041 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py @@ -0,0 +1,189 @@ +"""Tests for PromptStrategyPlugin lifecycle and hook integration.""" + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def _make_app(): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + app = FastAPI() + plugin = PromptStrategyPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return app, plugin + + +def _src(role, content): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + +class TestPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + plugin = PromptStrategyPlugin() + assert plugin.name == "prompt_strategy" + assert plugin.version == "0.2.0" + + def test_on_load_initialises_components(self): + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + plugin = PromptStrategyPlugin() + plugin.on_load() + assert plugin.classifier is not None + assert plugin.registry is not None + assert len(plugin.registry.all_strategies()) > 0 + assert dict(plugin.stats) == {} + + def test_full_lifecycle(self): + app, plugin = _make_app() + paths = [r.path for r in app.routes] + assert "/prompt_strategy/health" in paths + assert "/prompt_strategy/strategies" in paths + assert "/prompt_strategy/stats" in paths + plugin.on_shutdown() + + +class TestPluginRoutes: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_health(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["plugin"] == "prompt_strategy" + + def test_strategies_list(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/strategies") + assert resp.status_code == 200 + strategies = resp.json() + assert "identity_relation" in strategies + + def test_stats_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/stats") + assert resp.status_code == 200 + assert resp.json()["stats"] == {} + + +class TestHookIntegration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_pre_extract_swaps_prompt_for_identity(self): + """When identity/relation pattern is detected, the prompt is swapped.""" + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + result = trigger_hook( + "mem_reader.pre_extract", + prompt="original prompt", + prompt_type="chat", + mem_str="我叫王沐辰,我的儿子叫王明泽", + lang="zh", + sources=[], + ) + assert result != "original prompt" + assert "王沐辰" in result + assert "王明泽" in result + assert plugin.stats["identity_relation"] >= 1 + + def test_pre_extract_preserves_prompt_for_normal_text(self): + """When no classifier rule matches, original prompt passes through.""" + from memos.plugins.hooks import trigger_hook + + _make_app() + + sources = [_src("user", "今天天气不错,出去走走吧")] + result = trigger_hook( + "mem_reader.pre_extract", + prompt="original prompt", + prompt_type="chat", + mem_str="今天天气不错,出去走走吧", + lang="zh", + sources=sources, + ) + assert result == "original prompt" + + def test_pre_extract_english_identity(self): + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + result = trigger_hook( + "mem_reader.pre_extract", + prompt="original prompt", + prompt_type="chat", + mem_str="My name is Alice and my son is Bob", + lang="en", + sources=[], + ) + assert result != "original prompt" + assert "Alice" in result + assert plugin.stats["identity_relation"] >= 1 + + def test_pre_extract_version_pipeline_appends_supplement(self): + """When prompt_type='version', the plugin appends identity emphasis + instead of replacing the entire prompt.""" + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + version_prompt = "...existing version prompt with candidates..." + result = trigger_hook( + "mem_reader.pre_extract", + prompt=version_prompt, + prompt_type="version", + mem_str="我叫王沐辰,我的儿子叫王明泽", + lang="zh", + sources=[], + ) + assert version_prompt in result + assert "身份" in result or "关系" in result + assert plugin.stats["identity_relation"] >= 1 + + def test_pre_extract_version_pipeline_no_match(self): + """When prompt_type='version' but no identity pattern, prompt unchanged.""" + from memos.plugins.hooks import trigger_hook + + _make_app() + + version_prompt = "...existing version prompt with candidates..." + result = trigger_hook( + "mem_reader.pre_extract", + prompt=version_prompt, + prompt_type="version", + mem_str="今天天气不错", + lang="zh", + sources=[], + ) + assert result == version_prompt diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py new file mode 100644 index 000000000..c0aaf5a97 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py @@ -0,0 +1,61 @@ +"""Tests for StrategyRegistry and prompt building.""" + +from memos_prompt_strategy_plugin.strategies import PromptStrategy, StrategyRegistry + + +class TestStrategyRegistry: + def setup_method(self): + self.reg = StrategyRegistry() + self.reg.register_defaults() + + def test_default_strategies_registered(self): + strategies = self.reg.all_strategies() + assert "identity_relation" in strategies + + def test_build_prompt_returns_none_for_unknown(self): + result = self.reg.build_prompt("nonexistent_category", "en", "hello") + assert result is None + + def test_build_prompt_zh(self): + prompt = self.reg.build_prompt("identity_relation", "zh", "我叫王沐辰,我的儿子叫王明泽") + assert prompt is not None + assert "王沐辰" in prompt + assert "王明泽" in prompt + assert "身份" in prompt or "关系" in prompt + + def test_build_prompt_en(self): + prompt = self.reg.build_prompt("identity_relation", "en", "My name is Alice, my son is Bob") + assert prompt is not None + assert "Alice" in prompt + assert "Bob" in prompt + assert "identity" in prompt.lower() or "relationship" in prompt.lower() + + def test_build_prompt_with_custom_tags(self): + prompt = self.reg.build_prompt( + "identity_relation", "zh", "我叫张三", custom_tags=["family", "name"] + ) + assert prompt is not None + assert "family" in prompt + assert "name" in prompt + + def test_custom_strategy_registration(self): + custom = PromptStrategy( + name="custom_test", + template_en="Extract from: ${conversation} ${custom_tags_prompt}", + template_zh="提取:${conversation} ${custom_tags_prompt}", + description="Test strategy", + ) + self.reg.register(custom) + prompt = self.reg.build_prompt("custom_test", "en", "hello world") + assert prompt is not None + assert "hello world" in prompt + + +class TestStrategyRegistryIsolation: + def test_empty_registry_returns_none(self): + reg = StrategyRegistry() + assert reg.build_prompt("identity_relation", "en", "hi") is None + + def test_get_unknown_returns_none(self): + reg = StrategyRegistry() + assert reg.get("nonexistent") is None diff --git a/poetry.lock b/poetry.lock index ba31d1a31..dccd154d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -68,10 +68,10 @@ trio = ["trio (>=0.26.1)"] name = "async-timeout" version = "5.0.1" description = "Timeout context manager for asyncio programs" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_full_version < \"3.11.3\" and (extra == \"mem-scheduler\" or extra == \"all\")" +markers = "python_full_version < \"3.11.3\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -394,10 +394,9 @@ files = [ name = "chonkie" version = "1.1.1" description = "🦛 CHONK your texts with Chonkie ✨ - The no-nonsense chunking library" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"mem-reader\" or extra == \"all\"" files = [ {file = "chonkie-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c56cff89f38ff5cc06b2e8b4c9a802b85b77ba8ecdda6896f5dba6b0c54d4303"}, {file = "chonkie-1.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8bb0d88b4254a9bac7b494349ee3c94f94d3ee2f8cd4970d23e0c0ef3e6392a4"}, @@ -1296,7 +1295,6 @@ files = [ {file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"}, {file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"}, ] -markers = {main = "extra == \"pref-mem\" or extra == \"all\""} [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] @@ -1607,10 +1605,9 @@ files = [ name = "jieba" version = "0.42" description = "Chinese Words Segmentation Utilities" -optional = true +optional = false python-versions = "*" groups = ["main"] -markers = "extra == \"all\"" files = [ {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, ] @@ -3235,7 +3232,6 @@ files = [ {file = "pandas-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b4b0de34dc8499c2db34000ef8baad684cfa4cbd836ecee05f323ebfba348c7d"}, {file = "pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [package.dependencies] numpy = [ @@ -3298,10 +3294,9 @@ image = ["Pillow"] name = "pika" version = "1.3.2" description = "Pika Python AMQP Client Library" -optional = true +optional = false python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"mem-scheduler\" or extra == \"all\"" files = [ {file = "pika-1.3.2-py3-none-any.whl", hash = "sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f"}, {file = "pika-1.3.2.tar.gz", hash = "sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f"}, @@ -3568,7 +3563,83 @@ files = [ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} + +[[package]] +name = "psycopg2-binary" +version = "2.9.11" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "psycopg2-binary-2.9.11.tar.gz", hash = "sha256:b6aed9e096bf63f9e75edf2581aa9a7e7186d97ab5c177aa6c87797cd591236c"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6fe6b47d0b42ce1c9f1fa3e35bb365011ca22e39db37074458f27921dca40f2"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6c0e4262e089516603a09474ee13eabf09cb65c332277e39af68f6233911087"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c47676e5b485393f069b4d7a811267d3168ce46f988fa602658b8bb901e9e64d"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a28d8c01a7b27a1e3265b11250ba7557e5f72b5ee9e5f3a2fa8d2949c29bf5d2"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5f3f2732cf504a1aa9e9609d02f79bea1067d99edf844ab92c247bbca143303b"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:865f9945ed1b3950d968ec4690ce68c55019d79e4497366d36e090327ce7db14"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:91537a8df2bde69b1c1db01d6d944c831ca793952e4f57892600e96cee95f2cd"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4dca1f356a67ecb68c81a7bc7809f1569ad9e152ce7fd02c2f2036862ca9f66b"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0da4de5c1ac69d94ed4364b6cbe7190c1a70d325f112ba783d83f8440285f152"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37d8412565a7267f7d79e29ab66876e55cb5e8e7b3bbf94f8206f6795f8f7e7e"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:c665f01ec8ab273a61c62beeb8cce3014c214429ced8a308ca1fc410ecac3a39"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0e8480afd62362d0a6a27dd09e4ca2def6fa50ed3a4e7c09165266106b2ffa10"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:763c93ef1df3da6d1a90f86ea7f3f806dc06b21c198fa87c3c25504abec9404a"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f090b7ddd13ca842ebfe301cd587a76a4cf0913b1e429eb92c1be5dbeb1a19bc"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:366df99e710a2acd90efed3764bb1e28df6c675d33a7fb40df9b7281694432ee"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:92e3b669236327083a2e33ccfa0d320dd01b9803b3e14dd986a4fc54aa00f4e1"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e0deeb03da539fa3577fcb0b3f2554a97f7e5477c246098dbb18091a4a01c16f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b52a3f9bb540a3e4ec0f6ba6d31339727b2950c9772850d6545b7eae0b9d7c5"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:db4fd476874ccfdbb630a54426964959e58da4c61c9feba73e6094d51303d7d8"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47f212c1d3be608a12937cc131bd85502954398aaa1320cb4c14421a0ffccf4c"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e35b7abae2b0adab776add56111df1735ccc71406e56203515e228a8dc07089f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fcf21be3ce5f5659daefd2b3b3b6e4727b028221ddc94e6c1523425579664747"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:9bd81e64e8de111237737b29d68039b9c813bdf520156af36d26819c9a979e5f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:32770a4d666fbdafab017086655bcddab791d7cb260a16679cc5a7338b64343b"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3cb3a676873d7506825221045bd70e0427c905b9c8ee8d6acd70cfcbd6e576d"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20e7fb94e20b03dcc783f76c0865f9da39559dcc0c28dd1a3fce0d01902a6b9c"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4bdab48575b6f870f465b397c38f1b415520e9879fdf10a53ee4f49dcbdf8a21"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9d3a9edcfbe77a3ed4bc72836d466dfce4174beb79eda79ea155cc77237ed9e8"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:44fc5c2b8fa871ce7f0023f619f1349a0aa03a0857f2c96fbc01c657dcbbdb49"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9c55460033867b4622cda1b6872edf445809535144152e5d14941ef591980edf"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2d11098a83cca92deaeaed3d58cfd150d49b3b06ee0d0852be466bf87596899e"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:691c807d94aecfbc76a14e1408847d59ff5b5906a04a23e12a89007672b9e819"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b81627b691f29c4c30a8f322546ad039c40c328373b11dff7490a3e1b517855"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:b637d6d941209e8d96a072d7977238eea128046effbf37d1d8b2c0764750017d"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41360b01c140c2a03d346cec3280cf8a71aa07d94f3b1509fa0161c366af66b4"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"}, +] [[package]] name = "pycparser" @@ -3837,10 +3908,9 @@ windows-terminal = ["colorama (>=0.4.6)"] name = "pymilvus" version = "2.6.2" description = "Python Sdk for Milvus" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"pref-mem\" or extra == \"all\"" files = [ {file = "pymilvus-2.6.2-py3-none-any.whl", hash = "sha256:933e447e09424d490dcf595053b01a7277dadea7ae3235cd704363bd6792509d"}, {file = "pymilvus-2.6.2.tar.gz", hash = "sha256:b4802cc954de8f2d47bf8d6230e92196514dcb8a3726ba6098dc27909d4bc8e3"}, @@ -4033,7 +4103,6 @@ files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] -markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\" or extra == \"pref-mem\""} [[package]] name = "pywin32" @@ -4200,10 +4269,9 @@ dev = ["pytest"] name = "redis" version = "6.2.0" description = "Python client for Redis database and key-value store" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"mem-scheduler\" or extra == \"all\"" files = [ {file = "redis-6.2.0-py3-none-any.whl", hash = "sha256:c8ddf316ee0aab65f04a11229e94a64b2618451dab7a67cb2f77eb799d872d5e"}, {file = "redis-6.2.0.tar.gz", hash = "sha256:e821f129b75dde6cb99dd35e5c76e8c49512a5a0d8dfdc560b2fbd44b85ca977"}, @@ -5080,7 +5148,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5638,7 +5706,6 @@ files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "ujson" @@ -6373,4 +6440,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "faff240c05a74263a404e8d9324ffd2f342cb4f0a4c1f5455b87349f6ccc61a5" +content-hash = "0fd4408ce33b59ac489d4d9b0e632bb17538853f049a4edf17425ba83027b74a" diff --git a/pyproject.toml b/pyproject.toml index 9f17c0000..fc10df0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.8" +version = "2.0.9.post" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" @@ -48,6 +48,12 @@ dependencies = [ "python-dateutil (>=2.9.0.post0,<3.0.0)", "prometheus-client (>=0.23.1,<0.24.0)", "concurrent-log-handler (>=0.9.28,<1.0.0)", # Process-safe rotating file handler + "redis (>=6.2.0,<7.0.0)", # Key-value store + "pika (>=1.3.2,<2.0.0)", # RabbitMQ client + "jieba (>=0.38.1,<0.42.1)", # Chinese text segmentation + "chonkie (>=1.0.7,<2.0.0)", # Sentence chunking + "pymilvus (>=2.5.12,<3.0.0)", # Milvus vector DB + "psycopg2-binary (>=2.9.9,<3.0.0)", # PostgreSQL / PolarDB driver ] [project.urls] @@ -62,6 +68,10 @@ issues = "https://github.com/MemTensor/MemOS/issues" [project.scripts] memos = "memos.cli:main" +[project.entry-points."memos.plugins"] +demo = "memos_demo_plugin:DemoPlugin" +prompt_strategy = "memos_prompt_strategy_plugin:PromptStrategyPlugin" + [project.optional-dependencies] # These are optional dependencies for various features of MemoryOS. # Developers install: `poetry install --extras `. e.g., `poetry install --extras general-mem` @@ -153,7 +163,11 @@ build-backend = "poetry.core.masonry.api" # https://python-poetry.org/docs/dependency-specification#caret-requirements ############################################################################## -packages = [{include = "memos", from = "src"}] +packages = [ + {include = "memos", from = "src"}, + {include = "memos_demo_plugin", from = "extensions"}, + {include = "memos_prompt_strategy_plugin", from = "extensions"}, +] requires-poetry = ">=2.0" dependencies = { "python" = ">=3.10,<4.0" } diff --git a/scripts/check-public-push.sh b/scripts/check-public-push.sh new file mode 100755 index 000000000..0e18cff55 --- /dev/null +++ b/scripts/check-public-push.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash +# Pre-push hook: block private files from being pushed to the public repo. +# Private paths are read from .private-paths (one per line, # comments allowed). +# Installed by `make install` into .git/hooks/pre-push. + +REMOTE_NAME="$1" +REMOTE_URL="$2" + +# Only enforce on the public remote (skip MemOS-Enterprise) +if [[ "${REMOTE_URL}" != *"MemTensor/MemOS.git"* ]] || [[ "${REMOTE_URL}" == *"MemOS-Enterprise"* ]]; then + exit 0 +fi + +PRIVATE_PATHS_FILE=".private-paths" +if [ ! -f "${PRIVATE_PATHS_FILE}" ]; then + echo "⚠️ ${PRIVATE_PATHS_FILE} not found — skipping private-path check." + exit 0 +fi + +# Read private paths into regex patterns +PATTERNS=() +while IFS= read -r line; do + line="$(echo "${line}" | sed 's/#.*//; s/^[[:space:]]*//; s/[[:space:]]*$//')" + [ -z "${line}" ] && continue + # Convert path to regex: strip trailing /, add ^ anchor + pattern="^$(echo "${line}" | sed 's|/$||')" + PATTERNS+=("${pattern}") +done < "${PRIVATE_PATHS_FILE}" + +ERRORS=0 + +while read local_ref local_sha remote_ref remote_sha; do + # Skip delete operations + if [ "${local_sha}" = "0000000000000000000000000000000000000000" ]; then + continue + fi + + # For new remote refs, compare against public/main + if [ "${remote_sha}" = "0000000000000000000000000000000000000000" ]; then + base=$(git merge-base public/main "${local_sha}" 2>/dev/null || echo "public/main") + range="${base}..${local_sha}" + else + range="${remote_sha}..${local_sha}" + fi + + files=$(git diff --name-only "${range}" 2>/dev/null || true) + if [ -z "${files}" ]; then + continue + fi + + for pattern in "${PATTERNS[@]}"; do + matched=$(echo "${files}" | grep -E "${pattern}" || true) + if [ -n "${matched}" ]; then + echo "❌ BLOCKED: Private files detected in push to public repo!" + echo "" + echo " Pattern: ${pattern}" + echo " Files:" + echo "${matched}" | sed 's/^/ /' + echo "" + ERRORS=1 + fi + done +done + +if [ "${ERRORS}" -ne 0 ]; then + echo "💡 Use 'git sync-public \"\"' to safely sync CE code." + exit 1 +fi + +exit 0 diff --git a/scripts/sync-public.sh b/scripts/sync-public.sh new file mode 100755 index 000000000..b71526fd0 --- /dev/null +++ b/scripts/sync-public.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Sync CE-only changes from the enterprise repo to the public repo. +# Private paths are read from .private-paths (one per line, # comments allowed). +# +# Usage: +# git sync-public "" [commit-ref] +# make sync-public msg="" [commit=] + +PUBLIC_REMOTE="public" +PRIVATE_PATHS_FILE=".private-paths" + +CE_MSG="${1:?Usage: git sync-public \"\" [commit-ref]}" +COMMIT="${2:-HEAD}" +EE_BRANCH="$(git branch --show-current)" +PUBLIC_BRANCH="public-$(echo "${EE_BRANCH}" | tr '/' '-')" + +# Read private paths from config file +if [ ! -f "${PRIVATE_PATHS_FILE}" ]; then + echo "❌ ${PRIVATE_PATHS_FILE} not found. Cannot determine private paths." + exit 1 +fi + +EXCLUDE_ARGS="" +while IFS= read -r line; do + line="$(echo "${line}" | sed 's/#.*//; s/^[[:space:]]*//; s/[[:space:]]*$//')" + [ -z "${line}" ] && continue + EXCLUDE_ARGS="${EXCLUDE_ARGS} ':!${line}'" +done < "${PRIVATE_PATHS_FILE}" + +git fetch "${PUBLIC_REMOTE}" main + +# Find CE files changed in the specified commit +CE_FILES=$(eval git diff --name-only "${COMMIT}^..${COMMIT}" -- . ${EXCLUDE_ARGS}) + +if [ -z "${CE_FILES}" ]; then + echo "✅ No CE changes in commit $(git rev-parse --short "${COMMIT}"). Done." + exit 0 +fi + +echo "▶ CE changes from $(git log -1 --format='%h %s' "${COMMIT}"):" +echo "${CE_FILES}" | sed 's/^/ /' + +# Reuse existing public branch or create from public/main +if git show-ref --verify --quiet "refs/heads/${PUBLIC_BRANCH}"; then + git checkout "${PUBLIC_BRANCH}" +else + git checkout -B "${PUBLIC_BRANCH}" "${PUBLIC_REMOTE}/main" +fi + +# Checkout CE files from the enterprise commit +echo "${CE_FILES}" | xargs git checkout "${COMMIT}" -- + +# If there is no staged diff, the CE file content is already present on public branch. +if git diff --cached --quiet; then + echo "✅ All CE changes already synced on ${PUBLIC_BRANCH}. Nothing new to commit." + git checkout "${EE_BRANCH}" + exit 0 +fi + +git commit --no-verify -m "${CE_MSG}" +echo "▶ Pushing ${PUBLIC_BRANCH} to ${PUBLIC_REMOTE}..." +git push "${PUBLIC_REMOTE}" "${PUBLIC_BRANCH}" +git checkout "${EE_BRANCH}" + +echo "" +echo "✅ Done. Create PR:" +echo " https://github.com/MemTensor/MemOS/pull/new/${PUBLIC_BRANCH}" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index 36cc0b5b5..3643c4628 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.8" +__version__ = "2.0.9.post" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 06aa50c65..df3799b5d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -338,6 +338,26 @@ def get_memreader_config() -> dict[str, Any]: }, } + @staticmethod + def get_qwen_llm_config() -> dict[str, Any] | None: + if not os.getenv("QWEN_API_KEY"): + return None + return { + "backend": "qwen", + "config": { + "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"), + "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")), + "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")), + "top_p": float(os.getenv("QWEN_TOP_P", "0.9")), + "top_k": int(os.getenv("QWEN_TOP_K", "50")), + "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower() + == "true", + "api_key": os.getenv("QWEN_API_KEY", ""), + "api_base": os.getenv("QWEN_API_BASE", ""), + "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"), + }, + } + @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. @@ -608,6 +628,7 @@ def get_oss_config() -> dict[str, Any] | None: return config + @staticmethod def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" reader_config = APIConfig.get_reader_config() @@ -913,6 +934,7 @@ def get_product_default_config() -> dict[str, Any]: "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), + "qwen_llm": APIConfig.get_qwen_llm_config(), # General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.) "general_llm": APIConfig.get_memreader_general_llm_config(), # Image parser LLM (requires vision model) @@ -947,6 +969,7 @@ def get_product_default_config() -> dict[str, Any]: "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" ), }, + "memory_version_switch": os.getenv("MEM_READER_MEM_VERSION_SWITCH", "off"), }, }, "enable_textual_memory": True, diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index e9ed4f955..4240545f6 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -37,6 +38,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index aa2525878..a4cfcc77f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -35,6 +35,7 @@ from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -171,9 +172,20 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) + memory_history_manager = MemoryHistoryManager( + nli_client=nli_client, + graph_db=graph_db, + llm=llm, + embedder=embedder, + pre_update_retriever=pre_update_retriever, + ) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( @@ -249,6 +261,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + history_manager=memory_history_manager, pref_feedback=True, ) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index ee88ae639..c8024baa3 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -141,6 +141,7 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): sources = item.get("metadata", {}).get("sources", []) if ( item["metadata"]["memory_type"] != "RawFileMemory" + and sources and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 529a709a4..78185a035 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -9,10 +9,13 @@ from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +from memos.plugins.manager import plugin_manager load_dotenv() +plugin_manager.discover() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -38,6 +41,8 @@ # Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) +plugin_manager.init_app(app) + if __name__ == "__main__": import argparse diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index d4844d73f..9ed791fa8 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator, model_validator @@ -76,6 +76,13 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): default=None, description="Skills directory for the MemReader", ) + memory_version_switch: Literal["on", "off"] = Field( + default="off", + description="Turn on memory version or off", + ) + + # Allow passing additional fields without raising validation errors + model_config = ConfigDict(extra="allow", strict=True) class StrategyStructMemReaderConfig(BaseMemReaderConfig): diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py index a02dae9f6..c97746baf 100644 --- a/src/memos/extras/nli_model/client.py +++ b/src/memos/extras/nli_model/client.py @@ -1,4 +1,5 @@ import logging +import time import requests @@ -13,9 +14,18 @@ class NLIClient: Client for interacting with the deployed NLI model service. """ - def __init__(self, base_url: str = "http://localhost:32532"): + def __init__( + self, + base_url: str = "http://localhost:32532", + timeout: float = 30.0, + max_retries: int = 3, + backoff_seconds: float = 0.5, + ): self.base_url = base_url.rstrip("/") self.session = requests.Session() + self.timeout = timeout + self.max_retries = max_retries + self.backoff_seconds = backoff_seconds def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: """ @@ -35,27 +45,51 @@ def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult # Match schemas.CompareRequest payload = {"source": source, "targets": targets} - try: - response = self.session.post(url, json=payload, timeout=30) - response.raise_for_status() - data = response.json() - - # Match schemas.CompareResponse - results_str = data.get("results", []) - - results = [] - for res_str in results_str: - try: - results.append(NLIResult(res_str)) - except ValueError: + last_error: Exception | None = None + for attempt in range(1, self.max_retries + 1): + try: + response = self.session.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + data = response.json() + + results_str = data.get("results", []) + + results = [] + for res_str in results_str: + try: + results.append(NLIResult(res_str)) + except ValueError: + logger.warning( + f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + ) + results.append(NLIResult.UNRELATED) + + return results + except requests.RequestException as e: + last_error = e + if attempt < self.max_retries: logger.warning( - f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + "[NLIClient] Request failed (attempt %s/%s) url=%s targets=%s error=%s", + attempt, + self.max_retries, + url, + len(targets), + e, + ) + time.sleep(self.backoff_seconds * (2 ** (attempt - 1))) + else: + logger.error( + "[NLIClient] Request failed after %s attempts url=%s targets=%s error=%s", + self.max_retries, + url, + len(targets), + e, ) - results.append(NLIResult.UNRELATED) - - return results - except requests.RequestException as e: - logger.error(f"[NLIClient] Request failed: {e}") - # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. - return [NLIResult.UNRELATED] * len(targets) + logger.error( + "[NLIClient] NLI service unavailable or unstable. Please check that it is running at %s", + self.base_url, + ) + if last_error: + logger.error("[NLIClient] Last error: %s", last_error) + return [NLIResult.UNRELATED] * len(targets) diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py index d2e12175d..b5744bd26 100644 --- a/src/memos/extras/nli_model/server/config.py +++ b/src/memos/extras/nli_model/server/config.py @@ -13,6 +13,8 @@ NLI_DEVICE = "cuda" NLI_MODEL_HOST = "0.0.0.0" NLI_MODEL_PORT = 32532 +NLI_MAX_CONCURRENCY = 4 +NLI_INFER_TIMEOUT_SECONDS = 30.0 # Configure logging for NLI Server logging.basicConfig( diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py index 0ed9eae65..f02d25670 100644 --- a/src/memos/extras/nli_model/server/serve.py +++ b/src/memos/extras/nli_model/server/serve.py @@ -1,25 +1,36 @@ +import asyncio + from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI, HTTPException -from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT +from memos.extras.nli_model.server.config import ( + NLI_DEVICE, + NLI_INFER_TIMEOUT_SECONDS, + NLI_MAX_CONCURRENCY, + NLI_MODEL_HOST, + NLI_MODEL_PORT, +) from memos.extras.nli_model.server.handler import NLIHandler from memos.extras.nli_model.types import CompareRequest, CompareResponse # Global handler instance nli_handler: NLIHandler | None = None +nli_semaphore: asyncio.Semaphore | None = None @asynccontextmanager async def lifespan(app: FastAPI): - global nli_handler + global nli_handler, nli_semaphore nli_handler = NLIHandler(device=NLI_DEVICE) + nli_semaphore = asyncio.Semaphore(NLI_MAX_CONCURRENCY) yield # Clean up if needed nli_handler = None + nli_semaphore = None app = FastAPI(lifespan=lifespan) @@ -27,11 +38,17 @@ async def lifespan(app: FastAPI): @app.post("/compare_one_to_many", response_model=CompareResponse) async def compare_one_to_many(request: CompareRequest): - if nli_handler is None: + if nli_handler is None or nli_semaphore is None: raise HTTPException(status_code=503, detail="Model not loaded") try: - results = nli_handler.compare_one_to_many(request.source, request.targets) - return CompareResponse(results=results) + async with nli_semaphore: + results = await asyncio.wait_for( + asyncio.to_thread(nli_handler.compare_one_to_many, request.source, request.targets), + timeout=NLI_INFER_TIMEOUT_SECONDS, + ) + return CompareResponse(results=results) + except asyncio.TimeoutError as e: + raise HTTPException(status_code=504, detail="NLI inference timed out") from e except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index b8019004d..864a20e56 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -27,6 +27,7 @@ from memos.mem_reader.factory import MemReaderFactory from memos.mem_reader.read_multi_modal import detect_lang from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import ( MemoryManager, extract_working_binding_ids, @@ -90,6 +91,12 @@ def __init__(self, config: MemFeedbackConfig): }, is_reorganize=self.is_reorganize, ) + # Actually this is initialized through SimpleMemFeedback, so it's fine. + self.history_manager = MemoryHistoryManager( + nli_client=None, + graph_db=self.graph_store, + embedder=self.embedder, + ) self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None @@ -243,7 +250,7 @@ def _single_add_operation( datetime.now().isoformat() ) to_add_memory.metadata.background = new_memory_item.metadata.background - to_add_memory.metadata.sources = [] + to_add_memory.metadata.sources = new_memory_item.metadata.sources added_ids = self._retry_db_operation( lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False) @@ -287,33 +294,43 @@ def _single_update_operation( new_memory_item.memory = operation["text"] new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] - if memory_type == "WorkingMemory": - fields = { - "memory": new_memory_item.memory, - "key": new_memory_item.metadata.key, - "tags": new_memory_item.metadata.tags, - "embedding": new_memory_item.metadata.embedding, - "background": new_memory_item.metadata.background, - "covered_history": old_memory_item.id, - } - self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) - item_id = old_memory_item.id - else: - done = self._single_add_operation( - old_memory_item, new_memory_item, user_id, user_name, async_mode + if getattr(self.mem_reader, "memory_version_switch", "off") != "on": + if memory_type == "WorkingMemory": + fields = { + "memory": new_memory_item.memory, + "key": new_memory_item.metadata.key, + "tags": new_memory_item.metadata.tags, + "embedding": new_memory_item.metadata.embedding, + "background": new_memory_item.metadata.background, + "covered_history": old_memory_item.id, + } + self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) + item_id = old_memory_item.id + else: + done = self._single_add_operation( + old_memory_item, new_memory_item, user_id, user_name, async_mode + ) + item_id = done.get("id") + self.graph_store.update_node( + item_id, {"covered_history": old_memory_item.id}, user_name=user_name + ) + self.graph_store.update_node( + old_memory_item.id, {"status": "archived"}, user_name=user_name + ) + + logger.info( + f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" ) - item_id = done.get("id") - self.graph_store.update_node( - item_id, {"covered_history": old_memory_item.id}, user_name=user_name + else: + item_id = self._single_update_operation_with_versions( + old_memory_item=old_memory_item, + new_memory_item=new_memory_item, + user_name=user_name, ) - self.graph_store.update_node( - old_memory_item.id, {"status": "archived"}, user_name=user_name + logger.info( + f"[Memory Feedback UPDATE] Updated:{item_id} | history appended | memory_type: {old_memory_item.metadata.memory_type}" ) - logger.info( - f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" - ) - return { "id": item_id, "text": new_memory_item.memory, @@ -322,6 +339,79 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_operation_with_versions( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_name: str, + ) -> str: + try: + updated_item, archived_item, archived_metadata, updated_fields = ( + self.history_manager.update_from_feedback( + old_item=old_memory_item, + new_item=new_memory_item, + user_name=user_name, + ) + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] history fallback for %s: %s", old_memory_item.id, e + ) + updated_item = old_memory_item.model_copy(deep=True) + updated_item.memory = new_memory_item.memory + updated_item.metadata.key = new_memory_item.metadata.key + updated_item.metadata.tags = new_memory_item.metadata.tags + updated_item.metadata.background = new_memory_item.metadata.background + if getattr(new_memory_item.metadata, "sources", None) is not None: + current_sources = list(updated_item.metadata.sources or []) + updated_item.metadata.sources = ( + list(new_memory_item.metadata.sources or []) + current_sources + ) + if getattr(new_memory_item.metadata, "embedding", None) is not None: + updated_item.metadata.embedding = new_memory_item.metadata.embedding + if updated_item.metadata.memory_type == "PreferenceMemory": + updated_item.metadata.preference = updated_item.memory + updated_fields = { + "memory": updated_item.memory, + "key": updated_item.metadata.key, + "tags": updated_item.metadata.tags, + "embedding": updated_item.metadata.embedding, + "background": updated_item.metadata.background, + "sources": [ + source.model_dump(exclude_none=True) + if hasattr(source, "model_dump") + else source + for source in (updated_item.metadata.sources or []) + ], + "covered_history": old_memory_item.id, + } + archived_item = None + archived_metadata = None + + if archived_item and archived_metadata: + try: + self.graph_store.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_metadata, + user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] archive add failed for %s: %s", + old_memory_item.id, + e, + ) + self._retry_db_operation( + lambda: self.graph_store.update_node( + id=updated_item.id, + fields=updated_fields, + user_name=user_name, + ) + ) + self._del_working_binding(user_name, [old_memory_item]) + return updated_item.id + def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -330,9 +420,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> f"[Memory Feedback UPDATE] Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" ) - delete_ids = [] - if bindings_to_delete: - delete_ids = list({bindings_to_delete}) + delete_ids = list(bindings_to_delete) for mid in delete_ids: try: @@ -345,6 +433,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> logger.warning( f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) + return bindings_to_delete def semantics_feedback( self, @@ -470,7 +559,7 @@ def semantics_feedback( f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) - if update_results: + if update_results and getattr(self.mem_reader, "memory_version_switch", "off") != "on": updated_ids = [item["archived_id"] for item in update_results] self._del_working_binding(updated_ids, user_name) @@ -1060,7 +1149,14 @@ def check_validity(item): tags=tags, key=key, embedding=embedding, - sources=[{"type": "chat"}], + sources=[ + { + "type": "feedback", + "role": "user", + "chat_time": feedback_time, + "content": feedback_content, + } + ], background=background, type="fine", info=info, diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index dfc9b9fdf..d28a8e9da 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,6 +4,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -23,12 +24,14 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, + history_manager: MemoryHistoryManager, pref_feedback: bool = False, ): self.llm = llm self.embedder = embedder self.graph_store = graph_store self.memory_manager = memory_manager + self.history_manager = history_manager self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 7bd551fb8..8e54873c8 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -10,6 +10,9 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -29,6 +32,7 @@ def from_config( config_factory: MemReaderConfigFactory, graph_db: Optional["BaseGraphDB | None"] = None, searcher: Optional["Searcher | None"] = None, + history_manager: Optional["MemoryHistoryManager | None"] = None, ) -> BaseMemReader: """ Create a MemReader instance from configuration. @@ -55,4 +59,10 @@ def from_config( if searcher is not None: reader.set_searcher(searcher) + if history_manager is not None: + if hasattr(reader, "set_history_manager"): + reader.set_history_manager(history_manager) + else: + reader.history_manager = history_manager + return reader diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 4c0d4dcd0..0848b27f0 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -58,6 +58,9 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) + self.history_manager = None + self.memory_version_switch = getattr(config, "memory_version_switch", "off") + # Image parser LLM (requires vision model) # Falls back to general_llm if not configured (general_llm itself falls back to main llm) self.image_parser_llm = ( @@ -461,6 +464,21 @@ def _get_llm_response( if self.config.remove_prompt_example and examples: prompt = prompt.replace(examples, "") + + from memos.plugins.hook_defs import H as _H + from memos.plugins.hooks import trigger_hook as _trigger_hook + + _rv = _trigger_hook( + _H.MEM_READER_PRE_EXTRACT, + prompt=prompt, + prompt_type=prompt_type, + mem_str=mem_str, + lang=lang, + sources=sources, + ) + prompt = _rv if _rv is not None else prompt + logger.info(f"[MultiModalParser] Process String Fine After Plugin: {prompt}") + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) @@ -509,6 +527,7 @@ def _get_maybe_merged_memory( sources: list, **kwargs, ) -> dict: + # TODO: delete this function """ Check if extracted memory should be merged with similar existing memories. If merge is needed, return merged memory dict with merged_from field. @@ -523,102 +542,7 @@ def _get_maybe_merged_memory( Returns: Memory dict (possibly merged) with merged_from field if merged """ - # If no graph_db or user_name, return original - if not self.graph_db or "user_name" not in kwargs: - return extracted_memory_dict - user_name = kwargs.get("user_name") - - # Detect language - lang = "en" - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(mem_text) - - # Search for similar memories - merge_threshold = kwargs.get("merge_similarity_threshold", 0.3) - - try: - search_results = self.graph_db.search_by_embedding( - vector=self.embedder.embed(mem_text)[0], - top_k=20, - status="activated", - threshold=merge_threshold, - user_name=user_name, - ) - - if not search_results: - return extracted_memory_dict - - # Get full memory details - similar_memory_ids = [r["id"] for r in search_results if r.get("id")] - similar_memories_list = [ - self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name) - for mem_id in similar_memory_ids - ] - - # Filter out None and mode:fast memories - filtered_similar = [] - for mem in similar_memories_list: - if not mem: - continue - mem_metadata = mem.get("metadata", {}) - tags = mem_metadata.get("tags", []) - if isinstance(tags, list) and "mode:fast" in tags: - continue - filtered_similar.append( - { - "id": mem.get("id"), - "memory": mem.get("memory", ""), - } - ) - logger.info( - f"Valid similar memories for {mem_text} is " - f"{len(filtered_similar)}: {filtered_similar}" - ) - - if not filtered_similar: - return extracted_memory_dict - - # Create a temporary TextualMemoryItem for merge check - temp_memory_item = TextualMemoryItem( - memory=mem_text, - metadata=TreeNodeTextualMemoryMetadata( - user_id="", - session_id="", - memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"), - status="activated", - tags=extracted_memory_dict.get("tags", []), - key=extracted_memory_dict.get("key", ""), - ), - ) - - # Try to merge with LLM - merge_result = self._merge_memories_with_llm( - temp_memory_item, filtered_similar, lang=lang - ) - - if merge_result: - # Return merged memory dict - merged_dict = extracted_memory_dict.copy() - merged_content = merge_result.get("value", mem_text) - merged_dict["value"] = merged_content - merged_from_ids = merge_result.get("merged_from", []) - merged_dict["merged_from"] = merged_from_ids - return merged_dict - else: - return extracted_memory_dict - - except Exception as e: - logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}") - # On error, return original - return extracted_memory_dict + return extracted_memory_dict def _merge_memories_with_llm( self, @@ -720,6 +644,29 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 0: Memory version async extraction/update pipeline ========== + if ( + self.memory_version_switch == "on" + and self.history_manager is not None + and self.history_manager.is_applicable(fast_item) + ): + try: + user_name = kwargs.get("user_name") + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang] + new_items = self.history_manager.apply_mem_version_update( + fast_item, + user_name, + self.qwen_llm, + custom_tags=custom_tags, + custom_tags_prompt_template=custom_tags_prompt_template, + timeout_sec=30, + ) + return new_items + except Exception as ex: + logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: {ex}") + return [] + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) @@ -730,14 +677,15 @@ def _process_one_item( if resp.get("memory list", []): for m in resp.get("memory list", []): try: - # Check and merge with similar memories if needed - m_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=m, - mem_text=m.get("value", ""), - sources=sources, - original_query=mem_str, - **kwargs, - ) + m_maybe_merged = m + if self.memory_version_switch != "on": + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( m_maybe_merged.get("memory_type", "LongTermMemory") @@ -755,8 +703,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in m_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in m_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) @@ -765,13 +712,15 @@ def _process_one_item( elif resp.get("value") and resp.get("key"): try: # Check and merge with similar memories if needed - resp_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=resp, - mem_text=resp.get("value", "").strip(), - sources=sources, - original_query=mem_str, - **kwargs, - ) + resp_maybe_merged = resp + if self.memory_version_switch != "on": + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, @@ -782,8 +731,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in resp_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in resp_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) @@ -1011,6 +959,7 @@ def _process_multi_modal_data( scene_data_info, info, mode="fast", need_emb=False, **kwargs ) fast_memory_items = self._concat_multi_modal_memories(all_memory_items) + if mode == "fast": return fast_memory_items else: diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index f26be360c..fb3bda12b 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -11,6 +11,7 @@ from memos import log from memos.chunkers import ChunkerFactory +from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory @@ -20,6 +21,9 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.types.general_types import UserContext from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang @@ -183,6 +187,15 @@ def __init__(self, config: SimpleStructMemReaderConfig): if config.general_llm is not None else self.llm ) + self.qwen_llm = None + qwen_llm_config = getattr(config, "qwen_llm", None) + if qwen_llm_config: + try: + if isinstance(qwen_llm_config, dict): + qwen_llm_config = LLMConfigFactory.model_validate(qwen_llm_config) + self.qwen_llm = LLMFactory.from_config(qwen_llm_config) + except Exception as e: + logger.warning(f"[LLM] Qwen initialization failed: {e}") self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) self.save_rawfile = self.chunker.config.save_rawfile @@ -194,6 +207,7 @@ def __init__(self, config: SimpleStructMemReaderConfig): # Initialize graph_db as None, can be set later via set_graph_db for # recall operations self.graph_db = None + self.history_manager = None def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: self.graph_db = graph_db @@ -201,6 +215,9 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: def set_searcher(self, searcher: "Searcher | None") -> None: self.searcher = searcher + def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None: + self.history_manager = history_manager + def _make_memory_item( self, value: str, diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 8777b9f2e..7a4bdf326 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -12,6 +12,7 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory +from memos.extras.nli_model.client import NLIClient from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory from memos.log import get_logger @@ -19,10 +20,12 @@ from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_reader.factory import MemReaderFactory from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -245,6 +248,7 @@ def init_components() -> dict[str, Any]: graph_db_config = build_graph_db_config() llm_config = build_llm_config() embedder_config = build_embedder_config() + nli_client_config = APIConfig.get_nli_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -256,8 +260,20 @@ def init_components() -> dict[str, Any]: graph_db = GraphStoreFactory.from_config(graph_db_config) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) + nli_client = NLIClient(base_url=nli_client_config["base_url"]) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) + memory_history_manager = MemoryHistoryManager( + nli_client=nli_client, + graph_db=graph_db, + embedder=embedder, + pre_update_retriever=pre_update_retriever, + ) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 20dbb63b2..e0baf63ff 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -211,44 +211,48 @@ def _process_memories_with_reader( ) logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) - # Mark merged_from memories as archived when provided in memory metadata - summary_memories = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type != "RawFileMemory" - ] - if mem_reader.graph_db: - for memory in summary_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - "[Scheduler] Archived merged_from memory: %s", - old_id, - ) - except Exception as e: - logger.warning( - "[Scheduler] Failed to archive merged_from memory %s: %s", - old_id, - e, - ) - else: - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in summary_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + # fallback to simple deduplication logic when mem version switch is off + if getattr(mem_reader, "memory_version_switch", "off") != "on": + # Mark merged_from memories as archived when provided in memory metadata + summary_memories = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type != "RawFileMemory" + ] + if mem_reader.graph_db: + for memory in summary_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), + {"status": "archived"}, + user_name=user_name, + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in summary_memories ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) cloud_env = is_cloud_env() if cloud_env: @@ -386,10 +390,34 @@ def _process_memories_with_reader( delete_ids = list(dict.fromkeys(delete_ids)) if delete_ids: try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name - ) + if getattr(mem_reader, "memory_version_switch", "off") != "on": + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) + else: + # change to soft-delete for mem versions + flattened_memories = [] + if processed_memories and len(processed_memories) > 0: + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + allowed_types = ["UserMemory", "LongTermMemory"] + text_mem.soft_delete( + delete_ids, + user_name, + [ + mem.id + for mem in flattened_memories + if mem.metadata.memory_type in allowed_types + ], + ) + logger.info( + "Soft delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) except Exception as e: logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) else: diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index a9b2c43a4..b0f90b537 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel): memory: str | None = Field( default_factory=lambda: "", description="The content of the archived version of the memory." ) - update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field( default="unrelated", - description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).", ) archived_memory_id: str | None = Field( default=None, @@ -106,15 +106,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 8c896f538..e23df27cd 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -609,3 +609,33 @@ def add_graph_edges( future.result() except Exception as e: logger.exception("Add edge error: ", exc_info=e) + + def soft_delete( + self, + memory_ids: list[str], + user_name: str, + evolve_to_ids: list[str] | None = None, + ) -> None: + # for ruff check... + if not evolve_to_ids: + update_fields = {"status": "deleted"} + else: + update_fields = {"status": "deleted", "evolve_to": evolve_to_ids} + + # Execute the actual marking operation - in db. + with ContextThreadPoolExecutor() as executor: + futures = [] + for mid in memory_ids: + futures.append( + executor.submit( + self.graph_store.update_node, + id=mid, + fields=update_fields, + user_name=user_name, + ) + ) + + # Wait for all tasks to complete and raise any exceptions + for future in futures: + future.result() + return diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 98094877c..21b708862 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,87 +1,234 @@ +import json import logging +import re +import time +import uuid -from typing import Literal +from copy import deepcopy +from datetime import datetime +from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor +from memos.embedders.base import BaseEmbedder from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem +from memos.llms.base import BaseLLM +from memos.mem_reader.read_multi_modal.utils import detect_lang +from memos.memories.textual.item import ( + ArchivedTextualMemory, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever +from memos.templates.mem_reader_mem_version_prompts import ( + ASYNC_MEMORY_UPDATE_PROMPT_DICT, + MEMORY_MERGE_PROMPT_DICT, +) logger = logging.getLogger(__name__) -CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]" -DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]" - -def _append_related_content( - new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str] +def _rebuild_fast_node_history( + item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] ) -> None: """ - Append duplicate and conflict memory contents to the new item's memory text, - truncated to avoid excessive length. + Reconstruct the history list of a fast node: + 1. Replace resolved items with their evolved versions. + 2. Deduplicate by ID while preserving the newest versions. """ - max_per_item_len = 200 - max_section_len = 1000 - - def _format_section(title: str, items: list[str]) -> str: - if not items: - return "" - - section_content = "" - for mem in items: - # Truncate individual item - snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem - # Check total section length - if len(section_content) + len(snippet) + 5 > max_section_len: - section_content += "\n- ... (more items truncated)" + new_history = {} + + def _add(history_item): + item_id = history_item.archived_memory_id + current = new_history.get(item_id) + + if current is None or history_item.version > current.version: + new_history[item_id] = history_item + + # Apply replacements and filter superseded items + for i, h in enumerate(item.metadata.history): + if i in replacements: + # This item is resolved, insert its replacements + for replacement_item in replacements[i]: + _add(replacement_item) + else: + _add(h) + + item.metadata.history = list(new_history.values()) + + +def _sanitize_metadata_dict(data: dict[str, Any] | None) -> dict[str, Any]: + if not data: + return {} + sanitized = data.copy() + for key in ("id", "memory", "graph_id"): + sanitized.pop(key, None) + return sanitized + + +def _sanitize_metadata_model( + metadata: TreeNodeTextualMemoryMetadata, +) -> TreeNodeTextualMemoryMetadata: + data = _sanitize_metadata_dict(metadata.model_dump(exclude_none=True)) + return metadata.__class__(**data) + + +def _determine_lang(sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang break - section_content += f"\n- {snippet}" + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang - return f"\n\n{title}:{section_content}" - append_text = "" - append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts) - append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates) +def _parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() - if append_text: - new_item.memory += append_text + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() -def _detach_related_content(new_item: TextualMemoryItem) -> None: - """ - Detach duplicate and conflict memory contents from the new item's memory text. - """ - markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"] + try: + return json.loads(s) + except json.JSONDecodeError: + pass - cut_index = -1 - for marker in markers: - idx = new_item.memory.find(marker) - if idx != -1 and (cut_index == -1 or idx < cut_index): - cut_index = idx + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass - if cut_index != -1: - new_item.memory = new_item.memory[:cut_index] + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t - return + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.warning( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: + def __init__( + self, + nli_client: NLIClient, + graph_db: BaseGraphDB, + llm: BaseLLM | None = None, + embedder: BaseEmbedder | None = None, + pre_update_retriever: PreUpdateRetriever | None = None, + ) -> None: """ Initialize the MemoryHistoryManager. Args: nli_client: NLIClient for conflict/duplicate detection. graph_db: GraphDB instance for marking operations during history management. + llm: Optional LLM instance for memory merging during conflicts. """ self.nli_client = nli_client self.graph_db = graph_db + self.llm = llm + self.embedder = embedder + self.pre_update_retriever = pre_update_retriever + + def _compute_embedding(self, text: str) -> list[float] | None: + if not self.embedder: + return None + try: + return self.embedder.embed([text])[0] + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to compute embedding: {e}") + return None + + @staticmethod + def is_applicable(item: TextualMemoryItem) -> bool: + # Only deals with: + # 1. From doc or chat + # 2. LongTermMemory, UserMemory + allowed_sources = ["doc", "chat"] + allowed_memory_types = ["LongTermMemory", "UserMemory"] + return ( + item.metadata.sources[0].type in allowed_sources + and item.metadata.memory_type in allowed_memory_types + ) + + @staticmethod + def update_node_with_history( + item: TextualMemoryItem, + new_memory: str, + update_type: str, + tags: list[str] | None = None, + key: str | None = None, + ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + """ + This method is used to update a given item. + It updates the item.memory to new_memory, and pushes the old item.memory content to its history. + Instead, it also creates an archived_item to store the embeddings and sources of the old memory content, + and stores it to the graph_db. + """ + now = datetime.now().isoformat() + last_update_time = item.metadata.updated_at + + old_id = item.id + archived_id = str(uuid.uuid4()) + # archived memory(need to store this node to the db later) + archived_item = item.model_copy(deep=True) + archived_item.id = archived_id + archived_item.metadata.evolve_to = [old_id] + archived_item.metadata.status = "archived" + archived_item.metadata.created_at = last_update_time + archived_item.metadata.updated_at = now + + # original memory with updated contents and history + history_item = ArchivedTextualMemory( + version=item.metadata.version or 1, + is_fast=item.metadata.is_fast or False, + memory=item.memory, + update_type=update_type, + archived_memory_id=archived_id, + created_at=getattr(item.metadata, "updated_at", None) or last_update_time, + ) + item.memory = new_memory + item.metadata.version = (item.metadata.version or 1) + 1 + item.metadata.status = "activated" + item.metadata.updated_at = now + if tags is not None: + item.metadata.tags = tags + if key is not None: + item.metadata.key = key + if item.metadata.history is None: + item.metadata.history = [] + item.metadata.history.append(history_item) + + return item, archived_item def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: + ) -> list[str]: """ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, and attach them as history to the new fast item. @@ -91,7 +238,7 @@ def resolve_history_via_nli( related_items: Existing memory items that might be related. Returns: - List of duplicate or conflicting memory items judged by the NLI service. + List of duplicate or conflicting memory ids judged by the NLI service. """ if not related_items: return [] @@ -102,15 +249,19 @@ def resolve_history_via_nli( ) # 2. Process results and attach to history + duplicate_memory_ids = [] + conflict_memory_ids = [] duplicate_memories = [] conflict_memories = [] for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" + duplicate_memory_ids.append(r_item.id) duplicate_memories.append(r_item.memory) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" + conflict_memory_ids.append(r_item.id) conflict_memories.append(r_item.memory) else: update_type = "unrelated" @@ -118,45 +269,292 @@ def resolve_history_via_nli( # Safely get created_at, fallback to updated_at created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at + # TODO: change the way of marking fast nodes by directly using is_fast field. archived = ArchivedTextualMemory( version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, + is_fast=( + r_item.metadata.is_fast + or ("mode:fast" in (getattr(r_item.metadata, "tags", None) or [])) + ), memory=r_item.memory, update_type=update_type, archived_memory_id=r_item.id, created_at=created_at, ) new_item.metadata.history.append(archived) - logger.info( - f"[Chunker: MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" + + return duplicate_memory_ids + conflict_memory_ids + + def wait_and_update_fast_history( + self, item: TextualMemoryItem, user_name: str, timeout_sec: int = 30 + ) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + user_name: Required for db query. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices, user_name) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + + def format_prompt(self, item: TextualMemoryItem, custom_tags_prompt: str = "") -> str: + """ + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + custom_tags_prompt: Optional custom prompt for tags. + + Returns: + Formatted prompt string. + """ + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] + + def _fmt_time(ts: str | None) -> str | None: + if not ts or not isinstance(ts, str): + return None + try: + t = datetime.fromisoformat(ts.replace("Z", "")) + return t.strftime("%Y/%m/%d %H:%M:%S") + except Exception: + return ts + + for h in item.metadata.history or []: + created = getattr(h, "created_at", None) + tstr = _fmt_time(created) + time_suffix = f"[Time: {tstr}] " if tstr else "" + candidate_str = f"[ID:{h.archived_memory_id}]{time_suffix}{h.memory}" + + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) + + sources = item.metadata.sources if item.metadata else None + lang = _determine_lang(sources, item.memory) + empty_label = "None" + + def format_list(candidates): + return "\n".join(candidates) if candidates else empty_label + + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) + conversation_time_raw = getattr(item.metadata, "created_at", None) + conversation_time = _fmt_time(conversation_time_raw) or conversation_time_raw + + return ( + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation_time}", conversation_time) + .replace("${conversation}", item.memory) + ) + + def apply_llm_memory_updates( + self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """ + Apply the updates from the LLM response to the memory graph. + + Args: + llm_response: The parsed JSON response from the LLM. + source_item: The original fast item A whose history contains ArchivedTextualMemory entries. + We derive expected versions and candidate IDs from A.history. + user_name: user_name + + Returns: + List of new or updated memory items. + """ + memory_list = llm_response.get("memory list", []) + preserved_fact_tasks = self._build_preserved_fact_tasks(memory_list) + + expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories + # Recover candidate IDs and their expected versions from the source item's history + if source_item.metadata and source_item.metadata.history: + for h in source_item.metadata.history: + if h.archived_memory_id: + expected_versions[h.archived_memory_id] = h.version + + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + + # Snapshot source nodes before any in-place update. + snapshot_source_ids = {task["source_candidate_id"] for task in preserved_fact_tasks} + snapshot_source_ids.update( + str(candidate_id) + for mem_data in memory_list + for candidate_id in ( + list(mem_data.get("source_candidate_ids", [])) + + list(mem_data.get("conflicted_candidate_ids", [])) + ) + if candidate_id + ) + pre_update_source_item_map: dict[str, TextualMemoryItem] = {} + if snapshot_source_ids: + snapshot_nodes = ( + self.graph_db.get_nodes(sorted(snapshot_source_ids), user_name=user_name) or [] + ) + pre_update_source_item_map = { + item["id"]: TextualMemoryItem(**item) + for item in snapshot_nodes + if item and item.get("id") + } + + # 1. Handle Unrelated Candidates - Do nothing + # 2. Handle Memory List (Update or New) + processed_updates, created_items = self._process_memory_updates( + memory_list, expected_versions, user_name, source_item + ) + updated_items.extend(processed_updates) + new_items.extend(created_items) + + # 3. Handle preserved facts (split still-valid subfacts out of the old node) + new_items.extend( + self._handle_preserved_facts( + preserved_fact_tasks, + source_item, + user_name, + pre_update_source_item_map=pre_update_source_item_map, ) + ) + + return updated_items, new_items + + def _build_preserved_fact_tasks( + self, memory_list: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Flatten per-update preserved facts into executable extraction tasks.""" + tasks: list[dict[str, Any]] = [] + + for mem_data in memory_list: + preserved_facts = mem_data.get("preserved_facts", []) or [] + if not preserved_facts: + continue + + source_ids = list(mem_data.get("source_candidate_ids", []) or []) + conflict_ids = list(mem_data.get("conflicted_candidate_ids", []) or []) + target_ids = source_ids + conflict_ids + if not target_ids: + logger.warning( + "[MemoryHistoryManager] Dropping preserved_facts for create-only memory " + "item key=%s because it has no source/conflict candidate ids.", + mem_data.get("key", ""), + ) + continue + + # A preserved fact must come from the old node referenced by this update item. + # When multiple target ids exist, we bind preserved facts to the primary target + # (the first id), which is also the node actually updated in `_update_existing_memory()`. + source_candidate_id = str(target_ids[0]) + for preserved_fact in preserved_facts: + if not preserved_fact: + continue + tasks.append( + { + "source_candidate_id": source_candidate_id, + "value": preserved_fact.get("value", ""), + "tags": preserved_fact.get("tags", []), + "key": preserved_fact.get("key", ""), + "memory_type": preserved_fact.get("memory_type", "LongTermMemory"), + } + ) + + return tasks + + def build_fallback_new_items( + self, item: TextualMemoryItem, user_name: str | None = None + ) -> list[TextualMemoryItem]: + latest_item = item.model_copy(deep=True) - # 3. Concat duplicate/conflict memories to new_item.memory - # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. - _append_related_content(new_item, duplicate_memories, conflict_memories) + latest_item.id = str(uuid.uuid4()) + latest_item.metadata.is_fast = False + latest_item.metadata.status = "activated" + latest_item.metadata.history = [] + latest_item.metadata.working_binding = None + if hasattr(latest_item.metadata, "background"): + latest_item.metadata.background = "" - return duplicate_memories + conflict_memories + if hasattr(latest_item.metadata, "tags") and latest_item.metadata.tags: + latest_item.metadata.tags = [t for t in latest_item.metadata.tags if t != "mode:fast"] + + latest_item.metadata = _sanitize_metadata_model(latest_item.metadata) + + return [latest_item] def mark_memory_status( self, - memory_items: list[TextualMemoryItem], + memory_ids: list[str], status: Literal["activated", "resolving", "archived", "deleted"], - user_name: str | None = None, + user_name: str, ) -> None: """ Support status marking operations during history management. Common usages are: 1. Mark conflict/duplicate old memories' status as "resolving", to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. + 2. Mark resolved memories' status as "activated", to recover their visibility. """ # Execute the actual marking operation - in db. with ContextThreadPoolExecutor() as executor: futures = [] - for mem in memory_items: + for mid in memory_ids: futures.append( executor.submit( self.graph_db.update_node, - id=mem.id, + id=mid, fields={"status": status}, user_name=user_name, ) @@ -166,3 +564,617 @@ def mark_memory_status( for future in futures: future.result() return + + def prepare_history_candidates_via_nli(self, item: TextualMemoryItem, user_name: str) -> None: + """ + 1. Recall related memories + 2. Fast conflict/duplication check with NLI model + 3. Attach conflicting/duplicate old memory contents onto fast memory items + """ + if not self.is_applicable(item): + return + + if not self.pre_update_retriever: + logger.warning("[MemoryHistoryManager] PreUpdateRetriever is not initialized.") + return + + try: + # recall related memories + retrieve_start = time.perf_counter() + related = self.pre_update_retriever.retrieve( + item=item, + user_name=user_name, + ) + retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 + logger.info( + "[MemoryHistoryManager] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", + retrieve_ms, + getattr(item, "id", None), + ) + # NLI check & attaching contents + nli_start = time.perf_counter() + conflicting_or_duplicate_ids = self.resolve_history_via_nli(item, related) + nli_ms = (time.perf_counter() - nli_start) * 1000 + logger.info( + "[MemoryHistoryManager] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", + nli_ms, + getattr(item, "id", None), + len(related), + len(conflicting_or_duplicate_ids), + ) + + except Exception as e: + logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + + def apply_mem_version_update( + self, + original_item: TextualMemoryItem, + user_name: str, + llm: BaseLLM | None, + custom_tags: dict[str, str] | None, + custom_tags_prompt_template: str | None, + timeout_sec: int = 30, + ) -> list[TextualMemoryItem]: + """ + 1. Wait for 'fast histories' in the item to resolve, and rebuild its history + 2. Build memory extraction/update prompt (include custom tags and conversation context) + 3. Call LLM and parse JSON response + 4. Apply LLM updates to memory graph and return new items + """ + self.prepare_history_candidates_via_nli(original_item, user_name) + self.wait_and_update_fast_history(original_item, user_name, timeout_sec=timeout_sec) + + custom_tags_prompt = ( + custom_tags_prompt_template.replace("{custom_tags}", str(custom_tags)) + if custom_tags_prompt_template and custom_tags + else "" + ) + prompt = self.format_prompt(original_item, custom_tags_prompt) + + # Give plugins a chance to augment/replace the prompt while + # preserving the version-pipeline context (history candidates, etc.) + try: + from memos.plugins.hook_defs import H as _H + from memos.plugins.hooks import trigger_hook as _trigger_hook + + sources = original_item.metadata.sources if original_item.metadata else None + lang = _determine_lang(sources, original_item.memory) + _rv = _trigger_hook( + _H.MEM_READER_PRE_EXTRACT, + prompt=prompt, + prompt_type="version", + mem_str=original_item.memory, + lang=lang, + sources=sources or [], + ) + prompt = _rv if _rv is not None else prompt + except Exception as hook_err: + logger.debug("[MemoryHistoryManager] Plugin hook skipped: %s", hook_err) + logger.info( + f"[MultiModalParser] Process String Fine After Plugin (In Version Control): {prompt}" + ) + + try: + if llm is None: + raise ValueError("LLM is not initialized") + response_text = llm.generate([{"role": "user", "content": prompt}]) + if not response_text: + raise ValueError("Empty LLM response") + response_json = _parse_json_result(response_text) + if not response_json: + raise ValueError("Empty LLM JSON response") + + _, new_items = self.apply_llm_memory_updates( + response_json, original_item, user_name=user_name + ) + return new_items + + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Memory extraction/update fallback due to LLM failure: {e}" + ) + return self.build_fallback_new_items(original_item, user_name=user_name) + + def update_from_feedback( + self, + old_item: TextualMemoryItem, + new_item: TextualMemoryItem, + user_name: str, + update_type: Literal[ + "conflict", "duplicate", "extract", "unrelated", "feedback" + ] = "feedback", + ) -> tuple[TextualMemoryItem, TextualMemoryItem, dict[str, Any], dict[str, Any]]: + current_item, archived_item = self.update_node_with_history( + item=old_item.model_copy(deep=True), + new_memory=new_item.memory, + update_type=update_type, + tags=new_item.metadata.tags, + key=new_item.metadata.key, + ) + current_item.metadata.background = new_item.metadata.background + if getattr(new_item.metadata, "sources", None) is not None: + current_sources = list(current_item.metadata.sources or []) + current_item.metadata.sources = list(new_item.metadata.sources or []) + current_sources + if getattr(new_item.metadata, "embedding", None) is not None: + current_item.metadata.embedding = new_item.metadata.embedding + elif self.embedder: + current_item.metadata.embedding = self._compute_embedding(current_item.memory) + if current_item.metadata.memory_type == "PreferenceMemory": + current_item.metadata.preference = current_item.memory + + archived_embedding = getattr(old_item.metadata, "embedding", None) + if archived_embedding is None: + archived_embedding = TextualMemoryItem( + **self.graph_db.get_node(old_item.id, user_name=user_name, include_embedding=True) + ).metadata.embedding + arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) + arch_meta["embedding"] = archived_embedding + metadata_fields = _sanitize_metadata_dict( + current_item.metadata.model_dump(exclude_none=True) + ) + history_dump = [ + h.model_dump(exclude_none=True) for h in (current_item.metadata.history or []) + ] + update_fields = { + **metadata_fields, + "memory": current_item.memory, + "history": history_dump, + "version": current_item.metadata.version, + "covered_history": archived_item.id, + } + return current_item, archived_item, arch_meta, update_fields + + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int], user_name: str + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids, user_name=user_name) or [] + nodes_map = {n["id"]: TextualMemoryItem(**n) for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node = nodes_map.get(h_item.archived_memory_id) + + if not node: + continue + + metadata = _sanitize_metadata_model(node.metadata) # deal with embedded metadata + # Condition: Fast node is processed when it is marked as 'deleted' + if metadata.status == "deleted": + evolve_to_ids = metadata.evolve_to + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type, user_name) + replacements[i] = new_items + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str, user_name: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids, user_name=user_name) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def _process_memory_updates( + self, + memory_list: list[dict[str, Any]], + expected_versions: dict[str, int], + user_name: str, + source_item: TextualMemoryItem, + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """Process Memory List (Update or Create).""" + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + for mem_data in memory_list: + source_ids = mem_data.get("source_candidate_ids", []) + conflict_ids = mem_data.get("conflicted_candidate_ids", []) + + # Determine if this is an update or a creation + target_ids = source_ids + conflict_ids + + if target_ids: + updated_item, new_item = self._update_existing_memory( + mem_data, target_ids, source_ids, expected_versions, user_name, source_item + ) + if updated_item: + updated_items.append(updated_item) + if new_item: + new_items.append(new_item) + else: + item = self._create_new_memory(mem_data, source_item) + new_items.append(item) + return updated_items, new_items + + def _update_existing_memory( + self, + mem_data: dict[str, Any], + target_ids: list[str], + source_ids: list[str], + expected_versions: dict[str, int], + user_name: str, + fast_item: TextualMemoryItem, + ) -> tuple[TextualMemoryItem | None, TextualMemoryItem | None]: + """ + Update existing memory nodes using the LLM result. + + The first ID in target_ids is treated as the primary node. If additional target IDs + are provided, they are treated as secondary candidates and will be merged into the + primary. Merging means: + 1) Mark secondary nodes as archived and append the primary ID to evolve_to + 2) Merge their history entries into the primary history and re-order by created_at + + The method also applies CAS validation via expected_versions, archives the previous + version of the primary node, and persists the updated node back to the graph DB. + + Returns the updated primary TextualMemoryItem and optional new item when fallback is used. + """ + primary_id, secondary_ids = target_ids[0], target_ids[1:] + new_memory_value, tags, key = ( + mem_data.get("value", ""), + mem_data.get("tags", []), + mem_data.get("key", ""), + ) + + # Fetch candidate nodes from the *current* DB state and then select the primary. + # + # This read is intentionally not replaced by the pre-update snapshot captured in + # `apply_llm_memory_updates()`. Unlike restored memories, updates must operate on + # the latest DB state because: + # - CAS/version checking must compare against the newest persisted version + # - earlier updates in the same `memory_list` may already have changed a node + # - secondary merge decisions should reflect the current surviving nodes + # + # So this lookup is still required even though restored-memory handling now reuses + # a pre-update snapshot. + nodes_data = self.graph_db.get_nodes(target_ids, user_name=user_name) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + node_data = nodes_map.get(primary_id) + if not node_data: + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + # Fallback to create new item when the source_id is not valid(hallucination from llm) + new_item = self._create_new_memory(mem_data, fast_item) + return None, new_item + current_item = TextualMemoryItem(**node_data) + + # For concurrency control, need to make sure the primary item has not been modified by others during the run. + # If it has(version changed), then we need to use llm to merge again. + new_memory_value = self._apply_cas_merge( + primary_id, current_item, expected_versions, new_memory_value + ) + + update_type = "duplicate" if primary_id in source_ids else "conflict" + current_item, archived_item = self.update_node_with_history( + current_item, + new_memory_value, + update_type, + tags=tags, + key=key, + ) + + # create archived node for storing older versions of the memory, preserving the embedding + emb = TextualMemoryItem( + **self.graph_db.get_node(primary_id, user_name=user_name, include_embedding=True) + ).metadata.embedding + arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) + arch_meta["embedding"] = emb + self.graph_db.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=arch_meta, + user_name=user_name, + ) + + fields = _sanitize_metadata_dict(current_item.metadata.model_dump(exclude_none=True)) + merged_history = list(current_item.metadata.history or []) + new_primary_version = current_item.metadata.version or 1 + # Multiple related ids indicates existing duplicates/conflicts to be merged + if secondary_ids: + merged_history, new_primary_version = self._merge_secondary_nodes( + secondary_ids, primary_id, nodes_map, user_name, merged_history + ) + current_item.metadata.history = merged_history + current_item.metadata.version = new_primary_version + merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + embedding = self._compute_embedding(current_item.memory) + sources = [s.model_dump(exclude_none=True) for s in (fast_item.metadata.sources or [])] + # update old memory node with new content and updated history + self.graph_db.update_node( + id=primary_id, + fields={ + **fields, + "memory": current_item.memory, + "history": merged_history_dump, + "version": new_primary_version, + "embedding": embedding, + "sources": sources, + "session_id": fast_item.metadata.session_id, + }, + user_name=user_name, + ) + working_binding = getattr(current_item.metadata, "working_binding", None) + if working_binding and working_binding != current_item.id: + try: + self.mark_memory_status([str(working_binding)], "deleted", user_name=user_name) + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" + ) + + return current_item, None + + def _apply_cas_merge( + self, + primary_id: str, + current_item: TextualMemoryItem, + expected_versions: dict[str, int], + new_memory_value: str, + ) -> str: + expected_version = expected_versions.get(primary_id) + current_version = current_item.metadata.version or 1 + if expected_version is not None and current_version != expected_version: + logger.warning( + f"[MemoryHistoryManager] Version conflict for node {primary_id}: " + f"Expected v{expected_version}, but found v{current_version} in DB. " + "Triggering merge logic." + ) + merged_content = self._merge_conflicting_memory( + latest_memory=current_item.memory, + proposed_update=new_memory_value, + ) + return merged_content + + return new_memory_value + + def _merge_secondary_nodes( + self, + secondary_ids: list[str], + primary_id: str, + nodes_map: dict, + user_name: str, + base_history: list[ArchivedTextualMemory], + ) -> tuple[list[ArchivedTextualMemory], int]: + merged_history = list(base_history) + + for memory_id in secondary_ids: + node_data = nodes_map.get(memory_id) + if not node_data: + continue + metadata = node_data.get("metadata", {}) + evolve_to = list(metadata.get("evolve_to", []) or []) + if primary_id not in evolve_to: + evolve_to.append(primary_id) + # set secondary nodes to archived and record their evolving destinations + self.graph_db.update_node( + id=memory_id, + fields={"status": "archived", "evolve_to": evolve_to}, + user_name=user_name, + ) + secondary_item = TextualMemoryItem(**node_data) + if secondary_item.metadata.history: + merged_history.extend(secondary_item.metadata.history) + + # Currently we just sort the versions according to their creation time + def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: + created_at = history_item.created_at + if isinstance(created_at, datetime): + return created_at + if created_at: + try: + return datetime.fromisoformat(created_at) + except ValueError: + return datetime.min + return datetime.min + + def _dedupe_history_by_archived_id( + history: list[ArchivedTextualMemory], + ) -> list[ArchivedTextualMemory]: + seen_archived_ids: set[str] = set() + deduped_history: list[ArchivedTextualMemory] = [] + for history_item in history: + archived_id = history_item.archived_memory_id + if archived_id and archived_id in seen_archived_ids: + continue + if archived_id: + seen_archived_ids.add(archived_id) + deduped_history.append(history_item) + return deduped_history + + merged_history.sort(key=_history_sort_key) + merged_history = _dedupe_history_by_archived_id(merged_history) + max_version = 0 + for idx, history_item in enumerate(merged_history, start=1): + history_item.version = idx + max_version = idx + return merged_history, max_version + 1 + + def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: + """ + Call LLM to merge proposed update with latest memory content. + """ + if not self.llm: + return proposed_update + + lang = _determine_lang(None, f"{latest_memory}\n{proposed_update}") + prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) + prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( + "${proposed_update}", proposed_update + ) + + messages = [{"role": "user", "content": prompt}] + try: + response = self.llm.generate(messages) + if not response: + raise ValueError("LLM response is None.") + return response.strip() + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") + # Fallback: concatenate as a safe fallback. + return f"{latest_memory}\n\n[New Info]: {proposed_update}" + + def _create_new_memory( + self, mem_data: dict[str, Any], fast_item: TextualMemoryItem + ) -> TextualMemoryItem: + """Create New Node.""" + new_value = mem_data.get("value", "") + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + new_value = new_value_item.memory + tags = mem_data.get("tags", []) + key = mem_data.get("key", "") + background = mem_data.get("summary", "") + memory_type = mem_data.get("memory_type", "LongTermMemory") + now = datetime.now().isoformat() + metadata_updates = { + "is_fast": False, + "version": 1, + "memory_type": memory_type, + "status": "activated", + "background": background, + "working_binding": None, + "tags": tags, + "key": key, + "created_at": now, + "updated_at": now, + "history": [], + "embedding": self._compute_embedding(new_value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, field_value in metadata_updates.items(): + setattr(metadata, field_name, field_value) + metadata = _sanitize_metadata_model(metadata) + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=new_value, + metadata=metadata, + ) + return new_item + + def _handle_preserved_facts( + self, + preserved_fact_tasks: list[dict[str, Any]], + fast_item: TextualMemoryItem, + user_name: str, + pre_update_source_item_map: dict[str, TextualMemoryItem] | None = None, + ) -> list[TextualMemoryItem]: + """Create standalone nodes for preserved facts split from an updated source node.""" + if not preserved_fact_tasks: + return [] + + # Prefer the pre-update snapshot so preserved facts are extracted from the + # original source nodes referenced by the update item, not from already-updated + # DB state. + source_item_map = dict(pre_update_source_item_map or {}) + missing_source_ids = [ + r.get("source_candidate_id") + for r in preserved_fact_tasks + if r.get("source_candidate_id") and r.get("source_candidate_id") not in source_item_map + ] + if missing_source_ids: + # Fallback only for ids not present in the pre-update snapshot. + source_items = self.graph_db.get_nodes(missing_source_ids, user_name=user_name) or [] + source_item_map.update( + {item["id"]: TextualMemoryItem(**item) for item in source_items if item} + ) + + created_items = [] + for data in preserved_fact_tasks: + source_candidate_id = data.get("source_candidate_id") + source_item = source_item_map.get(source_candidate_id) + if source_item is None: + logger.warning( + "[MemoryHistoryManager] Preserved fact source %s not found. Skipping.", + source_candidate_id, + ) + continue + # deal with history + source_history = deepcopy(source_item.metadata.history) + value = data.get("value", "") + if not value: + logger.warning( + "[MemoryHistoryManager] Preserved fact from source %s has empty value. " + "Skipping.", + source_candidate_id, + ) + continue + value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) + value = value_item.memory + tags = data.get("tags", []) + key = data.get("key", "") + memory_type = data.get("memory_type", "LongTermMemory") + original_sources = deepcopy(source_item.metadata.sources) + version = source_item.metadata.version + new_history_item = ArchivedTextualMemory( + version=version, + is_fast=False, + memory=source_item.memory, + update_type="extract", + archived_memory_id=source_item.id, + created_at=source_item.metadata.created_at, + ) + # Re-use the old node's history and append one more archive entry pointing to + # the pre-update source node itself. This keeps the extracted node anchored to + # the original source memory snapshot(before update). + source_history.append(new_history_item) + # Create new node + metadata_updates = { + "memory_type": memory_type, + "status": "activated", + "is_fast": False, + "version": version + 1, + "sources": original_sources, + "tags": tags, + "key": key, + "created_at": datetime.now().isoformat(), + "history": source_history, + "embedding": self._compute_embedding(value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, field_value in metadata_updates.items(): + setattr(metadata, field_name, field_value) + metadata = _sanitize_metadata_model(metadata) + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=value, + metadata=metadata, + ) + + created_items.append(new_item) + + return created_items diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 96453f5a0..4d268fe6a 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -192,10 +192,12 @@ def _add_memories_batch( ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() + metadata_dict["working_binding"] = working_id # Add working_binding for fast mode tags = metadata_dict.get("tags") or [] if "mode:fast" in tags: + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_id}] direct built from raw inputs" metadata_dict["background"] = ( @@ -235,6 +237,8 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: exc_info=e, ) + # TODO: working id is same with item.id, need to fix, currently stop adding WorkingMemories here. + # here used to be: _submit_batches(working_nodes, "WorkingMemory") _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: @@ -320,6 +324,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] + # TODO: working id is same with item.id, need to fix working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: @@ -396,8 +401,11 @@ def _add_to_graph_memory( node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding: + metadata_dict["working_binding"] = working_binding tags = metadata_dict.get("tags") or [] if working_binding and ("mode:fast" in tags): + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index cb77d2243..b500f6b61 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -1,6 +1,7 @@ import concurrent.futures import re +from datetime import datetime from typing import Any from memos.context.context import ContextThreadPoolExecutor @@ -138,10 +139,11 @@ def vector_search( results = self.graph_db.search_by_embedding( vector=q_embed, top_k=top_k, - status=None, + status="activated", threshold=threshold, user_name=user_name, filter=search_filter, + return_fields=["id", "is_fast", "created_at"], ) return results except Exception as e: @@ -155,6 +157,7 @@ def keyword_search( top_k: int, search_filter: dict[str, Any] | None = None, ) -> list[dict]: + # Currently not used for large latency try: # 1. Tokenize using existing tokenizer keywords = self.tokenizer.tokenize_mixed(query_text) @@ -210,8 +213,11 @@ def retrieve( # 2. Recall futures = [] common_filter = { - "status": {"in": ["activated", "resolving"]}, - "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]}, + "or": [ + {"memory_type": "LongTermMemory"}, + {"memory_type": "UserMemory"}, + {"memory_type": "WorkingMemory"}, + ] } with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor: @@ -230,13 +236,7 @@ def retrieve( sim_threshold, ) ) - - # Task B: Keyword Search - futures.append( - executor.submit( - self.keyword_search, switched_query, user_name, top_k, common_filter - ) - ) + # TODO: recovering keyword search or other versions of search for multiple pathways # 3. Collect Results retrieved_ids = set() # for deduplicating ids @@ -247,7 +247,16 @@ def retrieve( continue for r in res: - retrieved_ids.add(r["id"]) + # exclude self and working binding + # also exclude fast nodes that's created after current node to avoid deadlock later + working_binding = item.metadata.working_binding or "" + is_fast = bool(r.get("is_fast", False)) + if (r["id"] != item.id and r["id"] != working_binding) and ( + not is_fast + or datetime.fromisoformat(r["created_at"]) + < datetime.fromisoformat(item.metadata.created_at) + ): + retrieved_ids.add(r["id"]) except Exception as e: logger.error(f"[PreUpdateRetriever] Search future task failed: {e}") diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 6df410c19..c4ab4a798 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -751,7 +751,14 @@ def _process_text_mem( ) # Mark merged_from memories as archived when provided in add_req.info - if sync_mode == "sync" and extract_mode == "fine": + if ( + sync_mode == "sync" + and extract_mode == "fine" + and ( + not hasattr(self.mem_reader, "memory_version_switch") + or self.mem_reader.memory_version_switch != "on" + ) + ): for memory in flattened_local: merged_from = (memory.metadata.info or {}).get("merged_from") if merged_from: diff --git a/src/memos/plugins/__init__.py b/src/memos/plugins/__init__.py new file mode 100644 index 000000000..0a0f8cde3 --- /dev/null +++ b/src/memos/plugins/__init__.py @@ -0,0 +1,20 @@ +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec +from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook +from memos.plugins.manager import PluginManager, plugin_manager + + +__all__ = [ + "H", + "HookSpec", + "MemOSPlugin", + "PluginManager", + "all_hook_specs", + "define_hook", + "get_hook_spec", + "hookable", + "plugin_manager", + "register_hook", + "register_hooks", + "trigger_hook", +] diff --git a/src/memos/plugins/base.py b/src/memos/plugins/base.py new file mode 100644 index 000000000..f55d81b75 --- /dev/null +++ b/src/memos/plugins/base.py @@ -0,0 +1,72 @@ +"""MemOS plugin base class — all plugins must inherit from this class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastapi import FastAPI + from starlette.middleware.base import BaseHTTPMiddleware + + +class MemOSPlugin: + """MemOS plugin base class. + + Provides three unified registration methods. Plugin developers need only + inherit from this class and register capabilities via self.register_* + in init_app. + """ + + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + + _app: FastAPI | None = None + + # ------------------------------------------------------------------ # + # Registration methods — called by plugins in init_app + # ------------------------------------------------------------------ # + + def register_router(self, router, **kwargs) -> None: + """Register a router.""" + self._app.include_router(router, **kwargs) + + def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None: + """Register middleware.""" + self._app.add_middleware(middleware_cls, **kwargs) + + def register_hook(self, name: str, callback: Callable) -> None: + """Register a single Hook callback.""" + from memos.plugins.hooks import register_hook + + register_hook(name, callback) + + def register_hooks(self, names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple Hook points.""" + from memos.plugins.hooks import register_hooks + + register_hooks(names, callback) + + # ------------------------------------------------------------------ # + # Internal methods — called by PluginManager, plugin developers need not care + # ------------------------------------------------------------------ # + + def _bind_app(self, app: FastAPI) -> None: + """Bind FastAPI instance so that register_* methods are available.""" + self._app = app + + # ------------------------------------------------------------------ # + # Lifecycle methods — override in subclasses + # ------------------------------------------------------------------ # + + def on_load(self) -> None: + """Called after the plugin is discovered. Used for initialization logic, e.g. checking dependencies, reading config.""" + + def init_app(self) -> None: + """Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here.""" + + def on_shutdown(self) -> None: + """Called when the service shuts down. Used for resource cleanup.""" diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py new file mode 100644 index 000000000..f7fb237ca --- /dev/null +++ b/src/memos/plugins/hook_defs.py @@ -0,0 +1,98 @@ +"""Hook declaration registry — single source of truth for CE repo Hook points. + +The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook. +Hooks triggered by custom trigger_hook must be explicitly declared in this file. + +Plugin-owned Hooks should be declared within each plugin package, not in this file. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +_specs: dict[str, HookSpec] = {} + + +@dataclass(frozen=True) +class HookSpec: + """Hook spec definition.""" + + name: str + description: str + params: list[str] + pipe_key: str | None = None + + +def define_hook( + name: str, + *, + description: str, + params: list[str], + pipe_key: str | None = None, +) -> None: + """Declare a Hook point. Skips if already exists (idempotent).""" + if name in _specs: + return + _specs[name] = HookSpec( + name=name, + description=description, + params=params, + pipe_key=pipe_key, + ) + logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key) + + +def get_hook_spec(name: str) -> HookSpec | None: + return _specs.get(name) + + +def all_hook_specs() -> dict[str, HookSpec]: + """Return all declared Hooks (including @hookable auto-declared + plugin-declared).""" + return dict(_specs) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE Hook name constants +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +class H: + """CE Hook name constants. Plugin-owned Hook constants should be defined within the plugin package.""" + + # @hookable("add") — AddHandler.handle_add_memories + ADD_BEFORE = "add.before" + ADD_AFTER = "add.after" + + # @hookable("search") — SearchHandler.handle_search_memories + SEARCH_BEFORE = "search.before" + SEARCH_AFTER = "search.after" + + # Custom Hook (manually triggered via trigger_hook) + ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + + # mem_reader — generic extension point before LLM extraction + MEM_READER_PRE_EXTRACT = "mem_reader.pre_extract" + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE custom Hook declarations (@hookable-generated ones need not be declared here) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +define_hook( + H.ADD_MEMORIES_POST_PROCESS, + description="Post-process result after add_memories returns, before constructing Response", + params=["request", "result"], + pipe_key="result", +) + +define_hook( + H.MEM_READER_PRE_EXTRACT, + description="Customize prompt before mem_reader LLM extraction", + params=["prompt", "prompt_type", "mem_str", "lang", "sources"], + pipe_key="prompt", +) diff --git a/src/memos/plugins/hooks.py b/src/memos/plugins/hooks.py new file mode 100644 index 000000000..eda98f98a --- /dev/null +++ b/src/memos/plugins/hooks.py @@ -0,0 +1,124 @@ +"""Hook runtime — registration, triggering, and @hookable decorator.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections import defaultdict +from functools import wraps +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +_hooks: dict[str, list[Callable]] = defaultdict(list) + + +def register_hook(name: str, callback: Callable) -> None: + """Register a hook callback. Undeclared hook names will log a warning.""" + from memos.plugins.hook_defs import get_hook_spec + + if get_hook_spec(name) is None: + logger.warning( + "Registering callback for undeclared hook: %s (callback=%s)", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + _hooks[name].append(callback) + logger.debug( + "Hook registered: %s -> %s", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + + +def register_hooks(names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple hook points.""" + for name in names: + register_hook(name, callback) + + +def trigger_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook, invoking all registered callbacks in order. + + - Zero overhead when no callbacks are registered + - Undeclared hook names will log a warning and be skipped + - pipe_key is auto-fetched from HookSpec, supports piped return value passing + """ + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + logger.warning("Undeclared hook triggered: %s — ignored", name) + return None + + pipe_key = spec.pipe_key + + for cb in _hooks.get(name, []): + try: + rv = cb(**kwargs) + if pipe_key is not None and rv is not None: + kwargs[pipe_key] = rv + except Exception: + logger.exception( + "Hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + + return kwargs.get(pipe_key) if pipe_key else None + + +def hookable(name: str): + """Decorator: automatically triggers name.before / name.after hook before and after the method. + + Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py. + Supports piped return values: before can modify request, after can modify result. + Compatible with both sync and async methods. + """ + from memos.plugins.hook_defs import define_hook + + define_hook( + f"{name}.before", + description=f"Before {name} executes; can modify request", + params=["request"], + pipe_key="request", + ) + define_hook( + f"{name}.after", + description=f"After {name} executes; can modify result", + params=["request", "result"], + pipe_key="result", + ) + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = await func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return sync_wrapper + + return decorator diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py new file mode 100644 index 000000000..3706855a9 --- /dev/null +++ b/src/memos/plugins/manager.py @@ -0,0 +1,75 @@ +"""Plugin manager — discover, load, and manage MemOS plugins.""" + +from __future__ import annotations + +import importlib.metadata +import logging + +from typing import TYPE_CHECKING + +from memos.plugins.base import MemOSPlugin + + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "memos.plugins" + + +class PluginManager: + """Discover, load, and manage MemOS plugins.""" + + def __init__(self): + self._plugins: dict[str, MemOSPlugin] = {} + + @property + def plugins(self) -> dict[str, MemOSPlugin]: + return dict(self._plugins) + + def discover(self) -> None: + """Discover and load all installed plugins via entry_points.""" + try: + eps = importlib.metadata.entry_points() + if hasattr(eps, "select"): + plugin_eps = eps.select(group=ENTRY_POINT_GROUP) + else: + plugin_eps = eps.get(ENTRY_POINT_GROUP, []) + except Exception: + logger.exception("Failed to query entry_points") + return + + for ep in plugin_eps: + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not isinstance(plugin, MemOSPlugin): + logger.warning("Plugin %s does not extend MemOSPlugin, skipped", ep.name) + continue + plugin.on_load() + self._plugins[plugin.name] = plugin + logger.info("Plugin discovered: %s v%s", plugin.name, plugin.version) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + + def init_app(self, app: FastAPI) -> None: + """Bind app and initialize all loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin._bind_app(app) + plugin.init_app() + logger.info("Plugin initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin: %s", plugin.name) + + def shutdown(self) -> None: + """Shut down all plugins and release resources.""" + for plugin in self._plugins.values(): + try: + plugin.on_shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) + + +plugin_manager = PluginManager() diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py new file mode 100644 index 000000000..f67fe223f --- /dev/null +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -0,0 +1,426 @@ +# ========================================== +# Memory Update & Maintenance +# ========================================== +ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 +您的核心任务是根据最新的用户对话、对话时间,以及系统提供的可能与最新对话相关的“候选记忆”(Candidates),来维护和更新用户的长期记忆图谱。 + +具体而言,“候选记忆”包含以下三种情况: +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)** +2. **潜在事实冲突记忆 (Conflict Candidates)** +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)** + +您需要根据最新对话以及候选记忆,决定是更新现有记忆节点,还是创建全新的记忆节点。 + +**核心原则(STRICT)**: + - 您的目标是**维护**记忆库,而非仅仅提取信息。 + - **优先更新**:如果对话内容涉及现有的“候选记忆”,应优先视为对该记忆节点的**更新**(补充细节或修正状态),而不是创建重复的新节点。 + - **按需新增**:仅当对话内容包含全新的、与现有“候选记忆”完全无关的话题时,才创建新的记忆节点。 + - 提取来源**只能**是【当前的对话内容】。严禁编造未提及的信息。 + +**表达规范(STRICT)**: + - 冲突更新时,`value` 必须是“只含最新事实”的独立陈述,不允许提及旧值或变化过程(如“原名/之前/曾经/从X到Y/改成/不再使用原名/现在自称”)。 + - 若最新状态本身为否定事实,可直接用否定表达,但仍不得包含旧值或对比语。 + - 对于姓名、身份、归属、偏好等字段的更新,始终输出最新值的肯定式表述(例:旧记忆“用户叫王强”,新对话“我叫李白”,输出应为“用户叫李白”)。 + - 涉及第三方人物/实体的客观信息必须使用 `LongTermMemory`,且主体保持为该第三方(如“王强住上海”)。 + +请执行以下操作: +1. 识别反映用户经历、信念、关切、决策、计划或反应的信息。 + - 如果消息来自用户,提取用户相关的记忆。 + - 如果来自助手,仅提取用户认可或回应的事实性记忆。 + +2. 清晰解析所有时间、人物和事件的指代(同原规则): + - 将相对时间(“昨天”)转换为绝对日期。 + - 明确区分事件时间和消息时间。 + - 解析代词和模糊指代。 + - 仅当指代为“我/我们/本人”等用户第一人称时才替换为“用户”。 + - 其他第三方人名/实体必须保留原名,不得替换为“用户”。 + - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 + - 候选记忆可能包含 [Time: ...] 表示该记忆的事件时间,请结合“对话时间”判断是否同一时段。 + +3. 不要遗漏用户可能记住的任何信息。 + - 包括所有关键经历、想法、情绪反应和计划——即使看似微小。 + - 优先考虑完整性和保真度,而非简洁性。 + - 不要泛化或跳过对用户具有个人意义的细节。 + +4. **处理逻辑(更新与新增)**: + 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: + + a) **更新现有记忆节点 (Update via Duplicate/Related)**: + - 检查“潜在重复/关联记忆”。 + - 如果新信息是对某条旧记忆的重复、确认或补充细节: + - 生成一条**更新后的完整记忆**放入 `value`(包含旧信息+新细节)。 + - 将该旧记忆的ID放入 `source_candidate_ids`。 + - 此时 `conflicted_candidate_ids` 应为空。 + - 如果该旧节点中还包含**未被本次更新覆盖、且可以独立存在**的其他子事实,请将它们放入当前这条更新项内部的 `preserved_facts`。 + - `preserved_facts` 中的每一条内容,都必须能在当前这条更新项引用的旧节点原文中直接定位到;它只是“拆分/改写该旧节点内部原本就存在的子事实”,**绝不能**从其他 candidate 挪用、拼接、概括或猜测内容。 + - 如果旧节点只是单一事实,或所有内容都已经被本次更新吸收进 `value`,则 `preserved_facts` 必须为空数组。 + + b) **修正冲突记忆节点 (Update via Conflict)**: + - 检查“潜在事实冲突记忆”。 + - 如果新信息否定了某条旧记忆,或更新了其状态(如“不再喜欢X”“改成Y”“取消计划”“从X转为Y”): + - 生成一条反映**最新状态**的记忆放入 `value`。 + - 将被修正的旧记忆ID放入 `conflicted_candidate_ids`。 + - 如果该旧节点本身是一条混合记忆,而本次只更新其中一部分,则必须把**未被新信息否定、且可独立存在**的剩余事实放入当前这条更新项内部的 `preserved_facts`。 + - `preserved_facts` **绝不能**包含已经被当前更新否定、替换或覆盖的旧事实。例如新信息把“深圳工作”改成“广州工作”,则 `preserved_facts` 里绝不能再出现“深圳工作”。 + - 对于可长期独立存在的属性(如电话号码、出生地、所属组织),优先拆分为独立事实,避免与可变状态混写在同一条记忆中。 + - 如果旧节点没有剩余的独立有效事实,则 `preserved_facts` 必须为空数组。 + + c) **创建新记忆节点 (Create New)**: + - 如果新信息与任何“候选记忆”都无直接关联(既非重复也非冲突): + - 生成一条独立的新记忆放入 `value`。 + - 确保 `source_candidate_ids` 和 `conflicted_candidate_ids` 均为 `[]`。 + - 新建记忆的 `preserved_facts` 必须为空数组。 + +5. 无关的 candidate,只需把它的 ID 放入 `unrelated_candidate_ids`。 + +6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 + +返回一个有效的JSON对象,结构如下: + +{ + "memory list": [ + { + "key": <字符串,简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory",区分该记忆是客观事实还是和用户相关的内容>, + "value": <字符串,更新后的完整记忆内容(针对更新/冲突情况)或全新记忆内容(针对新增情况)>, + "tags": <相关主题关键词列表>, + "source_candidate_ids": <字符串列表,被此条目更新的“重复/关联记忆”ID。若无则为 []>, + "conflicted_candidate_ids": <字符串列表,被此条目修正的“事实冲突记忆”ID。若无则为 []>, + "preserved_facts": [ + { + "key": <字符串,简洁的记忆标题>, + "value": <字符串,从当前这条更新项所引用的旧节点中拆出的、依然有效的独立事实>, + "tags": <相关主题关键词列表>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory"> + } + ] + }, + ... + ], + "unrelated_candidate_ids": [<字符串列表,被判断为与本次对话无关、应忽略的 candidate ID>], + "summary": <从用户视角自然总结本次记忆更新操作的段落,120–200字> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 +格式规则(STRICT): +- 必须输出**严格 JSON**,不允许出现尾随逗号。 +- 不要输出 Markdown、代码块或任何解释性文字。 + +${custom_tags_prompt} + +示例: +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: +[ID:101][Time: 2025/05/20 09:30:00] 用户喜欢喝拿铁,通常不加糖。 +[ID:102][Time: 2025/06/02 18:00:00] 用户讨厌下雨天。 + +2. **潜在事实冲突记忆 (Conflict Candidates)**: +[ID:201][Time: 2025/02/03 20:15:00] 用户喜欢打羽毛球,但不喜欢滑雪。 + +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +[ID:301][Time: 2025/06/20 10:00:00] 用户最近在看《星球大战》。 + +**对话时间**: +2025/06/26 09:00:00 + +**对话**: +user: 最近下雨比较频繁。我经常去喝点拿铁,尤其是加燕麦奶的很好喝。另外,我最近膝盖受伤了,以后再也不打羽毛球了。 + +**输出:** +{ + "memory list": [ + { + "key": "咖啡偏好", + "memory_type": "UserMemory", + "value": "用户喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", + "tags": ["饮食", "咖啡", "喜好"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + "preserved_facts": [] + }, + { + "key": "运动习惯变更", + "memory_type": "UserMemory", + "value": "用户因膝盖受伤,决定不再打羽毛球。", + "tags": ["运动", "健康", "羽毛球"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + "preserved_facts": [ + { + "key": "运动偏好", + "value": "用户不喜欢滑雪。", + "tags": ["运动", "滑雪", "喜好"], + "memory_type": "UserMemory" + } + ] + }, + { + "key": "天气状况", + "memory_type": "LongTermMemory", + "value": "最近(2025年6月)用户所在的地方下雨比较频繁。", + "tags": ["生活", "天气", "降水"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "preserved_facts": [] + } + ], + "unrelated_candidate_ids": ["301"], + "summary": "本次更新中,用户补充了拿铁偏好(加入燕麦奶),并因膝盖受伤将运动习惯更新为不再打羽毛球,同时保留了其不喜欢滑雪这一仍然有效的独立事实。此外,新增了一条关于近期下雨频繁的记忆。" +} + +请始终使用与对话相同的语言进行回复。以下是最新的输入: + +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: +${duplicate_candidates} + +2. **潜在事实冲突记忆 (Conflict Candidates)**: +${conflict_candidates} + +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +${unrelated_candidates} + +**对话时间**: +${conversation_time} + +**对话**: +${conversation} + +**输出:**""" + +ASYNC_MEMORY_UPDATE_PROMPT_EN = """You are a memory maintenance expert. +Your core task is to maintain and update the user's long-term memory graph based on the latest user conversation, the conversation time, and the system-provided "Candidates" that may be related to the latest conversation. + +Specifically, "Candidates" include three cases: +1. **Duplicate/Related Candidates** +2. **Conflict Candidates** +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)** + +You need to decide, based on the latest conversation and the candidates, whether to update existing memory nodes or create brand-new memory nodes. + +**Core Principles (STRICT)**: + - Your goal is to **maintain** the memory base, not merely extract information. + - **Prefer Update**: If the conversation touches any existing "Candidates", treat it as an **update** to that memory node (add details or correct status), rather than creating a duplicate new node. + - **Add As Needed**: Only create a new node when the conversation contains truly new topics that are completely unrelated to existing "Candidates". + - The extraction source must be ONLY the **current conversation**. Do not fabricate information not mentioned. + +**Expression Rules (STRICT)**: + - For conflict updates, `value` must be a standalone statement of the latest fact only, without mentioning old values or change history (e.g., "formerly/previously/used to/changed from X to Y/no longer used the old name/now goes by"). + - If the latest state is inherently negative, express the negation directly but still avoid old values or comparisons. + - For updates to name/identity/affiliation/preference fields, always output a positive statement of the latest value (e.g., old memory "User's name is Wang Qiang", new conversation "My name is Li Bai" → output "The user's name is Li Bai"). + - Objective facts about third-party people/entities must use `LongTermMemory`, and the subject must remain that third party (e.g., "Wang Qiang lives in Shanghai"). + +Please execute the following: +1. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions. + - If the message is from the user, extract user-related memories. + - If it is from the assistant, only extract factual memories that the user explicitly acknowledges or responds to. + +2. Disambiguate all references to time, people, and events (same rules as before): + - Convert relative time ("yesterday") to an absolute date. + - Clearly distinguish event time from message time. + - Resolve pronouns and ambiguous references. + - Replace only first-person references ("I/we/me") with "the user". + - Keep third-party names/entities unchanged; do not replace them with "the user". + - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). + - Candidates may include [Time: ...] to indicate event time; use the conversation time to judge whether they are the same period. + +3. Do not omit any information the user might want to remember. + - Include all key experiences, thoughts, emotional reactions, and plans — even if they seem minor. + - Prioritize completeness and fidelity over brevity. + - Do not generalize or skip details that are personally meaningful to the user. + +4. **Processing Logic (Update and Create)**: + Traverse each piece of information in the conversation that is worth remembering and apply: + + a) **Update existing memory node (Update via Duplicate/Related)**: + - Check Duplicate/Related Candidates. + - If the new information repeats, confirms, or adds details to an old memory: + - Generate an **updated complete memory** into `value` (old info + new details). + - Put the old memory IDs into `source_candidate_ids`. + - `conflicted_candidate_ids` must be []. + - If the old node also contains other sub-facts that remain valid and can stand alone independently, place them inside this same update item as `preserved_facts`. + - Every preserved fact must be directly traceable to the old node referenced by this update item. It is only a split-out/rephrased sub-fact already present inside that same old node, and must NEVER borrow, merge, summarize, or infer content from any other candidate. + - If the old node is effectively a single fact, or all of its content is already absorbed into `value`, then `preserved_facts` must be an empty array. + + b) **Fix conflicting memory node (Update via Conflict)**: + - Check Conflict Candidates. + - If the new information negates an old memory or updates its state (e.g., "no longer likes X", "changed to Y", "canceled plan", "from X to Y"): + - Generate a memory reflecting the **latest state** into `value`. + - Put the corrected old memory IDs into `conflicted_candidate_ids`. + - If the old node itself is a mixed memory and this update changes only one part of it, you must place the unaffected but still valid standalone facts into this same update item as `preserved_facts`. + - `preserved_facts` must NEVER contain any old fact that is contradicted, replaced, or covered by the current update. For example, if "works in Shenzhen" is updated to "works in Guangzhou", then `preserved_facts` must not contain "works in Shenzhen". + - For long-lived independent attributes (e.g., phone number, birthplace, affiliation), prefer splitting them into standalone facts instead of mixing them with mutable states. + - If the old node has no remaining independent valid facts, then `preserved_facts` must be an empty array. + + c) **Create new memory node (Create New)**: + - If the new information is not directly related to any "Candidates" (neither duplicate nor conflict): + - Generate an independent new memory into `value`. + - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. + - Newly created memories must use `preserved_facts: []`. + +5. For any unrelated candidate, simply place its ID into `unrelated_candidate_ids`. + +6. Avoid including any memories that violate laws or involve politically sensitive information. + +Return a valid JSON object with the structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": , + "source_candidate_ids": , + "conflicted_candidate_ids": , + "preserved_facts": [ + { + "key": , + "value": , + "tags": , + "memory_type": + } + ] + }, + ... + ], + "unrelated_candidate_ids": [], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must match the main language of the input conversation. If the input is English, output English. +- `memory_type` remains in English. +Format rules (STRICT): +- Output **strict JSON** only, no trailing commas. +- Do not include Markdown, code fences, or any explanations. + +${custom_tags_prompt} + +Example: +1. **Duplicate/Related Candidates**: +[ID:101][Time: 2025/05/20 09:30:00] The user likes latte and usually doesn't add sugar. +[ID:102][Time: 2025/05/18 18:00:00] The user hates rainy days. + +2. **Conflict Candidates**: +[ID:201][Time: 2025/02/03 20:15:00] The user likes badminton but dislikes skiing. + +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)**: +[ID:301][Time: 2025/06/20 10:00:00] The user recently watched Star Wars. + +**Conversation time**: +2025/06/26 09:00:00 + +**Conversation**: +user: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. + +**Output:** +{ + "memory list": [ + { + "key": "Coffee preference", + "memory_type": "UserMemory", + "value": "The user likes latte most, usually doesn't add sugar, and prefers oat milk.", + "tags": ["diet", "coffee", "preference"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + "preserved_facts": [] + }, + { + "key": "Sport habit change", + "memory_type": "UserMemory", + "value": "Due to a knee injury, the user decides to no longer play badminton.", + "tags": ["sport", "health", "badminton"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + "preserved_facts": [ + { + "key": "Sport preference", + "value": "The user dislikes skiing.", + "tags": ["sport", "skiing", "preference"], + "memory_type": "UserMemory" + } + ] + }, + { + "key": "Pet status", + "memory_type": "UserMemory", + "value": "The user recently (June 2025) adopted a cat.", + "tags": ["life", "pet", "cat"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "preserved_facts": [] + } + ], + "unrelated_candidate_ids": ["301"], + "summary": "In this update, the user refined their latte preference by adding oat milk and updated their sports habit to no longer playing badminton because of a knee injury, while preserving the still-valid independent fact that the user dislikes skiing. Additionally, a new memory was added that the user recently adopted a cat." +} + +Always reply in the same language as the conversation. The latest input is below: + +1. **Duplicate/Related Candidates**: +${duplicate_candidates} + +2. **Conflict Candidates**: +${conflict_candidates} + +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)**: +${unrelated_candidates} + +**Conversation time**: +${conversation_time} + +**Conversation**: +${conversation} + +**Output:**""" + +ASYNC_MEMORY_UPDATE_PROMPT_DICT = { + "zh": ASYNC_MEMORY_UPDATE_PROMPT_ZH, + "en": ASYNC_MEMORY_UPDATE_PROMPT_EN, +} + +MEMORY_MERGE_PROMPT_ZH = """ +您是记忆库维护专家。 +我们尝试更新一个记忆节点,但该节点在数据库中的内容在处理期间发生了变化(版本冲突)。 +我们需要将“本次处理得出的更新内容”合并到“当前数据库中最新的记忆内容”中。 + +**原始记忆(数据库中的最新版本):** +${latest_memory} + +**本次尝试的更新内容(基于旧版本得出的结论):** +${proposed_update} + +**任务:** +将“本次尝试的更新内容”合并到“原始记忆”中。 +- 如果更新内容包含新信息,请将其整合进去。 +- 如果更新内容与原始记忆冲突,请优先采纳更新内容(假设它是基于最新对话的修正),但请尽量保留原始记忆中依然有效的细节。 +- 确保合并后的结果是一个连贯、通顺的完整记忆片段。 + +请只返回合并后的记忆内容字符串,不要包含任何解释。 +""" + +MEMORY_MERGE_PROMPT_EN = """ +You are a memory maintenance expert. +We attempted to update a memory node, but the content of that node changed in the database during processing (version conflict). +We need to merge "the update derived in this attempt" into "the latest memory content currently stored in the database". + +Original memory (latest version in the database): +${latest_memory} + +Proposed update (derived based on an old version): +${proposed_update} + +Task: +Merge "the proposed update" into "the original memory". +- If the update contains new information, integrate it. +- If the update conflicts with the original memory, prefer the update (assuming it is a correction based on the latest conversation), while preserving any details from the original memory that remain valid. +- Ensure the merged result is a coherent, fluent, and complete memory passage. + +Return ONLY the merged memory content string. Do not include any explanation. +""" + +MEMORY_MERGE_PROMPT_DICT = { + "zh": MEMORY_MERGE_PROMPT_ZH, + "en": MEMORY_MERGE_PROMPT_EN, +} diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index a6ac186b7..bca5aa77b 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -1,6 +1,6 @@ import uuid -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest @@ -8,13 +8,13 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.memories.textual.item import ( + ArchivedTextualMemory, TextualMemoryItem, - TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, - _append_related_content, - _detach_related_content, + _rebuild_fast_node_history, ) @@ -34,60 +34,6 @@ def history_manager(mock_nli_client, mock_graph_db): return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db) -def test_detach_related_content(): - original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = ["Duplicate 1", "Duplicate 2"] - conflicts = ["Conflict 1", "Conflict 2"] - - # 1. Append content - _append_related_content(item, duplicates, conflicts) - - # Verify content was appended - assert item.memory != original_memory - assert "[possibly conflicting memories]" in item.memory - assert "[possibly duplicate memories]" in item.memory - assert "Duplicate 1" in item.memory - assert "Conflict 1" in item.memory - - # 2. Detach content - _detach_related_content(item) - - # 3. Verify content is restored - assert item.memory == original_memory - - -def test_detach_only_conflicts(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = [] - conflicts = ["Conflict A"] - - _append_related_content(item, duplicates, conflicts) - assert "Conflict A" in item.memory - assert "Duplicate" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - -def test_detach_only_duplicates(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = ["Duplicate A"] - conflicts = [] - - _append_related_content(item, duplicates, conflicts) - assert "Duplicate A" in item.memory - assert "Conflict" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - def test_truncation(history_manager, mock_nli_client): # Setup new_item = TextualMemoryItem(memory="Test") @@ -97,12 +43,14 @@ def test_truncation(history_manager, mock_nli_client): mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE] # Action - history_manager.resolve_history_via_nli(new_item, [related_item]) + resolved_ids = history_manager.resolve_history_via_nli(new_item, [related_item]) # Assert - assert "possibly duplicate memories" in new_item.memory - assert "..." in new_item.memory # Should be truncated - assert len(new_item.memory) < 1000 # Ensure reasonable length + assert new_item.memory == "Test" + assert resolved_ids == [related_item.id] + assert len(new_item.metadata.history) == 1 + assert new_item.metadata.history[0].memory == long_memory + assert new_item.metadata.history[0].update_type == "duplicate" def test_empty_related_items(history_manager, mock_nli_client): @@ -118,20 +66,822 @@ def test_mark_memory_status(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex id3 = uuid.uuid4().hex - items = [ - TextualMemoryItem(memory="M1", id=id1), - TextualMemoryItem(memory="M2", id=id2), - TextualMemoryItem(memory="M3", id=id3), - ] + memory_ids = [id1, id2, id3] status = "resolving" # Action - history_manager.mark_memory_status(items, status) + history_manager.mark_memory_status(memory_ids, status, user_name="u1") # Assert assert mock_graph_db.update_node.call_count == 3 - # Verify we called it correctly (user_name=None is passed by mark_memory_status) - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None) + # Verify we called it correctly + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name="u1") + + +def test_format_async_update_prompt(history_manager): + # Setup + # Create history items + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="101", memory="Duplicate content", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="201", memory="Conflict content", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=1, archived_memory_id="301", memory="Unrelated content", update_type="unrelated" + ) + + item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]), + ) + + # Execute + prompt = history_manager.format_prompt(item) + + # Verify + assert "[ID:101]" in prompt + assert "Duplicate content" in prompt + assert "[ID:201]" in prompt + assert "Conflict content" in prompt + assert "[ID:301]" in prompt + assert "Unrelated content" in prompt + assert "New user input" in prompt + + # Check that placeholders are gone (basic check) + assert "${duplicate_candidates}" not in prompt + assert "${conflict_candidates}" not in prompt + + +def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): + llm_response = { + "memory list": [ + { + "key": "New Memory", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": ["tag1"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "preserved_facts": [], + } + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + new_item = new_items[0] + assert new_item.memory == "New Content" + assert new_item.metadata.tags == ["tag1"] + assert new_item.metadata.history == [] + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db): + # Setup existing node + existing_id = uuid.uuid4().hex + existing_node = { + "id": existing_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + "preserved_facts": [], + } + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=existing_id, + memory="Old Content", + update_type="duplicate", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == existing_id + assert updated_item.memory == "Updated Content" + assert updated_item.metadata.version == 2 + assert updated_item.metadata.tags == ["new"] + assert len(updated_item.metadata.history) == 1 + + history_entry = updated_item.metadata.history[0] + assert history_entry.archived_memory_id != existing_id + assert history_entry.archived_memory_id is not None + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "duplicate" + + mock_graph_db.add_node.assert_called_once() + mock_graph_db.update_node.assert_called_once() + args, kwargs = mock_graph_db.update_node.call_args + assert kwargs["id"] == existing_id + assert kwargs["fields"]["memory"] == "Updated Content" + assert kwargs["fields"]["version"] == 2 + + +def test_apply_llm_memory_updates_preserved_facts(history_manager, mock_graph_db): + source_id = uuid.uuid4().hex + existing_node = { + "id": source_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + restored_item = TextualMemoryItem( + memory="Restored Content", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + history_manager._handle_preserved_facts = MagicMock(return_value=[restored_item]) + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [source_id], + "preserved_facts": [ + { + "key": "Preserved Fact", + "value": "Restored Content", + "tags": ["restored"], + "memory_type": "UserMemory", + } + ], + } + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=source_id, + memory="Old Content", + update_type="conflict", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 1 + assert new_items[0] == restored_item + history_manager._handle_preserved_facts.assert_called_once_with( + [ + { + "source_candidate_id": source_id, + "value": "Restored Content", + "tags": ["restored"], + "key": "Preserved Fact", + } + ], + source_item, + "u1", + pre_update_source_item_map=ANY, + ) + mock_graph_db.add_node.assert_called_once() + + +def test_apply_llm_memory_updates_drops_preserved_facts_for_create_only_items( + history_manager, mock_graph_db +): + conflict_id = uuid.uuid4().hex + history_manager._handle_preserved_facts = MagicMock(return_value=[]) + + existing_node = { + "id": conflict_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [conflict_id], + "preserved_facts": [], + }, + { + "key": "Create Only Memory", + "memory_type": "LongTermMemory", + "value": "Brand New Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "preserved_facts": [ + { + "key": "Should Be Dropped", + "value": "Should Be Ignored", + "tags": ["ignore"], + } + ], + }, + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=conflict_id, + memory="Old Content", + update_type="conflict", + ), + ] + ), + ) + + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + history_manager._handle_preserved_facts.assert_called_once_with( + [], + source_item, + "u1", + pre_update_source_item_map=ANY, + ) + + +def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + llm_response = {"memory list": [], "summary": "Summary"} + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=id1, + memory="M1", + update_type="unrelated", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=id2, + memory="M2", + update_type="unrelated", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 0 + mock_graph_db.update_node.assert_not_called() + + +def test_handle_preserved_facts_inherits_memory_type_from_source(history_manager): + source_id = uuid.uuid4().hex + source_item = TextualMemoryItem( + id=source_id, + memory="Wang Lin works in Shenzhen and phone number is 13800138000.", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + memory_type="UserMemory", + created_at="2023-01-01", + sources=[], + history=[], + ), + ) + fast_item = TextualMemoryItem( + memory="new input", + metadata=TreeNodeTextualMemoryMetadata(), + ) + + new_items = history_manager._handle_preserved_facts( + [ + { + "source_candidate_id": source_id, + "key": "Phone number", + "value": "Wang Lin's phone number is 13800138000.", + "tags": ["contact"], + } + ], + fast_item=fast_item, + user_name="u1", + pre_update_source_item_map={source_id: source_item}, + ) + + assert len(new_items) == 1 + assert new_items[0].metadata.memory_type == "UserMemory" + assert new_items[0].metadata.key == "Phone number" + assert new_items[0].metadata.history[-1].archived_memory_id == source_id + + +def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): + # Setup existing node (primary) + primary_id = uuid.uuid4().hex + secondary_id = uuid.uuid4().hex + existing_node = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [ + existing_node, + { + "id": secondary_id, + "memory": "Secondary", + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, + }, + ] + + llm_response = { + "memory list": [ + { + "key": "Conflict Resolved", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": [], + "source_candidate_ids": [], + "conflicted_candidate_ids": [primary_id, secondary_id], + "preserved_facts": [], + } + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=primary_id, + memory="Old Content", + update_type="conflict", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=secondary_id, + memory="Secondary", + update_type="conflict", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == primary_id + assert updated_item.metadata.history[0].update_type == "conflict" + + # Verify primary update + # The mock_graph_db.update_node is called for primary (update) AND secondary (delete) + + # Find call for primary + primary_update_calls = [ + c + for c in mock_graph_db.update_node.call_args_list + if c.kwargs["id"] == primary_id and "memory" in c.kwargs.get("fields", {}) + ] + assert len(primary_update_calls) >= 1 + assert primary_update_calls[0].kwargs["fields"]["memory"] == "New Content" + + # Find call for secondary + secondary_update_calls = [ + c for c in mock_graph_db.update_node.call_args_list if c.kwargs["id"] == secondary_id + ] + assert len(secondary_update_calls) >= 1 + last_secondary_update = secondary_update_calls[-1] + assert last_secondary_update.kwargs["fields"]["status"] == "archived" + assert last_secondary_update.kwargs["fields"]["evolve_to"] == [primary_id] + + +def test_rebuild_fast_node_history_dedup_and_replace(): + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="a", memory="m1", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="b", memory="m2", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=2, archived_memory_id="a", memory="m3", update_type="duplicate" + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]) + ) + + r1 = ArchivedTextualMemory( + version=2, archived_memory_id="b", memory="m4", update_type="conflict" + ) + r2 = ArchivedTextualMemory( + version=1, archived_memory_id="c", memory="m5", update_type="duplicate" + ) + + _rebuild_fast_node_history(item, {1: [r1, r2]}) + + by_id = {h.archived_memory_id: h for h in item.metadata.history} + assert set(by_id.keys()) == {"a", "b", "c"} + assert by_id["a"].version == 2 + assert by_id["b"].version == 2 + + +def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): + fast_id = uuid.uuid4().hex + history_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="conflict", is_fast=True + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) + ) + mock_graph_db.get_nodes.return_value = [ + { + "id": fast_id, + "memory": "fast", + "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}, + } + ] + + replacement_item = ArchivedTextualMemory( + version=1, archived_memory_id="n1", memory="r1", update_type="conflict" + ) + history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) + + replacements = history_manager._check_and_fetch_replacements(item, [0], user_name="u1") + + assert 0 in replacements + assert replacements[0][0].archived_memory_id == "n1" + history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict", "u1") + + +def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): + mock_graph_db.get_nodes.return_value = [ + { + "id": "x1", + "memory": "m1", + "metadata": {"version": 2, "is_fast": False, "created_at": "2024-01-01"}, + }, + { + "id": "x2", + "memory": "m2", + "metadata": {"version": 1, "is_fast": True, "created_at": "2024-01-02"}, + }, + ] + + results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate", user_name="u1") + + assert len(results) == 2 + ids = sorted([r.archived_memory_id for r in results]) + assert ids == ["x1", "x2"] + assert all(r.update_type == "duplicate" for r in results) + + +def test_wait_and_update_fast_history_rebuilds(history_manager): + fast_id = uuid.uuid4().hex + fast_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="duplicate", is_fast=True + ) + other_item = ArchivedTextualMemory( + version=1, archived_memory_id="k1", memory="keep", update_type="unrelated", is_fast=False + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[fast_item, other_item]) + ) + + replacement = ArchivedTextualMemory( + version=2, archived_memory_id="n1", memory="new", update_type="duplicate", is_fast=False + ) + history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) + + history_manager.wait_and_update_fast_history(item, user_name="u1", timeout_sec=1) + + ids = [h.archived_memory_id for h in item.metadata.history] + assert "n1" in ids + assert fast_id not in ids + history_manager._check_and_fetch_replacements.assert_called_once_with(item, [0], "u1") + + +def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): + llm = MagicMock() + llm.generate.return_value = "Merged Content" + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + existing_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, + } + ] + + mem_data = { + "key": "k", + "value": "Proposed", + "tags": ["t1"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = manager._update_existing_memory( + mem_data, + [existing_id], + [existing_id], + {existing_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated.memory == "Merged Content" + assert updated.metadata.version == 3 + assert new_item is None + mock_graph_db.update_node.assert_called_once() + + +def test_update_existing_memory_marks_working_binding_deleted(history_manager, mock_graph_db): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + working_binding = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is not None + assert new_item is None + history_manager.mark_memory_status.assert_called_once_with( + [str(working_binding)], "deleted", user_name="u1" + ) + + +def test_update_existing_memory_no_mark_when_working_binding_matches( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is not None + assert new_item is None + + +def test_update_from_feedback_returns_persistence_payload_without_side_effects( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + memory_id = str(uuid.uuid4()) + old_item = TextualMemoryItem( + id=memory_id, + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + memory_type="LongTermMemory", + embedding=[0.1, 0.2], + sources=[{"type": "chat", "content": "old source"}], + history=[], + ), + ) + new_item = TextualMemoryItem( + memory="Updated Content", + metadata=TreeNodeTextualMemoryMetadata( + tags=["fresh"], + key="topic", + background="new background", + embedding=[0.3, 0.4], + sources=[{"type": "feedback", "content": "new feedback source"}], + memory_type="LongTermMemory", + ), + ) + + current_item, archived_item, archived_metadata, update_fields = ( + history_manager.update_from_feedback( + old_item=old_item, + new_item=new_item, + user_name="u1", + ) + ) + + assert current_item.id == memory_id + assert current_item.memory == "Updated Content" + assert archived_item.memory == "Old Content" + assert current_item.metadata.sources[0].content == "new feedback source" + assert current_item.metadata.sources[0].type == "feedback" + assert current_item.metadata.sources[1].content == "old source" + assert archived_item.metadata.sources[0].content == "old source" + assert archived_metadata["embedding"] == [0.1, 0.2] + assert update_fields["memory"] == "Updated Content" + assert update_fields["covered_history"] == archived_item.id + assert update_fields["embedding"] == [0.3, 0.4] + mock_graph_db.get_node.assert_not_called() + mock_graph_db.add_node.assert_not_called() + mock_graph_db.update_node.assert_not_called() + history_manager.mark_memory_status.assert_not_called() + + +def test_update_existing_memory_node_missing(history_manager, mock_graph_db): + mock_graph_db.get_node.return_value = None + mock_graph_db.get_nodes.return_value = [] + mem_data = {"value": "v", "tags": [], "key": "k"} + + updated, new_item = history_manager._update_existing_memory( + mem_data, + ["missing"], + [], + {}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is None + assert new_item is not None + assert new_item.memory == "v" + mock_graph_db.update_node.assert_not_called() + + +def test_update_node_with_history(): + item = TextualMemoryItem( + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + tags=["old"], + key="k1", + history=[], + ), + ) + + updated, archived = MemoryHistoryManager.update_node_with_history( + item, + "New Content", + "conflict", + ) + + assert updated.memory == "New Content" + assert updated.metadata.version == 3 + assert updated.metadata.tags == ["old"] + assert updated.metadata.key == "k1" + assert len(updated.metadata.history) == 1 + history_entry = updated.metadata.history[0] + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "conflict" + assert history_entry.archived_memory_id == archived.id + assert archived.metadata.status == "archived" + assert archived.metadata.evolve_to == [updated.id] + + +def test_merge_conflicting_memory_llm_error(mock_graph_db): + llm = MagicMock() + llm.generate.side_effect = Exception("fail") + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + merged = manager._merge_conflicting_memory("Latest", "Proposed") + + assert merged == "Latest\n\n[New Info]: Proposed" + assert "Latest" in merged + assert "Proposed" in merged diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/conftest.py b/tests/plugins/conftest.py new file mode 100644 index 000000000..6a1a16b68 --- /dev/null +++ b/tests/plugins/conftest.py @@ -0,0 +1,17 @@ +"""Ensure @hookable-generated hooks are declared for core framework tests. + +In production, @hookable("add") runs at import time of add_handler.py, +declaring add.before / add.after. Core framework tests don't import handler +modules (to avoid heavy dependencies), so we trigger declarations here. + +Plugin-specific hooks are declared in each plugin's own tests/conftest.py. +""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("chat") +hookable("feedback") +hookable("memory.get") diff --git a/tests/plugins/run_plugin_server.py b/tests/plugins/run_plugin_server.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py new file mode 100644 index 000000000..77ea8dfce --- /dev/null +++ b/tests/plugins/test_plugin_demo.py @@ -0,0 +1,439 @@ +""" +Plugin system core framework tests. + +Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations): +1. Hook declaration registry (hook_defs) +2. Hook registration and triggering / pipe_key pipeline return value +3. @hookable decorator (sync + async + auto-declaration + pipeline return value) +4. MemOSPlugin base class register_* methods + +Plugin-specific functional tests are located in each plugin package: + extensions/memos_demo_plugin/tests/ +""" + +import asyncio +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +# ========================================================================= # +# 1. Hook declaration registry (hook_defs) +# ========================================================================= # + + +class TestHookDefs: + def test_define_hook_and_get_spec(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook( + "test.custom.hook", + description="test hook", + params=["request", "result"], + pipe_key="result", + ) + + spec = get_hook_spec("test.custom.hook") + assert spec is not None + assert spec.name == "test.custom.hook" + assert spec.params == ["request", "result"] + assert spec.pipe_key == "result" + + def test_define_hook_is_idempotent(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook("test.idempotent", description="first", params=["a"], pipe_key="a") + define_hook("test.idempotent", description="second", params=["b"], pipe_key="b") + + spec = get_hook_spec("test.idempotent") + assert spec.description == "first" + + def test_get_hook_spec_returns_none_for_unknown(self): + from memos.plugins.hook_defs import get_hook_spec + + assert get_hook_spec("definitely.does.not.exist") is None + + def test_all_hook_specs_includes_custom(self): + from memos.plugins.hook_defs import H, all_hook_specs + + specs = all_hook_specs() + assert H.ADD_MEMORIES_POST_PROCESS in specs + + def test_h_constants(self): + from memos.plugins.hook_defs import H + + assert H.ADD_BEFORE == "add.before" + assert H.ADD_AFTER == "add.after" + assert H.SEARCH_BEFORE == "search.before" + assert H.SEARCH_AFTER == "search.after" + assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process" + + +# ========================================================================= # +# 2. Hook registration and triggering / pipe_key pipeline return value +# ========================================================================= # + + +class TestHookMechanism: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_and_trigger(self): + from memos.plugins.hooks import register_hook, trigger_hook + + captured = {} + + def my_callback(*, request, **kwargs): + captured["request"] = request + + register_hook("add.before", my_callback) + trigger_hook("add.before", request="test_request") + + assert captured["request"] == "test_request" + + def test_register_hooks_batch(self): + from memos.plugins.hooks import register_hooks, trigger_hook + + call_count = 0 + + def my_callback(**kwargs): + nonlocal call_count + call_count += 1 + + register_hooks(["add.before", "search.before"], my_callback) + trigger_hook("add.before") + trigger_hook("search.before") + + assert call_count == 2 + + def test_trigger_undeclared_hook_returns_none(self): + from memos.plugins.hooks import trigger_hook + + result = trigger_hook("nonexistent.undeclared.hook", request="anything") + assert result is None + + def test_hook_exception_does_not_propagate(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook("test.exception", description="test", params=["x"]) + + results = [] + + def bad_callback(**kwargs): + raise ValueError("intentional error") + + def good_callback(**kwargs): + results.append("ok") + + register_hook("test.exception", bad_callback) + register_hook("test.exception", good_callback) + trigger_hook("test.exception", x=1) + + assert results == ["ok"] + + def test_trigger_hook_pipe_key_returns_modified_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.pipe", + description="pipe test", + params=["request", "result"], + pipe_key="result", + ) + + def double_result(*, request, result, **kwargs): + return result * 2 + + register_hook("test.pipe", double_result) + rv = trigger_hook("test.pipe", request="req", result=5) + + assert rv == 10 + + def test_trigger_hook_pipe_key_chains_callbacks(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.chain", + description="chain test", + params=["result"], + pipe_key="result", + ) + + def add_one(*, result, **kwargs): + return result + 1 + + def add_ten(*, result, **kwargs): + return result + 10 + + register_hook("test.chain", add_one) + register_hook("test.chain", add_ten) + + rv = trigger_hook("test.chain", result=0) + assert rv == 11 + + def test_trigger_hook_pipe_key_none_callback_no_modify(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.noop", + description="noop test", + params=["result"], + pipe_key="result", + ) + + def noop(*, result, **kwargs): + return None # explicitly return None — should not modify + + register_hook("test.noop", noop) + rv = trigger_hook("test.noop", result="original") + + assert rv == "original" + + def test_trigger_hook_notification_mode(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.notify", + description="notification test", + params=["data"], + pipe_key=None, + ) + + captured = [] + + def observer(*, data, **kwargs): + captured.append(data) + + register_hook("test.notify", observer) + rv = trigger_hook("test.notify", data="hello") + + assert rv is None + assert captured == ["hello"] + + +# ========================================================================= # +# 3. @hookable decorator +# ========================================================================= # + + +class TestHookableDecorator: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_hookable_auto_declares_specs(self): + from memos.plugins.hook_defs import get_hook_spec + from memos.plugins.hooks import hookable + + @hookable("auto_test") + def dummy(self, request): + return request + + before_spec = get_hook_spec("auto_test.before") + after_spec = get_hook_spec("auto_test.after") + + assert before_spec is not None + assert before_spec.pipe_key == "request" + assert after_spec is not None + assert after_spec.pipe_key == "result" + + def test_hookable_sync(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append(("before", request)) + + def on_after(*, request, result, **kwargs): + events.append(("after", result)) + + register_hook("sync_demo.before", on_before) + register_hook("sync_demo.after", on_after) + + class FakeHandler: + @hookable("sync_demo") + def do_work(self, request): + return f"processed:{request}" + + result = FakeHandler().do_work("my_input") + + assert result == "processed:my_input" + assert events == [("before", "my_input"), ("after", "processed:my_input")] + + def test_hookable_async(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append("before") + + def on_after(*, request, result, **kwargs): + events.append("after") + + register_hook("async_demo.before", on_before) + register_hook("async_demo.after", on_after) + + class FakeHandler: + @hookable("async_demo") + async def do_work(self, request): + return "async_result" + + result = asyncio.run(FakeHandler().do_work("req")) + + assert result == "async_result" + assert events == ["before", "after"] + + def test_hookable_before_can_modify_request(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_request(*, request, **kwargs): + return "modified_request" + + register_hook("modify_req.before", rewrite_request) + + class FakeHandler: + @hookable("modify_req") + def do_work(self, request): + return f"got:{request}" + + result = FakeHandler().do_work("original") + assert result == "got:modified_request" + + def test_hookable_after_can_modify_result(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_result(*, request, result, **kwargs): + return f"{result}+modified" + + register_hook("modify_res.after", rewrite_result) + + class FakeHandler: + @hookable("modify_res") + def do_work(self, request): + return "original_result" + + result = FakeHandler().do_work("req") + assert result == "original_result+modified" + + def test_hookable_falsy_return_preserved(self): + """ensure empty list / 0 / empty string are not treated as None""" + from memos.plugins.hooks import hookable, register_hook + + def return_empty_list(*, request, result, **kwargs): + return [] + + register_hook("falsy_test.after", return_empty_list) + + class FakeHandler: + @hookable("falsy_test") + def do_work(self, request): + return [1, 2, 3] + + result = FakeHandler().do_work("req") + assert result == [] + + +# ========================================================================= # +# 4. Base class register_* methods +# ========================================================================= # + + +class TestBaseClassRegisterMethods: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_router(self): + from fastapi import APIRouter + + from memos.plugins.base import MemOSPlugin + + app = FastAPI() + plugin = MemOSPlugin() + plugin._bind_app(app) + + router = APIRouter(prefix="/test") + + @router.get("/ping") + async def ping(): + return {"pong": True} + + plugin.register_router(router) + + paths = [r.path for r in app.routes] + assert "/test/ping" in paths + + def test_register_middleware(self): + from starlette.middleware.base import BaseHTTPMiddleware + + from memos.plugins.base import MemOSPlugin + + class NoopMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = FastAPI() + + @app.get("/x") + async def x(): + return {} + + plugin = MemOSPlugin() + plugin._bind_app(app) + plugin.register_middleware(NoopMiddleware) + + client = TestClient(app) + resp = client.get("/x") + assert resp.status_code == 200 + + def test_register_hook(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("test.reg.event", description="test", params=["x"]) + + called = [] + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hook("test.reg.event", lambda **kw: called.append(True)) + + trigger_hook("test.reg.event", x=1) + assert called == [True] + + def test_register_hooks_batch(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("batch.a", description="a", params=["x"]) + define_hook("batch.b", description="b", params=["x"]) + + count = 0 + + def cb(**kw): + nonlocal count + count += 1 + + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hooks(["batch.a", "batch.b"], cb) + + trigger_hook("batch.a", x=1) + trigger_hook("batch.b", x=2) + assert count == 2