From 6faf1f98a9177d26870ec5cc365c46ebf6745257 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 10:39:13 +0800 Subject: [PATCH 01/92] feat: Building fast-add related functions for memory versions. --- .../organize/history_manager.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 1afdc9281..97bbe7483 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,4 +1,5 @@ import logging +import time from typing import Literal @@ -67,6 +68,36 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: return +def _rebuild_fast_node_history( + item: TextualMemoryItem, + replacements: dict[int, list[ArchivedTextualMemory]], +) -> None: + """ + Reconstruct the history list of a fast node: + 1. Replace resolved items with their evolved versions. + 2. Deduplicate by ID while preserving newest versions. + """ + new_history = {} + + def _add(history_item): + item_id = history_item.archived_memory_id + current = new_history.get(item_id) + + if current is None or history_item.version > current.version: + new_history[item_id] = history_item + + # Apply replacements and filter superseded items + for i, h in enumerate(item.metadata.history): + if i in replacements: + # This item is resolved, insert its replacements + for replacement_item in replacements[i]: + _add(replacement_item) + else: + _add(h) + + item.metadata.history = list(new_history.values()) + + class MemoryHistoryManager: def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: """ @@ -79,6 +110,131 @@ def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: self.nli_client = nli_client self.graph_db = graph_db + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int] + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node_data = nodes_map.get(h_item.archived_memory_id) + + if not node_data: + continue + + metadata = node_data.get("metadata", {}) + status = metadata.get("status") + + # Condition: Fast node is processed when it is marked as 'deleted' + if status == "deleted": + evolve_to_ids = metadata.get("evolve_to", []) + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) + replacements[i] = new_items + + logger.info( + f"[MemoryHistoryManager] Resolved fast history item {h_item.archived_memory_id} -> {evolve_to_ids}" + ) + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[TextualMemoryItem]: From a1a179cbec3c93227080c53b1fafef6d175d12da Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 10:46:39 +0800 Subject: [PATCH 02/92] feat: Building fast-add related functions for memory versions. --- .../organize/history_manager.py | 112 +++++++++--------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 97bbe7483..4cdfac985 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -179,62 +179,6 @@ def _fetch_evolved_nodes( return results - def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: - """ - Scan the item's history. If any history item is marked as `is_fast`, - wait for it to be resolved (i.e., status becomes 'deleted' in the DB). - When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. - Finally, deduplicate the history. - - Args: - item: The memory item containing the history to check. - timeout_sec: Maximum time to wait for resolution in seconds. - """ - start_time = time.time() - - # 1. Identify pending items (fast nodes) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - while True: - if not pending_indices: - # All fast nodes resolved or none existed - break - - if time.time() - start_time > timeout_sec: - logger.warning( - f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" - ) - # Remove pending fast nodes from history - item.metadata.history = [ - h - for h in item.metadata.history - if not (getattr(h, "is_fast", False) and h.archived_memory_id) - ] - break - - # 2. Check status of the fast nodes and fetch replacements for evolved ones - replacements = self._check_and_fetch_replacements(item, pending_indices) - - # 3. If we have any resolved items, rebuild the history - if replacements: - _rebuild_fast_node_history(item, replacements) - - # Check if we are done (no pending items left) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - if pending_indices: - time.sleep(1) # This avoids visiting the DB too frequently - - return - def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[TextualMemoryItem]: @@ -293,6 +237,62 @@ def resolve_history_via_nli( return duplicate_memories + conflict_memories + def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + def mark_memory_status( self, memory_items: list[TextualMemoryItem], From fef1d901409b03064d5588d9a2a1230038419e92 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 11:30:26 +0800 Subject: [PATCH 03/92] feat: supporting memory versions in fast-add process --- src/memos/api/handlers/component_init.py | 9 +++- src/memos/configs/mem_reader.py | 6 ++- src/memos/mem_reader/factory.py | 18 ++++++++ src/memos/mem_reader/multi_modal_struct.py | 44 +++++++++++++++++++ src/memos/mem_reader/simple_struct.py | 12 +++++ .../init_components_for_scheduler.py | 14 +++++- 6 files changed, 100 insertions(+), 3 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index ba527d602..2e27e9da5 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -192,8 +193,14 @@ def init_server() -> dict[str, Any]: embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + pre_update_retriever=pre_update_retriever, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 4bd7953c0..98aff981f 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator, model_validator @@ -65,6 +65,10 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): default=None, description="Skills directory for the MemReader", ) + memory_version_switch: Literal["on", "off"] = Field( + default="off", + description="Turn on memory version or off", + ) class StrategyStructMemReaderConfig(BaseMemReaderConfig): diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 7bd551fb8..0907168a7 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -5,11 +5,15 @@ from memos.mem_reader.multi_modal_struct import MultiModalStructMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_reader.strategy_struct import StrategyStructMemReader +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memos_tools.singleton import singleton_factory if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -29,6 +33,8 @@ def from_config( config_factory: MemReaderConfigFactory, graph_db: Optional["BaseGraphDB | None"] = None, searcher: Optional["Searcher | None"] = None, + pre_update_retriever: PreUpdateRetriever | None = None, + history_manager: Optional["MemoryHistoryManager | None"] = None, ) -> BaseMemReader: """ Create a MemReader instance from configuration. @@ -55,4 +61,16 @@ def from_config( if searcher is not None: reader.set_searcher(searcher) + if pre_update_retriever is not None: + if hasattr(reader, "set_pre_update_retriever"): + reader.set_pre_update_retriever(pre_update_retriever) + else: + reader.pre_update_retriever = pre_update_retriever + + if history_manager is not None: + if hasattr(reader, "set_history_manager"): + reader.set_history_manager(history_manager) + else: + reader.history_manager = history_manager + return reader diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 8b0968ca1..adda62e26 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -56,6 +56,10 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) + self.pre_update_retriever = None + self.history_manager = None + self.memory_version_switch = getattr(config, "memory_version_switch", "off") + # Initialize MultiModalParser for routing to different parsers self.multi_modal_parser = MultiModalParser( embedder=self.embedder, @@ -934,6 +938,39 @@ def _process_tool_trajectory_fine( return fine_memory_items + def _fast_resolve_memory_duplicates_and_conflicts( + self, fast_memory_items: list[TextualMemoryItem], user_name: str + ) -> None: + """ + 1. Recall related memories + 2. Fast conflict/duplication check with NLI model + 3. Attach conflicting/duplicate old memory contents onto fast memory items + 4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible to /search, + but still visible for other conflict/duplication checks' recalls. + """ + if not self.pre_update_retriever or not self.history_manager: + logger.warning( + "[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized." + ) + return + + for item in fast_memory_items: + try: + # recall related memories + related = self.pre_update_retriever.retrieve( + item=item, + user_name=user_name, + ) + # NLI check & attaching contents + conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( + item, related + ) + # mark delete + self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving") + + except Exception as e: + logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + @timed def _process_multi_modal_data( self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs @@ -984,6 +1021,13 @@ def _process_multi_modal_data( scene_data_info, info, mode="fast", need_emb=False, **kwargs ) fast_memory_items = self._concat_multi_modal_memories(all_memory_items) + + # Perform conflict/duplicate check with old memories + # TODO: find a better way to pass in the user_name + user_name = kwargs.get("user_name") + if self.memory_version_switch == "on": + self._fast_resolve_memory_duplicates_and_conflicts(fast_memory_items, user_name) + if mode == "fast": return fast_memory_items else: diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index ceaf28bfa..237653d4a 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -20,6 +20,10 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) + from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang from memos.mem_reader.utils import ( @@ -185,6 +189,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): # Initialize graph_db as None, can be set later via set_graph_db for # recall operations self.graph_db = None + self.pre_update_retriever = None + self.history_manager = None def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: self.graph_db = graph_db @@ -192,6 +198,12 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: def set_searcher(self, searcher: "Searcher | None") -> None: self.searcher = searcher + def set_pre_update_retriever(self, pre_update_retriever: "PreUpdateRetriever | None") -> None: + self.pre_update_retriever = pre_update_retriever + + def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None: + self.history_manager = history_manager + def _make_memory_item( self, value: str, diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index b103acf3a..883c191fe 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -12,6 +12,7 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory +from memos.extras.nli_model.client import NLIClient from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory from memos.log import get_logger @@ -30,10 +31,12 @@ ) from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -287,6 +290,7 @@ def init_components() -> dict[str, Any]: graph_db_config = build_graph_db_config() llm_config = build_llm_config() embedder_config = build_embedder_config() + nli_client_config = APIConfig.get_nli_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -307,8 +311,16 @@ def init_components() -> dict[str, Any]: ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) + nli_client = NLIClient(base_url=nli_client_config["base_url"]) + memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + pre_update_retriever=pre_update_retriever, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( From a2b4b02a798437a364a9e698ce3a2ee8eaa82703 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 5 Feb 2026 17:10:59 +0800 Subject: [PATCH 04/92] feat: Functions for supporting memory versions in fine-add process --- src/memos/api/handlers/component_init.py | 2 +- src/memos/mem_reader/multi_modal_struct.py | 6 +- .../organize/history_manager.py | 583 +++++++++++++++--- .../mem_reader_mem_version_prompts.py | 168 +++++ 4 files changed, 680 insertions(+), 79 deletions(-) create mode 100644 src/memos/templates/mem_reader_mem_version_prompts.py diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 2e27e9da5..e764c7fd4 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -192,7 +192,7 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db, llm=llm) pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index adda62e26..16098907d 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -965,8 +965,10 @@ def _fast_resolve_memory_duplicates_and_conflicts( conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( item, related ) - # mark delete - self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving") + # mark delete(temporarily) + self.history_manager.mark_memory_status( + [m.id for m in conflicting_or_duplicate_items], "resolving" + ) except Exception as e: logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 4cdfac985..2c2d37654 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,13 +1,24 @@ import logging import time +import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem +from memos.llms.base import BaseLLM +from memos.memories.textual.item import ( + ArchivedTextualMemory, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.templates.mem_reader_mem_version_prompts import ( + ASYNC_MEMORY_UPDATE_PROMPT_ZH, + MEMORY_MERGE_PROMPT_ZH, +) logger = logging.getLogger(__name__) @@ -69,13 +80,12 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: def _rebuild_fast_node_history( - item: TextualMemoryItem, - replacements: dict[int, list[ArchivedTextualMemory]], + item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] ) -> None: """ Reconstruct the history list of a fast node: 1. Replace resolved items with their evolved versions. - 2. Deduplicate by ID while preserving newest versions. + 2. Deduplicate by ID while preserving the newest versions. """ new_history = {} @@ -99,89 +109,109 @@ def _add(history_item): class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: + def __init__( + self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None + ) -> None: """ Initialize the MemoryHistoryManager. Args: nli_client: NLIClient for conflict/duplicate detection. graph_db: GraphDB instance for marking operations during history management. + llm: Optional LLM instance for memory merging during conflicts. """ self.nli_client = nli_client self.graph_db = graph_db - - def _check_and_fetch_replacements( - self, item: TextualMemoryItem, pending_indices: list[int] - ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + self.llm = llm + + @staticmethod + def update_node_with_history( + item: TextualMemoryItem, + new_memory: str, + update_type: str, + tags: list[str] | None = None, + key: str | None = None, + ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + old_id = item.id + archived_id = str(uuid.uuid4()) + # archived memory(need to store this node to the db later) + archived_item = item.model_copy(deep=True) + archived_item.id = archived_id + archived_item.metadata.evolve_to = [old_id] + archived_item.metadata.status = "archived" + + # original memory with updated contents and history + archived_history = ArchivedTextualMemory( + version=item.metadata.version or 1, + is_fast=item.metadata.is_fast or False, + memory=item.memory, + update_type=update_type, + archived_memory_id=archived_id, + created_at=getattr(item.metadata, "created_at", None) or datetime.now().isoformat(), + ) + item.memory = new_memory + item.metadata.version = (item.metadata.version or 1) + 1 + item.metadata.status = "activated" + if tags is not None: + item.metadata.tags = tags + if key is not None: + item.metadata.key = key + if item.metadata.history is None: + item.metadata.history = [] + item.metadata.history.append(archived_history) + + return item, archived_item + + @staticmethod + def format_async_update_prompt( + item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + ) -> str: """ - Check DB status for pending items. If 'deleted', fetch evolved nodes. + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + conversation: The current conversation content. If None, uses item.memory. + custom_tags_prompt: Optional custom prompt for tags. Returns: - replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + Formatted prompt string. """ - pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] - - # Batch fetch pending nodes to check status - nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + # First, detach duplicate and conflict memory contents from the new item's memory text + _detach_related_content(item) - replacements = {} - - for i in pending_indices: - h_item = item.metadata.history[i] - node_data = nodes_map.get(h_item.archived_memory_id) + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] - if not node_data: - continue + for h in item.metadata.history or []: + candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" - metadata = node_data.get("metadata", {}) - status = metadata.get("status") - - # Condition: Fast node is processed when it is marked as 'deleted' - if status == "deleted": - evolve_to_ids = metadata.get("evolve_to", []) - - new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) - replacements[i] = new_items - - logger.info( - f"[MemoryHistoryManager] Resolved fast history item {h_item.archived_memory_id} -> {evolve_to_ids}" - ) - - return replacements - - def _fetch_evolved_nodes( - self, evolve_to_ids: list[str], update_type: str - ) -> list[ArchivedTextualMemory]: - """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" - if not evolve_to_ids: - return [] - - evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] - results = [] - - for enode in evolved_nodes: - if not enode or "id" not in enode: - continue + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) - enode_meta = enode.get("metadata", {}) + # Helper to format list as string + def format_list(candidates): + return "\n".join(candidates) if candidates else "无" - # Create new archived memory inheriting the update_type (conflict/duplicate) - new_archived = ArchivedTextualMemory( - version=enode_meta.get("version", 1), - is_fast=enode_meta.get("is_fast", False), - memory=enode.get("memory", ""), - update_type=update_type, - archived_memory_id=enode.get("id"), - created_at=enode_meta.get("created_at"), + return ( + ASYNC_MEMORY_UPDATE_PROMPT_ZH.replace( + "${duplicate_candidates}", format_list(duplicate_candidates) ) - results.append(new_archived) - - return results + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation}", conversation or item.memory) + ) def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: + ) -> list[str]: """ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, and attach them as history to the new fast item. @@ -191,7 +221,7 @@ def resolve_history_via_nli( related_items: Existing memory items that might be related. Returns: - List of duplicate or conflicting memory items judged by the NLI service. + List of duplicate or conflicting memory ids judged by the NLI service. """ if not related_items: return [] @@ -208,10 +238,10 @@ def resolve_history_via_nli( for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" - duplicate_memories.append(r_item.memory) + duplicate_memories.append(r_item.id) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" - conflict_memories.append(r_item.memory) + conflict_memories.append(r_item.id) else: update_type = "unrelated" @@ -227,9 +257,6 @@ def resolve_history_via_nli( created_at=created_at, ) new_item.metadata.history.append(archived) - logger.info( - f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" - ) # 3. Concat duplicate/conflict memories to new_item.memory # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. @@ -293,26 +320,88 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return + def apply_llm_memory_updates( + self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """ + Apply the updates from the LLM response to the memory graph. + + Args: + llm_response: The parsed JSON response from the LLM. + source_item: The original fast item A whose history contains ArchivedTextualMemory entries. + We derive expected versions and candidate IDs from A.history. + user_name: user_name + + Returns: + List of new or updated memory items. + """ + memory_list = llm_response.get("memory list", []) + restored_memories = llm_response.get("restored_memories", []) + used_source_ids = set() + used_conflict_ids = set() + # Collect IDs referenced by the LLM response + for mem_data in memory_list: + used_source_ids.update(mem_data.get("source_candidate_ids", [])) + used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) + # Collect IDs referenced by restored memories + restored_source_ids = { + mem.get("source_candidate_id") + for mem in restored_memories + if mem.get("source_candidate_id") + } + # All IDs used by the LLM response + used_ids = used_source_ids | used_conflict_ids | restored_source_ids + expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories + candidate_id_set: set[str] = set() + # Recover candidate IDs and their expected versions from the source item's history + if source_item.metadata and source_item.metadata.history: + for h in source_item.metadata.history: + if h.archived_memory_id: + candidate_id_set.add(h.archived_memory_id) + expected_versions[h.archived_memory_id] = h.version + + # IDs not used by the LLM response will be treated as unrelated + unrelated_ids = sorted(candidate_id_set - used_ids) + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + + # 1. Handle Unrelated Candidates - Restore status to activated + self._handle_unrelated_candidates(unrelated_ids, user_name=user_name) + + # 2. Handle Memory List (Update or New) + processed_updates, created_items = self._process_memory_updates( + memory_list, expected_versions + ) + updated_items.extend(processed_updates) + new_items.extend(created_items) + + # 3. Handle Restored Memories (Extract from conflict) + new_items.extend(self._handle_restored_memories(restored_memories)) + + return updated_items, new_items + def mark_memory_status( self, - memory_items: list[TextualMemoryItem], + memory_ids: list[str], status: Literal["activated", "resolving", "archived", "deleted"], + user_name: str, ) -> None: """ Support status marking operations during history management. Common usages are: 1. Mark conflict/duplicate old memories' status as "resolving", to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. + 2. Mark resolved memories' status as "activated", to recover their visibility. """ # Execute the actual marking operation - in db. with ContextThreadPoolExecutor() as executor: futures = [] - for mem in memory_items: + for mid in memory_ids: futures.append( executor.submit( self.graph_db.update_node, - id=mem.id, + id=mid, fields={"status": status}, + user_name=user_name, ) ) @@ -320,3 +409,345 @@ def mark_memory_status( for future in futures: future.result() return + + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int] + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node_data = nodes_map.get(h_item.archived_memory_id) + + if not node_data: + continue + + metadata = node_data.get("metadata", {}) + status = metadata.get("status") + + # Condition: Fast node is processed when it is marked as 'deleted' + if status == "deleted": + evolve_to_ids = metadata.get("evolve_to", []) + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) + replacements[i] = new_items + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) -> None: + """Handle Unrelated Candidates - Restore status to `activated`.""" + if unrelated_ids: + self.mark_memory_status(unrelated_ids, "activated", user_name) + + def _process_memory_updates( + self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int] + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """Process Memory List (Update or New).""" + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + for mem_data in memory_list: + source_ids = mem_data.get("source_candidate_ids", []) + conflict_ids = mem_data.get("conflicted_candidate_ids", []) + + # Determine if this is an update or a new node + target_ids = source_ids + conflict_ids + + if target_ids: + item = self._update_existing_memory( + mem_data, + target_ids, + source_ids, + expected_versions, + ) + if item: + updated_items.append(item) + else: + item = self._create_new_memory(mem_data) + new_items.append(item) + return updated_items, new_items + + def _update_existing_memory( + self, + mem_data: dict[str, Any], + target_ids: list[str], + source_ids: list[str], + expected_versions: dict[str, int], + ) -> TextualMemoryItem | None: + """ + Update existing memory nodes using the LLM result. + + The first ID in target_ids is treated as the primary node. If additional target IDs + are provided, they are treated as secondary candidates and will be merged into the + primary. Merging means: + 1) Mark secondary nodes as archived and append the primary ID to evolve_to + 2) Merge their history entries into the primary history and re-order by created_at + + The method also applies CAS validation via expected_versions, archives the previous + version of the primary node, and persists the updated node back to the graph DB. + + Returns the updated primary TextualMemoryItem, or None when the primary node + cannot be found. + """ + original_primary_id, primary_id, secondary_ids = ( + target_ids[0], + target_ids[0], + target_ids[1:], + ) + new_value, tags, key = ( + mem_data.get("value", ""), + mem_data.get("tags", []), + mem_data.get("key", ""), + ) + + # Fetch candidate nodes in batch and then select the primary + nodes_data = self.graph_db.get_nodes(target_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + node_data = nodes_map.get(primary_id) + if not node_data: + node_data = self.graph_db.get_node(primary_id) + if not node_data: + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + return None + + current_item = TextualMemoryItem(**node_data) + new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) + + update_type = "duplicate" if original_primary_id in source_ids else "conflict" + current_item, archived_item = self.update_node_with_history( + current_item, + new_value, + update_type, + tags=tags, + key=key, + ) + now = datetime.now().isoformat() + if hasattr(archived_item.metadata, "created_at") and not getattr( + archived_item.metadata, "created_at", None + ): + archived_item.metadata.created_at = now + if hasattr(archived_item.metadata, "updated_at"): + archived_item.metadata.updated_at = now + if hasattr(current_item.metadata, "updated_at"): + current_item.metadata.updated_at = now + self.graph_db.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_item.metadata.model_dump(exclude_none=True), + ) + fields = current_item.metadata.model_dump(exclude_none=True) + fields.pop("created_at", None) + fields["updated_at"] = now + self.graph_db.update_node( + id=primary_id, + fields={"memory": current_item.memory, **fields}, + ) + self._merge_secondary_nodes(secondary_ids, primary_id) + + return current_item + + def _apply_cas_merge( + self, + primary_id: str, + current_item: TextualMemoryItem, + expected_versions: dict[str, int], + new_value: str, + ) -> str: + expected_version = expected_versions.get(primary_id) + current_version = current_item.metadata.version or 1 + if expected_version is not None and current_version != expected_version: + logger.warning( + f"[MemoryHistoryManager] Version conflict for node {primary_id}: " + f"Expected v{expected_version}, but found v{current_version} in DB. " + "Triggering merge logic." + ) + merged_content = self._merge_conflicting_memory( + latest_memory=current_item.memory, + proposed_update=new_value, + ) + return merged_content + + return new_value + + def _merge_secondary_nodes(self, secondary_ids: list[str], primary_id: str) -> None: + if not secondary_ids: + return + + node_ids = [primary_id, *secondary_ids] + nodes_data = self.graph_db.get_nodes(ids=node_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + primary_data = nodes_map.get(primary_id) or self.graph_db.get_node(primary_id) + if not primary_data: + return + + primary_item = TextualMemoryItem(**primary_data) + merged_history = list(primary_item.metadata.history or []) + + for memory_id in secondary_ids: + node_data = nodes_map.get(memory_id) or self.graph_db.get_node(memory_id) + if not node_data: + continue + metadata = node_data.get("metadata", {}) + evolve_to = list(metadata.get("evolve_to", []) or []) + if primary_id not in evolve_to: + evolve_to.append(primary_id) + self.graph_db.update_node( + id=memory_id, + fields={"status": "archived", "evolve_to": evolve_to}, + ) + secondary_item = TextualMemoryItem(**node_data) + if secondary_item.metadata.history: + merged_history.extend(secondary_item.metadata.history) + + def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: + created_at = history_item.created_at + if isinstance(created_at, datetime): + return created_at + if created_at: + try: + return datetime.fromisoformat(created_at) + except ValueError: + return datetime.min + return datetime.min + + merged_history.sort(key=_history_sort_key) + for idx, history_item in enumerate(merged_history, start=1): + history_item.version = idx + primary_item.metadata.history = merged_history + self.graph_db.update_node( + id=primary_id, + fields={"history": [h.model_dump(exclude_none=True) for h in merged_history]}, + ) + + def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: + """ + Call LLM to merge proposed update with latest memory content. + """ + if not self.llm: + return proposed_update + + prompt = MEMORY_MERGE_PROMPT_ZH.replace("${latest_memory}", latest_memory).replace( + "${proposed_update}", proposed_update + ) + + messages = [{"role": "user", "content": prompt}] + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") + # Fallback: append proposed update? or just return proposed? + # Returning proposed might overwrite latest changes. + # Returning latest might lose proposed changes. + # Let's concatenate as a safe fallback. + return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" + + def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: + """Create New Node.""" + new_value = mem_data.get("value", "") + tags = mem_data.get("tags", []) + key = mem_data.get("key", "") + memory_type = mem_data.get("memory_type", "LongTermMemory") + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=new_value, + metadata=TreeNodeTextualMemoryMetadata( + is_fast=False, + version=1, + memory_type=memory_type, + status="activated", + tags=tags, + key=key, + created_at=datetime.now().isoformat(), + history=[], + ), + ) + return new_item + + def _handle_restored_memories( + self, restored_memories: list[dict[str, Any]] + ) -> list[TextualMemoryItem]: + """Handle Restored Memories (Extract from conflict).""" + source_ids = [r.get("source_candidate_id") for r in restored_memories] + source_items = self.graph_db.get_nodes(source_ids) + source_items = [TextualMemoryItem(**i) for i in source_items] + + created_items = [] + for i, data in enumerate(restored_memories): + source_item = source_items[i] + source_history = source_item.history.copy() + value = data.get("value", "") + tags = data.get("tags", []) + keys = data.get("keys", []) + memory_type = data.get("memory_type", "LongTermMemory") + version = source_item.verion + new_history_item = ArchivedTextualMemory( + version=version, + is_fast=False, + memory=source_item.memory, + update_type="extract", + archived_memory_id=source_item.id, + created_at=source_item.metadata.created_at, + ) + source_history.append(new_history_item) # Re-use the history of the old node + # Create new node + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=value, + metadata=TreeNodeTextualMemoryMetadata( + memory_type=memory_type, + status="activated", + is_fast=False, + version=version + 1, + key=keys, + tags=tags, + created_at=datetime.now().isoformat(), + history=source_history, + ), + ) + created_items.append(new_item) + + return created_items diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py new file mode 100644 index 000000000..a5933f192 --- /dev/null +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -0,0 +1,168 @@ +# ========================================== +# 记忆更新与维护链路 (Memory Update & Maintenance) +# ========================================== +ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 +您的核心任务是根据最新的用户对话,维护和更新用户的长期记忆图谱。您需要利用系统提供的“候选记忆”(Candidates)来决定是更新现有记忆节点,还是创建全新的记忆节点。 + +系统检测到以下“候选记忆”与用户的最新输入存在潜在关联: + +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: +${duplicate_candidates} + +2. **潜在事实冲突记忆 (Conflict Candidates)**: +${conflict_candidates} + +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +${unrelated_candidates} + +请执行以下操作: + +1. **核心原则(STRICT)**: + - 您的目标是**维护**记忆库,而非仅仅提取信息。 + - **优先更新**:如果对话内容涉及现有的“候选记忆”,应优先视为对该记忆节点的**更新**(补充细节或修正状态),而不是创建重复的新节点。 + - **按需新增**:仅当对话内容包含全新的、与现有“候选记忆”完全无关的话题时,才创建新的记忆节点。 + - 提取来源**只能**是【当前的对话内容】。严禁编造未提及的信息。 + +2. 识别反映用户经历、信念、关切、决策、计划或反应的信息。 + - 如果消息来自用户,提取用户相关的记忆。 + - 如果来自助手,仅提取用户认可或回应的事实性记忆。 + +3. 清晰解析所有时间、人物和事件的指代(同原规则): + - 将相对时间(“昨天”)转换为绝对日期。 + - 明确区分事件时间和消息时间。 + - 解析代词和模糊指代。 + - 使用第三人称(“用户”)。 + - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 + +4. **处理逻辑(更新与新增)**: + 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: + + a) **更新现有记忆节点 (Update via Duplicate/Related)**: + - 检查“潜在重复/关联记忆”。 + - 如果新信息是对某条旧记忆的重复、确认或补充细节: + - 生成一条**更新后的完整记忆**放入 `value`(包含旧信息+新细节)。 + - 将该旧记忆的ID放入 `source_candidate_ids`。 + - 此时 `conflicted_candidate_ids` 应为空。 + + b) **修正冲突记忆节点 (Update via Conflict)**: + - 检查“潜在事实冲突记忆”。 + - 如果新信息否定了某条旧记忆,或更新了其状态(如“不再喜欢X”“改成Y”“取消计划”“从X转为Y”): + - 生成一条反映**最新状态**的记忆放入 `value`。 + - 将被修正的旧记忆ID放入 `conflicted_candidate_ids`。 + + c) **创建新记忆节点 (Create New)**: + - 如果新信息与任何“候选记忆”都无直接关联(既非重复也非冲突): + - 生成一条独立的新记忆放入 `value`。 + - 确保 `source_candidate_ids` 和 `conflicted_candidate_ids` 均为 `[]`。 + +5. **挽救剩余信息 (Restored Memories)**: + - 当您决定更新一个冲突节点(情况 b)时,请检查该旧节点是否包含**未被新信息否定且依然独立的有效事实**。 + - 如果有,请将这些“剩余有效信息”提取出来,作为 `restored_memories` 返回。这能防止因节点更新导致非冲突信息的丢失。 + +返回一个有效的JSON对象,结构如下: + +{ + "memory list": [ + { + "key": <字符串,简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory",区分该记忆是客观事实还是和用户相关的内容>, + "value": <字符串,更新后的完整记忆内容(针对更新/冲突情况)或全新记忆内容(针对新增情况)>, + "tags": <相关主题关键词列表>, + "source_candidate_ids": <字符串列表,被此条目更新的“重复/关联记忆”ID。若无则为 []>, + "conflicted_candidate_ids": <字符串列表,被此条目修正的“事实冲突记忆”ID。若无则为 []>, + }, + ... + ], + "restored_memories": [ + { + "source_candidate_id": <字符串,来源旧冲突记忆ID>, + "value": <字符串,从旧节点中分离出的、依然有效的剩余事实>, + "tags": <相关主题关键词列表> + }, + ... + ], + "summary": <从用户视角自然总结本次记忆更新操作的段落,120–200字> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +${custom_tags_prompt} + +示例: +**潜在重复/关联记忆 (Duplicate/Related Candidates):** +[ID:101] 用户喜欢喝拿铁,通常不加糖。 +[ID:102] 用户讨厌下雨天。 + +**潜在事实冲突记忆 (Conflict Candidates):** +[ID:201] 用户喜欢打羽毛球,但不喜欢滑雪。 + +**对话:** +user: [2025年6月26日]:我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 + +**输出:** +{ + "memory list": [ + { + "key": "咖啡偏好", + "memory_type": "UserMemory", + "value": "用户最喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", + "tags": ["饮食", "咖啡", "喜好"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + }, + { + "key": "运动习惯变更", + "memory_type": "UserMemory", + "value": "用户因膝盖受伤,决定不再打羽毛球。", + "tags": ["运动", "健康", "羽毛球"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + }, + { + "key": "宠物状况", + "memory_type": "UserMemory", + "value": "用户最近(2025年6月)刚养了一只猫。", + "tags": ["生活", "宠物", "猫"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + } + ], + "restored_memories": [ + { + "key": "运动偏好", + "memory_type": "UserMemory", + "value": "用户不喜欢滑雪。", + "tags": ["运动", "滑雪", "喜好"] + "source_candidate_id": "201", + } + ], + "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了用户养猫的记忆。对天气的喜好未受影响。" +} + +请始终使用与对话相同的语言进行回复。 + +对话: +${conversation} +""" + +MEMORY_MERGE_PROMPT_ZH = """ +您是记忆库维护专家。 +我们尝试更新一个记忆节点,但该节点在数据库中的内容在处理期间发生了变化(版本冲突)。 +我们需要将“本次处理得出的更新内容”合并到“当前数据库中最新的记忆内容”中。 + +**原始记忆(数据库中的最新版本):** +${latest_memory} + +**本次尝试的更新内容(基于旧版本得出的结论):** +${proposed_update} + +**任务:** +将“本次尝试的更新内容”合并到“原始记忆”中。 +- 如果更新内容包含新信息,请将其整合进去。 +- 如果更新内容与原始记忆冲突,请优先采纳更新内容(假设它是基于最新对话的修正),但请尽量保留原始记忆中依然有效的细节。 +- 确保合并后的结果是一个连贯、通顺的完整记忆片段。 + +请只返回合并后的记忆内容字符串,不要包含任何解释。 +""" From 3bb3784f5b0350aa0401a4055b329942b9dfbcc3 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 10:19:13 +0800 Subject: [PATCH 05/92] feat: Supporting English prompts for memory versions. --- .../organize/history_manager.py | 40 +++- .../mem_reader_mem_version_prompts.py | 178 +++++++++++++++++- 2 files changed, 207 insertions(+), 11 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 2c2d37654..65295c769 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -10,14 +10,15 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.llms.base import BaseLLM +from memos.mem_reader.read_multi_modal.utils import detect_lang from memos.memories.textual.item import ( ArchivedTextualMemory, TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) from memos.templates.mem_reader_mem_version_prompts import ( - ASYNC_MEMORY_UPDATE_PROMPT_ZH, - MEMORY_MERGE_PROMPT_ZH, + ASYNC_MEMORY_UPDATE_PROMPT_DICT, + MEMORY_MERGE_PROMPT_DICT, ) @@ -162,9 +163,22 @@ def update_node_with_history( return item, archived_item - @staticmethod + def _determine_lang(self, sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang + def format_async_update_prompt( - item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" ) -> str: """ Format the prompt for asynchronous memory update. @@ -195,14 +209,18 @@ def format_async_update_prompt( # Includes "unrelated" and any other types unrelated_candidates.append(candidate_str) - # Helper to format list as string + sources = item.metadata.sources if item.metadata else None + lang = self._determine_lang(sources, conversation or item.memory) + empty_label = "None" + def format_list(candidates): - return "\n".join(candidates) if candidates else "无" + return "\n".join(candidates) if candidates else empty_label + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) return ( - ASYNC_MEMORY_UPDATE_PROMPT_ZH.replace( - "${duplicate_candidates}", format_list(duplicate_candidates) - ) + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) .replace("${conflict_candidates}", format_list(conflict_candidates)) .replace("${unrelated_candidates}", format_list(unrelated_candidates)) .replace("${custom_tags_prompt}", custom_tags_prompt) @@ -668,7 +686,9 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> if not self.llm: return proposed_update - prompt = MEMORY_MERGE_PROMPT_ZH.replace("${latest_memory}", latest_memory).replace( + lang = self._determine_lang(None, f"{latest_memory}\n{proposed_update}") + prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) + prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( "${proposed_update}", proposed_update ) diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index a5933f192..61e56f86c 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -1,5 +1,5 @@ # ========================================== -# 记忆更新与维护链路 (Memory Update & Maintenance) +# Memory Update & Maintenance # ========================================== ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 您的核心任务是根据最新的用户对话,维护和更新用户的长期记忆图谱。您需要利用系统提供的“候选记忆”(Candidates)来决定是更新现有记忆节点,还是创建全新的记忆节点。 @@ -147,6 +147,157 @@ ${conversation} """ +ASYNC_MEMORY_UPDATE_PROMPT_EN = """You are a memory maintenance expert. +Your core task is to maintain and update the user's long-term memory graph based on the latest conversation. Use the system-provided "Candidates" to decide whether to update existing memory nodes or create brand-new ones. + +The system detected the following "Candidates" potentially related to the user's latest input: + +1. Duplicate/Related Candidates: +${duplicate_candidates} + +2. Conflict Candidates: +${conflict_candidates} + +3. Possibly unrelated (needs further judgment): +${unrelated_candidates} + +Please execute the following: + +1. Core Principles (STRICT): + - Your goal is maintenance of the memory base, not mere extraction. + - Prefer Update: If the conversation touches any existing "Candidates", treat it as an update to that memory node (add details or correct status) rather than creating a duplicate node. + - Add As Needed: Only create a new node when the conversation contains truly new topics that are completely unrelated to existing "Candidates". + - The extraction source must be ONLY the current conversation. Do not fabricate information not mentioned. + +2. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions. + - If the message is from the user, extract user-related memories. + - If it is from the assistant, only extract factual memories that the user explicitly acknowledges or responds to. + +3. Disambiguate all references to time, people, and events: + - Convert relative time ("yesterday") to an absolute date. + - Clearly distinguish event time from message time. + - Resolve pronouns and ambiguous references. + - Use third person ("the user"). + - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). + +4. Processing Logic (Update and Create): + Traverse each piece of information in the conversation that is worth remembering and apply: + + a) Update existing memory node (Update via Duplicate/Related): + - Check Duplicate/Related Candidates. + - If the new information repeats, confirms, or adds details to an old memory: + - Generate an updated complete memory into `value` (old info + new details). + - Put the old memory IDs into `source_candidate_ids`. + - `conflicted_candidate_ids` must be []. + + b) Fix conflicting memory node (Update via Conflict): + - Check Conflict Candidates. + - If the new information negates an old memory or updates its state (e.g., "no longer likes X", "changed to Y", "canceled plan", "from X to Y"): + - Generate a memory reflecting the latest state into `value`. + - Put the corrected old memory IDs into `conflicted_candidate_ids`. + + c) Create new memory node (Create New): + - If the new information is not directly related to any "Candidates" (neither duplicate nor conflict): + - Generate an independent new memory into `value`. + - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. + +5. Restored Memories: + - When you decide to update a conflict node (case b), check whether the old node contains valid facts that were NOT negated by the new information and remain independently valid. + - If any, extract these "remaining valid facts" and return them as `restored_memories` to prevent loss of non-conflict information due to node updates. + +Return a valid JSON object with the structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": , + "source_candidate_ids": , + "conflicted_candidate_ids": , + }, + ... + ], + "restored_memories": [ + { + "source_candidate_id": , + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must match the main language of the input conversation. If the input is English, output English. +- `memory_type` remains in English. + +${custom_tags_prompt} + +Example: +Duplicate/Related Candidates: +[ID:101] The user likes latte and usually doesn't add sugar. +[ID:102] The user hates rainy days. + +Conflict Candidates: +[ID:201] The user likes badminton but dislikes skiing. + +Conversation: +user: [June 26, 2025]: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. + +Output: +{ + "memory list": [ + { + "key": "Coffee preference", + "memory_type": "UserMemory", + "value": "The user likes latte most, usually doesn't add sugar, and prefers oat milk.", + "tags": ["diet", "coffee", "preference"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + }, + { + "key": "Sport habit change", + "memory_type": "UserMemory", + "value": "Due to a knee injury, the user decides to no longer play badminton.", + "tags": ["sport", "health", "badminton"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + }, + { + "key": "Pet status", + "memory_type": "UserMemory", + "value": "The user recently (June 2025) adopted a cat.", + "tags": ["life", "pet", "cat"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + } + ], + "restored_memories": [ + { + "key": "Sports preference", + "memory_type": "UserMemory", + "value": "The user dislikes skiing.", + "tags": ["sport", "skiing", "preference"], + "source_candidate_id": "201" + } + ], + "summary": "In this update, the user refined their latte preference (adds oat milk), and due to a knee injury updated their sports habit (stop playing badminton). The system preserves the record that the user dislikes skiing. Additionally, a new memory is added that the user adopted a cat. Weather preference is unaffected." +} + +Always reply in the same language as the conversation. + +Conversation: +${conversation} +""" + +ASYNC_MEMORY_UPDATE_PROMPT_DICT = { + "zh": ASYNC_MEMORY_UPDATE_PROMPT_ZH, + "en": ASYNC_MEMORY_UPDATE_PROMPT_EN, +} + MEMORY_MERGE_PROMPT_ZH = """ 您是记忆库维护专家。 我们尝试更新一个记忆节点,但该节点在数据库中的内容在处理期间发生了变化(版本冲突)。 @@ -166,3 +317,28 @@ 请只返回合并后的记忆内容字符串,不要包含任何解释。 """ + +MEMORY_MERGE_PROMPT_EN = """ +You are a memory maintenance expert. +We attempted to update a memory node, but the content of that node changed in the database during processing (version conflict). +We need to merge "the update derived in this attempt" into "the latest memory content currently stored in the database". + +Original memory (latest version in the database): +${latest_memory} + +Proposed update (derived based on an old version): +${proposed_update} + +Task: +Merge "the proposed update" into "the original memory". +- If the update contains new information, integrate it. +- If the update conflicts with the original memory, prefer the update (assuming it is a correction based on the latest conversation), while preserving any details from the original memory that remain valid. +- Ensure the merged result is a coherent, fluent, and complete memory passage. + +Return ONLY the merged memory content string. Do not include any explanation. +""" + +MEMORY_MERGE_PROMPT_DICT = { + "zh": MEMORY_MERGE_PROMPT_ZH, + "en": MEMORY_MERGE_PROMPT_EN, +} From 6d43713c5a992b5d1cc947c11724a8f40dad5e67 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 18:49:27 +0800 Subject: [PATCH 06/92] feat: Merge memory version features into the fine/async pathway in add process --- src/memos/api/config.py | 1 + src/memos/mem_reader/multi_modal_struct.py | 89 ++++-- .../handlers/mem_read_handler.py | 99 +++---- src/memos/memories/textual/item.py | 4 +- .../organize/history_manager.py | 265 ++++++++++-------- 5 files changed, 274 insertions(+), 184 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 70d9366e3..c4a126140 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -834,6 +834,7 @@ def get_product_default_config() -> dict[str, Any]: "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" ), }, + "memory_version_switch": os.getenv("MEM_READER_MEM_VERSION_SWITCH", "off"), }, }, "enable_textual_memory": True, diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 16098907d..bb576fe24 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -701,6 +701,21 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 0: Memory versioning async update pipeline ========== + if ( + self.memory_version_switch == "on" + and self.history_manager is not None + and self.history_manager.is_applicable(fast_item) + ): + try: + new_items = self._process_async_versioning_update( + fast_item, mem_str, custom_tags, **kwargs + ) + return new_items + except Exception as e: + logger.warning(f"[MultiModalFine] Async versioning pipeline failed: {e}") + return [] + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) @@ -711,14 +726,15 @@ def _process_one_item( if resp.get("memory list", []): for m in resp.get("memory list", []): try: - # Check and merge with similar memories if needed - m_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=m, - mem_text=m.get("value", ""), - sources=sources, - original_query=mem_str, - **kwargs, - ) + m_maybe_merged = m + if self.memory_version_switch != "on": + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( m_maybe_merged.get("memory_type", "LongTermMemory") @@ -735,8 +751,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in m_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in m_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) @@ -745,13 +760,15 @@ def _process_one_item( elif resp.get("value") and resp.get("key"): try: # Check and merge with similar memories if needed - resp_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=resp, - mem_text=resp.get("value", "").strip(), - sources=sources, - original_query=mem_str, - **kwargs, - ) + resp_maybe_merged = resp + if self.memory_version_switch != "on": + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, @@ -762,8 +779,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in resp_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in resp_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) @@ -955,6 +971,8 @@ def _fast_resolve_memory_duplicates_and_conflicts( return for item in fast_memory_items: + if not self.history_manager.is_applicable(item): + continue try: # recall related memories related = self.pre_update_retriever.retrieve( @@ -962,17 +980,46 @@ def _fast_resolve_memory_duplicates_and_conflicts( user_name=user_name, ) # NLI check & attaching contents - conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( + conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( item, related ) # mark delete(temporarily) self.history_manager.mark_memory_status( - [m.id for m in conflicting_or_duplicate_items], "resolving" + conflicting_or_duplicate_ids, "resolving", user_name=user_name ) except Exception as e: logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + def _process_async_versioning_update( + self, item: TextualMemoryItem, mem_str: str, custom_tags: dict[str, str], **kwargs + ) -> list[TextualMemoryItem]: + """ + 1. Wait for fast node resolution and rebuild its history + 2. Build async update prompt (include custom tags and conversation context) + 3. Call LLM and parse JSON response + 4. Apply LLM updates to memory graph and return new items + """ + self.history_manager.wait_and_update_fast_history(item, timeout_sec=30) + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = self.history_manager.format_async_update_prompt( + item, + conversation=kwargs.get("chat_history") or mem_str, + custom_tags_prompt=custom_tags_prompt, + ) + response_text = self.llm.generate([{"role": "user", "content": prompt}]) + response_json = parse_json_result(response_text) + user_name = kwargs.get("user_name") + _, new_items = self.history_manager.apply_llm_memory_updates( + response_json, item, user_name=user_name + ) + return new_items + @timed def _process_multi_modal_data( self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 5d86c5589..bd1d69671 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -196,59 +196,60 @@ def _process_memories_with_reader( enhanced_mem_ids, ) - # add raw file nodes and edges - if mem_reader.save_rawfile: - raw_file_mem_group = [ + if getattr(mem_reader, "memory_version_switch", "off") != "on": + # add raw file nodes and edges + if mem_reader.save_rawfile: + raw_file_mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type == "RawFileMemory" + ] + text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + enhanced_mem_ids, + user_id=user_id, + user_name=user_name, + ) + logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + + # Mark merged_from memories as archived when provided in memory metadata + summary_memories = [ memory for memory in flattened_memories - if memory.metadata.memory_type == "RawFileMemory" + if memory.metadata.memory_type != "RawFileMemory" ] - text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - enhanced_mem_ids, - user_id=user_id, - user_name=user_name, - ) - logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) - - # Mark merged_from memories as archived when provided in memory metadata - summary_memories = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type != "RawFileMemory" - ] - if mem_reader.graph_db: - for memory in summary_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - "[Scheduler] Archived merged_from memory: %s", - old_id, - ) - except Exception as e: - logger.warning( - "[Scheduler] Failed to archive merged_from memory %s: %s", - old_id, - e, - ) - else: - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in summary_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + if mem_reader.graph_db: + for memory in summary_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), {"status": "archived"}, user_name=user_name + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in summary_memories ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) cloud_env = is_cloud_env() if cloud_env: diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 7e40f1d50..573a83856 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -109,8 +109,8 @@ class TextualMemoryMetadata(BaseModel): default_factory=list, description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) history: list[ArchivedTextualMemory] | None = Field( diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 65295c769..dbe31a4a2 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -109,6 +109,21 @@ def _add(history_item): item.metadata.history = list(new_history.values()) +def _determine_lang(sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang + + class MemoryHistoryManager: def __init__( self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None @@ -125,6 +140,18 @@ def __init__( self.graph_db = graph_db self.llm = llm + @staticmethod + def is_applicable(item: TextualMemoryItem) -> bool: + # Only deals with: + # 1. From doc or chat + # 2. LongTermMemory, UserMemory + allowed_sources = ["doc", "chat"] + allowed_memory_types = ["LongTermMemory", "UserMemory"] + return ( + item.metadata.sources[0].type in allowed_sources + and item.metadata.memory_type in allowed_memory_types + ) + @staticmethod def update_node_with_history( item: TextualMemoryItem, @@ -163,70 +190,6 @@ def update_node_with_history( return item, archived_item - def _determine_lang(self, sources: list | None, fallback_text: str) -> str: - lang = None - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - if isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(fallback_text) - return lang - - def format_async_update_prompt( - self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" - ) -> str: - """ - Format the prompt for asynchronous memory update. - - Args: - item: The TextualMemoryItem containing history candidates. - conversation: The current conversation content. If None, uses item.memory. - custom_tags_prompt: Optional custom prompt for tags. - - Returns: - Formatted prompt string. - """ - # First, detach duplicate and conflict memory contents from the new item's memory text - _detach_related_content(item) - - duplicate_candidates = [] - conflict_candidates = [] - unrelated_candidates = [] - - for h in item.metadata.history or []: - candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" - - if h.update_type == "duplicate": - duplicate_candidates.append(candidate_str) - elif h.update_type == "conflict": - conflict_candidates.append(candidate_str) - else: - # Includes "unrelated" and any other types - unrelated_candidates.append(candidate_str) - - sources = item.metadata.sources if item.metadata else None - lang = self._determine_lang(sources, conversation or item.memory) - empty_label = "None" - - def format_list(candidates): - return "\n".join(candidates) if candidates else empty_label - - prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( - lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] - ) - return ( - prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) - .replace("${conflict_candidates}", format_list(conflict_candidates)) - .replace("${unrelated_candidates}", format_list(unrelated_candidates)) - .replace("${custom_tags_prompt}", custom_tags_prompt) - .replace("${conversation}", conversation or item.memory) - ) - def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[str]: @@ -250,16 +213,20 @@ def resolve_history_via_nli( ) # 2. Process results and attach to history + duplicate_memory_ids = [] + conflict_memory_ids = [] duplicate_memories = [] conflict_memories = [] for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" - duplicate_memories.append(r_item.id) + duplicate_memory_ids.append(r_item.id) + duplicate_memories.append(r_item.memory) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" - conflict_memories.append(r_item.id) + conflict_memory_ids.append(r_item.id) + conflict_memories.append(r_item.memory) else: update_type = "unrelated" @@ -280,7 +247,7 @@ def resolve_history_via_nli( # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. _append_related_content(new_item, duplicate_memories, conflict_memories) - return duplicate_memories + conflict_memories + return duplicate_memory_ids + conflict_memory_ids def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: """ @@ -338,6 +305,56 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return + def format_async_update_prompt( + self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + ) -> str: + """ + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + conversation: The current conversation content. If None, uses item.memory. + custom_tags_prompt: Optional custom prompt for tags. + + Returns: + Formatted prompt string. + """ + # First, detach duplicate and conflict memory contents from the new item's memory text + _detach_related_content(item) + + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] + + for h in item.metadata.history or []: + candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" + + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) + + sources = item.metadata.sources if item.metadata else None + lang = _determine_lang(sources, conversation or item.memory) + empty_label = "None" + + def format_list(candidates): + return "\n".join(candidates) if candidates else empty_label + + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) + return ( + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation}", conversation or item.memory) + ) + def apply_llm_memory_updates( self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: @@ -388,7 +405,7 @@ def apply_llm_memory_updates( # 2. Handle Memory List (Update or New) processed_updates, created_items = self._process_memory_updates( - memory_list, expected_versions + memory_list, expected_versions, user_name ) updated_items.extend(processed_updates) new_items.extend(created_items) @@ -499,7 +516,7 @@ def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) self.mark_memory_status(unrelated_ids, "activated", user_name) def _process_memory_updates( - self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int] + self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int], user_name: str ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: """Process Memory List (Update or New).""" updated_items: list[TextualMemoryItem] = [] @@ -513,10 +530,7 @@ def _process_memory_updates( if target_ids: item = self._update_existing_memory( - mem_data, - target_ids, - source_ids, - expected_versions, + mem_data, target_ids, source_ids, expected_versions, user_name ) if item: updated_items.append(item) @@ -531,6 +545,7 @@ def _update_existing_memory( target_ids: list[str], source_ids: list[str], expected_versions: dict[str, int], + user_name: str, ) -> TextualMemoryItem | None: """ Update existing memory nodes using the LLM result. @@ -559,18 +574,19 @@ def _update_existing_memory( ) # Fetch candidate nodes in batch and then select the primary + # We update the primary and then merge the secondaries to the primary nodes_data = self.graph_db.get_nodes(target_ids) or [] nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} node_data = nodes_map.get(primary_id) if not node_data: - node_data = self.graph_db.get_node(primary_id) - if not node_data: - logger.warning( - f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." - ) - return None - + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + return None current_item = TextualMemoryItem(**node_data) + + # For concurrency control, need to make sure the primary item has not been modified by others in the meantime + # If it has(version changed), then we need to use llm to merge again. new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) update_type = "duplicate" if original_primary_id in source_ids else "conflict" @@ -581,28 +597,41 @@ def _update_existing_memory( tags=tags, key=key, ) - now = datetime.now().isoformat() - if hasattr(archived_item.metadata, "created_at") and not getattr( - archived_item.metadata, "created_at", None - ): - archived_item.metadata.created_at = now - if hasattr(archived_item.metadata, "updated_at"): - archived_item.metadata.updated_at = now - if hasattr(current_item.metadata, "updated_at"): - current_item.metadata.updated_at = now + + # create archived node for storing older versions of the memory self.graph_db.add_node( id=archived_item.id, memory=archived_item.memory, metadata=archived_item.metadata.model_dump(exclude_none=True), + user_name=user_name, ) + + now = datetime.now().isoformat() + if hasattr(current_item.metadata, "updated_at"): + current_item.metadata.updated_at = now fields = current_item.metadata.model_dump(exclude_none=True) fields.pop("created_at", None) - fields["updated_at"] = now + merged_history = list(current_item.metadata.history or []) + new_primary_version = current_item.metadata.version or 1 + # Multiple related ids indicates existing duplicates/conflicts to be merged + if secondary_ids: + merged_history, new_primary_version = self._merge_secondary_nodes( + secondary_ids, primary_id, nodes_map, user_name, merged_history + ) + current_item.metadata.history = merged_history + current_item.metadata.version = new_primary_version + merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + # update old memory node with new content and updated history self.graph_db.update_node( id=primary_id, - fields={"memory": current_item.memory, **fields}, + fields={ + "memory": current_item.memory, + **fields, + "history": merged_history_dump, + "version": new_primary_version, + }, + user_name=user_name, ) - self._merge_secondary_nodes(secondary_ids, primary_id) return current_item @@ -629,36 +658,35 @@ def _apply_cas_merge( return new_value - def _merge_secondary_nodes(self, secondary_ids: list[str], primary_id: str) -> None: - if not secondary_ids: - return - - node_ids = [primary_id, *secondary_ids] - nodes_data = self.graph_db.get_nodes(ids=node_ids) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} - primary_data = nodes_map.get(primary_id) or self.graph_db.get_node(primary_id) - if not primary_data: - return - - primary_item = TextualMemoryItem(**primary_data) - merged_history = list(primary_item.metadata.history or []) + def _merge_secondary_nodes( + self, + secondary_ids: list[str], + primary_id: str, + nodes_map: dict, + user_name: str, + base_history: list[ArchivedTextualMemory], + ) -> tuple[list[ArchivedTextualMemory], int]: + merged_history = list(base_history) for memory_id in secondary_ids: - node_data = nodes_map.get(memory_id) or self.graph_db.get_node(memory_id) + node_data = nodes_map.get(memory_id) if not node_data: continue metadata = node_data.get("metadata", {}) evolve_to = list(metadata.get("evolve_to", []) or []) if primary_id not in evolve_to: evolve_to.append(primary_id) + # set secondary nodes to archived and record their evolving destinations self.graph_db.update_node( id=memory_id, fields={"status": "archived", "evolve_to": evolve_to}, + user_name=user_name, ) secondary_item = TextualMemoryItem(**node_data) if secondary_item.metadata.history: merged_history.extend(secondary_item.metadata.history) + # Currently we just sort the versions according to their creation time def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: created_at = history_item.created_at if isinstance(created_at, datetime): @@ -670,14 +698,27 @@ def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: return datetime.min return datetime.min + def _dedupe_history_by_archived_id( + history: list[ArchivedTextualMemory], + ) -> list[ArchivedTextualMemory]: + seen_archived_ids: set[str] = set() + deduped_history: list[ArchivedTextualMemory] = [] + for history_item in history: + archived_id = history_item.archived_memory_id + if archived_id and archived_id in seen_archived_ids: + continue + if archived_id: + seen_archived_ids.add(archived_id) + deduped_history.append(history_item) + return deduped_history + merged_history.sort(key=_history_sort_key) + merged_history = _dedupe_history_by_archived_id(merged_history) + max_version = 0 for idx, history_item in enumerate(merged_history, start=1): history_item.version = idx - primary_item.metadata.history = merged_history - self.graph_db.update_node( - id=primary_id, - fields={"history": [h.model_dump(exclude_none=True) for h in merged_history]}, - ) + max_version = idx + return merged_history, max_version + 1 def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: """ @@ -686,7 +727,7 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> if not self.llm: return proposed_update - lang = self._determine_lang(None, f"{latest_memory}\n{proposed_update}") + lang = _determine_lang(None, f"{latest_memory}\n{proposed_update}") prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( "${proposed_update}", proposed_update From 4b8a379782608f25beced502c1035ae8ef0a78f0 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 21:31:49 +0800 Subject: [PATCH 07/92] fix: Replenish missing embeddings during memory version updates/adds. --- src/memos/api/handlers/component_init.py | 4 +++- .../organize/history_manager.py | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index e764c7fd4..e1b4c5f2f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -192,7 +192,9 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db, llm=llm) + memory_history_manager = MemoryHistoryManager( + nli_client=nli_client, graph_db=graph_db, llm=llm, embedder=embedder + ) pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index dbe31a4a2..5b0f411d5 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -6,6 +6,7 @@ from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor +from memos.embedders.base import BaseEmbedder from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB @@ -126,7 +127,11 @@ def _determine_lang(sources: list | None, fallback_text: str) -> str: class MemoryHistoryManager: def __init__( - self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None + self, + nli_client: NLIClient, + graph_db: BaseGraphDB, + llm: BaseLLM | None = None, + embedder: BaseEmbedder | None = None, ) -> None: """ Initialize the MemoryHistoryManager. @@ -139,6 +144,16 @@ def __init__( self.nli_client = nli_client self.graph_db = graph_db self.llm = llm + self.embedder = embedder + + def _compute_embedding(self, text: str) -> list[float] | None: + if not self.embedder: + return None + try: + return self.embedder.embed([text])[0] + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to compute embedding: {e}") + return None @staticmethod def is_applicable(item: TextualMemoryItem) -> bool: @@ -621,6 +636,7 @@ def _update_existing_memory( current_item.metadata.history = merged_history current_item.metadata.version = new_primary_version merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + embedding = self._compute_embedding(current_item.memory) # update old memory node with new content and updated history self.graph_db.update_node( id=primary_id, @@ -629,6 +645,7 @@ def _update_existing_memory( **fields, "history": merged_history_dump, "version": new_primary_version, + "embedding": embedding, }, user_name=user_name, ) @@ -764,6 +781,7 @@ def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: key=key, created_at=datetime.now().isoformat(), history=[], + embedding=self._compute_embedding(new_value), ), ) return new_item @@ -807,6 +825,7 @@ def _handle_restored_memories( tags=tags, created_at=datetime.now().isoformat(), history=source_history, + embedding=self._compute_embedding(value), ), ) created_items.append(new_item) From b50e11ee96e6ebee7f8176d6cb1f8ca46de02ce4 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 22:05:45 +0800 Subject: [PATCH 08/92] fix: Identify fast nodes correctly in memory versions. --- .../textual/tree_text_memory/organize/history_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 5b0f411d5..f8294b82e 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -248,9 +248,13 @@ def resolve_history_via_nli( # Safely get created_at, fallback to updated_at created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at + # TODO: change the way of marking fast nodes by directly using is_fast field. archived = ArchivedTextualMemory( version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, + is_fast=( + r_item.metadata.is_fast + or ("mode:fast" in (getattr(r_item.metadata, "tags", None) or [])) + ), memory=r_item.memory, update_type=update_type, archived_memory_id=r_item.id, From 7317815802a41601188083de9129252d17439d4d Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 10:21:24 +0800 Subject: [PATCH 09/92] fix: Fix branch merge --- .../handlers/mem_read_handler.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index bd1d69671..c2a10b531 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -196,22 +196,23 @@ def _process_memories_with_reader( enhanced_mem_ids, ) - if getattr(mem_reader, "memory_version_switch", "off") != "on": - # add raw file nodes and edges - if mem_reader.save_rawfile: - raw_file_mem_group = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type == "RawFileMemory" - ] - text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - enhanced_mem_ids, - user_id=user_id, - user_name=user_name, - ) - logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + # add raw file nodes and edges + if mem_reader.save_rawfile: + raw_file_mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type == "RawFileMemory" + ] + text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + enhanced_mem_ids, + user_id=user_id, + user_name=user_name, + ) + logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + # fallback to simple deduplication logic when mem version switch is off + if getattr(mem_reader, "memory_version_switch", "off") != "on": # Mark merged_from memories as archived when provided in memory metadata summary_memories = [ memory @@ -230,7 +231,9 @@ def _process_memories_with_reader( for old_id in old_ids: try: mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name + str(old_id), + {"status": "archived"}, + user_name=user_name, ) logger.info( "[Scheduler] Archived merged_from memory: %s", From 7bc18381cc7a0ff6e2b405f836ce134479b9ca36 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 14:57:03 +0800 Subject: [PATCH 10/92] fix: Fix WorkingMemory behaviors and time settings in memory versions. --- .../organize/history_manager.py | 20 ++++++++++++++----- .../tree_text_memory/organize/manager.py | 9 ++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index f8294b82e..e19c0d593 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -175,6 +175,9 @@ def update_node_with_history( tags: list[str] | None = None, key: str | None = None, ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + now = datetime.now().isoformat() + last_update_time = item.metadata.updated_at + old_id = item.id archived_id = str(uuid.uuid4()) # archived memory(need to store this node to the db later) @@ -182,6 +185,8 @@ def update_node_with_history( archived_item.id = archived_id archived_item.metadata.evolve_to = [old_id] archived_item.metadata.status = "archived" + archived_item.metadata.created_at = last_update_time + archived_item.metadata.updated_at = now # original memory with updated contents and history archived_history = ArchivedTextualMemory( @@ -190,11 +195,12 @@ def update_node_with_history( memory=item.memory, update_type=update_type, archived_memory_id=archived_id, - created_at=getattr(item.metadata, "created_at", None) or datetime.now().isoformat(), + created_at=getattr(item.metadata, "updated_at", None) or last_update_time, ) item.memory = new_memory item.metadata.version = (item.metadata.version or 1) + 1 item.metadata.status = "activated" + item.metadata.updated_at = now if tags is not None: item.metadata.tags = tags if key is not None: @@ -625,11 +631,7 @@ def _update_existing_memory( user_name=user_name, ) - now = datetime.now().isoformat() - if hasattr(current_item.metadata, "updated_at"): - current_item.metadata.updated_at = now fields = current_item.metadata.model_dump(exclude_none=True) - fields.pop("created_at", None) merged_history = list(current_item.metadata.history or []) new_primary_version = current_item.metadata.version or 1 # Multiple related ids indicates existing duplicates/conflicts to be merged @@ -653,6 +655,14 @@ def _update_existing_memory( }, user_name=user_name, ) + working_binding = getattr(current_item.metadata, "working_binding", None) + if working_binding and working_binding != current_item.id: + try: + self.mark_memory_status([str(working_binding)], "deleted", user_name=user_name) + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" + ) return current_item diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index cbc349d67..9733553c0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -191,10 +191,12 @@ def _add_memories_batch( ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() + metadata_dict["working_binding"] = working_id # Add working_binding for fast mode tags = metadata_dict.get("tags") or [] if "mode:fast" in tags: + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_id}] direct built from raw inputs" metadata_dict["background"] = ( @@ -234,7 +236,8 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: exc_info=e, ) - _submit_batches(working_nodes, "WorkingMemory") + # TODO: working id is same with item.id, need to fix, currently stop adding WorkingMemories here. + # here used to be: _submit_batches(working_nodes, "WorkingMemory") _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: @@ -318,6 +321,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] + # TODO: working id is same with item.id, need to fix working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: @@ -393,8 +397,11 @@ def _add_to_graph_memory( node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding: + metadata_dict["working_binding"] = working_binding tags = metadata_dict.get("tags") or [] if working_binding and ("mode:fast" in tags): + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: From 37e7ebb52179bd5ed3dc0be37a4b1c21f9fa33d3 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 15:49:37 +0800 Subject: [PATCH 11/92] feat: Refine prompts for memory versions to take account of time factors. --- src/memos/mem_reader/multi_modal_struct.py | 6 +--- .../organize/history_manager.py | 25 ++++++++++++---- .../mem_reader_mem_version_prompts.py | 30 ++++++++++++++----- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index bb576fe24..8237e22e5 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1007,11 +1007,7 @@ def _process_async_versioning_update( if custom_tags else "" ) - prompt = self.history_manager.format_async_update_prompt( - item, - conversation=kwargs.get("chat_history") or mem_str, - custom_tags_prompt=custom_tags_prompt, - ) + prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) response_text = self.llm.generate([{"role": "user", "content": prompt}]) response_json = parse_json_result(response_text) user_name = kwargs.get("user_name") diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index e19c0d593..3c719b189 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -331,14 +331,13 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return def format_async_update_prompt( - self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + self, item: TextualMemoryItem, custom_tags_prompt: str = "" ) -> str: """ Format the prompt for asynchronous memory update. Args: item: The TextualMemoryItem containing history candidates. - conversation: The current conversation content. If None, uses item.memory. custom_tags_prompt: Optional custom prompt for tags. Returns: @@ -351,8 +350,20 @@ def format_async_update_prompt( conflict_candidates = [] unrelated_candidates = [] + def _fmt_time(ts: str | None) -> str | None: + if not ts or not isinstance(ts, str): + return None + try: + t = datetime.fromisoformat(ts.replace("Z", "")) + return t.strftime("%Y/%m/%d %H:%M:%S") + except Exception: + return ts + for h in item.metadata.history or []: - candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" + created = getattr(h, "created_at", None) + tstr = _fmt_time(created) + time_suffix = f"[Time: {tstr}] " if tstr else "" + candidate_str = f"[ID:{h.archived_memory_id}]{time_suffix}{h.memory}" if h.update_type == "duplicate": duplicate_candidates.append(candidate_str) @@ -363,7 +374,7 @@ def format_async_update_prompt( unrelated_candidates.append(candidate_str) sources = item.metadata.sources if item.metadata else None - lang = _determine_lang(sources, conversation or item.memory) + lang = _determine_lang(sources, item.memory) empty_label = "None" def format_list(candidates): @@ -372,12 +383,16 @@ def format_list(candidates): prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] ) + conversation_time_raw = getattr(item.metadata, "created_at", None) + conversation_time = _fmt_time(conversation_time_raw) or conversation_time_raw + return ( prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) .replace("${conflict_candidates}", format_list(conflict_candidates)) .replace("${unrelated_candidates}", format_list(unrelated_candidates)) .replace("${custom_tags_prompt}", custom_tags_prompt) - .replace("${conversation}", conversation or item.memory) + .replace("${conversation_time}", conversation_time) + .replace("${conversation}", item.memory) ) def apply_llm_memory_updates( diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index 61e56f86c..d9d57065e 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -33,6 +33,7 @@ - 解析代词和模糊指代。 - 使用第三人称(“用户”)。 - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 + - 候选记忆可能包含 [Time: ...] 表示该记忆的事件时间,请结合“对话时间”判断是否同一时段。 4. **处理逻辑(更新与新增)**: 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: @@ -92,14 +93,17 @@ 示例: **潜在重复/关联记忆 (Duplicate/Related Candidates):** -[ID:101] 用户喜欢喝拿铁,通常不加糖。 -[ID:102] 用户讨厌下雨天。 +[ID:101][Time: 2025/05/20 09:30:00] 用户喜欢喝拿铁,通常不加糖。 +[ID:102][Time: 2025/05/18 18:00:00] 用户讨厌下雨天。 **潜在事实冲突记忆 (Conflict Candidates):** -[ID:201] 用户喜欢打羽毛球,但不喜欢滑雪。 +[ID:201][Time: 2025/02/03 20:15:00] 用户喜欢打羽毛球,但不喜欢滑雪。 + +**对话时间:** +2025/06/26 09:00:00 **对话:** -user: [2025年6月26日]:我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 +user: 我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 **输出:** { @@ -143,6 +147,9 @@ 请始终使用与对话相同的语言进行回复。 +对话时间: +${conversation_time} + 对话: ${conversation} """ @@ -179,6 +186,7 @@ - Resolve pronouns and ambiguous references. - Use third person ("the user"). - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). + - Candidates may include [Time: ...] to indicate event time; use the conversation time to judge whether they are the same period. 4. Processing Logic (Update and Create): Traverse each piece of information in the conversation that is worth remembering and apply: @@ -238,14 +246,17 @@ Example: Duplicate/Related Candidates: -[ID:101] The user likes latte and usually doesn't add sugar. -[ID:102] The user hates rainy days. +[ID:101][Time: 2025/05/20 09:30:00] The user likes latte and usually doesn't add sugar. +[ID:102][Time: 2025/05/18 18:00:00] The user hates rainy days. Conflict Candidates: -[ID:201] The user likes badminton but dislikes skiing. +[ID:201][Time: 2025/02/03 20:15:00] The user likes badminton but dislikes skiing. + +Conversation time: +2025/06/26 09:00:00 Conversation: -user: [June 26, 2025]: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. +user: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. Output: { @@ -289,6 +300,9 @@ Always reply in the same language as the conversation. +Conversation time: +${conversation_time} + Conversation: ${conversation} """ From 06735f45e96ab8f124e3e620b7e3afe3a4b85cff Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 17:30:55 +0800 Subject: [PATCH 12/92] fix: Fix default values for memory items. --- src/memos/memories/textual/item.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 573a83856..654e19354 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -105,15 +105,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) version: int = Field( default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) From 98cfff414b96f9dd9b848ff873b3b6a3b4718875 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 17:32:29 +0800 Subject: [PATCH 13/92] test: Update test for memory versions --- .../memories/textual/test_history_manager.py | 589 +++++++++++++++++- 1 file changed, 576 insertions(+), 13 deletions(-) diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 46cf3a1f6..f0ea8458f 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -8,13 +8,15 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.memories.textual.item import ( + ArchivedTextualMemory, TextualMemoryItem, - TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, _append_related_content, _detach_related_content, + _rebuild_fast_node_history, ) @@ -36,7 +38,7 @@ def history_manager(mock_nli_client, mock_graph_db): def test_detach_related_content(): original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate 1", "Duplicate 2"] conflicts = ["Conflict 1", "Conflict 2"] @@ -60,7 +62,7 @@ def test_detach_related_content(): def test_detach_only_conflicts(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = [] conflicts = ["Conflict A"] @@ -75,7 +77,7 @@ def test_detach_only_conflicts(): def test_detach_only_duplicates(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate A"] conflicts = [] @@ -118,20 +120,581 @@ def test_mark_memory_status(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex id3 = uuid.uuid4().hex - items = [ - TextualMemoryItem(memory="M1", id=id1), - TextualMemoryItem(memory="M2", id=id2), - TextualMemoryItem(memory="M3", id=id3), - ] + memory_ids = [id1, id2, id3] status = "resolving" # Action - history_manager.mark_memory_status(items, status) + history_manager.mark_memory_status(memory_ids, status, user_name="u1") # Assert assert mock_graph_db.update_node.call_count == 3 # Verify we called it correctly - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name="u1") + + +def test_format_async_update_prompt(history_manager): + # Setup + # Create history items + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="101", memory="Duplicate content", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="201", memory="Conflict content", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=1, archived_memory_id="301", memory="Unrelated content", update_type="unrelated" + ) + + item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]), + ) + + # Execute + prompt = history_manager.format_async_update_prompt(item) + + # Verify + assert "[ID:101]" in prompt + assert "Duplicate content" in prompt + assert "[ID:201]" in prompt + assert "Conflict content" in prompt + assert "[ID:301]" in prompt + assert "Unrelated content" in prompt + assert "New user input" in prompt + + # Check that placeholders are gone (basic check) + assert "${duplicate_candidates}" not in prompt + assert "${conflict_candidates}" not in prompt + + +def test_format_async_update_prompt_with_args(history_manager): + item = TextualMemoryItem( + memory="Explicit conversation", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + + prompt = history_manager.format_async_update_prompt(item, custom_tags_prompt="Custom Tags") + + assert "Explicit conversation" in prompt + assert "Custom Tags" in prompt + + +def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): + llm_response = { + "memory list": [ + { + "key": "New Memory", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": ["tag1"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + new_item = new_items[0] + assert new_item.memory == "New Content" + assert new_item.metadata.tags == ["tag1"] + assert new_item.metadata.history == [] + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db): + # Setup existing node + existing_id = uuid.uuid4().hex + existing_node = { + "id": existing_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=existing_id, + memory="Old Content", + update_type="duplicate", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == existing_id + assert updated_item.memory == "Updated Content" + assert updated_item.metadata.version == 2 + assert updated_item.metadata.tags == ["new"] + assert len(updated_item.metadata.history) == 1 + + history_entry = updated_item.metadata.history[0] + assert history_entry.archived_memory_id != existing_id + assert history_entry.archived_memory_id is not None + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "duplicate" + + mock_graph_db.add_node.assert_called_once() + mock_graph_db.update_node.assert_called_once() + args, kwargs = mock_graph_db.update_node.call_args + assert kwargs["id"] == existing_id + assert kwargs["fields"]["memory"] == "Updated Content" + assert kwargs["fields"]["version"] == 2 + + +def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): + source_id = uuid.uuid4().hex + restored_item = TextualMemoryItem( + memory="Restored Content", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) + llm_response = { + "memory list": [], + "restored_memories": [ + {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=source_id, + memory="Old Content", + update_type="conflict", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + assert new_items[0] == restored_item + history_manager._handle_restored_memories.assert_called_once_with( + llm_response["restored_memories"] + ) + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + llm_response = {"memory list": [], "restored_memories": [], "summary": "Summary"} + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=id1, + memory="M1", + update_type="unrelated", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=id2, + memory="M2", + update_type="unrelated", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 0 + + # Check that update_node was called to set status="activated" + # mark_memory_status calls update_node for each item + assert mock_graph_db.update_node.call_count == 2 + + # We can inspect calls + calls = mock_graph_db.update_node.call_args_list + ids = sorted([c.kwargs["id"] for c in calls]) + assert ids == sorted([id1, id2]) + for c in calls: + assert c.kwargs["fields"]["status"] == "activated" + + +def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): + # Setup existing node (primary) + primary_id = uuid.uuid4().hex + secondary_id = uuid.uuid4().hex + existing_node = {"id": primary_id, "memory": "Old Content", "metadata": {"version": 1}} + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [ + existing_node, + { + "id": secondary_id, + "memory": "Secondary", + "metadata": {"version": 1}, + }, + ] + + llm_response = { + "memory list": [ + { + "key": "Conflict Resolved", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": [], + "source_candidate_ids": [], + "conflicted_candidate_ids": [primary_id, secondary_id], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=primary_id, + memory="Old Content", + update_type="conflict", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=secondary_id, + memory="Secondary", + update_type="conflict", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == primary_id + assert updated_item.metadata.history[0].update_type == "conflict" + + # Verify primary update + # The mock_graph_db.update_node is called for primary (update) AND secondary (delete) + + # Find call for primary + primary_update_calls = [ + c + for c in mock_graph_db.update_node.call_args_list + if c.kwargs["id"] == primary_id and "memory" in c.kwargs.get("fields", {}) + ] + assert len(primary_update_calls) >= 1 + assert primary_update_calls[0].kwargs["fields"]["memory"] == "New Content" + + # Find call for secondary + secondary_update_calls = [ + c for c in mock_graph_db.update_node.call_args_list if c.kwargs["id"] == secondary_id + ] + assert len(secondary_update_calls) >= 1 + last_secondary_update = secondary_update_calls[-1] + assert last_secondary_update.kwargs["fields"]["status"] == "archived" + assert last_secondary_update.kwargs["fields"]["evolve_to"] == [primary_id] + + +def test_rebuild_fast_node_history_dedup_and_replace(): + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="a", memory="m1", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="b", memory="m2", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=2, archived_memory_id="a", memory="m3", update_type="duplicate" + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]) + ) + + r1 = ArchivedTextualMemory( + version=2, archived_memory_id="b", memory="m4", update_type="conflict" + ) + r2 = ArchivedTextualMemory( + version=1, archived_memory_id="c", memory="m5", update_type="duplicate" + ) + + _rebuild_fast_node_history(item, {1: [r1, r2]}) + + by_id = {h.archived_memory_id: h for h in item.metadata.history} + assert set(by_id.keys()) == {"a", "b", "c"} + assert by_id["a"].version == 2 + assert by_id["b"].version == 2 + + +def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): + fast_id = uuid.uuid4().hex + history_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="conflict", is_fast=True + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) + ) + mock_graph_db.get_nodes.return_value = [ + {"id": fast_id, "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}} + ] + + replacement_item = ArchivedTextualMemory( + version=1, archived_memory_id="n1", memory="r1", update_type="conflict" + ) + history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) + + replacements = history_manager._check_and_fetch_replacements(item, [0]) + + assert 0 in replacements + assert replacements[0][0].archived_memory_id == "n1" + history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict") + + +def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): + mock_graph_db.get_nodes.return_value = [ + { + "id": "x1", + "memory": "m1", + "metadata": {"version": 2, "is_fast": False, "created_at": "2024-01-01"}, + }, + { + "id": "x2", + "memory": "m2", + "metadata": {"version": 1, "is_fast": True, "created_at": "2024-01-02"}, + }, + ] + + results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate") + + assert len(results) == 2 + ids = sorted([r.archived_memory_id for r in results]) + assert ids == ["x1", "x2"] + assert all(r.update_type == "duplicate" for r in results) + + +def test_wait_and_update_fast_history_rebuilds(history_manager): + fast_id = uuid.uuid4().hex + fast_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="duplicate", is_fast=True + ) + other_item = ArchivedTextualMemory( + version=1, archived_memory_id="k1", memory="keep", update_type="unrelated", is_fast=False + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[fast_item, other_item]) + ) + + replacement = ArchivedTextualMemory( + version=2, archived_memory_id="n1", memory="new", update_type="duplicate", is_fast=False + ) + history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) + + history_manager.wait_and_update_fast_history(item, timeout_sec=1) + + ids = [h.archived_memory_id for h in item.metadata.history] + assert "n1" in ids + assert fast_id not in ids + history_manager._check_and_fetch_replacements.assert_called_once() + + +def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): + llm = MagicMock() + llm.generate.return_value = "Merged Content" + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + existing_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2}, + } + mock_graph_db.get_nodes.return_value = [ + {"id": existing_id, "memory": "Old Content", "metadata": {"version": 2}} + ] + + mem_data = { + "key": "k", + "value": "Proposed", + "tags": ["t1"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + } + + updated = manager._update_existing_memory( + mem_data, [existing_id], [existing_id], {existing_id: 1}, user_name="u1" + ) + + assert updated.memory == "Merged Content" + assert updated.metadata.version == 3 + mock_graph_db.update_node.assert_called_once() + + +def test_update_existing_memory_marks_working_binding_deleted(history_manager, mock_graph_db): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + working_binding = uuid.uuid4().hex + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated = history_manager._update_existing_memory( + mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + ) + + assert updated is not None + history_manager.mark_memory_status.assert_called_once_with( + [str(working_binding)], "deleted", user_name="u1" + ) + + +def test_update_existing_memory_no_mark_when_working_binding_matches( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated = history_manager._update_existing_memory( + mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + ) + + assert updated is not None + history_manager.mark_memory_status.assert_not_called() + + +def test_update_existing_memory_node_missing(history_manager, mock_graph_db): + mock_graph_db.get_node.return_value = None + mock_graph_db.get_nodes.return_value = [] + mem_data = {"value": "v", "tags": [], "key": "k"} + + updated = history_manager._update_existing_memory(mem_data, ["missing"], [], {}, user_name="u1") + + assert updated is None + mock_graph_db.update_node.assert_not_called() + + +def test_update_node_with_history(): + item = TextualMemoryItem( + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + tags=["old"], + key="k1", + history=[], + ), + ) + + updated, archived = MemoryHistoryManager.update_node_with_history( + item, + "New Content", + "conflict", + ) + + assert updated.memory == "New Content" + assert updated.metadata.version == 3 + assert updated.metadata.tags == ["old"] + assert updated.metadata.key == "k1" + assert len(updated.metadata.history) == 1 + history_entry = updated.metadata.history[0] + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "conflict" + assert history_entry.archived_memory_id == archived.id + assert archived.metadata.status == "archived" + assert archived.metadata.evolve_to == [updated.id] + + +def test_merge_conflicting_memory_llm_error(mock_graph_db): + llm = MagicMock() + llm.generate.side_effect = Exception("fail") + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + merged = manager._merge_conflicting_memory("Latest", "Proposed") + + assert "System Merge Fallback" in merged + assert "Latest" in merged + assert "Proposed" in merged From 8af6e9b28f6634f7c741b26e25e1e9be56e91ddc Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 10:31:57 +0800 Subject: [PATCH 14/92] fix: Fix some behaviors in memory versions. --- src/memos/api/handlers/formatters_handler.py | 1 + .../organize/history_manager.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 06c4fd223..f515bb772 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -158,6 +158,7 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): sources = item.get("metadata", {}).get("sources", []) if ( item["metadata"]["memory_type"] != "RawFileMemory" + and sources and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 3c719b189..6dd273836 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -612,6 +612,11 @@ def _update_existing_memory( mem_data.get("tags", []), mem_data.get("key", ""), ) + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory # Fetch candidate nodes in batch and then select the primary # We update the primary and then merge the secondaries to the primary @@ -696,8 +701,12 @@ def _apply_cas_merge( f"Expected v{expected_version}, but found v{current_version} in DB. " "Triggering merge logic." ) + latest_item = TextualMemoryItem( + memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(latest_item) merged_content = self._merge_conflicting_memory( - latest_memory=current_item.memory, + latest_memory=latest_item.memory, proposed_update=new_value, ) return merged_content @@ -794,6 +803,11 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: """Create New Node.""" new_value = mem_data.get("value", "") + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory tags = mem_data.get("tags", []) key = mem_data.get("key", "") memory_type = mem_data.get("memory_type", "LongTermMemory") @@ -828,6 +842,9 @@ def _handle_restored_memories( source_item = source_items[i] source_history = source_item.history.copy() value = data.get("value", "") + value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) + _detach_related_content(value_item) + value = value_item.memory tags = data.get("tags", []) keys = data.get("keys", []) memory_type = data.get("memory_type", "LongTermMemory") From 6b96298f5cabd21daab4c06a7138a11d8bd17060 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 11:49:25 +0800 Subject: [PATCH 15/92] feat: Add latency information for nli model in the log. --- src/memos/mem_reader/multi_modal_struct.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 8237e22e5..b552e6f37 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1,6 +1,7 @@ import concurrent.futures import json import re +import time import traceback from typing import TYPE_CHECKING, Any @@ -975,14 +976,30 @@ def _fast_resolve_memory_duplicates_and_conflicts( continue try: # recall related memories + retrieve_start = time.perf_counter() related = self.pre_update_retriever.retrieve( item=item, user_name=user_name, ) + retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 + logger.info( + "[MultiModalStruct] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", + retrieve_ms, + getattr(item, "id", None), + ) # NLI check & attaching contents + nli_start = time.perf_counter() conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( item, related ) + nli_ms = (time.perf_counter() - nli_start) * 1000 + logger.info( + "[MultiModalStruct] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", + nli_ms, + getattr(item, "id", None), + len(related), + len(conflicting_or_duplicate_ids), + ) # mark delete(temporarily) self.history_manager.mark_memory_status( conflicting_or_duplicate_ids, "resolving", user_name=user_name From 5b742a56094013147dee21c71ec49766ff434e31 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 10:39:13 +0800 Subject: [PATCH 16/92] feat: Building fast-add related functions for memory versions. --- .../organize/history_manager.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 1afdc9281..97bbe7483 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,4 +1,5 @@ import logging +import time from typing import Literal @@ -67,6 +68,36 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: return +def _rebuild_fast_node_history( + item: TextualMemoryItem, + replacements: dict[int, list[ArchivedTextualMemory]], +) -> None: + """ + Reconstruct the history list of a fast node: + 1. Replace resolved items with their evolved versions. + 2. Deduplicate by ID while preserving newest versions. + """ + new_history = {} + + def _add(history_item): + item_id = history_item.archived_memory_id + current = new_history.get(item_id) + + if current is None or history_item.version > current.version: + new_history[item_id] = history_item + + # Apply replacements and filter superseded items + for i, h in enumerate(item.metadata.history): + if i in replacements: + # This item is resolved, insert its replacements + for replacement_item in replacements[i]: + _add(replacement_item) + else: + _add(h) + + item.metadata.history = list(new_history.values()) + + class MemoryHistoryManager: def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: """ @@ -79,6 +110,131 @@ def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: self.nli_client = nli_client self.graph_db = graph_db + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int] + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node_data = nodes_map.get(h_item.archived_memory_id) + + if not node_data: + continue + + metadata = node_data.get("metadata", {}) + status = metadata.get("status") + + # Condition: Fast node is processed when it is marked as 'deleted' + if status == "deleted": + evolve_to_ids = metadata.get("evolve_to", []) + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) + replacements[i] = new_items + + logger.info( + f"[MemoryHistoryManager] Resolved fast history item {h_item.archived_memory_id} -> {evolve_to_ids}" + ) + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[TextualMemoryItem]: From 47e0f4a034079b55bbe0cb985d5d5fac7f86715e Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 10:46:39 +0800 Subject: [PATCH 17/92] feat: Building fast-add related functions for memory versions. --- .../organize/history_manager.py | 112 +++++++++--------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 97bbe7483..4cdfac985 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -179,62 +179,6 @@ def _fetch_evolved_nodes( return results - def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: - """ - Scan the item's history. If any history item is marked as `is_fast`, - wait for it to be resolved (i.e., status becomes 'deleted' in the DB). - When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. - Finally, deduplicate the history. - - Args: - item: The memory item containing the history to check. - timeout_sec: Maximum time to wait for resolution in seconds. - """ - start_time = time.time() - - # 1. Identify pending items (fast nodes) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - while True: - if not pending_indices: - # All fast nodes resolved or none existed - break - - if time.time() - start_time > timeout_sec: - logger.warning( - f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" - ) - # Remove pending fast nodes from history - item.metadata.history = [ - h - for h in item.metadata.history - if not (getattr(h, "is_fast", False) and h.archived_memory_id) - ] - break - - # 2. Check status of the fast nodes and fetch replacements for evolved ones - replacements = self._check_and_fetch_replacements(item, pending_indices) - - # 3. If we have any resolved items, rebuild the history - if replacements: - _rebuild_fast_node_history(item, replacements) - - # Check if we are done (no pending items left) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - if pending_indices: - time.sleep(1) # This avoids visiting the DB too frequently - - return - def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[TextualMemoryItem]: @@ -293,6 +237,62 @@ def resolve_history_via_nli( return duplicate_memories + conflict_memories + def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + def mark_memory_status( self, memory_items: list[TextualMemoryItem], From 7aaddf874268bdc059f9a1830195350d05be51ff Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Feb 2026 11:30:26 +0800 Subject: [PATCH 18/92] feat: supporting memory versions in fast-add process --- src/memos/api/handlers/component_init.py | 9 +++- src/memos/configs/mem_reader.py | 6 ++- src/memos/mem_reader/factory.py | 18 ++++++++ src/memos/mem_reader/multi_modal_struct.py | 44 +++++++++++++++++++ src/memos/mem_reader/simple_struct.py | 12 +++++ .../init_components_for_scheduler.py | 14 +++++- 6 files changed, 100 insertions(+), 3 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index ba527d602..2e27e9da5 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -192,8 +193,14 @@ def init_server() -> dict[str, Any]: embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + pre_update_retriever=pre_update_retriever, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 4bd7953c0..98aff981f 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator, model_validator @@ -65,6 +65,10 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): default=None, description="Skills directory for the MemReader", ) + memory_version_switch: Literal["on", "off"] = Field( + default="off", + description="Turn on memory version or off", + ) class StrategyStructMemReaderConfig(BaseMemReaderConfig): diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 7bd551fb8..0907168a7 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -5,11 +5,15 @@ from memos.mem_reader.multi_modal_struct import MultiModalStructMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_reader.strategy_struct import StrategyStructMemReader +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memos_tools.singleton import singleton_factory if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -29,6 +33,8 @@ def from_config( config_factory: MemReaderConfigFactory, graph_db: Optional["BaseGraphDB | None"] = None, searcher: Optional["Searcher | None"] = None, + pre_update_retriever: PreUpdateRetriever | None = None, + history_manager: Optional["MemoryHistoryManager | None"] = None, ) -> BaseMemReader: """ Create a MemReader instance from configuration. @@ -55,4 +61,16 @@ def from_config( if searcher is not None: reader.set_searcher(searcher) + if pre_update_retriever is not None: + if hasattr(reader, "set_pre_update_retriever"): + reader.set_pre_update_retriever(pre_update_retriever) + else: + reader.pre_update_retriever = pre_update_retriever + + if history_manager is not None: + if hasattr(reader, "set_history_manager"): + reader.set_history_manager(history_manager) + else: + reader.history_manager = history_manager + return reader diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 8b0968ca1..adda62e26 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -56,6 +56,10 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) + self.pre_update_retriever = None + self.history_manager = None + self.memory_version_switch = getattr(config, "memory_version_switch", "off") + # Initialize MultiModalParser for routing to different parsers self.multi_modal_parser = MultiModalParser( embedder=self.embedder, @@ -934,6 +938,39 @@ def _process_tool_trajectory_fine( return fine_memory_items + def _fast_resolve_memory_duplicates_and_conflicts( + self, fast_memory_items: list[TextualMemoryItem], user_name: str + ) -> None: + """ + 1. Recall related memories + 2. Fast conflict/duplication check with NLI model + 3. Attach conflicting/duplicate old memory contents onto fast memory items + 4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible to /search, + but still visible for other conflict/duplication checks' recalls. + """ + if not self.pre_update_retriever or not self.history_manager: + logger.warning( + "[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized." + ) + return + + for item in fast_memory_items: + try: + # recall related memories + related = self.pre_update_retriever.retrieve( + item=item, + user_name=user_name, + ) + # NLI check & attaching contents + conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( + item, related + ) + # mark delete + self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving") + + except Exception as e: + logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + @timed def _process_multi_modal_data( self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs @@ -984,6 +1021,13 @@ def _process_multi_modal_data( scene_data_info, info, mode="fast", need_emb=False, **kwargs ) fast_memory_items = self._concat_multi_modal_memories(all_memory_items) + + # Perform conflict/duplicate check with old memories + # TODO: find a better way to pass in the user_name + user_name = kwargs.get("user_name") + if self.memory_version_switch == "on": + self._fast_resolve_memory_duplicates_and_conflicts(fast_memory_items, user_name) + if mode == "fast": return fast_memory_items else: diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index ceaf28bfa..237653d4a 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -20,6 +20,10 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.organize.history_manager import ( + MemoryHistoryManager, + ) + from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang from memos.mem_reader.utils import ( @@ -185,6 +189,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): # Initialize graph_db as None, can be set later via set_graph_db for # recall operations self.graph_db = None + self.pre_update_retriever = None + self.history_manager = None def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: self.graph_db = graph_db @@ -192,6 +198,12 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: def set_searcher(self, searcher: "Searcher | None") -> None: self.searcher = searcher + def set_pre_update_retriever(self, pre_update_retriever: "PreUpdateRetriever | None") -> None: + self.pre_update_retriever = pre_update_retriever + + def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None: + self.history_manager = history_manager + def _make_memory_item( self, value: str, diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index b103acf3a..883c191fe 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -12,6 +12,7 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory +from memos.extras.nli_model.client import NLIClient from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory from memos.log import get_logger @@ -30,10 +31,12 @@ ) from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer @@ -287,6 +290,7 @@ def init_components() -> dict[str, Any]: graph_db_config = build_graph_db_config() llm_config = build_llm_config() embedder_config = build_embedder_config() + nli_client_config = APIConfig.get_nli_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -307,8 +311,16 @@ def init_components() -> dict[str, Any]: ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) + nli_client = NLIClient(base_url=nli_client_config["base_url"]) + memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + pre_update_retriever=pre_update_retriever, + history_manager=memory_history_manager, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( From 872c4961a49cc2e8fbb00f6fb15f762c45212c06 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 5 Feb 2026 17:10:59 +0800 Subject: [PATCH 19/92] feat: Functions for supporting memory versions in fine-add process --- src/memos/api/handlers/component_init.py | 2 +- src/memos/mem_reader/multi_modal_struct.py | 6 +- .../organize/history_manager.py | 583 +++++++++++++++--- .../mem_reader_mem_version_prompts.py | 168 +++++ 4 files changed, 680 insertions(+), 79 deletions(-) create mode 100644 src/memos/templates/mem_reader_mem_version_prompts.py diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 2e27e9da5..e764c7fd4 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -192,7 +192,7 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db, llm=llm) pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index adda62e26..16098907d 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -965,8 +965,10 @@ def _fast_resolve_memory_duplicates_and_conflicts( conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( item, related ) - # mark delete - self.history_manager.mark_memory_status(conflicting_or_duplicate_items, "resolving") + # mark delete(temporarily) + self.history_manager.mark_memory_status( + [m.id for m in conflicting_or_duplicate_items], "resolving" + ) except Exception as e: logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 4cdfac985..2c2d37654 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,13 +1,24 @@ import logging import time +import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem +from memos.llms.base import BaseLLM +from memos.memories.textual.item import ( + ArchivedTextualMemory, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.templates.mem_reader_mem_version_prompts import ( + ASYNC_MEMORY_UPDATE_PROMPT_ZH, + MEMORY_MERGE_PROMPT_ZH, +) logger = logging.getLogger(__name__) @@ -69,13 +80,12 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: def _rebuild_fast_node_history( - item: TextualMemoryItem, - replacements: dict[int, list[ArchivedTextualMemory]], + item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] ) -> None: """ Reconstruct the history list of a fast node: 1. Replace resolved items with their evolved versions. - 2. Deduplicate by ID while preserving newest versions. + 2. Deduplicate by ID while preserving the newest versions. """ new_history = {} @@ -99,89 +109,109 @@ def _add(history_item): class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: + def __init__( + self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None + ) -> None: """ Initialize the MemoryHistoryManager. Args: nli_client: NLIClient for conflict/duplicate detection. graph_db: GraphDB instance for marking operations during history management. + llm: Optional LLM instance for memory merging during conflicts. """ self.nli_client = nli_client self.graph_db = graph_db - - def _check_and_fetch_replacements( - self, item: TextualMemoryItem, pending_indices: list[int] - ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + self.llm = llm + + @staticmethod + def update_node_with_history( + item: TextualMemoryItem, + new_memory: str, + update_type: str, + tags: list[str] | None = None, + key: str | None = None, + ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + old_id = item.id + archived_id = str(uuid.uuid4()) + # archived memory(need to store this node to the db later) + archived_item = item.model_copy(deep=True) + archived_item.id = archived_id + archived_item.metadata.evolve_to = [old_id] + archived_item.metadata.status = "archived" + + # original memory with updated contents and history + archived_history = ArchivedTextualMemory( + version=item.metadata.version or 1, + is_fast=item.metadata.is_fast or False, + memory=item.memory, + update_type=update_type, + archived_memory_id=archived_id, + created_at=getattr(item.metadata, "created_at", None) or datetime.now().isoformat(), + ) + item.memory = new_memory + item.metadata.version = (item.metadata.version or 1) + 1 + item.metadata.status = "activated" + if tags is not None: + item.metadata.tags = tags + if key is not None: + item.metadata.key = key + if item.metadata.history is None: + item.metadata.history = [] + item.metadata.history.append(archived_history) + + return item, archived_item + + @staticmethod + def format_async_update_prompt( + item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + ) -> str: """ - Check DB status for pending items. If 'deleted', fetch evolved nodes. + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + conversation: The current conversation content. If None, uses item.memory. + custom_tags_prompt: Optional custom prompt for tags. Returns: - replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + Formatted prompt string. """ - pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] - - # Batch fetch pending nodes to check status - nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + # First, detach duplicate and conflict memory contents from the new item's memory text + _detach_related_content(item) - replacements = {} - - for i in pending_indices: - h_item = item.metadata.history[i] - node_data = nodes_map.get(h_item.archived_memory_id) + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] - if not node_data: - continue + for h in item.metadata.history or []: + candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" - metadata = node_data.get("metadata", {}) - status = metadata.get("status") - - # Condition: Fast node is processed when it is marked as 'deleted' - if status == "deleted": - evolve_to_ids = metadata.get("evolve_to", []) - - new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) - replacements[i] = new_items - - logger.info( - f"[MemoryHistoryManager] Resolved fast history item {h_item.archived_memory_id} -> {evolve_to_ids}" - ) - - return replacements - - def _fetch_evolved_nodes( - self, evolve_to_ids: list[str], update_type: str - ) -> list[ArchivedTextualMemory]: - """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" - if not evolve_to_ids: - return [] - - evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] - results = [] - - for enode in evolved_nodes: - if not enode or "id" not in enode: - continue + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) - enode_meta = enode.get("metadata", {}) + # Helper to format list as string + def format_list(candidates): + return "\n".join(candidates) if candidates else "无" - # Create new archived memory inheriting the update_type (conflict/duplicate) - new_archived = ArchivedTextualMemory( - version=enode_meta.get("version", 1), - is_fast=enode_meta.get("is_fast", False), - memory=enode.get("memory", ""), - update_type=update_type, - archived_memory_id=enode.get("id"), - created_at=enode_meta.get("created_at"), + return ( + ASYNC_MEMORY_UPDATE_PROMPT_ZH.replace( + "${duplicate_candidates}", format_list(duplicate_candidates) ) - results.append(new_archived) - - return results + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation}", conversation or item.memory) + ) def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: + ) -> list[str]: """ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, and attach them as history to the new fast item. @@ -191,7 +221,7 @@ def resolve_history_via_nli( related_items: Existing memory items that might be related. Returns: - List of duplicate or conflicting memory items judged by the NLI service. + List of duplicate or conflicting memory ids judged by the NLI service. """ if not related_items: return [] @@ -208,10 +238,10 @@ def resolve_history_via_nli( for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" - duplicate_memories.append(r_item.memory) + duplicate_memories.append(r_item.id) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" - conflict_memories.append(r_item.memory) + conflict_memories.append(r_item.id) else: update_type = "unrelated" @@ -227,9 +257,6 @@ def resolve_history_via_nli( created_at=created_at, ) new_item.metadata.history.append(archived) - logger.info( - f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" - ) # 3. Concat duplicate/conflict memories to new_item.memory # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. @@ -293,26 +320,88 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return + def apply_llm_memory_updates( + self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """ + Apply the updates from the LLM response to the memory graph. + + Args: + llm_response: The parsed JSON response from the LLM. + source_item: The original fast item A whose history contains ArchivedTextualMemory entries. + We derive expected versions and candidate IDs from A.history. + user_name: user_name + + Returns: + List of new or updated memory items. + """ + memory_list = llm_response.get("memory list", []) + restored_memories = llm_response.get("restored_memories", []) + used_source_ids = set() + used_conflict_ids = set() + # Collect IDs referenced by the LLM response + for mem_data in memory_list: + used_source_ids.update(mem_data.get("source_candidate_ids", [])) + used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) + # Collect IDs referenced by restored memories + restored_source_ids = { + mem.get("source_candidate_id") + for mem in restored_memories + if mem.get("source_candidate_id") + } + # All IDs used by the LLM response + used_ids = used_source_ids | used_conflict_ids | restored_source_ids + expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories + candidate_id_set: set[str] = set() + # Recover candidate IDs and their expected versions from the source item's history + if source_item.metadata and source_item.metadata.history: + for h in source_item.metadata.history: + if h.archived_memory_id: + candidate_id_set.add(h.archived_memory_id) + expected_versions[h.archived_memory_id] = h.version + + # IDs not used by the LLM response will be treated as unrelated + unrelated_ids = sorted(candidate_id_set - used_ids) + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + + # 1. Handle Unrelated Candidates - Restore status to activated + self._handle_unrelated_candidates(unrelated_ids, user_name=user_name) + + # 2. Handle Memory List (Update or New) + processed_updates, created_items = self._process_memory_updates( + memory_list, expected_versions + ) + updated_items.extend(processed_updates) + new_items.extend(created_items) + + # 3. Handle Restored Memories (Extract from conflict) + new_items.extend(self._handle_restored_memories(restored_memories)) + + return updated_items, new_items + def mark_memory_status( self, - memory_items: list[TextualMemoryItem], + memory_ids: list[str], status: Literal["activated", "resolving", "archived", "deleted"], + user_name: str, ) -> None: """ Support status marking operations during history management. Common usages are: 1. Mark conflict/duplicate old memories' status as "resolving", to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. + 2. Mark resolved memories' status as "activated", to recover their visibility. """ # Execute the actual marking operation - in db. with ContextThreadPoolExecutor() as executor: futures = [] - for mem in memory_items: + for mid in memory_ids: futures.append( executor.submit( self.graph_db.update_node, - id=mem.id, + id=mid, fields={"status": status}, + user_name=user_name, ) ) @@ -320,3 +409,345 @@ def mark_memory_status( for future in futures: future.result() return + + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int] + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node_data = nodes_map.get(h_item.archived_memory_id) + + if not node_data: + continue + + metadata = node_data.get("metadata", {}) + status = metadata.get("status") + + # Condition: Fast node is processed when it is marked as 'deleted' + if status == "deleted": + evolve_to_ids = metadata.get("evolve_to", []) + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) + replacements[i] = new_items + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) -> None: + """Handle Unrelated Candidates - Restore status to `activated`.""" + if unrelated_ids: + self.mark_memory_status(unrelated_ids, "activated", user_name) + + def _process_memory_updates( + self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int] + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """Process Memory List (Update or New).""" + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + for mem_data in memory_list: + source_ids = mem_data.get("source_candidate_ids", []) + conflict_ids = mem_data.get("conflicted_candidate_ids", []) + + # Determine if this is an update or a new node + target_ids = source_ids + conflict_ids + + if target_ids: + item = self._update_existing_memory( + mem_data, + target_ids, + source_ids, + expected_versions, + ) + if item: + updated_items.append(item) + else: + item = self._create_new_memory(mem_data) + new_items.append(item) + return updated_items, new_items + + def _update_existing_memory( + self, + mem_data: dict[str, Any], + target_ids: list[str], + source_ids: list[str], + expected_versions: dict[str, int], + ) -> TextualMemoryItem | None: + """ + Update existing memory nodes using the LLM result. + + The first ID in target_ids is treated as the primary node. If additional target IDs + are provided, they are treated as secondary candidates and will be merged into the + primary. Merging means: + 1) Mark secondary nodes as archived and append the primary ID to evolve_to + 2) Merge their history entries into the primary history and re-order by created_at + + The method also applies CAS validation via expected_versions, archives the previous + version of the primary node, and persists the updated node back to the graph DB. + + Returns the updated primary TextualMemoryItem, or None when the primary node + cannot be found. + """ + original_primary_id, primary_id, secondary_ids = ( + target_ids[0], + target_ids[0], + target_ids[1:], + ) + new_value, tags, key = ( + mem_data.get("value", ""), + mem_data.get("tags", []), + mem_data.get("key", ""), + ) + + # Fetch candidate nodes in batch and then select the primary + nodes_data = self.graph_db.get_nodes(target_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + node_data = nodes_map.get(primary_id) + if not node_data: + node_data = self.graph_db.get_node(primary_id) + if not node_data: + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + return None + + current_item = TextualMemoryItem(**node_data) + new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) + + update_type = "duplicate" if original_primary_id in source_ids else "conflict" + current_item, archived_item = self.update_node_with_history( + current_item, + new_value, + update_type, + tags=tags, + key=key, + ) + now = datetime.now().isoformat() + if hasattr(archived_item.metadata, "created_at") and not getattr( + archived_item.metadata, "created_at", None + ): + archived_item.metadata.created_at = now + if hasattr(archived_item.metadata, "updated_at"): + archived_item.metadata.updated_at = now + if hasattr(current_item.metadata, "updated_at"): + current_item.metadata.updated_at = now + self.graph_db.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_item.metadata.model_dump(exclude_none=True), + ) + fields = current_item.metadata.model_dump(exclude_none=True) + fields.pop("created_at", None) + fields["updated_at"] = now + self.graph_db.update_node( + id=primary_id, + fields={"memory": current_item.memory, **fields}, + ) + self._merge_secondary_nodes(secondary_ids, primary_id) + + return current_item + + def _apply_cas_merge( + self, + primary_id: str, + current_item: TextualMemoryItem, + expected_versions: dict[str, int], + new_value: str, + ) -> str: + expected_version = expected_versions.get(primary_id) + current_version = current_item.metadata.version or 1 + if expected_version is not None and current_version != expected_version: + logger.warning( + f"[MemoryHistoryManager] Version conflict for node {primary_id}: " + f"Expected v{expected_version}, but found v{current_version} in DB. " + "Triggering merge logic." + ) + merged_content = self._merge_conflicting_memory( + latest_memory=current_item.memory, + proposed_update=new_value, + ) + return merged_content + + return new_value + + def _merge_secondary_nodes(self, secondary_ids: list[str], primary_id: str) -> None: + if not secondary_ids: + return + + node_ids = [primary_id, *secondary_ids] + nodes_data = self.graph_db.get_nodes(ids=node_ids) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + primary_data = nodes_map.get(primary_id) or self.graph_db.get_node(primary_id) + if not primary_data: + return + + primary_item = TextualMemoryItem(**primary_data) + merged_history = list(primary_item.metadata.history or []) + + for memory_id in secondary_ids: + node_data = nodes_map.get(memory_id) or self.graph_db.get_node(memory_id) + if not node_data: + continue + metadata = node_data.get("metadata", {}) + evolve_to = list(metadata.get("evolve_to", []) or []) + if primary_id not in evolve_to: + evolve_to.append(primary_id) + self.graph_db.update_node( + id=memory_id, + fields={"status": "archived", "evolve_to": evolve_to}, + ) + secondary_item = TextualMemoryItem(**node_data) + if secondary_item.metadata.history: + merged_history.extend(secondary_item.metadata.history) + + def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: + created_at = history_item.created_at + if isinstance(created_at, datetime): + return created_at + if created_at: + try: + return datetime.fromisoformat(created_at) + except ValueError: + return datetime.min + return datetime.min + + merged_history.sort(key=_history_sort_key) + for idx, history_item in enumerate(merged_history, start=1): + history_item.version = idx + primary_item.metadata.history = merged_history + self.graph_db.update_node( + id=primary_id, + fields={"history": [h.model_dump(exclude_none=True) for h in merged_history]}, + ) + + def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: + """ + Call LLM to merge proposed update with latest memory content. + """ + if not self.llm: + return proposed_update + + prompt = MEMORY_MERGE_PROMPT_ZH.replace("${latest_memory}", latest_memory).replace( + "${proposed_update}", proposed_update + ) + + messages = [{"role": "user", "content": prompt}] + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") + # Fallback: append proposed update? or just return proposed? + # Returning proposed might overwrite latest changes. + # Returning latest might lose proposed changes. + # Let's concatenate as a safe fallback. + return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" + + def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: + """Create New Node.""" + new_value = mem_data.get("value", "") + tags = mem_data.get("tags", []) + key = mem_data.get("key", "") + memory_type = mem_data.get("memory_type", "LongTermMemory") + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=new_value, + metadata=TreeNodeTextualMemoryMetadata( + is_fast=False, + version=1, + memory_type=memory_type, + status="activated", + tags=tags, + key=key, + created_at=datetime.now().isoformat(), + history=[], + ), + ) + return new_item + + def _handle_restored_memories( + self, restored_memories: list[dict[str, Any]] + ) -> list[TextualMemoryItem]: + """Handle Restored Memories (Extract from conflict).""" + source_ids = [r.get("source_candidate_id") for r in restored_memories] + source_items = self.graph_db.get_nodes(source_ids) + source_items = [TextualMemoryItem(**i) for i in source_items] + + created_items = [] + for i, data in enumerate(restored_memories): + source_item = source_items[i] + source_history = source_item.history.copy() + value = data.get("value", "") + tags = data.get("tags", []) + keys = data.get("keys", []) + memory_type = data.get("memory_type", "LongTermMemory") + version = source_item.verion + new_history_item = ArchivedTextualMemory( + version=version, + is_fast=False, + memory=source_item.memory, + update_type="extract", + archived_memory_id=source_item.id, + created_at=source_item.metadata.created_at, + ) + source_history.append(new_history_item) # Re-use the history of the old node + # Create new node + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=value, + metadata=TreeNodeTextualMemoryMetadata( + memory_type=memory_type, + status="activated", + is_fast=False, + version=version + 1, + key=keys, + tags=tags, + created_at=datetime.now().isoformat(), + history=source_history, + ), + ) + created_items.append(new_item) + + return created_items diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py new file mode 100644 index 000000000..a5933f192 --- /dev/null +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -0,0 +1,168 @@ +# ========================================== +# 记忆更新与维护链路 (Memory Update & Maintenance) +# ========================================== +ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 +您的核心任务是根据最新的用户对话,维护和更新用户的长期记忆图谱。您需要利用系统提供的“候选记忆”(Candidates)来决定是更新现有记忆节点,还是创建全新的记忆节点。 + +系统检测到以下“候选记忆”与用户的最新输入存在潜在关联: + +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: +${duplicate_candidates} + +2. **潜在事实冲突记忆 (Conflict Candidates)**: +${conflict_candidates} + +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +${unrelated_candidates} + +请执行以下操作: + +1. **核心原则(STRICT)**: + - 您的目标是**维护**记忆库,而非仅仅提取信息。 + - **优先更新**:如果对话内容涉及现有的“候选记忆”,应优先视为对该记忆节点的**更新**(补充细节或修正状态),而不是创建重复的新节点。 + - **按需新增**:仅当对话内容包含全新的、与现有“候选记忆”完全无关的话题时,才创建新的记忆节点。 + - 提取来源**只能**是【当前的对话内容】。严禁编造未提及的信息。 + +2. 识别反映用户经历、信念、关切、决策、计划或反应的信息。 + - 如果消息来自用户,提取用户相关的记忆。 + - 如果来自助手,仅提取用户认可或回应的事实性记忆。 + +3. 清晰解析所有时间、人物和事件的指代(同原规则): + - 将相对时间(“昨天”)转换为绝对日期。 + - 明确区分事件时间和消息时间。 + - 解析代词和模糊指代。 + - 使用第三人称(“用户”)。 + - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 + +4. **处理逻辑(更新与新增)**: + 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: + + a) **更新现有记忆节点 (Update via Duplicate/Related)**: + - 检查“潜在重复/关联记忆”。 + - 如果新信息是对某条旧记忆的重复、确认或补充细节: + - 生成一条**更新后的完整记忆**放入 `value`(包含旧信息+新细节)。 + - 将该旧记忆的ID放入 `source_candidate_ids`。 + - 此时 `conflicted_candidate_ids` 应为空。 + + b) **修正冲突记忆节点 (Update via Conflict)**: + - 检查“潜在事实冲突记忆”。 + - 如果新信息否定了某条旧记忆,或更新了其状态(如“不再喜欢X”“改成Y”“取消计划”“从X转为Y”): + - 生成一条反映**最新状态**的记忆放入 `value`。 + - 将被修正的旧记忆ID放入 `conflicted_candidate_ids`。 + + c) **创建新记忆节点 (Create New)**: + - 如果新信息与任何“候选记忆”都无直接关联(既非重复也非冲突): + - 生成一条独立的新记忆放入 `value`。 + - 确保 `source_candidate_ids` 和 `conflicted_candidate_ids` 均为 `[]`。 + +5. **挽救剩余信息 (Restored Memories)**: + - 当您决定更新一个冲突节点(情况 b)时,请检查该旧节点是否包含**未被新信息否定且依然独立的有效事实**。 + - 如果有,请将这些“剩余有效信息”提取出来,作为 `restored_memories` 返回。这能防止因节点更新导致非冲突信息的丢失。 + +返回一个有效的JSON对象,结构如下: + +{ + "memory list": [ + { + "key": <字符串,简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory",区分该记忆是客观事实还是和用户相关的内容>, + "value": <字符串,更新后的完整记忆内容(针对更新/冲突情况)或全新记忆内容(针对新增情况)>, + "tags": <相关主题关键词列表>, + "source_candidate_ids": <字符串列表,被此条目更新的“重复/关联记忆”ID。若无则为 []>, + "conflicted_candidate_ids": <字符串列表,被此条目修正的“事实冲突记忆”ID。若无则为 []>, + }, + ... + ], + "restored_memories": [ + { + "source_candidate_id": <字符串,来源旧冲突记忆ID>, + "value": <字符串,从旧节点中分离出的、依然有效的剩余事实>, + "tags": <相关主题关键词列表> + }, + ... + ], + "summary": <从用户视角自然总结本次记忆更新操作的段落,120–200字> +} + +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 + +${custom_tags_prompt} + +示例: +**潜在重复/关联记忆 (Duplicate/Related Candidates):** +[ID:101] 用户喜欢喝拿铁,通常不加糖。 +[ID:102] 用户讨厌下雨天。 + +**潜在事实冲突记忆 (Conflict Candidates):** +[ID:201] 用户喜欢打羽毛球,但不喜欢滑雪。 + +**对话:** +user: [2025年6月26日]:我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 + +**输出:** +{ + "memory list": [ + { + "key": "咖啡偏好", + "memory_type": "UserMemory", + "value": "用户最喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", + "tags": ["饮食", "咖啡", "喜好"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + }, + { + "key": "运动习惯变更", + "memory_type": "UserMemory", + "value": "用户因膝盖受伤,决定不再打羽毛球。", + "tags": ["运动", "健康", "羽毛球"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + }, + { + "key": "宠物状况", + "memory_type": "UserMemory", + "value": "用户最近(2025年6月)刚养了一只猫。", + "tags": ["生活", "宠物", "猫"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + } + ], + "restored_memories": [ + { + "key": "运动偏好", + "memory_type": "UserMemory", + "value": "用户不喜欢滑雪。", + "tags": ["运动", "滑雪", "喜好"] + "source_candidate_id": "201", + } + ], + "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了用户养猫的记忆。对天气的喜好未受影响。" +} + +请始终使用与对话相同的语言进行回复。 + +对话: +${conversation} +""" + +MEMORY_MERGE_PROMPT_ZH = """ +您是记忆库维护专家。 +我们尝试更新一个记忆节点,但该节点在数据库中的内容在处理期间发生了变化(版本冲突)。 +我们需要将“本次处理得出的更新内容”合并到“当前数据库中最新的记忆内容”中。 + +**原始记忆(数据库中的最新版本):** +${latest_memory} + +**本次尝试的更新内容(基于旧版本得出的结论):** +${proposed_update} + +**任务:** +将“本次尝试的更新内容”合并到“原始记忆”中。 +- 如果更新内容包含新信息,请将其整合进去。 +- 如果更新内容与原始记忆冲突,请优先采纳更新内容(假设它是基于最新对话的修正),但请尽量保留原始记忆中依然有效的细节。 +- 确保合并后的结果是一个连贯、通顺的完整记忆片段。 + +请只返回合并后的记忆内容字符串,不要包含任何解释。 +""" From 7459f7f922e273d90ef75af3222158ec57af9669 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 10:19:13 +0800 Subject: [PATCH 20/92] feat: Supporting English prompts for memory versions. --- .../organize/history_manager.py | 40 +++- .../mem_reader_mem_version_prompts.py | 178 +++++++++++++++++- 2 files changed, 207 insertions(+), 11 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 2c2d37654..65295c769 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -10,14 +10,15 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.llms.base import BaseLLM +from memos.mem_reader.read_multi_modal.utils import detect_lang from memos.memories.textual.item import ( ArchivedTextualMemory, TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) from memos.templates.mem_reader_mem_version_prompts import ( - ASYNC_MEMORY_UPDATE_PROMPT_ZH, - MEMORY_MERGE_PROMPT_ZH, + ASYNC_MEMORY_UPDATE_PROMPT_DICT, + MEMORY_MERGE_PROMPT_DICT, ) @@ -162,9 +163,22 @@ def update_node_with_history( return item, archived_item - @staticmethod + def _determine_lang(self, sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang + def format_async_update_prompt( - item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" ) -> str: """ Format the prompt for asynchronous memory update. @@ -195,14 +209,18 @@ def format_async_update_prompt( # Includes "unrelated" and any other types unrelated_candidates.append(candidate_str) - # Helper to format list as string + sources = item.metadata.sources if item.metadata else None + lang = self._determine_lang(sources, conversation or item.memory) + empty_label = "None" + def format_list(candidates): - return "\n".join(candidates) if candidates else "无" + return "\n".join(candidates) if candidates else empty_label + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) return ( - ASYNC_MEMORY_UPDATE_PROMPT_ZH.replace( - "${duplicate_candidates}", format_list(duplicate_candidates) - ) + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) .replace("${conflict_candidates}", format_list(conflict_candidates)) .replace("${unrelated_candidates}", format_list(unrelated_candidates)) .replace("${custom_tags_prompt}", custom_tags_prompt) @@ -668,7 +686,9 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> if not self.llm: return proposed_update - prompt = MEMORY_MERGE_PROMPT_ZH.replace("${latest_memory}", latest_memory).replace( + lang = self._determine_lang(None, f"{latest_memory}\n{proposed_update}") + prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) + prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( "${proposed_update}", proposed_update ) diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index a5933f192..61e56f86c 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -1,5 +1,5 @@ # ========================================== -# 记忆更新与维护链路 (Memory Update & Maintenance) +# Memory Update & Maintenance # ========================================== ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 您的核心任务是根据最新的用户对话,维护和更新用户的长期记忆图谱。您需要利用系统提供的“候选记忆”(Candidates)来决定是更新现有记忆节点,还是创建全新的记忆节点。 @@ -147,6 +147,157 @@ ${conversation} """ +ASYNC_MEMORY_UPDATE_PROMPT_EN = """You are a memory maintenance expert. +Your core task is to maintain and update the user's long-term memory graph based on the latest conversation. Use the system-provided "Candidates" to decide whether to update existing memory nodes or create brand-new ones. + +The system detected the following "Candidates" potentially related to the user's latest input: + +1. Duplicate/Related Candidates: +${duplicate_candidates} + +2. Conflict Candidates: +${conflict_candidates} + +3. Possibly unrelated (needs further judgment): +${unrelated_candidates} + +Please execute the following: + +1. Core Principles (STRICT): + - Your goal is maintenance of the memory base, not mere extraction. + - Prefer Update: If the conversation touches any existing "Candidates", treat it as an update to that memory node (add details or correct status) rather than creating a duplicate node. + - Add As Needed: Only create a new node when the conversation contains truly new topics that are completely unrelated to existing "Candidates". + - The extraction source must be ONLY the current conversation. Do not fabricate information not mentioned. + +2. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions. + - If the message is from the user, extract user-related memories. + - If it is from the assistant, only extract factual memories that the user explicitly acknowledges or responds to. + +3. Disambiguate all references to time, people, and events: + - Convert relative time ("yesterday") to an absolute date. + - Clearly distinguish event time from message time. + - Resolve pronouns and ambiguous references. + - Use third person ("the user"). + - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). + +4. Processing Logic (Update and Create): + Traverse each piece of information in the conversation that is worth remembering and apply: + + a) Update existing memory node (Update via Duplicate/Related): + - Check Duplicate/Related Candidates. + - If the new information repeats, confirms, or adds details to an old memory: + - Generate an updated complete memory into `value` (old info + new details). + - Put the old memory IDs into `source_candidate_ids`. + - `conflicted_candidate_ids` must be []. + + b) Fix conflicting memory node (Update via Conflict): + - Check Conflict Candidates. + - If the new information negates an old memory or updates its state (e.g., "no longer likes X", "changed to Y", "canceled plan", "from X to Y"): + - Generate a memory reflecting the latest state into `value`. + - Put the corrected old memory IDs into `conflicted_candidate_ids`. + + c) Create new memory node (Create New): + - If the new information is not directly related to any "Candidates" (neither duplicate nor conflict): + - Generate an independent new memory into `value`. + - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. + +5. Restored Memories: + - When you decide to update a conflict node (case b), check whether the old node contains valid facts that were NOT negated by the new information and remain independently valid. + - If any, extract these "remaining valid facts" and return them as `restored_memories` to prevent loss of non-conflict information due to node updates. + +Return a valid JSON object with the structure: + +{ + "memory list": [ + { + "key": , + "memory_type": , + "value": , + "tags": , + "source_candidate_ids": , + "conflicted_candidate_ids": , + }, + ... + ], + "restored_memories": [ + { + "source_candidate_id": , + "value": , + "tags": + }, + ... + ], + "summary": +} + +Language rules: +- The `key`, `value`, `tags`, and `summary` fields must match the main language of the input conversation. If the input is English, output English. +- `memory_type` remains in English. + +${custom_tags_prompt} + +Example: +Duplicate/Related Candidates: +[ID:101] The user likes latte and usually doesn't add sugar. +[ID:102] The user hates rainy days. + +Conflict Candidates: +[ID:201] The user likes badminton but dislikes skiing. + +Conversation: +user: [June 26, 2025]: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. + +Output: +{ + "memory list": [ + { + "key": "Coffee preference", + "memory_type": "UserMemory", + "value": "The user likes latte most, usually doesn't add sugar, and prefers oat milk.", + "tags": ["diet", "coffee", "preference"], + "source_candidate_ids": ["101"], + "conflicted_candidate_ids": [], + }, + { + "key": "Sport habit change", + "memory_type": "UserMemory", + "value": "Due to a knee injury, the user decides to no longer play badminton.", + "tags": ["sport", "health", "badminton"], + "source_candidate_ids": [], + "conflicted_candidate_ids": ["201"], + }, + { + "key": "Pet status", + "memory_type": "UserMemory", + "value": "The user recently (June 2025) adopted a cat.", + "tags": ["life", "pet", "cat"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + } + ], + "restored_memories": [ + { + "key": "Sports preference", + "memory_type": "UserMemory", + "value": "The user dislikes skiing.", + "tags": ["sport", "skiing", "preference"], + "source_candidate_id": "201" + } + ], + "summary": "In this update, the user refined their latte preference (adds oat milk), and due to a knee injury updated their sports habit (stop playing badminton). The system preserves the record that the user dislikes skiing. Additionally, a new memory is added that the user adopted a cat. Weather preference is unaffected." +} + +Always reply in the same language as the conversation. + +Conversation: +${conversation} +""" + +ASYNC_MEMORY_UPDATE_PROMPT_DICT = { + "zh": ASYNC_MEMORY_UPDATE_PROMPT_ZH, + "en": ASYNC_MEMORY_UPDATE_PROMPT_EN, +} + MEMORY_MERGE_PROMPT_ZH = """ 您是记忆库维护专家。 我们尝试更新一个记忆节点,但该节点在数据库中的内容在处理期间发生了变化(版本冲突)。 @@ -166,3 +317,28 @@ 请只返回合并后的记忆内容字符串,不要包含任何解释。 """ + +MEMORY_MERGE_PROMPT_EN = """ +You are a memory maintenance expert. +We attempted to update a memory node, but the content of that node changed in the database during processing (version conflict). +We need to merge "the update derived in this attempt" into "the latest memory content currently stored in the database". + +Original memory (latest version in the database): +${latest_memory} + +Proposed update (derived based on an old version): +${proposed_update} + +Task: +Merge "the proposed update" into "the original memory". +- If the update contains new information, integrate it. +- If the update conflicts with the original memory, prefer the update (assuming it is a correction based on the latest conversation), while preserving any details from the original memory that remain valid. +- Ensure the merged result is a coherent, fluent, and complete memory passage. + +Return ONLY the merged memory content string. Do not include any explanation. +""" + +MEMORY_MERGE_PROMPT_DICT = { + "zh": MEMORY_MERGE_PROMPT_ZH, + "en": MEMORY_MERGE_PROMPT_EN, +} From b3e1bdc15128c3544ec026e471acdb303a7b2f13 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 18:49:27 +0800 Subject: [PATCH 21/92] feat: Merge memory version features into the fine/async pathway in add process --- src/memos/api/config.py | 1 + src/memos/mem_reader/multi_modal_struct.py | 89 ++++-- .../handlers/mem_read_handler.py | 99 +++---- src/memos/memories/textual/item.py | 4 +- .../organize/history_manager.py | 265 ++++++++++-------- 5 files changed, 274 insertions(+), 184 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 70d9366e3..c4a126140 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -834,6 +834,7 @@ def get_product_default_config() -> dict[str, Any]: "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" ), }, + "memory_version_switch": os.getenv("MEM_READER_MEM_VERSION_SWITCH", "off"), }, }, "enable_textual_memory": True, diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 16098907d..bb576fe24 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -701,6 +701,21 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 0: Memory versioning async update pipeline ========== + if ( + self.memory_version_switch == "on" + and self.history_manager is not None + and self.history_manager.is_applicable(fast_item) + ): + try: + new_items = self._process_async_versioning_update( + fast_item, mem_str, custom_tags, **kwargs + ) + return new_items + except Exception as e: + logger.warning(f"[MultiModalFine] Async versioning pipeline failed: {e}") + return [] + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) @@ -711,14 +726,15 @@ def _process_one_item( if resp.get("memory list", []): for m in resp.get("memory list", []): try: - # Check and merge with similar memories if needed - m_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=m, - mem_text=m.get("value", ""), - sources=sources, - original_query=mem_str, - **kwargs, - ) + m_maybe_merged = m + if self.memory_version_switch != "on": + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( m_maybe_merged.get("memory_type", "LongTermMemory") @@ -735,8 +751,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in m_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in m_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) @@ -745,13 +760,15 @@ def _process_one_item( elif resp.get("value") and resp.get("key"): try: # Check and merge with similar memories if needed - resp_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=resp, - mem_text=resp.get("value", "").strip(), - sources=sources, - original_query=mem_str, - **kwargs, - ) + resp_maybe_merged = resp + if self.memory_version_switch != "on": + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, @@ -762,8 +779,7 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in resp_maybe_merged: + if self.memory_version_switch != "on" and "merged_from" in resp_maybe_merged: node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) @@ -955,6 +971,8 @@ def _fast_resolve_memory_duplicates_and_conflicts( return for item in fast_memory_items: + if not self.history_manager.is_applicable(item): + continue try: # recall related memories related = self.pre_update_retriever.retrieve( @@ -962,17 +980,46 @@ def _fast_resolve_memory_duplicates_and_conflicts( user_name=user_name, ) # NLI check & attaching contents - conflicting_or_duplicate_items = self.history_manager.resolve_history_via_nli( + conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( item, related ) # mark delete(temporarily) self.history_manager.mark_memory_status( - [m.id for m in conflicting_or_duplicate_items], "resolving" + conflicting_or_duplicate_ids, "resolving", user_name=user_name ) except Exception as e: logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + def _process_async_versioning_update( + self, item: TextualMemoryItem, mem_str: str, custom_tags: dict[str, str], **kwargs + ) -> list[TextualMemoryItem]: + """ + 1. Wait for fast node resolution and rebuild its history + 2. Build async update prompt (include custom tags and conversation context) + 3. Call LLM and parse JSON response + 4. Apply LLM updates to memory graph and return new items + """ + self.history_manager.wait_and_update_fast_history(item, timeout_sec=30) + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = self.history_manager.format_async_update_prompt( + item, + conversation=kwargs.get("chat_history") or mem_str, + custom_tags_prompt=custom_tags_prompt, + ) + response_text = self.llm.generate([{"role": "user", "content": prompt}]) + response_json = parse_json_result(response_text) + user_name = kwargs.get("user_name") + _, new_items = self.history_manager.apply_llm_memory_updates( + response_json, item, user_name=user_name + ) + return new_items + @timed def _process_multi_modal_data( self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 5d86c5589..bd1d69671 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -196,59 +196,60 @@ def _process_memories_with_reader( enhanced_mem_ids, ) - # add raw file nodes and edges - if mem_reader.save_rawfile: - raw_file_mem_group = [ + if getattr(mem_reader, "memory_version_switch", "off") != "on": + # add raw file nodes and edges + if mem_reader.save_rawfile: + raw_file_mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type == "RawFileMemory" + ] + text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + enhanced_mem_ids, + user_id=user_id, + user_name=user_name, + ) + logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + + # Mark merged_from memories as archived when provided in memory metadata + summary_memories = [ memory for memory in flattened_memories - if memory.metadata.memory_type == "RawFileMemory" + if memory.metadata.memory_type != "RawFileMemory" ] - text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - enhanced_mem_ids, - user_id=user_id, - user_name=user_name, - ) - logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) - - # Mark merged_from memories as archived when provided in memory metadata - summary_memories = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type != "RawFileMemory" - ] - if mem_reader.graph_db: - for memory in summary_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - "[Scheduler] Archived merged_from memory: %s", - old_id, - ) - except Exception as e: - logger.warning( - "[Scheduler] Failed to archive merged_from memory %s: %s", - old_id, - e, - ) - else: - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in summary_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + if mem_reader.graph_db: + for memory in summary_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), {"status": "archived"}, user_name=user_name + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in summary_memories ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." + ) cloud_env = is_cloud_env() if cloud_env: diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 7e40f1d50..573a83856 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -109,8 +109,8 @@ class TextualMemoryMetadata(BaseModel): default_factory=list, description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) history: list[ArchivedTextualMemory] | None = Field( diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 65295c769..dbe31a4a2 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -109,6 +109,21 @@ def _add(history_item): item.metadata.history = list(new_history.values()) +def _determine_lang(sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang + + class MemoryHistoryManager: def __init__( self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None @@ -125,6 +140,18 @@ def __init__( self.graph_db = graph_db self.llm = llm + @staticmethod + def is_applicable(item: TextualMemoryItem) -> bool: + # Only deals with: + # 1. From doc or chat + # 2. LongTermMemory, UserMemory + allowed_sources = ["doc", "chat"] + allowed_memory_types = ["LongTermMemory", "UserMemory"] + return ( + item.metadata.sources[0].type in allowed_sources + and item.metadata.memory_type in allowed_memory_types + ) + @staticmethod def update_node_with_history( item: TextualMemoryItem, @@ -163,70 +190,6 @@ def update_node_with_history( return item, archived_item - def _determine_lang(self, sources: list | None, fallback_text: str) -> str: - lang = None - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - if isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(fallback_text) - return lang - - def format_async_update_prompt( - self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" - ) -> str: - """ - Format the prompt for asynchronous memory update. - - Args: - item: The TextualMemoryItem containing history candidates. - conversation: The current conversation content. If None, uses item.memory. - custom_tags_prompt: Optional custom prompt for tags. - - Returns: - Formatted prompt string. - """ - # First, detach duplicate and conflict memory contents from the new item's memory text - _detach_related_content(item) - - duplicate_candidates = [] - conflict_candidates = [] - unrelated_candidates = [] - - for h in item.metadata.history or []: - candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" - - if h.update_type == "duplicate": - duplicate_candidates.append(candidate_str) - elif h.update_type == "conflict": - conflict_candidates.append(candidate_str) - else: - # Includes "unrelated" and any other types - unrelated_candidates.append(candidate_str) - - sources = item.metadata.sources if item.metadata else None - lang = self._determine_lang(sources, conversation or item.memory) - empty_label = "None" - - def format_list(candidates): - return "\n".join(candidates) if candidates else empty_label - - prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( - lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] - ) - return ( - prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) - .replace("${conflict_candidates}", format_list(conflict_candidates)) - .replace("${unrelated_candidates}", format_list(unrelated_candidates)) - .replace("${custom_tags_prompt}", custom_tags_prompt) - .replace("${conversation}", conversation or item.memory) - ) - def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] ) -> list[str]: @@ -250,16 +213,20 @@ def resolve_history_via_nli( ) # 2. Process results and attach to history + duplicate_memory_ids = [] + conflict_memory_ids = [] duplicate_memories = [] conflict_memories = [] for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" - duplicate_memories.append(r_item.id) + duplicate_memory_ids.append(r_item.id) + duplicate_memories.append(r_item.memory) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" - conflict_memories.append(r_item.id) + conflict_memory_ids.append(r_item.id) + conflict_memories.append(r_item.memory) else: update_type = "unrelated" @@ -280,7 +247,7 @@ def resolve_history_via_nli( # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. _append_related_content(new_item, duplicate_memories, conflict_memories) - return duplicate_memories + conflict_memories + return duplicate_memory_ids + conflict_memory_ids def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: """ @@ -338,6 +305,56 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return + def format_async_update_prompt( + self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + ) -> str: + """ + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + conversation: The current conversation content. If None, uses item.memory. + custom_tags_prompt: Optional custom prompt for tags. + + Returns: + Formatted prompt string. + """ + # First, detach duplicate and conflict memory contents from the new item's memory text + _detach_related_content(item) + + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] + + for h in item.metadata.history or []: + candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" + + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) + + sources = item.metadata.sources if item.metadata else None + lang = _determine_lang(sources, conversation or item.memory) + empty_label = "None" + + def format_list(candidates): + return "\n".join(candidates) if candidates else empty_label + + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) + return ( + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation}", conversation or item.memory) + ) + def apply_llm_memory_updates( self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: @@ -388,7 +405,7 @@ def apply_llm_memory_updates( # 2. Handle Memory List (Update or New) processed_updates, created_items = self._process_memory_updates( - memory_list, expected_versions + memory_list, expected_versions, user_name ) updated_items.extend(processed_updates) new_items.extend(created_items) @@ -499,7 +516,7 @@ def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) self.mark_memory_status(unrelated_ids, "activated", user_name) def _process_memory_updates( - self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int] + self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int], user_name: str ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: """Process Memory List (Update or New).""" updated_items: list[TextualMemoryItem] = [] @@ -513,10 +530,7 @@ def _process_memory_updates( if target_ids: item = self._update_existing_memory( - mem_data, - target_ids, - source_ids, - expected_versions, + mem_data, target_ids, source_ids, expected_versions, user_name ) if item: updated_items.append(item) @@ -531,6 +545,7 @@ def _update_existing_memory( target_ids: list[str], source_ids: list[str], expected_versions: dict[str, int], + user_name: str, ) -> TextualMemoryItem | None: """ Update existing memory nodes using the LLM result. @@ -559,18 +574,19 @@ def _update_existing_memory( ) # Fetch candidate nodes in batch and then select the primary + # We update the primary and then merge the secondaries to the primary nodes_data = self.graph_db.get_nodes(target_ids) or [] nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} node_data = nodes_map.get(primary_id) if not node_data: - node_data = self.graph_db.get_node(primary_id) - if not node_data: - logger.warning( - f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." - ) - return None - + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + return None current_item = TextualMemoryItem(**node_data) + + # For concurrency control, need to make sure the primary item has not been modified by others in the meantime + # If it has(version changed), then we need to use llm to merge again. new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) update_type = "duplicate" if original_primary_id in source_ids else "conflict" @@ -581,28 +597,41 @@ def _update_existing_memory( tags=tags, key=key, ) - now = datetime.now().isoformat() - if hasattr(archived_item.metadata, "created_at") and not getattr( - archived_item.metadata, "created_at", None - ): - archived_item.metadata.created_at = now - if hasattr(archived_item.metadata, "updated_at"): - archived_item.metadata.updated_at = now - if hasattr(current_item.metadata, "updated_at"): - current_item.metadata.updated_at = now + + # create archived node for storing older versions of the memory self.graph_db.add_node( id=archived_item.id, memory=archived_item.memory, metadata=archived_item.metadata.model_dump(exclude_none=True), + user_name=user_name, ) + + now = datetime.now().isoformat() + if hasattr(current_item.metadata, "updated_at"): + current_item.metadata.updated_at = now fields = current_item.metadata.model_dump(exclude_none=True) fields.pop("created_at", None) - fields["updated_at"] = now + merged_history = list(current_item.metadata.history or []) + new_primary_version = current_item.metadata.version or 1 + # Multiple related ids indicates existing duplicates/conflicts to be merged + if secondary_ids: + merged_history, new_primary_version = self._merge_secondary_nodes( + secondary_ids, primary_id, nodes_map, user_name, merged_history + ) + current_item.metadata.history = merged_history + current_item.metadata.version = new_primary_version + merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + # update old memory node with new content and updated history self.graph_db.update_node( id=primary_id, - fields={"memory": current_item.memory, **fields}, + fields={ + "memory": current_item.memory, + **fields, + "history": merged_history_dump, + "version": new_primary_version, + }, + user_name=user_name, ) - self._merge_secondary_nodes(secondary_ids, primary_id) return current_item @@ -629,36 +658,35 @@ def _apply_cas_merge( return new_value - def _merge_secondary_nodes(self, secondary_ids: list[str], primary_id: str) -> None: - if not secondary_ids: - return - - node_ids = [primary_id, *secondary_ids] - nodes_data = self.graph_db.get_nodes(ids=node_ids) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} - primary_data = nodes_map.get(primary_id) or self.graph_db.get_node(primary_id) - if not primary_data: - return - - primary_item = TextualMemoryItem(**primary_data) - merged_history = list(primary_item.metadata.history or []) + def _merge_secondary_nodes( + self, + secondary_ids: list[str], + primary_id: str, + nodes_map: dict, + user_name: str, + base_history: list[ArchivedTextualMemory], + ) -> tuple[list[ArchivedTextualMemory], int]: + merged_history = list(base_history) for memory_id in secondary_ids: - node_data = nodes_map.get(memory_id) or self.graph_db.get_node(memory_id) + node_data = nodes_map.get(memory_id) if not node_data: continue metadata = node_data.get("metadata", {}) evolve_to = list(metadata.get("evolve_to", []) or []) if primary_id not in evolve_to: evolve_to.append(primary_id) + # set secondary nodes to archived and record their evolving destinations self.graph_db.update_node( id=memory_id, fields={"status": "archived", "evolve_to": evolve_to}, + user_name=user_name, ) secondary_item = TextualMemoryItem(**node_data) if secondary_item.metadata.history: merged_history.extend(secondary_item.metadata.history) + # Currently we just sort the versions according to their creation time def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: created_at = history_item.created_at if isinstance(created_at, datetime): @@ -670,14 +698,27 @@ def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: return datetime.min return datetime.min + def _dedupe_history_by_archived_id( + history: list[ArchivedTextualMemory], + ) -> list[ArchivedTextualMemory]: + seen_archived_ids: set[str] = set() + deduped_history: list[ArchivedTextualMemory] = [] + for history_item in history: + archived_id = history_item.archived_memory_id + if archived_id and archived_id in seen_archived_ids: + continue + if archived_id: + seen_archived_ids.add(archived_id) + deduped_history.append(history_item) + return deduped_history + merged_history.sort(key=_history_sort_key) + merged_history = _dedupe_history_by_archived_id(merged_history) + max_version = 0 for idx, history_item in enumerate(merged_history, start=1): history_item.version = idx - primary_item.metadata.history = merged_history - self.graph_db.update_node( - id=primary_id, - fields={"history": [h.model_dump(exclude_none=True) for h in merged_history]}, - ) + max_version = idx + return merged_history, max_version + 1 def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: """ @@ -686,7 +727,7 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> if not self.llm: return proposed_update - lang = self._determine_lang(None, f"{latest_memory}\n{proposed_update}") + lang = _determine_lang(None, f"{latest_memory}\n{proposed_update}") prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( "${proposed_update}", proposed_update From cf7ad8d87527685386b2fe3e0cc67f79afe3dac4 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 21:31:49 +0800 Subject: [PATCH 22/92] fix: Replenish missing embeddings during memory version updates/adds. --- src/memos/api/handlers/component_init.py | 4 +++- .../organize/history_manager.py | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index e764c7fd4..e1b4c5f2f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -192,7 +192,9 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db, llm=llm) + memory_history_manager = MemoryHistoryManager( + nli_client=nli_client, graph_db=graph_db, llm=llm, embedder=embedder + ) pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index dbe31a4a2..5b0f411d5 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -6,6 +6,7 @@ from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor +from memos.embedders.base import BaseEmbedder from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB @@ -126,7 +127,11 @@ def _determine_lang(sources: list | None, fallback_text: str) -> str: class MemoryHistoryManager: def __init__( - self, nli_client: NLIClient, graph_db: BaseGraphDB, llm: BaseLLM | None = None + self, + nli_client: NLIClient, + graph_db: BaseGraphDB, + llm: BaseLLM | None = None, + embedder: BaseEmbedder | None = None, ) -> None: """ Initialize the MemoryHistoryManager. @@ -139,6 +144,16 @@ def __init__( self.nli_client = nli_client self.graph_db = graph_db self.llm = llm + self.embedder = embedder + + def _compute_embedding(self, text: str) -> list[float] | None: + if not self.embedder: + return None + try: + return self.embedder.embed([text])[0] + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to compute embedding: {e}") + return None @staticmethod def is_applicable(item: TextualMemoryItem) -> bool: @@ -621,6 +636,7 @@ def _update_existing_memory( current_item.metadata.history = merged_history current_item.metadata.version = new_primary_version merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + embedding = self._compute_embedding(current_item.memory) # update old memory node with new content and updated history self.graph_db.update_node( id=primary_id, @@ -629,6 +645,7 @@ def _update_existing_memory( **fields, "history": merged_history_dump, "version": new_primary_version, + "embedding": embedding, }, user_name=user_name, ) @@ -764,6 +781,7 @@ def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: key=key, created_at=datetime.now().isoformat(), history=[], + embedding=self._compute_embedding(new_value), ), ) return new_item @@ -807,6 +825,7 @@ def _handle_restored_memories( tags=tags, created_at=datetime.now().isoformat(), history=source_history, + embedding=self._compute_embedding(value), ), ) created_items.append(new_item) From 418602969539ccd11fd46eede0121c26a40a682c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Feb 2026 22:05:45 +0800 Subject: [PATCH 23/92] fix: Identify fast nodes correctly in memory versions. --- .../textual/tree_text_memory/organize/history_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 5b0f411d5..f8294b82e 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -248,9 +248,13 @@ def resolve_history_via_nli( # Safely get created_at, fallback to updated_at created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at + # TODO: change the way of marking fast nodes by directly using is_fast field. archived = ArchivedTextualMemory( version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, + is_fast=( + r_item.metadata.is_fast + or ("mode:fast" in (getattr(r_item.metadata, "tags", None) or [])) + ), memory=r_item.memory, update_type=update_type, archived_memory_id=r_item.id, From 40dae4ebe3e811ea57bd236049f5c56b23f9eff9 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 10:21:24 +0800 Subject: [PATCH 24/92] fix: Fix branch merge --- .../handlers/mem_read_handler.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index bd1d69671..c2a10b531 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -196,22 +196,23 @@ def _process_memories_with_reader( enhanced_mem_ids, ) - if getattr(mem_reader, "memory_version_switch", "off") != "on": - # add raw file nodes and edges - if mem_reader.save_rawfile: - raw_file_mem_group = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type == "RawFileMemory" - ] - text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - enhanced_mem_ids, - user_id=user_id, - user_name=user_name, - ) - logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + # add raw file nodes and edges + if mem_reader.save_rawfile: + raw_file_mem_group = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type == "RawFileMemory" + ] + text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + enhanced_mem_ids, + user_id=user_id, + user_name=user_name, + ) + logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) + # fallback to simple deduplication logic when mem version switch is off + if getattr(mem_reader, "memory_version_switch", "off") != "on": # Mark merged_from memories as archived when provided in memory metadata summary_memories = [ memory @@ -230,7 +231,9 @@ def _process_memories_with_reader( for old_id in old_ids: try: mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name + str(old_id), + {"status": "archived"}, + user_name=user_name, ) logger.info( "[Scheduler] Archived merged_from memory: %s", From 64471396df2266a5b442da945d393d167afe0ce8 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 14:57:03 +0800 Subject: [PATCH 25/92] fix: Fix WorkingMemory behaviors and time settings in memory versions. --- .../organize/history_manager.py | 20 ++++++++++++++----- .../tree_text_memory/organize/manager.py | 9 ++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index f8294b82e..e19c0d593 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -175,6 +175,9 @@ def update_node_with_history( tags: list[str] | None = None, key: str | None = None, ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + now = datetime.now().isoformat() + last_update_time = item.metadata.updated_at + old_id = item.id archived_id = str(uuid.uuid4()) # archived memory(need to store this node to the db later) @@ -182,6 +185,8 @@ def update_node_with_history( archived_item.id = archived_id archived_item.metadata.evolve_to = [old_id] archived_item.metadata.status = "archived" + archived_item.metadata.created_at = last_update_time + archived_item.metadata.updated_at = now # original memory with updated contents and history archived_history = ArchivedTextualMemory( @@ -190,11 +195,12 @@ def update_node_with_history( memory=item.memory, update_type=update_type, archived_memory_id=archived_id, - created_at=getattr(item.metadata, "created_at", None) or datetime.now().isoformat(), + created_at=getattr(item.metadata, "updated_at", None) or last_update_time, ) item.memory = new_memory item.metadata.version = (item.metadata.version or 1) + 1 item.metadata.status = "activated" + item.metadata.updated_at = now if tags is not None: item.metadata.tags = tags if key is not None: @@ -625,11 +631,7 @@ def _update_existing_memory( user_name=user_name, ) - now = datetime.now().isoformat() - if hasattr(current_item.metadata, "updated_at"): - current_item.metadata.updated_at = now fields = current_item.metadata.model_dump(exclude_none=True) - fields.pop("created_at", None) merged_history = list(current_item.metadata.history or []) new_primary_version = current_item.metadata.version or 1 # Multiple related ids indicates existing duplicates/conflicts to be merged @@ -653,6 +655,14 @@ def _update_existing_memory( }, user_name=user_name, ) + working_binding = getattr(current_item.metadata, "working_binding", None) + if working_binding and working_binding != current_item.id: + try: + self.mark_memory_status([str(working_binding)], "deleted", user_name=user_name) + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" + ) return current_item diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index cbc349d67..9733553c0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -191,10 +191,12 @@ def _add_memories_batch( ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() + metadata_dict["working_binding"] = working_id # Add working_binding for fast mode tags = metadata_dict.get("tags") or [] if "mode:fast" in tags: + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_id}] direct built from raw inputs" metadata_dict["background"] = ( @@ -234,7 +236,8 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: exc_info=e, ) - _submit_batches(working_nodes, "WorkingMemory") + # TODO: working id is same with item.id, need to fix, currently stop adding WorkingMemories here. + # here used to be: _submit_batches(working_nodes, "WorkingMemory") _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: @@ -318,6 +321,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] + # TODO: working id is same with item.id, need to fix working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: @@ -393,8 +397,11 @@ def _add_to_graph_memory( node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding: + metadata_dict["working_binding"] = working_binding tags = metadata_dict.get("tags") or [] if working_binding and ("mode:fast" in tags): + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: From 1e45115f79a24aa64a0a00193f005f9e630ef5ce Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 15:49:37 +0800 Subject: [PATCH 26/92] feat: Refine prompts for memory versions to take account of time factors. --- src/memos/mem_reader/multi_modal_struct.py | 6 +--- .../organize/history_manager.py | 25 ++++++++++++---- .../mem_reader_mem_version_prompts.py | 30 ++++++++++++++----- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index bb576fe24..8237e22e5 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1007,11 +1007,7 @@ def _process_async_versioning_update( if custom_tags else "" ) - prompt = self.history_manager.format_async_update_prompt( - item, - conversation=kwargs.get("chat_history") or mem_str, - custom_tags_prompt=custom_tags_prompt, - ) + prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) response_text = self.llm.generate([{"role": "user", "content": prompt}]) response_json = parse_json_result(response_text) user_name = kwargs.get("user_name") diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index e19c0d593..3c719b189 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -331,14 +331,13 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int return def format_async_update_prompt( - self, item: TextualMemoryItem, conversation: str | None = None, custom_tags_prompt: str = "" + self, item: TextualMemoryItem, custom_tags_prompt: str = "" ) -> str: """ Format the prompt for asynchronous memory update. Args: item: The TextualMemoryItem containing history candidates. - conversation: The current conversation content. If None, uses item.memory. custom_tags_prompt: Optional custom prompt for tags. Returns: @@ -351,8 +350,20 @@ def format_async_update_prompt( conflict_candidates = [] unrelated_candidates = [] + def _fmt_time(ts: str | None) -> str | None: + if not ts or not isinstance(ts, str): + return None + try: + t = datetime.fromisoformat(ts.replace("Z", "")) + return t.strftime("%Y/%m/%d %H:%M:%S") + except Exception: + return ts + for h in item.metadata.history or []: - candidate_str = f"[ID:{h.archived_memory_id}] {h.memory}" + created = getattr(h, "created_at", None) + tstr = _fmt_time(created) + time_suffix = f"[Time: {tstr}] " if tstr else "" + candidate_str = f"[ID:{h.archived_memory_id}]{time_suffix}{h.memory}" if h.update_type == "duplicate": duplicate_candidates.append(candidate_str) @@ -363,7 +374,7 @@ def format_async_update_prompt( unrelated_candidates.append(candidate_str) sources = item.metadata.sources if item.metadata else None - lang = _determine_lang(sources, conversation or item.memory) + lang = _determine_lang(sources, item.memory) empty_label = "None" def format_list(candidates): @@ -372,12 +383,16 @@ def format_list(candidates): prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] ) + conversation_time_raw = getattr(item.metadata, "created_at", None) + conversation_time = _fmt_time(conversation_time_raw) or conversation_time_raw + return ( prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) .replace("${conflict_candidates}", format_list(conflict_candidates)) .replace("${unrelated_candidates}", format_list(unrelated_candidates)) .replace("${custom_tags_prompt}", custom_tags_prompt) - .replace("${conversation}", conversation or item.memory) + .replace("${conversation_time}", conversation_time) + .replace("${conversation}", item.memory) ) def apply_llm_memory_updates( diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index 61e56f86c..d9d57065e 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -33,6 +33,7 @@ - 解析代词和模糊指代。 - 使用第三人称(“用户”)。 - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 + - 候选记忆可能包含 [Time: ...] 表示该记忆的事件时间,请结合“对话时间”判断是否同一时段。 4. **处理逻辑(更新与新增)**: 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: @@ -92,14 +93,17 @@ 示例: **潜在重复/关联记忆 (Duplicate/Related Candidates):** -[ID:101] 用户喜欢喝拿铁,通常不加糖。 -[ID:102] 用户讨厌下雨天。 +[ID:101][Time: 2025/05/20 09:30:00] 用户喜欢喝拿铁,通常不加糖。 +[ID:102][Time: 2025/05/18 18:00:00] 用户讨厌下雨天。 **潜在事实冲突记忆 (Conflict Candidates):** -[ID:201] 用户喜欢打羽毛球,但不喜欢滑雪。 +[ID:201][Time: 2025/02/03 20:15:00] 用户喜欢打羽毛球,但不喜欢滑雪。 + +**对话时间:** +2025/06/26 09:00:00 **对话:** -user: [2025年6月26日]:我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 +user: 我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 **输出:** { @@ -143,6 +147,9 @@ 请始终使用与对话相同的语言进行回复。 +对话时间: +${conversation_time} + 对话: ${conversation} """ @@ -179,6 +186,7 @@ - Resolve pronouns and ambiguous references. - Use third person ("the user"). - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). + - Candidates may include [Time: ...] to indicate event time; use the conversation time to judge whether they are the same period. 4. Processing Logic (Update and Create): Traverse each piece of information in the conversation that is worth remembering and apply: @@ -238,14 +246,17 @@ Example: Duplicate/Related Candidates: -[ID:101] The user likes latte and usually doesn't add sugar. -[ID:102] The user hates rainy days. +[ID:101][Time: 2025/05/20 09:30:00] The user likes latte and usually doesn't add sugar. +[ID:102][Time: 2025/05/18 18:00:00] The user hates rainy days. Conflict Candidates: -[ID:201] The user likes badminton but dislikes skiing. +[ID:201][Time: 2025/02/03 20:15:00] The user likes badminton but dislikes skiing. + +Conversation time: +2025/06/26 09:00:00 Conversation: -user: [June 26, 2025]: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. +user: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. Output: { @@ -289,6 +300,9 @@ Always reply in the same language as the conversation. +Conversation time: +${conversation_time} + Conversation: ${conversation} """ From 36553dad3ed9ef7f28d0da59e87d9f30f7b3a51c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 17:30:55 +0800 Subject: [PATCH 27/92] fix: Fix default values for memory items. --- src/memos/memories/textual/item.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 573a83856..654e19354 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -105,15 +105,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) version: int = Field( default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) From c7f83bdc1dceb2e95766bb63c61da5ac9bd4350c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Feb 2026 17:32:29 +0800 Subject: [PATCH 28/92] test: Update test for memory versions --- .../memories/textual/test_history_manager.py | 589 +++++++++++++++++- 1 file changed, 576 insertions(+), 13 deletions(-) diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 46cf3a1f6..f0ea8458f 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -8,13 +8,15 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.memories.textual.item import ( + ArchivedTextualMemory, TextualMemoryItem, - TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, _append_related_content, _detach_related_content, + _rebuild_fast_node_history, ) @@ -36,7 +38,7 @@ def history_manager(mock_nli_client, mock_graph_db): def test_detach_related_content(): original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate 1", "Duplicate 2"] conflicts = ["Conflict 1", "Conflict 2"] @@ -60,7 +62,7 @@ def test_detach_related_content(): def test_detach_only_conflicts(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = [] conflicts = ["Conflict A"] @@ -75,7 +77,7 @@ def test_detach_only_conflicts(): def test_detach_only_duplicates(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate A"] conflicts = [] @@ -118,20 +120,581 @@ def test_mark_memory_status(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex id3 = uuid.uuid4().hex - items = [ - TextualMemoryItem(memory="M1", id=id1), - TextualMemoryItem(memory="M2", id=id2), - TextualMemoryItem(memory="M3", id=id3), - ] + memory_ids = [id1, id2, id3] status = "resolving" # Action - history_manager.mark_memory_status(items, status) + history_manager.mark_memory_status(memory_ids, status, user_name="u1") # Assert assert mock_graph_db.update_node.call_count == 3 # Verify we called it correctly - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}) + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name="u1") + + +def test_format_async_update_prompt(history_manager): + # Setup + # Create history items + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="101", memory="Duplicate content", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="201", memory="Conflict content", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=1, archived_memory_id="301", memory="Unrelated content", update_type="unrelated" + ) + + item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]), + ) + + # Execute + prompt = history_manager.format_async_update_prompt(item) + + # Verify + assert "[ID:101]" in prompt + assert "Duplicate content" in prompt + assert "[ID:201]" in prompt + assert "Conflict content" in prompt + assert "[ID:301]" in prompt + assert "Unrelated content" in prompt + assert "New user input" in prompt + + # Check that placeholders are gone (basic check) + assert "${duplicate_candidates}" not in prompt + assert "${conflict_candidates}" not in prompt + + +def test_format_async_update_prompt_with_args(history_manager): + item = TextualMemoryItem( + memory="Explicit conversation", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + + prompt = history_manager.format_async_update_prompt(item, custom_tags_prompt="Custom Tags") + + assert "Explicit conversation" in prompt + assert "Custom Tags" in prompt + + +def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): + llm_response = { + "memory list": [ + { + "key": "New Memory", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": ["tag1"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + new_item = new_items[0] + assert new_item.memory == "New Content" + assert new_item.metadata.tags == ["tag1"] + assert new_item.metadata.history == [] + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db): + # Setup existing node + existing_id = uuid.uuid4().hex + existing_node = { + "id": existing_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=existing_id, + memory="Old Content", + update_type="duplicate", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == existing_id + assert updated_item.memory == "Updated Content" + assert updated_item.metadata.version == 2 + assert updated_item.metadata.tags == ["new"] + assert len(updated_item.metadata.history) == 1 + + history_entry = updated_item.metadata.history[0] + assert history_entry.archived_memory_id != existing_id + assert history_entry.archived_memory_id is not None + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "duplicate" + + mock_graph_db.add_node.assert_called_once() + mock_graph_db.update_node.assert_called_once() + args, kwargs = mock_graph_db.update_node.call_args + assert kwargs["id"] == existing_id + assert kwargs["fields"]["memory"] == "Updated Content" + assert kwargs["fields"]["version"] == 2 + + +def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): + source_id = uuid.uuid4().hex + restored_item = TextualMemoryItem( + memory="Restored Content", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) + llm_response = { + "memory list": [], + "restored_memories": [ + {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=source_id, + memory="Old Content", + update_type="conflict", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + assert new_items[0] == restored_item + history_manager._handle_restored_memories.assert_called_once_with( + llm_response["restored_memories"] + ) + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + llm_response = {"memory list": [], "restored_memories": [], "summary": "Summary"} + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=id1, + memory="M1", + update_type="unrelated", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=id2, + memory="M2", + update_type="unrelated", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 0 + + # Check that update_node was called to set status="activated" + # mark_memory_status calls update_node for each item + assert mock_graph_db.update_node.call_count == 2 + + # We can inspect calls + calls = mock_graph_db.update_node.call_args_list + ids = sorted([c.kwargs["id"] for c in calls]) + assert ids == sorted([id1, id2]) + for c in calls: + assert c.kwargs["fields"]["status"] == "activated" + + +def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): + # Setup existing node (primary) + primary_id = uuid.uuid4().hex + secondary_id = uuid.uuid4().hex + existing_node = {"id": primary_id, "memory": "Old Content", "metadata": {"version": 1}} + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [ + existing_node, + { + "id": secondary_id, + "memory": "Secondary", + "metadata": {"version": 1}, + }, + ] + + llm_response = { + "memory list": [ + { + "key": "Conflict Resolved", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": [], + "source_candidate_ids": [], + "conflicted_candidate_ids": [primary_id, secondary_id], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=primary_id, + memory="Old Content", + update_type="conflict", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=secondary_id, + memory="Secondary", + update_type="conflict", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == primary_id + assert updated_item.metadata.history[0].update_type == "conflict" + + # Verify primary update + # The mock_graph_db.update_node is called for primary (update) AND secondary (delete) + + # Find call for primary + primary_update_calls = [ + c + for c in mock_graph_db.update_node.call_args_list + if c.kwargs["id"] == primary_id and "memory" in c.kwargs.get("fields", {}) + ] + assert len(primary_update_calls) >= 1 + assert primary_update_calls[0].kwargs["fields"]["memory"] == "New Content" + + # Find call for secondary + secondary_update_calls = [ + c for c in mock_graph_db.update_node.call_args_list if c.kwargs["id"] == secondary_id + ] + assert len(secondary_update_calls) >= 1 + last_secondary_update = secondary_update_calls[-1] + assert last_secondary_update.kwargs["fields"]["status"] == "archived" + assert last_secondary_update.kwargs["fields"]["evolve_to"] == [primary_id] + + +def test_rebuild_fast_node_history_dedup_and_replace(): + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="a", memory="m1", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="b", memory="m2", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=2, archived_memory_id="a", memory="m3", update_type="duplicate" + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]) + ) + + r1 = ArchivedTextualMemory( + version=2, archived_memory_id="b", memory="m4", update_type="conflict" + ) + r2 = ArchivedTextualMemory( + version=1, archived_memory_id="c", memory="m5", update_type="duplicate" + ) + + _rebuild_fast_node_history(item, {1: [r1, r2]}) + + by_id = {h.archived_memory_id: h for h in item.metadata.history} + assert set(by_id.keys()) == {"a", "b", "c"} + assert by_id["a"].version == 2 + assert by_id["b"].version == 2 + + +def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): + fast_id = uuid.uuid4().hex + history_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="conflict", is_fast=True + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) + ) + mock_graph_db.get_nodes.return_value = [ + {"id": fast_id, "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}} + ] + + replacement_item = ArchivedTextualMemory( + version=1, archived_memory_id="n1", memory="r1", update_type="conflict" + ) + history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) + + replacements = history_manager._check_and_fetch_replacements(item, [0]) + + assert 0 in replacements + assert replacements[0][0].archived_memory_id == "n1" + history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict") + + +def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): + mock_graph_db.get_nodes.return_value = [ + { + "id": "x1", + "memory": "m1", + "metadata": {"version": 2, "is_fast": False, "created_at": "2024-01-01"}, + }, + { + "id": "x2", + "memory": "m2", + "metadata": {"version": 1, "is_fast": True, "created_at": "2024-01-02"}, + }, + ] + + results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate") + + assert len(results) == 2 + ids = sorted([r.archived_memory_id for r in results]) + assert ids == ["x1", "x2"] + assert all(r.update_type == "duplicate" for r in results) + + +def test_wait_and_update_fast_history_rebuilds(history_manager): + fast_id = uuid.uuid4().hex + fast_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="duplicate", is_fast=True + ) + other_item = ArchivedTextualMemory( + version=1, archived_memory_id="k1", memory="keep", update_type="unrelated", is_fast=False + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[fast_item, other_item]) + ) + + replacement = ArchivedTextualMemory( + version=2, archived_memory_id="n1", memory="new", update_type="duplicate", is_fast=False + ) + history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) + + history_manager.wait_and_update_fast_history(item, timeout_sec=1) + + ids = [h.archived_memory_id for h in item.metadata.history] + assert "n1" in ids + assert fast_id not in ids + history_manager._check_and_fetch_replacements.assert_called_once() + + +def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): + llm = MagicMock() + llm.generate.return_value = "Merged Content" + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + existing_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2}, + } + mock_graph_db.get_nodes.return_value = [ + {"id": existing_id, "memory": "Old Content", "metadata": {"version": 2}} + ] + + mem_data = { + "key": "k", + "value": "Proposed", + "tags": ["t1"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + } + + updated = manager._update_existing_memory( + mem_data, [existing_id], [existing_id], {existing_id: 1}, user_name="u1" + ) + + assert updated.memory == "Merged Content" + assert updated.metadata.version == 3 + mock_graph_db.update_node.assert_called_once() + + +def test_update_existing_memory_marks_working_binding_deleted(history_manager, mock_graph_db): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + working_binding = uuid.uuid4().hex + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated = history_manager._update_existing_memory( + mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + ) + + assert updated is not None + history_manager.mark_memory_status.assert_called_once_with( + [str(working_binding)], "deleted", user_name="u1" + ) + + +def test_update_existing_memory_no_mark_when_working_binding_matches( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated = history_manager._update_existing_memory( + mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + ) + + assert updated is not None + history_manager.mark_memory_status.assert_not_called() + + +def test_update_existing_memory_node_missing(history_manager, mock_graph_db): + mock_graph_db.get_node.return_value = None + mock_graph_db.get_nodes.return_value = [] + mem_data = {"value": "v", "tags": [], "key": "k"} + + updated = history_manager._update_existing_memory(mem_data, ["missing"], [], {}, user_name="u1") + + assert updated is None + mock_graph_db.update_node.assert_not_called() + + +def test_update_node_with_history(): + item = TextualMemoryItem( + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + tags=["old"], + key="k1", + history=[], + ), + ) + + updated, archived = MemoryHistoryManager.update_node_with_history( + item, + "New Content", + "conflict", + ) + + assert updated.memory == "New Content" + assert updated.metadata.version == 3 + assert updated.metadata.tags == ["old"] + assert updated.metadata.key == "k1" + assert len(updated.metadata.history) == 1 + history_entry = updated.metadata.history[0] + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "conflict" + assert history_entry.archived_memory_id == archived.id + assert archived.metadata.status == "archived" + assert archived.metadata.evolve_to == [updated.id] + + +def test_merge_conflicting_memory_llm_error(mock_graph_db): + llm = MagicMock() + llm.generate.side_effect = Exception("fail") + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + merged = manager._merge_conflicting_memory("Latest", "Proposed") + + assert "System Merge Fallback" in merged + assert "Latest" in merged + assert "Proposed" in merged From 7eb76789ff7dee36022cf36c7fbddd7c111ccaff Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 10:31:57 +0800 Subject: [PATCH 29/92] fix: Fix some behaviors in memory versions. --- src/memos/api/handlers/formatters_handler.py | 1 + .../organize/history_manager.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 06c4fd223..f515bb772 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -158,6 +158,7 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): sources = item.get("metadata", {}).get("sources", []) if ( item["metadata"]["memory_type"] != "RawFileMemory" + and sources and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 3c719b189..6dd273836 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -612,6 +612,11 @@ def _update_existing_memory( mem_data.get("tags", []), mem_data.get("key", ""), ) + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory # Fetch candidate nodes in batch and then select the primary # We update the primary and then merge the secondaries to the primary @@ -696,8 +701,12 @@ def _apply_cas_merge( f"Expected v{expected_version}, but found v{current_version} in DB. " "Triggering merge logic." ) + latest_item = TextualMemoryItem( + memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(latest_item) merged_content = self._merge_conflicting_memory( - latest_memory=current_item.memory, + latest_memory=latest_item.memory, proposed_update=new_value, ) return merged_content @@ -794,6 +803,11 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: """Create New Node.""" new_value = mem_data.get("value", "") + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory tags = mem_data.get("tags", []) key = mem_data.get("key", "") memory_type = mem_data.get("memory_type", "LongTermMemory") @@ -828,6 +842,9 @@ def _handle_restored_memories( source_item = source_items[i] source_history = source_item.history.copy() value = data.get("value", "") + value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) + _detach_related_content(value_item) + value = value_item.memory tags = data.get("tags", []) keys = data.get("keys", []) memory_type = data.get("memory_type", "LongTermMemory") From 27ff12cf1277e7ddbc8cd7d0fa03399eabe2c53c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 11:49:25 +0800 Subject: [PATCH 30/92] feat: Add latency information for nli model in the log. --- src/memos/mem_reader/multi_modal_struct.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 8237e22e5..b552e6f37 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1,6 +1,7 @@ import concurrent.futures import json import re +import time import traceback from typing import TYPE_CHECKING, Any @@ -975,14 +976,30 @@ def _fast_resolve_memory_duplicates_and_conflicts( continue try: # recall related memories + retrieve_start = time.perf_counter() related = self.pre_update_retriever.retrieve( item=item, user_name=user_name, ) + retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 + logger.info( + "[MultiModalStruct] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", + retrieve_ms, + getattr(item, "id", None), + ) # NLI check & attaching contents + nli_start = time.perf_counter() conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( item, related ) + nli_ms = (time.perf_counter() - nli_start) * 1000 + logger.info( + "[MultiModalStruct] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", + nli_ms, + getattr(item, "id", None), + len(related), + len(conflicting_or_duplicate_ids), + ) # mark delete(temporarily) self.history_manager.mark_memory_status( conflicting_or_duplicate_ids, "resolving", user_name=user_name From d974dedeb16e135fe12f220f07eae8be8875e80a Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 16:05:03 +0800 Subject: [PATCH 31/92] Fix: Avoid latency problems caused by database "IN" operations --- .../textual/tree_text_memory/retrieve/pre_update.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index cb77d2243..bde9db770 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -210,8 +210,14 @@ def retrieve( # 2. Recall futures = [] common_filter = { - "status": {"in": ["activated", "resolving"]}, - "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]}, + "or": [ + {"status": "activated", "memory_type": "LongTermMemory"}, + {"status": "activated", "memory_type": "UserMemory"}, + {"status": "activated", "memory_type": "WorkingMemory"}, + {"status": "resolving", "memory_type": "LongTermMemory"}, + {"status": "resolving", "memory_type": "UserMemory"}, + {"status": "resolving", "memory_type": "WorkingMemory"}, + ] } with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor: From 71ed69fbb9d74349b5ad9be4dbbb12b310d041c7 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 16:19:09 +0800 Subject: [PATCH 32/92] fix: Simplify the rapid recall process to prevent excessive latency. --- .../textual/tree_text_memory/retrieve/pre_update.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index bde9db770..1d66732d2 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -236,13 +236,7 @@ def retrieve( sim_threshold, ) ) - - # Task B: Keyword Search - futures.append( - executor.submit( - self.keyword_search, switched_query, user_name, top_k, common_filter - ) - ) + # TODO: recovering keyword search or other versions of search for multiple pathways # 3. Collect Results retrieved_ids = set() # for deduplicating ids From 03cd191880197c5d4994a35209b5dd5db290422e Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 18:29:56 +0800 Subject: [PATCH 33/92] fix: Fix metadata passing amongst mem versions. --- .../organize/history_manager.py | 84 ++++++++++++------- src/memos/multi_mem_cube/single_cube.py | 9 +- 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 6dd273836..731d23a8c 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -445,13 +445,13 @@ def apply_llm_memory_updates( # 2. Handle Memory List (Update or New) processed_updates, created_items = self._process_memory_updates( - memory_list, expected_versions, user_name + memory_list, expected_versions, user_name, source_item ) updated_items.extend(processed_updates) new_items.extend(created_items) # 3. Handle Restored Memories (Extract from conflict) - new_items.extend(self._handle_restored_memories(restored_memories)) + new_items.extend(self._handle_restored_memories(restored_memories, source_item)) return updated_items, new_items @@ -556,7 +556,11 @@ def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) self.mark_memory_status(unrelated_ids, "activated", user_name) def _process_memory_updates( - self, memory_list: list[dict[str, Any]], expected_versions: dict[str, int], user_name: str + self, + memory_list: list[dict[str, Any]], + expected_versions: dict[str, int], + user_name: str, + source_item: TextualMemoryItem, ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: """Process Memory List (Update or New).""" updated_items: list[TextualMemoryItem] = [] @@ -575,7 +579,7 @@ def _process_memory_updates( if item: updated_items.append(item) else: - item = self._create_new_memory(mem_data) + item = self._create_new_memory(mem_data, source_item) new_items.append(item) return updated_items, new_items @@ -643,11 +647,16 @@ def _update_existing_memory( key=key, ) - # create archived node for storing older versions of the memory + # create archived node for storing older versions of the memory, preserving the embedding + emb = TextualMemoryItem( + **self.graph_db.get_node(primary_id, include_embedding=True) + ).metadata.embedding + arch_meta = archived_item.metadata.model_dump(exclude_none=True) + arch_meta["embedding"] = emb self.graph_db.add_node( id=archived_item.id, memory=archived_item.memory, - metadata=archived_item.metadata.model_dump(exclude_none=True), + metadata=arch_meta, user_name=user_name, ) @@ -668,10 +677,10 @@ def _update_existing_memory( id=primary_id, fields={ "memory": current_item.memory, - **fields, "history": merged_history_dump, "version": new_primary_version, "embedding": embedding, + **fields, }, user_name=user_name, ) @@ -800,7 +809,9 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> # Let's concatenate as a safe fallback. return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" - def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: + def _create_new_memory( + self, mem_data: dict[str, Any], fast_item: TextualMemoryItem + ) -> TextualMemoryItem: """Create New Node.""" new_value = mem_data.get("value", "") new_value_item = TextualMemoryItem( @@ -811,26 +822,30 @@ def _create_new_memory(self, mem_data: dict[str, Any]) -> TextualMemoryItem: tags = mem_data.get("tags", []) key = mem_data.get("key", "") memory_type = mem_data.get("memory_type", "LongTermMemory") + metadata_updates = { + "is_fast": False, + "version": 1, + "memory_type": memory_type, + "status": "activated", + "tags": tags, + "key": key, + "created_at": datetime.now().isoformat(), + "history": [], + "embedding": self._compute_embedding(new_value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, value in metadata_updates.items(): + setattr(metadata, field_name, value) new_item = TextualMemoryItem( id=str(uuid.uuid4()), memory=new_value, - metadata=TreeNodeTextualMemoryMetadata( - is_fast=False, - version=1, - memory_type=memory_type, - status="activated", - tags=tags, - key=key, - created_at=datetime.now().isoformat(), - history=[], - embedding=self._compute_embedding(new_value), - ), + metadata=metadata, ) return new_item def _handle_restored_memories( - self, restored_memories: list[dict[str, Any]] + self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem ) -> list[TextualMemoryItem]: """Handle Restored Memories (Extract from conflict).""" source_ids = [r.get("source_candidate_id") for r in restored_memories] @@ -840,6 +855,7 @@ def _handle_restored_memories( created_items = [] for i, data in enumerate(restored_memories): source_item = source_items[i] + # deal with history source_history = source_item.history.copy() value = data.get("value", "") value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) @@ -859,21 +875,27 @@ def _handle_restored_memories( ) source_history.append(new_history_item) # Re-use the history of the old node # Create new node + metadata_updates = { + "memory_type": memory_type, + "status": "activated", + "is_fast": False, + "version": version + 1, + "tags": tags, + "key": keys, + "created_at": datetime.now().isoformat(), + "history": source_history, + "embedding": self._compute_embedding(value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, value in metadata_updates.items(): + setattr(metadata, field_name, value) + new_item = TextualMemoryItem( id=str(uuid.uuid4()), memory=value, - metadata=TreeNodeTextualMemoryMetadata( - memory_type=memory_type, - status="activated", - is_fast=False, - version=version + 1, - key=keys, - tags=tags, - created_at=datetime.now().isoformat(), - history=source_history, - embedding=self._compute_embedding(value), - ), + metadata=metadata, ) + created_items.append(new_item) return created_items diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 307bf05b2..546d3043f 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -906,7 +906,14 @@ def _process_text_mem( ) # Mark merged_from memories as archived when provided in add_req.info - if sync_mode == "sync" and extract_mode == "fine": + if ( + sync_mode == "sync" + and extract_mode == "fine" + and ( + not hasattr(self.mem_reader, "memory_version_switch") + or self.mem_reader.memory_version_switch != "on" + ) + ): for memory in flattened_local: merged_from = (memory.metadata.info or {}).get("merged_from") if merged_from: From f572f4989ce59f41bab3e0bd88e97c9760491356 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 19:35:56 +0800 Subject: [PATCH 34/92] fix: Fix wrong sources for updated nodes. --- .../textual/tree_text_memory/organize/history_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 731d23a8c..f3032183b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -574,7 +574,7 @@ def _process_memory_updates( if target_ids: item = self._update_existing_memory( - mem_data, target_ids, source_ids, expected_versions, user_name + mem_data, target_ids, source_ids, expected_versions, user_name, source_item ) if item: updated_items.append(item) @@ -590,6 +590,7 @@ def _update_existing_memory( source_ids: list[str], expected_versions: dict[str, int], user_name: str, + fast_item: TextualMemoryItem, ) -> TextualMemoryItem | None: """ Update existing memory nodes using the LLM result. @@ -672,15 +673,17 @@ def _update_existing_memory( current_item.metadata.version = new_primary_version merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] embedding = self._compute_embedding(current_item.memory) + sources = [s.model_dump(exclude_none=True) for s in (fast_item.metadata.sources or [])] # update old memory node with new content and updated history self.graph_db.update_node( id=primary_id, fields={ + **fields, "memory": current_item.memory, "history": merged_history_dump, "version": new_primary_version, "embedding": embedding, - **fields, + "sources": sources, }, user_name=user_name, ) From f11c51fa0363c5c276281aef0e17b1fa7f119bcc Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Feb 2026 19:43:56 +0800 Subject: [PATCH 35/92] fix: Fix session_id. --- .../textual/tree_text_memory/organize/history_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index f3032183b..9148cf55b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -684,6 +684,7 @@ def _update_existing_memory( "version": new_primary_version, "embedding": embedding, "sources": sources, + "session_id": fast_item.metadata.session_id, }, user_name=user_name, ) From 3744d84b73739e9ac6cb4731444ddaa4de5884dc Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 12 Feb 2026 11:59:45 +0800 Subject: [PATCH 36/92] fix: Add fallback in mem versions to deal with llm failure situations --- src/memos/mem_reader/multi_modal_struct.py | 20 +++++--- .../organize/history_manager.py | 46 +++++++++++++++---- .../mem_reader_mem_version_prompts.py | 22 +++++---- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index b552e6f37..a5432cded 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1025,13 +1025,21 @@ def _process_async_versioning_update( else "" ) prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) - response_text = self.llm.generate([{"role": "user", "content": prompt}]) - response_json = parse_json_result(response_text) user_name = kwargs.get("user_name") - _, new_items = self.history_manager.apply_llm_memory_updates( - response_json, item, user_name=user_name - ) - return new_items + try: + response_text = self.llm.generate([{"role": "user", "content": prompt}]) + if not response_text: + raise ValueError("Empty LLM response") + response_json = parse_json_result(response_text) + if not response_json: + raise ValueError("Empty LLM JSON response") + _, new_items = self.history_manager.apply_llm_memory_updates( + response_json, item, user_name=user_name + ) + return new_items + except Exception as e: + logger.warning(f"[MultiModalStruct] Async update fallback due to LLM failure: {e}") + return self.history_manager.build_fallback_new_items(item, user_name=user_name) @timed def _process_multi_modal_data( diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 9148cf55b..250ad44a9 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -455,6 +455,30 @@ def apply_llm_memory_updates( return updated_items, new_items + def build_fallback_new_items( + self, item: TextualMemoryItem, user_name: str | None = None + ) -> list[TextualMemoryItem]: + latest_item = item.model_copy(deep=True) + _detach_related_content(latest_item) + + history = latest_item.metadata.history or [] + archived_ids = [h.archived_memory_id for h in history if h.archived_memory_id] + if archived_ids: + self.mark_memory_status(archived_ids, "activated", user_name or "") + + latest_item.id = str(uuid.uuid4()) + latest_item.metadata.is_fast = False + latest_item.metadata.status = "activated" + latest_item.metadata.history = [] + latest_item.metadata.working_binding = None + if hasattr(latest_item.metadata, "background"): + latest_item.metadata.background = "" + + if hasattr(latest_item.metadata, "tags") and latest_item.metadata.tags: + latest_item.metadata.tags = [t for t in latest_item.metadata.tags if t != "mode:fast"] + + return [latest_item] + def mark_memory_status( self, memory_ids: list[str], @@ -573,11 +597,13 @@ def _process_memory_updates( target_ids = source_ids + conflict_ids if target_ids: - item = self._update_existing_memory( + updated_item, new_item = self._update_existing_memory( mem_data, target_ids, source_ids, expected_versions, user_name, source_item ) - if item: - updated_items.append(item) + if updated_item: + updated_items.append(updated_item) + if new_item: + new_items.append(new_item) else: item = self._create_new_memory(mem_data, source_item) new_items.append(item) @@ -591,7 +617,7 @@ def _update_existing_memory( expected_versions: dict[str, int], user_name: str, fast_item: TextualMemoryItem, - ) -> TextualMemoryItem | None: + ) -> tuple[TextualMemoryItem | None, TextualMemoryItem | None]: """ Update existing memory nodes using the LLM result. @@ -604,8 +630,7 @@ def _update_existing_memory( The method also applies CAS validation via expected_versions, archives the previous version of the primary node, and persists the updated node back to the graph DB. - Returns the updated primary TextualMemoryItem, or None when the primary node - cannot be found. + Returns the updated primary TextualMemoryItem and optional new item when fallback is used. """ original_primary_id, primary_id, secondary_ids = ( target_ids[0], @@ -632,7 +657,9 @@ def _update_existing_memory( logger.warning( f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." ) - return None + # Fallback to create new item when the source_id is hallucinated by llm + new_item = self._create_new_memory(mem_data, fast_item) + return None, new_item current_item = TextualMemoryItem(**node_data) # For concurrency control, need to make sure the primary item has not been modified by others in the meantime @@ -697,7 +724,7 @@ def _update_existing_memory( f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" ) - return current_item + return current_item, None def _apply_cas_merge( self, @@ -825,12 +852,15 @@ def _create_new_memory( new_value = new_value_item.memory tags = mem_data.get("tags", []) key = mem_data.get("key", "") + background = mem_data.get("summary", "") memory_type = mem_data.get("memory_type", "LongTermMemory") metadata_updates = { "is_fast": False, "version": 1, "memory_type": memory_type, "status": "activated", + "background": background, + "working_binding": None, "tags": tags, "key": key, "created_at": datetime.now().isoformat(), diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index d9d57065e..bb7e4acbf 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -88,6 +88,9 @@ 语言规则: - `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** - `memory_type` 保持英文。 +格式规则(STRICT): +- 必须输出**严格 JSON**,不允许出现尾随逗号。 +- 不要输出 Markdown、代码块或任何解释性文字。 ${custom_tags_prompt} @@ -114,7 +117,7 @@ "value": "用户最喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", "tags": ["饮食", "咖啡", "喜好"], "source_candidate_ids": ["101"], - "conflicted_candidate_ids": [], + "conflicted_candidate_ids": [] }, { "key": "运动习惯变更", @@ -122,7 +125,7 @@ "value": "用户因膝盖受伤,决定不再打羽毛球。", "tags": ["运动", "健康", "羽毛球"], "source_candidate_ids": [], - "conflicted_candidate_ids": ["201"], + "conflicted_candidate_ids": ["201"] }, { "key": "宠物状况", @@ -130,7 +133,7 @@ "value": "用户最近(2025年6月)刚养了一只猫。", "tags": ["生活", "宠物", "猫"], "source_candidate_ids": [], - "conflicted_candidate_ids": [], + "conflicted_candidate_ids": [] } ], "restored_memories": [ @@ -138,8 +141,8 @@ "key": "运动偏好", "memory_type": "UserMemory", "value": "用户不喜欢滑雪。", - "tags": ["运动", "滑雪", "喜好"] - "source_candidate_id": "201", + "tags": ["运动", "滑雪", "喜好"], + "source_candidate_id": "201" } ], "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了用户养猫的记忆。对天气的喜好未受影响。" @@ -241,6 +244,9 @@ Language rules: - The `key`, `value`, `tags`, and `summary` fields must match the main language of the input conversation. If the input is English, output English. - `memory_type` remains in English. +Format rules (STRICT): +- Output **strict JSON** only, no trailing commas. +- Do not include Markdown, code fences, or any explanations. ${custom_tags_prompt} @@ -267,7 +273,7 @@ "value": "The user likes latte most, usually doesn't add sugar, and prefers oat milk.", "tags": ["diet", "coffee", "preference"], "source_candidate_ids": ["101"], - "conflicted_candidate_ids": [], + "conflicted_candidate_ids": [] }, { "key": "Sport habit change", @@ -275,7 +281,7 @@ "value": "Due to a knee injury, the user decides to no longer play badminton.", "tags": ["sport", "health", "badminton"], "source_candidate_ids": [], - "conflicted_candidate_ids": ["201"], + "conflicted_candidate_ids": ["201"] }, { "key": "Pet status", @@ -283,7 +289,7 @@ "value": "The user recently (June 2025) adopted a cat.", "tags": ["life", "pet", "cat"], "source_candidate_ids": [], - "conflicted_candidate_ids": [], + "conflicted_candidate_ids": [] } ], "restored_memories": [ From 70afbed3ceac41a8767aafbf142b410f7daa3ef0 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 12 Feb 2026 12:17:07 +0800 Subject: [PATCH 37/92] fix: Temp fix metadata bug caused by db.get_node returning id and memory in metadata --- .../organize/history_manager.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 250ad44a9..306ce3efe 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -110,6 +110,22 @@ def _add(history_item): item.metadata.history = list(new_history.values()) +def _sanitize_metadata_dict(data: dict[str, Any] | None) -> dict[str, Any]: + if not data: + return {} + sanitized = data.copy() + for key in ("id", "memory", "graph_id"): + sanitized.pop(key, None) + return sanitized + + +def _sanitize_metadata_model( + metadata: TreeNodeTextualMemoryMetadata, +) -> TreeNodeTextualMemoryMetadata: + data = _sanitize_metadata_dict(metadata.model_dump(exclude_none=True)) + return metadata.__class__(**data) + + def _determine_lang(sources: list | None, fallback_text: str) -> str: lang = None if sources: @@ -477,6 +493,8 @@ def build_fallback_new_items( if hasattr(latest_item.metadata, "tags") and latest_item.metadata.tags: latest_item.metadata.tags = [t for t in latest_item.metadata.tags if t != "mode:fast"] + latest_item.metadata = _sanitize_metadata_model(latest_item.metadata) + return [latest_item] def mark_memory_status( @@ -679,7 +697,7 @@ def _update_existing_memory( emb = TextualMemoryItem( **self.graph_db.get_node(primary_id, include_embedding=True) ).metadata.embedding - arch_meta = archived_item.metadata.model_dump(exclude_none=True) + arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) arch_meta["embedding"] = emb self.graph_db.add_node( id=archived_item.id, @@ -688,7 +706,7 @@ def _update_existing_memory( user_name=user_name, ) - fields = current_item.metadata.model_dump(exclude_none=True) + fields = _sanitize_metadata_dict(current_item.metadata.model_dump(exclude_none=True)) merged_history = list(current_item.metadata.history or []) new_primary_version = current_item.metadata.version or 1 # Multiple related ids indicates existing duplicates/conflicts to be merged @@ -870,6 +888,7 @@ def _create_new_memory( metadata = fast_item.metadata.model_copy(deep=True) for field_name, value in metadata_updates.items(): setattr(metadata, field_name, value) + metadata = _sanitize_metadata_model(metadata) new_item = TextualMemoryItem( id=str(uuid.uuid4()), @@ -923,6 +942,7 @@ def _handle_restored_memories( metadata = fast_item.metadata.model_copy(deep=True) for field_name, value in metadata_updates.items(): setattr(metadata, field_name, value) + metadata = _sanitize_metadata_model(metadata) new_item = TextualMemoryItem( id=str(uuid.uuid4()), From 3f17f677518b9541d40969eaec246e6c1529204f Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 12 Feb 2026 12:17:19 +0800 Subject: [PATCH 38/92] test: fix mem versions test --- .../memories/textual/test_history_manager.py | 83 +++++++++++++++---- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index f0ea8458f..11b5d7f44 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -227,6 +227,8 @@ def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db "created_at": "2023-01-01", "tags": ["old"], "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", }, } mock_graph_db.get_node.return_value = existing_node @@ -324,7 +326,7 @@ def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): assert len(new_items) == 1 assert new_items[0] == restored_item history_manager._handle_restored_memories.assert_called_once_with( - llm_response["restored_memories"] + llm_response["restored_memories"], source_item ) mock_graph_db.add_node.assert_not_called() @@ -376,14 +378,18 @@ def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph # Setup existing node (primary) primary_id = uuid.uuid4().hex secondary_id = uuid.uuid4().hex - existing_node = {"id": primary_id, "memory": "Old Content", "metadata": {"version": 1}} + existing_node = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, + } mock_graph_db.get_node.return_value = existing_node mock_graph_db.get_nodes.return_value = [ existing_node, { "id": secondary_id, "memory": "Secondary", - "metadata": {"version": 1}, + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, }, ] @@ -565,10 +571,14 @@ def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): mock_graph_db.get_node.return_value = { "id": existing_id, "memory": "Old Content", - "metadata": {"version": 2}, + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, } mock_graph_db.get_nodes.return_value = [ - {"id": existing_id, "memory": "Old Content", "metadata": {"version": 2}} + { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, + } ] mem_data = { @@ -579,12 +589,20 @@ def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): "conflicted_candidate_ids": [], } - updated = manager._update_existing_memory( - mem_data, [existing_id], [existing_id], {existing_id: 1}, user_name="u1" + updated, new_item = manager._update_existing_memory( + mem_data, + [existing_id], + [existing_id], + {existing_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), ) assert updated.memory == "Merged Content" assert updated.metadata.version == 3 + assert new_item is None mock_graph_db.update_node.assert_called_once() @@ -592,11 +610,16 @@ def test_update_existing_memory_marks_working_binding_deleted(history_manager, m history_manager.mark_memory_status = MagicMock() primary_id = uuid.uuid4().hex working_binding = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, + } mock_graph_db.get_nodes.return_value = [ { "id": primary_id, "memory": "Old Content", - "metadata": {"version": 1, "working_binding": working_binding}, + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, } ] mem_data = { @@ -607,11 +630,19 @@ def test_update_existing_memory_marks_working_binding_deleted(history_manager, m "conflicted_candidate_ids": [], } - updated = history_manager._update_existing_memory( - mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), ) assert updated is not None + assert new_item is None history_manager.mark_memory_status.assert_called_once_with( [str(working_binding)], "deleted", user_name="u1" ) @@ -622,11 +653,16 @@ def test_update_existing_memory_no_mark_when_working_binding_matches( ): history_manager.mark_memory_status = MagicMock() primary_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, + } mock_graph_db.get_nodes.return_value = [ { "id": primary_id, "memory": "Old Content", - "metadata": {"version": 1, "working_binding": primary_id}, + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, } ] mem_data = { @@ -637,11 +673,19 @@ def test_update_existing_memory_no_mark_when_working_binding_matches( "conflicted_candidate_ids": [], } - updated = history_manager._update_existing_memory( - mem_data, [primary_id], [primary_id], {primary_id: 1}, user_name="u1" + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), ) assert updated is not None + assert new_item is None history_manager.mark_memory_status.assert_not_called() @@ -650,9 +694,20 @@ def test_update_existing_memory_node_missing(history_manager, mock_graph_db): mock_graph_db.get_nodes.return_value = [] mem_data = {"value": "v", "tags": [], "key": "k"} - updated = history_manager._update_existing_memory(mem_data, ["missing"], [], {}, user_name="u1") + updated, new_item = history_manager._update_existing_memory( + mem_data, + ["missing"], + [], + {}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) assert updated is None + assert new_item is not None + assert new_item.memory == "v" mock_graph_db.update_node.assert_not_called() From 280373cb80ada3db3039c1277c675e490f3e7dec Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 25 Feb 2026 10:51:51 +0800 Subject: [PATCH 39/92] feat: Better prompt for mem version reader --- .../mem_reader_mem_version_prompts.py | 172 +++++++++++------- 1 file changed, 105 insertions(+), 67 deletions(-) diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index bb7e4acbf..27cb0aba3 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -2,40 +2,42 @@ # Memory Update & Maintenance # ========================================== ASYNC_MEMORY_UPDATE_PROMPT_ZH = """您是记忆库维护专家。 -您的核心任务是根据最新的用户对话,维护和更新用户的长期记忆图谱。您需要利用系统提供的“候选记忆”(Candidates)来决定是更新现有记忆节点,还是创建全新的记忆节点。 +您的核心任务是根据最新的用户对话、对话时间,以及系统提供的可能与最新对话相关的“候选记忆”(Candidates),来维护和更新用户的长期记忆图谱。 -系统检测到以下“候选记忆”与用户的最新输入存在潜在关联: +具体而言,“候选记忆”包含以下三种情况: +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)** +2. **潜在事实冲突记忆 (Conflict Candidates)** +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)** -1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: -${duplicate_candidates} - -2. **潜在事实冲突记忆 (Conflict Candidates)**: -${conflict_candidates} +您需要根据最新对话以及候选记忆,决定是更新现有记忆节点,还是创建全新的记忆节点。 -3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: -${unrelated_candidates} - -请执行以下操作: - -1. **核心原则(STRICT)**: +**核心原则(STRICT)**: - 您的目标是**维护**记忆库,而非仅仅提取信息。 - **优先更新**:如果对话内容涉及现有的“候选记忆”,应优先视为对该记忆节点的**更新**(补充细节或修正状态),而不是创建重复的新节点。 - **按需新增**:仅当对话内容包含全新的、与现有“候选记忆”完全无关的话题时,才创建新的记忆节点。 - 提取来源**只能**是【当前的对话内容】。严禁编造未提及的信息。 -2. 识别反映用户经历、信念、关切、决策、计划或反应的信息。 +**表达规范(STRICT)**: + - 冲突更新时,`value` 必须是“只含最新事实”的独立陈述,不允许提及旧值或变化过程(如“原名/之前/曾经/从X到Y/改成/不再使用原名/现在自称”)。 + - 若最新状态本身为否定事实,可直接用否定表达,但仍不得包含旧值或对比语。 + - 对于姓名、身份、归属、偏好等字段的更新,始终输出最新值的肯定式表述(例:旧记忆“用户叫王强”,新对话“我叫李白”,输出应为“用户叫李白”)。 + - 涉及第三方人物/实体的客观信息必须使用 `LongTermMemory`,且主体保持为该第三方(如“王强住上海”)。 + +请执行以下操作: +1. 识别反映用户经历、信念、关切、决策、计划或反应的信息。 - 如果消息来自用户,提取用户相关的记忆。 - 如果来自助手,仅提取用户认可或回应的事实性记忆。 -3. 清晰解析所有时间、人物和事件的指代(同原规则): +2. 清晰解析所有时间、人物和事件的指代(同原规则): - 将相对时间(“昨天”)转换为绝对日期。 - 明确区分事件时间和消息时间。 - 解析代词和模糊指代。 - - 使用第三人称(“用户”)。 + - 仅当指代为“我/我们/本人”等用户第一人称时才替换为“用户”。 + - 其他第三方人名/实体必须保留原名,不得替换为“用户”。 - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 - 候选记忆可能包含 [Time: ...] 表示该记忆的事件时间,请结合“对话时间”判断是否同一时段。 -4. **处理逻辑(更新与新增)**: +3. **处理逻辑(更新与新增)**: 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: a) **更新现有记忆节点 (Update via Duplicate/Related)**: @@ -59,6 +61,10 @@ 5. **挽救剩余信息 (Restored Memories)**: - 当您决定更新一个冲突节点(情况 b)时,请检查该旧节点是否包含**未被新信息否定且依然独立的有效事实**。 - 如果有,请将这些“剩余有效信息”提取出来,作为 `restored_memories` 返回。这能防止因节点更新导致非冲突信息的丢失。 + - 若旧节点包含多个事实(如“人物状态 + 电话号码”),而新对话仅更新其中一部分,则必须将未被否定的事实拆分为独立条目放入 `restored_memories`,并在冲突更新的 `value` 中仅保留被更新的那一部分。 + - 对于可长期独立存在的属性(如电话号码、出生地、所属组织),优先拆分为独立事实,避免与可变状态混写在同一条记忆中。 + +6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 返回一个有效的JSON对象,结构如下: @@ -95,18 +101,21 @@ ${custom_tags_prompt} 示例: -**潜在重复/关联记忆 (Duplicate/Related Candidates):** +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: [ID:101][Time: 2025/05/20 09:30:00] 用户喜欢喝拿铁,通常不加糖。 -[ID:102][Time: 2025/05/18 18:00:00] 用户讨厌下雨天。 +[ID:102][Time: 2025/06/02 18:00:00] 用户讨厌下雨天。 -**潜在事实冲突记忆 (Conflict Candidates):** +2. **潜在事实冲突记忆 (Conflict Candidates)**: [ID:201][Time: 2025/02/03 20:15:00] 用户喜欢打羽毛球,但不喜欢滑雪。 -**对话时间:** +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +[ID:301][Time: 2025/06/20 10:00:00] 用户最近在看《星球大战》。 + +**对话时间**: 2025/06/26 09:00:00 -**对话:** -user: 我还是最喜欢喝拿铁,尤其是加燕麦奶的。另外,我膝盖受伤了,以后再也不打羽毛球了。最近我刚养了一只猫。 +**对话**: +user: 最近下雨比较频繁。我经常去喝点拿铁,尤其是加燕麦奶的很好喝。另外,我最近膝盖受伤了,以后再也不打羽毛球了。 **输出:** { @@ -114,7 +123,7 @@ { "key": "咖啡偏好", "memory_type": "UserMemory", - "value": "用户最喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", + "value": "用户喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", "tags": ["饮食", "咖啡", "喜好"], "source_candidate_ids": ["101"], "conflicted_candidate_ids": [] @@ -128,10 +137,10 @@ "conflicted_candidate_ids": ["201"] }, { - "key": "宠物状况", - "memory_type": "UserMemory", - "value": "用户最近(2025年6月)刚养了一只猫。", - "tags": ["生活", "宠物", "猫"], + "key": "天气状况", + "memory_type": "LongTermMemory", + "value": "最近(2025年6月)用户所在的地方下雨比较频繁。", + "tags": ["生活", "天气", "降水"], "source_candidate_ids": [], "conflicted_candidate_ids": [] } @@ -145,76 +154,92 @@ "source_candidate_id": "201" } ], - "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了用户养猫的记忆。对天气的喜好未受影响。" + "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了近期下雨频繁的记忆。对天气的喜好未受影响。" } -请始终使用与对话相同的语言进行回复。 +请始终使用与对话相同的语言进行回复。以下是最新的输入: + +1. **潜在重复/关联记忆 (Duplicate/Related Candidates)**: +${duplicate_candidates} + +2. **潜在事实冲突记忆 (Conflict Candidates)**: +${conflict_candidates} + +3. **可能无关,但需要进一步判断的记忆 (Unrelated Candidates)**: +${unrelated_candidates} -对话时间: +**对话时间**: ${conversation_time} -对话: +**对话**: ${conversation} -""" + +**输出:**""" ASYNC_MEMORY_UPDATE_PROMPT_EN = """You are a memory maintenance expert. -Your core task is to maintain and update the user's long-term memory graph based on the latest conversation. Use the system-provided "Candidates" to decide whether to update existing memory nodes or create brand-new ones. +Your core task is to maintain and update the user's long-term memory graph based on the latest user conversation, the conversation time, and the system-provided "Candidates" that may be related to the latest conversation. -The system detected the following "Candidates" potentially related to the user's latest input: +Specifically, "Candidates" include three cases: +1. **Duplicate/Related Candidates** +2. **Conflict Candidates** +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)** -1. Duplicate/Related Candidates: -${duplicate_candidates} +You need to decide, based on the latest conversation and the candidates, whether to update existing memory nodes or create brand-new memory nodes. -2. Conflict Candidates: -${conflict_candidates} +**Core Principles (STRICT)**: + - Your goal is to **maintain** the memory base, not merely extract information. + - **Prefer Update**: If the conversation touches any existing "Candidates", treat it as an **update** to that memory node (add details or correct status), rather than creating a duplicate new node. + - **Add As Needed**: Only create a new node when the conversation contains truly new topics that are completely unrelated to existing "Candidates". + - The extraction source must be ONLY the **current conversation**. Do not fabricate information not mentioned. -3. Possibly unrelated (needs further judgment): -${unrelated_candidates} +**Expression Rules (STRICT)**: + - For conflict updates, `value` must be a standalone statement of the latest fact only, without mentioning old values or change history (e.g., "formerly/previously/used to/changed from X to Y/no longer used the old name/now goes by"). + - If the latest state is inherently negative, express the negation directly but still avoid old values or comparisons. + - For updates to name/identity/affiliation/preference fields, always output a positive statement of the latest value (e.g., old memory "User's name is Wang Qiang", new conversation "My name is Li Bai" → output "The user's name is Li Bai"). + - Objective facts about third-party people/entities must use `LongTermMemory`, and the subject must remain that third party (e.g., "Wang Qiang lives in Shanghai"). Please execute the following: - -1. Core Principles (STRICT): - - Your goal is maintenance of the memory base, not mere extraction. - - Prefer Update: If the conversation touches any existing "Candidates", treat it as an update to that memory node (add details or correct status) rather than creating a duplicate node. - - Add As Needed: Only create a new node when the conversation contains truly new topics that are completely unrelated to existing "Candidates". - - The extraction source must be ONLY the current conversation. Do not fabricate information not mentioned. - -2. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions. +1. Identify information that reflects the user's experiences, beliefs, concerns, decisions, plans, or reactions. - If the message is from the user, extract user-related memories. - If it is from the assistant, only extract factual memories that the user explicitly acknowledges or responds to. -3. Disambiguate all references to time, people, and events: +2. Disambiguate all references to time, people, and events (same rules as before): - Convert relative time ("yesterday") to an absolute date. - Clearly distinguish event time from message time. - Resolve pronouns and ambiguous references. - - Use third person ("the user"). + - Replace only first-person references ("I/we/me") with "the user". + - Keep third-party names/entities unchanged; do not replace them with "the user". - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). - Candidates may include [Time: ...] to indicate event time; use the conversation time to judge whether they are the same period. -4. Processing Logic (Update and Create): +3. **Processing Logic (Update and Create)**: Traverse each piece of information in the conversation that is worth remembering and apply: - a) Update existing memory node (Update via Duplicate/Related): + a) **Update existing memory node (Update via Duplicate/Related)**: - Check Duplicate/Related Candidates. - If the new information repeats, confirms, or adds details to an old memory: - - Generate an updated complete memory into `value` (old info + new details). + - Generate an **updated complete memory** into `value` (old info + new details). - Put the old memory IDs into `source_candidate_ids`. - `conflicted_candidate_ids` must be []. - b) Fix conflicting memory node (Update via Conflict): + b) **Fix conflicting memory node (Update via Conflict)**: - Check Conflict Candidates. - If the new information negates an old memory or updates its state (e.g., "no longer likes X", "changed to Y", "canceled plan", "from X to Y"): - - Generate a memory reflecting the latest state into `value`. + - Generate a memory reflecting the **latest state** into `value`. - Put the corrected old memory IDs into `conflicted_candidate_ids`. - c) Create new memory node (Create New): + c) **Create new memory node (Create New)**: - If the new information is not directly related to any "Candidates" (neither duplicate nor conflict): - Generate an independent new memory into `value`. - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. -5. Restored Memories: +5. **Restored Memories**: - When you decide to update a conflict node (case b), check whether the old node contains valid facts that were NOT negated by the new information and remain independently valid. - If any, extract these "remaining valid facts" and return them as `restored_memories` to prevent loss of non-conflict information due to node updates. + - If the old node contains multiple facts (e.g., "status + phone number") and the new conversation only updates one part, you must split the unaffected facts into separate entries in `restored_memories`, and keep only the updated part in the conflict `value`. + - For long-lived independent attributes (e.g., phone number, birthplace, affiliation), prefer splitting them into standalone facts instead of mixing them with mutable states. + +6. Avoid including any memories that violate laws or involve politically sensitive information. Return a valid JSON object with the structure: @@ -251,20 +276,23 @@ ${custom_tags_prompt} Example: -Duplicate/Related Candidates: +1. **Duplicate/Related Candidates**: [ID:101][Time: 2025/05/20 09:30:00] The user likes latte and usually doesn't add sugar. [ID:102][Time: 2025/05/18 18:00:00] The user hates rainy days. -Conflict Candidates: +2. **Conflict Candidates**: [ID:201][Time: 2025/02/03 20:15:00] The user likes badminton but dislikes skiing. -Conversation time: +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)**: +[ID:301][Time: 2025/06/20 10:00:00] The user recently watched Star Wars. + +**Conversation time**: 2025/06/26 09:00:00 -Conversation: +**Conversation**: user: I still like latte the most, especially with oat milk. Also, my knee is injured, so I'll never play badminton again. Recently I adopted a cat. -Output: +**Output:** { "memory list": [ { @@ -304,14 +332,24 @@ "summary": "In this update, the user refined their latte preference (adds oat milk), and due to a knee injury updated their sports habit (stop playing badminton). The system preserves the record that the user dislikes skiing. Additionally, a new memory is added that the user adopted a cat. Weather preference is unaffected." } -Always reply in the same language as the conversation. +Always reply in the same language as the conversation. The latest input is below: -Conversation time: +1. **Duplicate/Related Candidates**: +${duplicate_candidates} + +2. **Conflict Candidates**: +${conflict_candidates} + +3. **Possibly unrelated candidates that require further judgment (Unrelated Candidates)**: +${unrelated_candidates} + +**Conversation time**: ${conversation_time} -Conversation: +**Conversation**: ${conversation} -""" + +**Output:**""" ASYNC_MEMORY_UPDATE_PROMPT_DICT = { "zh": ASYNC_MEMORY_UPDATE_PROMPT_ZH, From 84a27497c762f3d03171318c768989c2880cee7b Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 25 Feb 2026 10:53:03 +0800 Subject: [PATCH 40/92] feat: Temporarily switch to qwen llm for mem versions --- src/memos/api/config.py | 21 +++++++++++++++++++++ src/memos/configs/mem_reader.py | 3 +++ src/memos/mem_reader/multi_modal_struct.py | 2 +- src/memos/mem_reader/simple_struct.py | 10 ++++++++++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 25496ce54..23e1e97c8 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -338,6 +338,26 @@ def get_memreader_config() -> dict[str, Any]: }, } + @staticmethod + def get_qwen_llm_config() -> dict[str, Any] | None: + if not os.getenv("QWEN_API_KEY"): + return None + return { + "backend": "qwen", + "config": { + "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"), + "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")), + "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")), + "top_p": float(os.getenv("QWEN_TOP_P", "0.9")), + "top_k": int(os.getenv("QWEN_TOP_K", "50")), + "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower() + == "true", + "api_key": os.getenv("QWEN_API_KEY", ""), + "api_base": os.getenv("QWEN_API_BASE", ""), + "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"), + }, + } + @staticmethod def get_activation_vllm_config() -> dict[str, Any]: """Get Ollama configuration.""" @@ -802,6 +822,7 @@ def get_product_default_config() -> dict[str, Any]: "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), + "qwen_llm": APIConfig.get_qwen_llm_config(), "embedder": APIConfig.get_embedder_config(), "chunker": { "backend": "sentence", diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 98aff981f..63d0982db 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -70,6 +70,9 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): description="Turn on memory version or off", ) + # Allow passing additional fields without raising validation errors + model_config = ConfigDict(extra="allow", strict=True) + class StrategyStructMemReaderConfig(BaseMemReaderConfig): """StrategyStruct MemReader configuration class.""" diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index be543c47f..4a4ab5534 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1029,7 +1029,7 @@ def _process_async_versioning_update( prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) user_name = kwargs.get("user_name") try: - response_text = self.llm.generate([{"role": "user", "content": prompt}]) + response_text = self.qwen_llm.generate([{"role": "user", "content": prompt}]) if not response_text: raise ValueError("Empty LLM response") response_json = parse_json_result(response_text) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 237653d4a..180fd6ee8 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -11,6 +11,7 @@ from memos import log from memos.chunkers import ChunkerFactory +from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory @@ -178,6 +179,15 @@ def __init__(self, config: SimpleStructMemReaderConfig): """ self.config = config self.llm = LLMFactory.from_config(config.llm) + self.qwen_llm = None + qwen_llm_config = getattr(config, "qwen_llm", None) + if qwen_llm_config: + try: + if isinstance(qwen_llm_config, dict): + qwen_llm_config = LLMConfigFactory.model_validate(qwen_llm_config) + self.qwen_llm = LLMFactory.from_config(qwen_llm_config) + except Exception as e: + logger.warning(f"[LLM] Qwen initialization failed: {e}") self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) self.save_rawfile = self.chunker.config.save_rawfile From bfb110daed5edc27a939390cedab9b944011617c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 25 Feb 2026 11:20:36 +0800 Subject: [PATCH 41/92] feat: Refine prompts for mem versions --- .../templates/mem_reader_mem_version_prompts.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index 27cb0aba3..f330bac03 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -37,7 +37,12 @@ - 状态变化/否定表达必须被视为冲突更新(如“不再/不喜欢/取消/改为/不打算/否认”)。 - 候选记忆可能包含 [Time: ...] 表示该记忆的事件时间,请结合“对话时间”判断是否同一时段。 -3. **处理逻辑(更新与新增)**: +3. 不要遗漏用户可能记住的任何信息。 + - 包括所有关键经历、想法、情绪反应和计划——即使看似微小。 + - 优先考虑完整性和保真度,而非简洁性。 + - 不要泛化或跳过对用户具有个人意义的细节。 + +4. **处理逻辑(更新与新增)**: 请遍历对话中每一个值得记忆的信息点,并按以下逻辑处理: a) **更新现有记忆节点 (Update via Duplicate/Related)**: @@ -212,7 +217,12 @@ - State changes/negations must be treated as conflict updates (e.g., "no longer/doesn't like/canceled/changed to/doesn't plan/denies"). - Candidates may include [Time: ...] to indicate event time; use the conversation time to judge whether they are the same period. -3. **Processing Logic (Update and Create)**: +3. Do not omit any information the user might want to remember. + - Include all key experiences, thoughts, emotional reactions, and plans — even if they seem minor. + - Prioritize completeness and fidelity over brevity. + - Do not generalize or skip details that are personally meaningful to the user. + +4. **Processing Logic (Update and Create)**: Traverse each piece of information in the conversation that is worth remembering and apply: a) **Update existing memory node (Update via Duplicate/Related)**: From 1e4add38a4536768cc0e963da385201ceeb5ee1c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 25 Feb 2026 17:00:22 +0800 Subject: [PATCH 42/92] fix: Pass user_name to polardb add_node function to make it compatible with the new changes. --- src/memos/mem_reader/multi_modal_struct.py | 4 ++-- .../organize/history_manager.py | 23 +++++++++++-------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 4a4ab5534..73c53dc74 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1019,7 +1019,8 @@ def _process_async_versioning_update( 3. Call LLM and parse JSON response 4. Apply LLM updates to memory graph and return new items """ - self.history_manager.wait_and_update_fast_history(item, timeout_sec=30) + user_name = kwargs.get("user_name") + self.history_manager.wait_and_update_fast_history(item, user_name, timeout_sec=30) lang = detect_lang(kwargs.get("chat_history") or mem_str) custom_tags_prompt = ( PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) @@ -1027,7 +1028,6 @@ def _process_async_versioning_update( else "" ) prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) - user_name = kwargs.get("user_name") try: response_text = self.qwen_llm.generate([{"role": "user", "content": prompt}]) if not response_text: diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 306ce3efe..3e62ce56e 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -290,7 +290,9 @@ def resolve_history_via_nli( return duplicate_memory_ids + conflict_memory_ids - def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int = 30) -> None: + def wait_and_update_fast_history( + self, item: TextualMemoryItem, user_name: str, timeout_sec: int = 30 + ) -> None: """ Scan the item's history. If any history item is marked as `is_fast`, wait for it to be resolved (i.e., status becomes 'deleted' in the DB). @@ -299,6 +301,7 @@ def wait_and_update_fast_history(self, item: TextualMemoryItem, timeout_sec: int Args: item: The memory item containing the history to check. + user_name: Required for db query. timeout_sec: Maximum time to wait for resolution in seconds. """ start_time = time.time() @@ -467,7 +470,7 @@ def apply_llm_memory_updates( new_items.extend(created_items) # 3. Handle Restored Memories (Extract from conflict) - new_items.extend(self._handle_restored_memories(restored_memories, source_item)) + new_items.extend(self._handle_restored_memories(restored_memories, source_item, user_name)) return updated_items, new_items @@ -528,7 +531,7 @@ def mark_memory_status( return def _check_and_fetch_replacements( - self, item: TextualMemoryItem, pending_indices: list[int] + self, item: TextualMemoryItem, pending_indices: list[int], user_name: str ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: """ Check DB status for pending items. If 'deleted', fetch evolved nodes. @@ -539,7 +542,7 @@ def _check_and_fetch_replacements( pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] # Batch fetch pending nodes to check status - nodes_data = self.graph_db.get_nodes(ids=pending_ids) or [] + nodes_data = self.graph_db.get_nodes(ids=pending_ids, user_name=user_name) or [] nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} replacements = {} @@ -558,19 +561,19 @@ def _check_and_fetch_replacements( if status == "deleted": evolve_to_ids = metadata.get("evolve_to", []) - new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type) + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type, user_name) replacements[i] = new_items return replacements def _fetch_evolved_nodes( - self, evolve_to_ids: list[str], update_type: str + self, evolve_to_ids: list[str], update_type: str, user_name: str ) -> list[ArchivedTextualMemory]: """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" if not evolve_to_ids: return [] - evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids) or [] + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids, user_name=user_name) or [] results = [] for enode in evolved_nodes: @@ -668,7 +671,7 @@ def _update_existing_memory( # Fetch candidate nodes in batch and then select the primary # We update the primary and then merge the secondaries to the primary - nodes_data = self.graph_db.get_nodes(target_ids) or [] + nodes_data = self.graph_db.get_nodes(target_ids, user_name=user_name) or [] nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} node_data = nodes_map.get(primary_id) if not node_data: @@ -898,11 +901,11 @@ def _create_new_memory( return new_item def _handle_restored_memories( - self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem + self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem, user_name: str ) -> list[TextualMemoryItem]: """Handle Restored Memories (Extract from conflict).""" source_ids = [r.get("source_candidate_id") for r in restored_memories] - source_items = self.graph_db.get_nodes(source_ids) + source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) source_items = [TextualMemoryItem(**i) for i in source_items] created_items = [] From 7fb018212deb3149814c4197a308cc1d50521bad Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 25 Feb 2026 20:37:23 +0800 Subject: [PATCH 43/92] feat: make nli client more robust --- src/memos/extras/nli_model/client.py | 78 ++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py index a02dae9f6..c97746baf 100644 --- a/src/memos/extras/nli_model/client.py +++ b/src/memos/extras/nli_model/client.py @@ -1,4 +1,5 @@ import logging +import time import requests @@ -13,9 +14,18 @@ class NLIClient: Client for interacting with the deployed NLI model service. """ - def __init__(self, base_url: str = "http://localhost:32532"): + def __init__( + self, + base_url: str = "http://localhost:32532", + timeout: float = 30.0, + max_retries: int = 3, + backoff_seconds: float = 0.5, + ): self.base_url = base_url.rstrip("/") self.session = requests.Session() + self.timeout = timeout + self.max_retries = max_retries + self.backoff_seconds = backoff_seconds def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: """ @@ -35,27 +45,51 @@ def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult # Match schemas.CompareRequest payload = {"source": source, "targets": targets} - try: - response = self.session.post(url, json=payload, timeout=30) - response.raise_for_status() - data = response.json() - - # Match schemas.CompareResponse - results_str = data.get("results", []) - - results = [] - for res_str in results_str: - try: - results.append(NLIResult(res_str)) - except ValueError: + last_error: Exception | None = None + for attempt in range(1, self.max_retries + 1): + try: + response = self.session.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + data = response.json() + + results_str = data.get("results", []) + + results = [] + for res_str in results_str: + try: + results.append(NLIResult(res_str)) + except ValueError: + logger.warning( + f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + ) + results.append(NLIResult.UNRELATED) + + return results + except requests.RequestException as e: + last_error = e + if attempt < self.max_retries: logger.warning( - f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + "[NLIClient] Request failed (attempt %s/%s) url=%s targets=%s error=%s", + attempt, + self.max_retries, + url, + len(targets), + e, + ) + time.sleep(self.backoff_seconds * (2 ** (attempt - 1))) + else: + logger.error( + "[NLIClient] Request failed after %s attempts url=%s targets=%s error=%s", + self.max_retries, + url, + len(targets), + e, ) - results.append(NLIResult.UNRELATED) - - return results - except requests.RequestException as e: - logger.error(f"[NLIClient] Request failed: {e}") - # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. - return [NLIResult.UNRELATED] * len(targets) + logger.error( + "[NLIClient] NLI service unavailable or unstable. Please check that it is running at %s", + self.base_url, + ) + if last_error: + logger.error("[NLIClient] Last error: %s", last_error) + return [NLIResult.UNRELATED] * len(targets) From 6d1a9268c94cd5eeee104c17dd28c6b2f4522505 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Thu, 26 Feb 2026 14:27:46 +0800 Subject: [PATCH 44/92] feat: make nli server more robust. --- src/memos/extras/nli_model/server/config.py | 2 ++ src/memos/extras/nli_model/server/serve.py | 27 +++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py index d2e12175d..b5744bd26 100644 --- a/src/memos/extras/nli_model/server/config.py +++ b/src/memos/extras/nli_model/server/config.py @@ -13,6 +13,8 @@ NLI_DEVICE = "cuda" NLI_MODEL_HOST = "0.0.0.0" NLI_MODEL_PORT = 32532 +NLI_MAX_CONCURRENCY = 4 +NLI_INFER_TIMEOUT_SECONDS = 30.0 # Configure logging for NLI Server logging.basicConfig( diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py index 0ed9eae65..f02d25670 100644 --- a/src/memos/extras/nli_model/server/serve.py +++ b/src/memos/extras/nli_model/server/serve.py @@ -1,25 +1,36 @@ +import asyncio + from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI, HTTPException -from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT +from memos.extras.nli_model.server.config import ( + NLI_DEVICE, + NLI_INFER_TIMEOUT_SECONDS, + NLI_MAX_CONCURRENCY, + NLI_MODEL_HOST, + NLI_MODEL_PORT, +) from memos.extras.nli_model.server.handler import NLIHandler from memos.extras.nli_model.types import CompareRequest, CompareResponse # Global handler instance nli_handler: NLIHandler | None = None +nli_semaphore: asyncio.Semaphore | None = None @asynccontextmanager async def lifespan(app: FastAPI): - global nli_handler + global nli_handler, nli_semaphore nli_handler = NLIHandler(device=NLI_DEVICE) + nli_semaphore = asyncio.Semaphore(NLI_MAX_CONCURRENCY) yield # Clean up if needed nli_handler = None + nli_semaphore = None app = FastAPI(lifespan=lifespan) @@ -27,11 +38,17 @@ async def lifespan(app: FastAPI): @app.post("/compare_one_to_many", response_model=CompareResponse) async def compare_one_to_many(request: CompareRequest): - if nli_handler is None: + if nli_handler is None or nli_semaphore is None: raise HTTPException(status_code=503, detail="Model not loaded") try: - results = nli_handler.compare_one_to_many(request.source, request.targets) - return CompareResponse(results=results) + async with nli_semaphore: + results = await asyncio.wait_for( + asyncio.to_thread(nli_handler.compare_one_to_many, request.source, request.targets), + timeout=NLI_INFER_TIMEOUT_SECONDS, + ) + return CompareResponse(results=results) + except asyncio.TimeoutError as e: + raise HTTPException(status_code=504, detail="NLI inference timed out") from e except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e From d0cc255e31d55e3aa642e0f00a1637220e37dc48 Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Thu, 26 Feb 2026 18:45:08 +0800 Subject: [PATCH 45/92] feat: add profile mock API and multi-view memory support - Add 3 new profile (attribute tree) mock endpoints: CreateProfileTemplate, BindProfile, EditProfile with in-memory mock storage (ProfileMockStore) - Add profile_models.py and profile_handler.py for mock phase - Add profile_detail_list / event_detail_list to SearchMemory response - Add include_memory_view to APISearchRequest with legacy bool field compatibility (search_tool_memory, include_preference, etc.) - Add allow_memory_view dispatch to single_cube add_memories - Add role_id / role_name to MessageDict and all ChatCompletion*Param TypedDicts for multi-speaker (group-chat) support - Revert user_id to backward-compatible single str; multi-user routing handled via writable_cube_ids dict (cube_type + user_or_agent_id) - Fix IdProfileMapping.profile_config_id to default to "default_config" Made-with: Cursor --- src/memos/api/handlers/add_handler.py | 11 +- src/memos/api/handlers/profile_handler.py | 329 ++++++++++++++++++ src/memos/api/handlers/search_handler.py | 18 + src/memos/api/product_models.py | 140 +++++++- src/memos/api/profile_models.py | 162 +++++++++ src/memos/api/routers/server_router.py | 47 +++ src/memos/multi_mem_cube/single_cube.py | 33 +- src/memos/types/general_types.py | 11 +- ...chat_completion_assistant_message_param.py | 9 +- .../chat_completion_system_message_param.py | 9 +- .../chat_completion_user_message_param.py | 10 +- 11 files changed, 762 insertions(+), 17 deletions(-) create mode 100644 src/memos/api/handlers/profile_handler.py create mode 100644 src/memos/api/profile_models.py diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..d8d00aea0 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -118,15 +118,22 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: Normalize target cube ids from add_req. Priority: 1) writable_cube_ids (deprecated mem_cube_id is converted to this in model validator) - 2) fallback to user_id + 2) user_id (single str after validator normalization) """ if add_req.writable_cube_ids: + if isinstance(add_req.writable_cube_ids, dict): + return list(dict.fromkeys(add_req.writable_cube_ids.keys())) return list(dict.fromkeys(add_req.writable_cube_ids)) - return [add_req.user_id] + if add_req.user_id: + return [add_req.user_id] + + return [] def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: cube_ids = self._resolve_cube_ids(add_req) + if not cube_ids: + raise ValueError("No writable cube id could be resolved from add request.") if len(cube_ids) == 1: cube_id = cube_ids[0] diff --git a/src/memos/api/handlers/profile_handler.py b/src/memos/api/handlers/profile_handler.py new file mode 100644 index 000000000..83e6f3851 --- /dev/null +++ b/src/memos/api/handlers/profile_handler.py @@ -0,0 +1,329 @@ +""" +Profile (Attribute Tree) Mock Handler. + +Fully synchronous, in-memory mock implementation. +No PolarDB, no scheduler, no embedding — just dict storage. + +This handler is standalone and does NOT depend on HandlerDependencies. +""" + +from __future__ import annotations + +import copy +import uuid +from typing import Any + +from fastapi import HTTPException + +from memos.api.profile_models import ( + DEFAULT_PROFILE_CONFIG_ID, + BindProfileRequest, + BindProfileResponse, + CreateProfileTemplateRequest, + CreateProfileTemplateResponse, + EditProfileRequest, + EditProfileResponse, +) +from memos.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_PROFILE_TEMPLATE: dict[str, dict[str, dict[str, Any]]] = { + "客观档案": { + "姓名": {"内容": "", "证据": ""}, + "性别": {"内容": "", "证据": ""}, + "年龄": {"内容": "", "证据": ""}, + "当前关系状态": {"内容": "单身", "证据": ""}, + }, + "人格状态": { + "温柔程度": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, + "自尊表现": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, + "依附感": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, + "强势感": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, + "退让程度": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, + }, +} + + +class ProfileMockStore: + """ + In-memory storage for profile templates, bindings, and instances. + + Data is lost on server restart — this is intentional for a mock. + + Storage layout: + templates: {profile_config_id: {metadata_dict}} + bindings: {entity_id: profile_config_id} + instances: {profile_instance_id: {merged_profile_data}} + """ + + def __init__(self) -> None: + self.templates: dict[str, dict[str, Any]] = { + DEFAULT_PROFILE_CONFIG_ID: copy.deepcopy(DEFAULT_PROFILE_TEMPLATE) + } + self.bindings: dict[str, str] = {} + self.instances: dict[str, dict[str, Any]] = {} + + def make_instance_id(self, entity_id: str, profile_config_id: str) -> str: + """Generate profile_instance_id: entity_id + '_' + profile_config_id.""" + return f"{entity_id}_{profile_config_id}" + + +# Singleton store shared across all requests +_store = ProfileMockStore() + + +def get_store() -> ProfileMockStore: + """Get the global mock store instance.""" + return _store + + +class ProfileHandler: + """Mock handler for profile (attribute tree) endpoints.""" + + def __init__(self, store: ProfileMockStore | None = None) -> None: + self.store = store or get_store() + + # ----------------------------------------------------------------- + # CreateProfileTemplate + # ----------------------------------------------------------------- + async def create_template( + self, req: CreateProfileTemplateRequest + ) -> CreateProfileTemplateResponse: + """ + Create a new profile template. + + Generates a UUID as profile_config_id, stores the template metadata. + """ + profile_config_id = f"tpl_{uuid.uuid4().hex[:12]}" + self.store.templates[profile_config_id] = copy.deepcopy(req.metadata) + + logger.info( + "Created profile template %s with %d categories", + profile_config_id, + len(req.metadata), + ) + + return CreateProfileTemplateResponse( + code=200, + message="success", + data={"profile_config_id": profile_config_id}, + ) + + # ----------------------------------------------------------------- + # BindProfile + # ----------------------------------------------------------------- + async def bind_profile(self, req: BindProfileRequest) -> BindProfileResponse: + """ + Bind user/agent IDs to profile templates. + + Rules: + - If not yet bound: create binding + instantiate template defaults + - If already bound to SAME template: idempotent, return existing instance_id + - If already bound to DIFFERENT template: return 400 error + """ + instance_ids: list[str] = [] + + for mapping in req.id_profile_map: + entity_id = mapping.id + profile_config_id = mapping.profile_config_id + + # Check if template exists + if profile_config_id not in self.store.templates: + raise HTTPException( + status_code=400, + detail=f"Template '{profile_config_id}' not found. " + f"Please create it first via CreateProfileTemplate.", + ) + + # Check existing binding + existing = self.store.bindings.get(entity_id) + + if existing is not None and existing != profile_config_id: + raise HTTPException( + status_code=400, + detail=f"Entity '{entity_id}' is already bound to template " + f"'{existing}'. Cannot rebind to '{profile_config_id}'. " + f"One ID can only bind to one template.", + ) + + instance_id = self.store.make_instance_id(entity_id, profile_config_id) + + if existing == profile_config_id: + # Idempotent — already bound to same template + logger.info( + "Entity %s already bound to %s, idempotent", + entity_id, + profile_config_id, + ) + else: + # New binding — instantiate template with defaults filled in + self.store.bindings[entity_id] = profile_config_id + template_data = self.store.templates[profile_config_id] + self.store.instances[instance_id] = copy.deepcopy(template_data) + self._fill_defaults(instance_id) + + logger.info( + "Bound entity %s to template %s → instance %s", + entity_id, + profile_config_id, + instance_id, + ) + + instance_ids.append(instance_id) + + return BindProfileResponse( + code=200, + message="success", + data={"profile_instance_id": instance_ids}, + ) + + # ----------------------------------------------------------------- + # EditProfile + # ----------------------------------------------------------------- + async def edit_profile(self, req: EditProfileRequest) -> EditProfileResponse: + """ + Edit profile values for a bound user/agent (editProfileConfig §3.3). + + Execution logic (per final spec): + 1. Update the passed-in fields' values, regardless of prior values. + 2. Check fields with default values; if empty, fill with defaults. + 3. If a field is not in the tree, add it to this user's instance only. + 4. Mark fields with algorithm_updatable from the metadata dict. + """ + entity_id = req.id + profile_config_id = req.profile_config_id or DEFAULT_PROFILE_CONFIG_ID + instance_id = self.store.make_instance_id(entity_id, profile_config_id) + + # Ensure template exists (bootstrap default if needed) + if profile_config_id not in self.store.templates: + if profile_config_id == DEFAULT_PROFILE_CONFIG_ID: + self.store.templates[profile_config_id] = copy.deepcopy(DEFAULT_PROFILE_TEMPLATE) + else: + raise HTTPException( + status_code=400, + detail=f"Template '{profile_config_id}' not found. " + "Please create it first via create_profile_template.", + ) + + # Auto-bind if not yet bound; reject cross-template edits + existing_binding = self.store.bindings.get(entity_id) + if existing_binding is None: + self.store.bindings[entity_id] = profile_config_id + template_data = self.store.templates[profile_config_id] + self.store.instances[instance_id] = copy.deepcopy(template_data) + self._fill_defaults(instance_id) + logger.info("Auto-bound entity %s to template %s", entity_id, profile_config_id) + elif existing_binding != profile_config_id: + raise HTTPException( + status_code=400, + detail=f"Entity '{entity_id}' is already bound to template " + f"'{existing_binding}'. Cannot edit with '{profile_config_id}'.", + ) + + # Apply field updates from metadata dict + if req.metadata: + instance_data = self.store.instances[instance_id] + for field_name, field_value in req.metadata.items(): + updated = False + for _category_name, category_fields in instance_data.items(): + if isinstance(category_fields, dict) and field_name in category_fields: + if isinstance(field_value, dict): + category_fields[field_name].update(field_value) + else: + category_fields[field_name] = field_value + updated = True + break + + if not updated: + instance_data[field_name] = field_value + + logger.info( + "Updated %d field(s) for instance %s", len(req.metadata), instance_id + ) + + return EditProfileResponse( + code=200, + message="success", + data={"profile_instance_id": instance_id}, + ) + + # ----------------------------------------------------------------- + # Internal helpers + # ----------------------------------------------------------------- + def _fill_defaults(self, instance_id: str) -> None: + """ + Ensure every field dict in a freshly instantiated instance has + algorithm_updatable set to True (the default for new bindings). + + Template defaults (内容, 当前描述, 变化趋势, etc.) are already copied + in via deepcopy; this pass only guarantees the control flag is present. + """ + instance_data = self.store.instances.get(instance_id) + if instance_data is None: + return + + for category_name, category_fields in instance_data.items(): + if not isinstance(category_fields, dict): + continue + for field_name, field_value in category_fields.items(): + if not isinstance(field_value, dict): + continue + # Ensure algorithm_updatable has a default + if "algorithm_updatable" not in field_value: + field_value["algorithm_updatable"] = True + + logger.debug("Filled defaults for instance %s", instance_id) + + # ----------------------------------------------------------------- + # Get profile data (for SearchMemory mock) + # ----------------------------------------------------------------- + def get_profile_for_entity(self, entity_id: str) -> list[dict[str, Any]]: + """ + Retrieve stored profile data as profile_detail_list items. + + Used by SearchMemory mock when include_memory_view contains "profile". + Returns a flat list of profile fields formatted for the API response. + """ + binding = self.store.bindings.get(entity_id) + if binding is None: + return [] + + instance_id = self.store.make_instance_id(entity_id, binding) + instance_data = self.store.instances.get(instance_id) + if instance_data is None: + return [] + + results: list[dict[str, Any]] = [] + for category_name, category_fields in instance_data.items(): + if not isinstance(category_fields, dict): + continue + for field_name, field_value in category_fields.items(): + if not isinstance(field_value, dict): + continue + + # Build content string from field value + content_parts = [] + for k, v in field_value.items(): + if k != "algorithm_updatable": + content_parts.append(f"{k}: {v}") + content_str = "; ".join(content_parts) if content_parts else str(field_value) + + results.append( + { + "type": "ProfileMemory", + "content": f"{category_name}.{field_name}: {content_str}", + "score": 0.85, # Mock score + "metadata": { + "profile_field": f"{category_name}.{field_name}", + "profile_category": category_name, + "algorithm_updatable": field_value.get( + "algorithm_updatable", True + ), + "template_id": binding, + "profile_instance_id": instance_id, + }, + } + ) + + return results diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 267d1bb28..fb5ec1ec7 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -12,6 +12,7 @@ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.handlers.formatters_handler import rerank_knowledge_mem +from memos.api.handlers.profile_handler import ProfileHandler from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -43,6 +44,7 @@ def __init__(self, dependencies: HandlerDependencies): self._validate_dependencies( "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) + self.profile_handler = ProfileHandler() def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -90,6 +92,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse top_k=search_req_local.top_k, file_mem_proportion=0.5, ) + self._attach_profile_and_event_mock_results(results, search_req_local) self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" @@ -100,6 +103,21 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse data=results, ) + def _attach_profile_and_event_mock_results( + self, results: dict[str, Any], search_req: APISearchRequest + ) -> None: + include_views = set(search_req.include_memory_view or []) + if "profile" in include_views: + profile_detail_list = self.profile_handler.get_profile_for_entity(search_req.user_id) + else: + profile_detail_list = [] + + # Event mock is wired for interface compatibility in this phase. + event_detail_list: list[dict[str, Any]] = [] + + results["profile_detail_list"] = profile_detail_list + results["event_detail_list"] = event_detail_list + @staticmethod def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]: if relativity <= 0: diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 5bf27e985..93b094abc 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -12,6 +12,22 @@ logger = get_logger(__name__) T = TypeVar("T") +VALID_MEMORY_VIEWS: set[str] = { + "detail_factual", + "preference", + "skill", + "profile", + "event", + "tool_memory", +} +DEFAULT_SEARCH_MEMORY_VIEWS: list[str] = [ + "detail_factual", + "preference", + "skill", + "profile", + "event", +] + class BaseRequest(BaseModel): """Base model for all requests.""" @@ -399,6 +415,16 @@ class APISearchRequest(BaseRequest): description="Number of skill memories to retrieve (top-K). Default: 3.", ) + include_memory_view: list[str] | None = Field( + None, + description=( + "Controls which memory views to search. " + "Values: 'detail_factual', 'preference', 'skill', 'profile', 'event', 'tool_memory'. " + "When provided, overrides individual bool flags (include_preference, etc). " + "Default (None): ['detail_factual', 'preference', 'skill', 'profile', 'event']." + ), + ) + # ==== Filter conditions ==== # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( @@ -488,6 +514,8 @@ def _convert_deprecated_fields(self) -> "APISearchRequest": - moscube is ignored with warning - operation ignored """ + fields_set = self.model_fields_set + # Convert mem_cube_id to readable_cube_ids (new field takes priority) if self.mem_cube_id is not None: if not self.readable_cube_ids: @@ -512,6 +540,49 @@ def _convert_deprecated_fields(self) -> "APISearchRequest": "This field is deprecated and ignored." ) + # include_memory_view takes highest priority. Legacy bool fields are ignored. + if self.include_memory_view is not None: + normalized_views: list[str] = [] + seen: set[str] = set() + for view in self.include_memory_view: + if view not in VALID_MEMORY_VIEWS: + logger.warning( + "Unknown memory view '%s' in include_memory_view; it will be ignored.", + view, + ) + continue + if view not in seen: + normalized_views.append(view) + seen.add(view) + + self.include_memory_view = normalized_views + self.include_preference = "preference" in normalized_views + self.include_skill_memory = "skill" in normalized_views + self.search_tool_memory = "tool_memory" in normalized_views + return self + + # Compatibility strategy: + # 1) legacy bool fields explicitly provided -> derive include_memory_view from defaults + # 2) none provided -> use new default include_memory_view + legacy_fields = {"include_preference", "include_skill_memory", "search_tool_memory"} + legacy_given = any(field in fields_set for field in legacy_fields) + + views = list(DEFAULT_SEARCH_MEMORY_VIEWS) + if legacy_given: + if not self.include_preference and "preference" in views: + views.remove("preference") + if not self.include_skill_memory and "skill" in views: + views.remove("skill") + if self.search_tool_memory and "tool_memory" not in views: + views.append("tool_memory") + + self.include_memory_view = views + return self + + self.include_memory_view = list(DEFAULT_SEARCH_MEMORY_VIEWS) + self.include_preference = "preference" in self.include_memory_view + self.include_skill_memory = "skill" in self.include_memory_view + self.search_tool_memory = "tool_memory" in self.include_memory_view return self @@ -519,7 +590,7 @@ class APIADDRequest(BaseRequest): """Request model for creating memories.""" # ==== Basic identifiers ==== - user_id: str = Field(None, description="User ID") + user_id: str | list[str] = Field(None, description="User ID or user ID list") session_id: str | None = Field( None, description="Session ID. If not provided, a default session will be used.", @@ -529,8 +600,12 @@ class APIADDRequest(BaseRequest): project_id: str | None = Field(None, description="Project ID") # ==== Multi-cube writing ==== - writable_cube_ids: list[str] | None = Field( - None, description="List of cube IDs user can write for multi-cube add" + writable_cube_ids: list[str] | dict[str, Any] | None = Field( + None, + description=( + "Writable cube IDs. Supports legacy list[str] and dict format " + "{cube_id: {cube_type, user_or_agent_id, ...}}." + ), ) # ==== Async control ==== @@ -608,6 +683,16 @@ class APIADDRequest(BaseRequest): description=("Whether this request represents user feedback. Default: False."), ) + # ==== Memory view control ==== + allow_memory_view: list[str] | None = Field( + None, + description=( + "Controls which memory views are allowed for this add request. " + "Example: ['detail_factual', 'preference', 'skill', 'profile', 'event', 'tool_memory']. " + "Default (None) means all views are allowed." + ), + ) + # ==== Backward compatibility fields (will delete later) ==== mem_cube_id: str | None = Field( None, @@ -656,6 +741,39 @@ def _convert_deprecated_fields(self) -> "APIADDRequest": ) self.mode = None + # user_id is a backward-compatible field (str only in practice). + # list[str] is accepted at the interface level but normalized to the first + # element here. Multi-user / multi-agent routing is handled entirely by + # writable_cube_ids (dict with cube_type & user_or_agent_id); profile + # target users are derived from that dict, NOT from user_id. + if isinstance(self.user_id, list): + if not self.user_id: + raise ValueError("APIADDRequest.user_id list cannot be empty.") + logger.info( + "APIADDRequest received user_id as list; normalizing to first element '%s'.", + self.user_id[0], + ) + self.user_id = self.user_id[0] + + # Accept writable_cube_ids dict and normalize to list of cube IDs. + if isinstance(self.writable_cube_ids, dict): + self.writable_cube_ids = list(self.writable_cube_ids.keys()) + + if self.allow_memory_view is not None: + normalized_views: list[str] = [] + seen: set[str] = set() + for view in self.allow_memory_view: + if view not in VALID_MEMORY_VIEWS: + logger.warning( + "Unknown memory view '%s' in allow_memory_view; it will be ignored.", + view, + ) + continue + if view not in seen: + normalized_views.append(view) + seen.add(view) + self.allow_memory_view = normalized_views + # Convert mem_cube_id to writable_cube_ids (new field takes priority) if self.mem_cube_id: logger.warning( @@ -921,11 +1039,25 @@ class SearchMemoryData(BaseModel): tool_memory_detail_list: list[MessageDetail] | None = Field( None, alias="tool_memory_detail_list", - description="List of tool_memor details (usually None)", + description="List of tool_memory details (usually None)", ) preference_note: str = Field( None, alias="preference_note", description="String of preference_note" ) + profile_detail_list: list[dict[str, Any]] | None = Field( + None, + alias="profile_detail_list", + description=( + "Attribute-tree fields that matched the query (mixed-rank results). " + "Each item: {type, content, score, metadata{profile_field, profile_category, " + "algorithm_updatable, template_id, profile_instance_id}}." + ), + ) + event_detail_list: list[dict[str, Any]] | None = Field( + None, + alias="event_detail_list", + description="Event memory items that matched the query (reserved, empty in mock phase).", + ) class GetKnowledgebaseFileData(BaseModel): diff --git a/src/memos/api/profile_models.py b/src/memos/api/profile_models.py new file mode 100644 index 000000000..6d62db332 --- /dev/null +++ b/src/memos/api/profile_models.py @@ -0,0 +1,162 @@ +""" +Profile (Attribute Tree) Mock Models. + +Pydantic request/response models for the 3 new profile endpoints: +- CreateProfileTemplate +- BindProfile +- EditProfile + +These are used for the mock phase only. No PolarDB or scheduler dependency. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from memos.api.product_models import BaseResponse + +DEFAULT_PROFILE_CONFIG_ID = "default_config" + + +# ============================================================================= +# Sub-models +# ============================================================================= + + +class IdProfileMapping(BaseModel): + """Single mapping entry for BindProfile: one id → one profile_config_id.""" + + id: str = Field(..., description="User ID or Agent ID to bind") + profile_config_id: str = Field( + DEFAULT_PROFILE_CONFIG_ID, + description="Profile template ID to bind to this id. Defaults to default_config.", + ) + + +# ============================================================================= +# CreateProfileTemplate +# ============================================================================= + + +class CreateProfileTemplateRequest(BaseModel): + """ + Request for creating a new profile template. + + metadata example: + { + "客观档案": { + "姓名": { + "内容": "张三", + "依据": "...", + "algorithm_updatable": true + } + }, + "人格状态": { + "温柔程度": { + "当前描述": "语气柔和", + "变化趋势": "上升", + "依据": "...", + "algorithm_updatable": true + } + } + } + """ + + metadata: dict[str, Any] = Field( + ..., + description="Profile template structure with categories and fields", + ) + + +class CreateProfileTemplateResponse(BaseResponse): + """Response for CreateProfileTemplate. data contains profile_config_id.""" + + data: dict[str, Any] | None = Field( + None, description="Response data containing profile_config_id" + ) + + +# ============================================================================= +# BindProfile +# ============================================================================= + + +class BindProfileRequest(BaseModel): + """ + Request for binding user/agent IDs to profile templates. + + Currently one ID can only bind to one template. + If the same ID is already bound to the same template, it's idempotent. + If bound to a different template, return an error. + """ + + id_profile_map: list[IdProfileMapping] = Field( + ..., + description="List of {id, profile_config_id} pairs to bind", + min_length=1, + ) + + +class BindProfileResponse(BaseResponse): + """Response for BindProfile. data contains profile_instance_id list.""" + + data: dict[str, Any] | None = Field( + None, + description="Response data containing profile_instance_id list", + ) + + +# ============================================================================= +# EditProfile +# ============================================================================= + + +class EditProfileRequest(BaseModel): + """ + Request for editProfileConfig (§3.3). + + Rules (from final spec): + - If already bound: update the specified fields. + - If not yet bound: auto-bind to the given template, then apply updates. + - If already bound to a *different* template: return 400. + - Update field values regardless of whether they had values before. + - Check fields with default values: if empty, fill with defaults. + - If a field is not in the template tree: add it to this user's instance only + (does not affect the template or other users). + - Fields marked with algorithm_updatable=false are locked from future + algorithm extraction overwrites. + + metadata example: + { + "姓名": {"内容": "张三", "algorithm_updatable": false}, + "性别": {"内容": "女"}, + "Smauel称呼我": {"内容": "小张", "algorithm_updatable": false} + } + """ + + id: str = Field( + ..., description="User ID or Agent ID whose profile to edit" + ) + profile_config_id: str = Field( + DEFAULT_PROFILE_CONFIG_ID, + description="Profile template ID. Defaults to default_config.", + ) + metadata: dict[str, Any] | None = Field( + None, + description=( + "Fields to update. Keys are field names, values are field data dicts " + "(e.g. {'内容': '张三', 'algorithm_updatable': false}). " + "Fields with algorithm_updatable=false are locked from algorithm overwrites." + ), + ) + + +class EditProfileResponse(BaseResponse): + """Response for editProfileConfig. data contains profile_instance_id.""" + + data: dict[str, Any] | None = Field( + None, + description="Response data containing profile_instance_id", + ) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index af6ae4fe5..748262097 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -22,6 +22,7 @@ from memos.api.handlers.base_handler import HandlerDependencies from memos.api.handlers.chat_handler import ChatHandler from memos.api.handlers.feedback_handler import FeedbackHandler +from memos.api.handlers.profile_handler import ProfileHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( AllStatusResponse, @@ -53,6 +54,14 @@ SuggestionResponse, TaskQueueResponse, ) +from memos.api.profile_models import ( + BindProfileRequest, + BindProfileResponse, + CreateProfileTemplateRequest, + CreateProfileTemplateResponse, + EditProfileRequest, + EditProfileResponse, +) from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -96,6 +105,9 @@ graph_db = components["graph_db"] vector_db = components["vector_db"] +# Initialize profile handler (standalone, no HandlerDependencies needed) +profile_handler = ProfileHandler() + # ============================================================================= # Search API Endpoints @@ -358,6 +370,41 @@ def feedback_memories(feedback_req: APIFeedbackRequest): return feedback_handler.handle_feedback_memories(feedback_req) +# ============================================================================= +# Profile (Attribute Tree) API Endpoints — Mock +# ============================================================================= + + +@router.post( + "/create_profile_template", + summary="Create a profile template", + response_model=CreateProfileTemplateResponse, +) +async def create_profile_template(req: CreateProfileTemplateRequest): + """Create a new profile (attribute tree) template.""" + return await profile_handler.create_template(req) + + +@router.post( + "/bind_profile_config", + summary="Bind user/agent to a profile template", + response_model=BindProfileResponse, +) +async def bind_profile(req: BindProfileRequest): + """Bind user/agent IDs to profile templates (one ID → one template currently).""" + return await profile_handler.bind_profile(req) + + +@router.post( + "/edit_profile_config", + summary="Edit profile values", + response_model=EditProfileResponse, +) +async def edit_profile(req: EditProfileRequest): + """Edit profile values for a bound user/agent. Auto-binds if not yet bound.""" + return await profile_handler.edit_profile(req) + + # ============================================================================= # Other API Endpoints (for internal use) # ============================================================================= diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2a6991e8c..78204a410 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -85,12 +85,35 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: f"Processing add with mode={sync_mode}, session={target_session_id}" ) - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) + allowed_views = set(add_req.allow_memory_view or []) + allow_all_views = add_req.allow_memory_view is None + allow_text_mem = allow_all_views or "detail_factual" in allowed_views + allow_pref_mem = allow_all_views or "preference" in allowed_views + + if not allow_text_mem and not allow_pref_mem: + self.logger.info( + "[SingleCubeView] cube=%s Skip add: allow_memory_view excludes " + "both detail_factual and preference.", + self.cube_id, + ) + return [] - text_results = text_future.result() - pref_results = pref_future.result() + if allow_text_mem and allow_pref_mem: + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit( + self._process_text_mem, add_req, user_context, sync_mode + ) + pref_future = executor.submit( + self._process_pref_mem, add_req, user_context, sync_mode + ) + text_results = text_future.result() + pref_results = pref_future.result() + elif allow_text_mem: + text_results = self._process_text_mem(add_req, user_context, sync_mode) + pref_results = [] + else: + text_results = [] + pref_results = self._process_pref_mem(add_req, user_context, sync_mode) self.logger.info( f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 8234caf8b..580d447ff 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -53,9 +53,14 @@ class MessageDict(TypedDict, total=False): role: MessageRole content: str - chat_time: str | None # Optional timestamp for the message, format is not - # restricted, it can be any vague or precise time string. - message_id: str | None # Optional unique identifier for the message + chat_time: str | None # Optional timestamp; any vague or precise time string. + message_id: str | None # Unique message ID. Used for cascade-deletion of profile + # memory fields via source_message_ids. Corresponds to messages[].id in the + # cloud-service (product) API. + role_id: str | None # Speaker identity in multi-user / multi-agent (group-chat) + # scenarios. Must match an entry in the request-level user_id list. + role_name: str | None # Display name of the speaker. Optional; when provided, + # must be accompanied by role_id. RawMessageDict: TypeAlias = ChatCompletionContentPartTextParam | File diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index f28796c2d..d975164e4 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -53,4 +53,11 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message""" + """Optional unique identifier for the message. Corresponds to messages[].id in the + cloud-service API; used for cascade-deletion of profile memory fields.""" + + role_id: str | None + """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios.""" + + role_name: str | None + """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index 13a9a89af..8d9d83de7 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -33,4 +33,11 @@ class ChatCompletionSystemMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message""" + """Optional unique identifier for the message. Corresponds to messages[].id in the + cloud-service API; used for cascade-deletion of profile memory fields.""" + + role_id: str | None + """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios.""" + + role_name: str | None + """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py index b5bee9842..482684af1 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -31,4 +31,12 @@ class ChatCompletionUserMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message""" + """Optional unique identifier for the message. Corresponds to messages[].id in the + cloud-service API; used for cascade-deletion of profile memory fields.""" + + role_id: str | None + """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios. + Must match an entry in the request-level user_id list.""" + + role_name: str | None + """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" From b2cdabc2012769fc33ac36fe409aa6411853a158 Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Thu, 26 Feb 2026 19:12:03 +0800 Subject: [PATCH 46/92] fix: preserve writable_cube_ids dict, fix feedback path + ruff format - Stop flattening writable_cube_ids dict in validator; preserve cube-level metadata (cube_type, user_or_agent_id) for future use - Fix feedback path: pass _resolve_cube_ids() result (list[str]) into APIFeedbackRequest instead of raw dict to avoid ValidationError - Apply ruff check fixes (unused vars, blank lines, line length) Made-with: Cursor --- src/memos/api/handlers/add_handler.py | 2 +- src/memos/api/handlers/profile_handler.py | 14 ++++++-------- src/memos/api/product_models.py | 6 +++--- src/memos/api/profile_models.py | 5 ++--- .../textual/tree_text_memory/retrieve/searcher.py | 2 +- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index d8d00aea0..5470a2d71 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -87,7 +87,7 @@ def _check_messages(messages: MessageList) -> None: task_id=add_req.task_id, history=feedback_history, feedback_content=feedback_content, - writable_cube_ids=add_req.writable_cube_ids, + writable_cube_ids=self._resolve_cube_ids(add_req), async_mode=add_req.async_mode, info=add_req.info, ) diff --git a/src/memos/api/handlers/profile_handler.py b/src/memos/api/handlers/profile_handler.py index 83e6f3851..0b1c516d8 100644 --- a/src/memos/api/handlers/profile_handler.py +++ b/src/memos/api/handlers/profile_handler.py @@ -11,6 +11,7 @@ import copy import uuid + from typing import Any from fastapi import HTTPException @@ -26,6 +27,7 @@ ) from memos.log import get_logger + logger = get_logger(__name__) DEFAULT_PROFILE_TEMPLATE: dict[str, dict[str, dict[str, Any]]] = { @@ -238,9 +240,7 @@ async def edit_profile(self, req: EditProfileRequest) -> EditProfileResponse: if not updated: instance_data[field_name] = field_value - logger.info( - "Updated %d field(s) for instance %s", len(req.metadata), instance_id - ) + logger.info("Updated %d field(s) for instance %s", len(req.metadata), instance_id) return EditProfileResponse( code=200, @@ -263,10 +263,10 @@ def _fill_defaults(self, instance_id: str) -> None: if instance_data is None: return - for category_name, category_fields in instance_data.items(): + for _category_name, category_fields in instance_data.items(): if not isinstance(category_fields, dict): continue - for field_name, field_value in category_fields.items(): + for _field_name, field_value in category_fields.items(): if not isinstance(field_value, dict): continue # Ensure algorithm_updatable has a default @@ -317,9 +317,7 @@ def get_profile_for_entity(self, entity_id: str) -> list[dict[str, Any]]: "metadata": { "profile_field": f"{category_name}.{field_name}", "profile_category": category_name, - "algorithm_updatable": field_value.get( - "algorithm_updatable", True - ), + "algorithm_updatable": field_value.get("algorithm_updatable", True), "template_id": binding, "profile_instance_id": instance_id, }, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 93b094abc..98c51a1b2 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -755,9 +755,9 @@ def _convert_deprecated_fields(self) -> "APIADDRequest": ) self.user_id = self.user_id[0] - # Accept writable_cube_ids dict and normalize to list of cube IDs. - if isinstance(self.writable_cube_ids, dict): - self.writable_cube_ids = list(self.writable_cube_ids.keys()) + # writable_cube_ids dict is preserved as-is so downstream code can read + # cube-level metadata (cube_type, user_or_agent_id, etc.). + # Cube ID extraction (keys) is done lazily in _resolve_cube_ids. if self.allow_memory_view is not None: normalized_views: list[str] = [] diff --git a/src/memos/api/profile_models.py b/src/memos/api/profile_models.py index 6d62db332..6e70f70b4 100644 --- a/src/memos/api/profile_models.py +++ b/src/memos/api/profile_models.py @@ -17,6 +17,7 @@ from memos.api.product_models import BaseResponse + DEFAULT_PROFILE_CONFIG_ID = "default_config" @@ -136,9 +137,7 @@ class EditProfileRequest(BaseModel): } """ - id: str = Field( - ..., description="User ID or Agent ID whose profile to edit" - ) + id: str = Field(..., description="User ID or Agent ID whose profile to edit") profile_config_id: str = Field( DEFAULT_PROFILE_CONFIG_ID, description="Profile template ID. Defaults to default_config.", diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9dcbe8c56..cc269e8c4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception as e: + except Exception: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) From d4ad974098f52c698d66d484138c0de795c64a6e Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Thu, 26 Feb 2026 20:11:14 +0800 Subject: [PATCH 47/92] fix: revert user_id to str only, remove list[str] support user_id was originally str; list[str] was added in this branch but never released. Multi-user routing is handled by writable_cube_ids dict, so user_id should not carry multi-user semantics. Made-with: Cursor --- src/memos/api/product_models.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 98c51a1b2..6e35d7841 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -590,7 +590,7 @@ class APIADDRequest(BaseRequest): """Request model for creating memories.""" # ==== Basic identifiers ==== - user_id: str | list[str] = Field(None, description="User ID or user ID list") + user_id: str = Field(None, description="User ID") session_id: str | None = Field( None, description="Session ID. If not provided, a default session will be used.", @@ -741,20 +741,6 @@ def _convert_deprecated_fields(self) -> "APIADDRequest": ) self.mode = None - # user_id is a backward-compatible field (str only in practice). - # list[str] is accepted at the interface level but normalized to the first - # element here. Multi-user / multi-agent routing is handled entirely by - # writable_cube_ids (dict with cube_type & user_or_agent_id); profile - # target users are derived from that dict, NOT from user_id. - if isinstance(self.user_id, list): - if not self.user_id: - raise ValueError("APIADDRequest.user_id list cannot be empty.") - logger.info( - "APIADDRequest received user_id as list; normalizing to first element '%s'.", - self.user_id[0], - ) - self.user_id = self.user_id[0] - # writable_cube_ids dict is preserved as-is so downstream code can read # cube-level metadata (cube_type, user_or_agent_id, etc.). # Cube ID extraction (keys) is done lazily in _resolve_cube_ids. From 2eaa30a84f81740dcaa4d1ab3f598cfa0b329131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E6=99=A8=E9=98=B3?= Date: Thu, 26 Feb 2026 21:36:58 +0800 Subject: [PATCH 48/92] Revert from "feat: add profile mock API and multi-view memory support" to "Merge #2 into private from feature/profile-mock-interfaces" * Reverted: d0cc255e feat: add profile mock API and multi-view memory support b2cdabc2 fix: preserve writable_cube_ids dict, fix feedback path + ruff format d4ad9740 fix: revert user_id to str only, remove list[str] support f70a0374 Merge #2 into private from feature/profile-mock-interfaces This reverts change request#2 --- src/memos/api/handlers/add_handler.py | 13 +- src/memos/api/handlers/profile_handler.py | 327 ------------------ src/memos/api/handlers/search_handler.py | 18 - src/memos/api/product_models.py | 124 +------ src/memos/api/profile_models.py | 161 --------- src/memos/api/routers/server_router.py | 47 --- .../tree_text_memory/retrieve/searcher.py | 2 +- src/memos/multi_mem_cube/single_cube.py | 33 +- src/memos/types/general_types.py | 11 +- ...chat_completion_assistant_message_param.py | 9 +- .../chat_completion_system_message_param.py | 9 +- .../chat_completion_user_message_param.py | 10 +- 12 files changed, 18 insertions(+), 746 deletions(-) delete mode 100644 src/memos/api/handlers/profile_handler.py delete mode 100644 src/memos/api/profile_models.py diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 5470a2d71..3cdbedabf 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -87,7 +87,7 @@ def _check_messages(messages: MessageList) -> None: task_id=add_req.task_id, history=feedback_history, feedback_content=feedback_content, - writable_cube_ids=self._resolve_cube_ids(add_req), + writable_cube_ids=add_req.writable_cube_ids, async_mode=add_req.async_mode, info=add_req.info, ) @@ -118,22 +118,15 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: Normalize target cube ids from add_req. Priority: 1) writable_cube_ids (deprecated mem_cube_id is converted to this in model validator) - 2) user_id (single str after validator normalization) + 2) fallback to user_id """ if add_req.writable_cube_ids: - if isinstance(add_req.writable_cube_ids, dict): - return list(dict.fromkeys(add_req.writable_cube_ids.keys())) return list(dict.fromkeys(add_req.writable_cube_ids)) - if add_req.user_id: - return [add_req.user_id] - - return [] + return [add_req.user_id] def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: cube_ids = self._resolve_cube_ids(add_req) - if not cube_ids: - raise ValueError("No writable cube id could be resolved from add request.") if len(cube_ids) == 1: cube_id = cube_ids[0] diff --git a/src/memos/api/handlers/profile_handler.py b/src/memos/api/handlers/profile_handler.py deleted file mode 100644 index 0b1c516d8..000000000 --- a/src/memos/api/handlers/profile_handler.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Profile (Attribute Tree) Mock Handler. - -Fully synchronous, in-memory mock implementation. -No PolarDB, no scheduler, no embedding — just dict storage. - -This handler is standalone and does NOT depend on HandlerDependencies. -""" - -from __future__ import annotations - -import copy -import uuid - -from typing import Any - -from fastapi import HTTPException - -from memos.api.profile_models import ( - DEFAULT_PROFILE_CONFIG_ID, - BindProfileRequest, - BindProfileResponse, - CreateProfileTemplateRequest, - CreateProfileTemplateResponse, - EditProfileRequest, - EditProfileResponse, -) -from memos.log import get_logger - - -logger = get_logger(__name__) - -DEFAULT_PROFILE_TEMPLATE: dict[str, dict[str, dict[str, Any]]] = { - "客观档案": { - "姓名": {"内容": "", "证据": ""}, - "性别": {"内容": "", "证据": ""}, - "年龄": {"内容": "", "证据": ""}, - "当前关系状态": {"内容": "单身", "证据": ""}, - }, - "人格状态": { - "温柔程度": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, - "自尊表现": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, - "依附感": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, - "强势感": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, - "退让程度": {"当前描述": "中性", "变化趋势": "稳定", "依据": ""}, - }, -} - - -class ProfileMockStore: - """ - In-memory storage for profile templates, bindings, and instances. - - Data is lost on server restart — this is intentional for a mock. - - Storage layout: - templates: {profile_config_id: {metadata_dict}} - bindings: {entity_id: profile_config_id} - instances: {profile_instance_id: {merged_profile_data}} - """ - - def __init__(self) -> None: - self.templates: dict[str, dict[str, Any]] = { - DEFAULT_PROFILE_CONFIG_ID: copy.deepcopy(DEFAULT_PROFILE_TEMPLATE) - } - self.bindings: dict[str, str] = {} - self.instances: dict[str, dict[str, Any]] = {} - - def make_instance_id(self, entity_id: str, profile_config_id: str) -> str: - """Generate profile_instance_id: entity_id + '_' + profile_config_id.""" - return f"{entity_id}_{profile_config_id}" - - -# Singleton store shared across all requests -_store = ProfileMockStore() - - -def get_store() -> ProfileMockStore: - """Get the global mock store instance.""" - return _store - - -class ProfileHandler: - """Mock handler for profile (attribute tree) endpoints.""" - - def __init__(self, store: ProfileMockStore | None = None) -> None: - self.store = store or get_store() - - # ----------------------------------------------------------------- - # CreateProfileTemplate - # ----------------------------------------------------------------- - async def create_template( - self, req: CreateProfileTemplateRequest - ) -> CreateProfileTemplateResponse: - """ - Create a new profile template. - - Generates a UUID as profile_config_id, stores the template metadata. - """ - profile_config_id = f"tpl_{uuid.uuid4().hex[:12]}" - self.store.templates[profile_config_id] = copy.deepcopy(req.metadata) - - logger.info( - "Created profile template %s with %d categories", - profile_config_id, - len(req.metadata), - ) - - return CreateProfileTemplateResponse( - code=200, - message="success", - data={"profile_config_id": profile_config_id}, - ) - - # ----------------------------------------------------------------- - # BindProfile - # ----------------------------------------------------------------- - async def bind_profile(self, req: BindProfileRequest) -> BindProfileResponse: - """ - Bind user/agent IDs to profile templates. - - Rules: - - If not yet bound: create binding + instantiate template defaults - - If already bound to SAME template: idempotent, return existing instance_id - - If already bound to DIFFERENT template: return 400 error - """ - instance_ids: list[str] = [] - - for mapping in req.id_profile_map: - entity_id = mapping.id - profile_config_id = mapping.profile_config_id - - # Check if template exists - if profile_config_id not in self.store.templates: - raise HTTPException( - status_code=400, - detail=f"Template '{profile_config_id}' not found. " - f"Please create it first via CreateProfileTemplate.", - ) - - # Check existing binding - existing = self.store.bindings.get(entity_id) - - if existing is not None and existing != profile_config_id: - raise HTTPException( - status_code=400, - detail=f"Entity '{entity_id}' is already bound to template " - f"'{existing}'. Cannot rebind to '{profile_config_id}'. " - f"One ID can only bind to one template.", - ) - - instance_id = self.store.make_instance_id(entity_id, profile_config_id) - - if existing == profile_config_id: - # Idempotent — already bound to same template - logger.info( - "Entity %s already bound to %s, idempotent", - entity_id, - profile_config_id, - ) - else: - # New binding — instantiate template with defaults filled in - self.store.bindings[entity_id] = profile_config_id - template_data = self.store.templates[profile_config_id] - self.store.instances[instance_id] = copy.deepcopy(template_data) - self._fill_defaults(instance_id) - - logger.info( - "Bound entity %s to template %s → instance %s", - entity_id, - profile_config_id, - instance_id, - ) - - instance_ids.append(instance_id) - - return BindProfileResponse( - code=200, - message="success", - data={"profile_instance_id": instance_ids}, - ) - - # ----------------------------------------------------------------- - # EditProfile - # ----------------------------------------------------------------- - async def edit_profile(self, req: EditProfileRequest) -> EditProfileResponse: - """ - Edit profile values for a bound user/agent (editProfileConfig §3.3). - - Execution logic (per final spec): - 1. Update the passed-in fields' values, regardless of prior values. - 2. Check fields with default values; if empty, fill with defaults. - 3. If a field is not in the tree, add it to this user's instance only. - 4. Mark fields with algorithm_updatable from the metadata dict. - """ - entity_id = req.id - profile_config_id = req.profile_config_id or DEFAULT_PROFILE_CONFIG_ID - instance_id = self.store.make_instance_id(entity_id, profile_config_id) - - # Ensure template exists (bootstrap default if needed) - if profile_config_id not in self.store.templates: - if profile_config_id == DEFAULT_PROFILE_CONFIG_ID: - self.store.templates[profile_config_id] = copy.deepcopy(DEFAULT_PROFILE_TEMPLATE) - else: - raise HTTPException( - status_code=400, - detail=f"Template '{profile_config_id}' not found. " - "Please create it first via create_profile_template.", - ) - - # Auto-bind if not yet bound; reject cross-template edits - existing_binding = self.store.bindings.get(entity_id) - if existing_binding is None: - self.store.bindings[entity_id] = profile_config_id - template_data = self.store.templates[profile_config_id] - self.store.instances[instance_id] = copy.deepcopy(template_data) - self._fill_defaults(instance_id) - logger.info("Auto-bound entity %s to template %s", entity_id, profile_config_id) - elif existing_binding != profile_config_id: - raise HTTPException( - status_code=400, - detail=f"Entity '{entity_id}' is already bound to template " - f"'{existing_binding}'. Cannot edit with '{profile_config_id}'.", - ) - - # Apply field updates from metadata dict - if req.metadata: - instance_data = self.store.instances[instance_id] - for field_name, field_value in req.metadata.items(): - updated = False - for _category_name, category_fields in instance_data.items(): - if isinstance(category_fields, dict) and field_name in category_fields: - if isinstance(field_value, dict): - category_fields[field_name].update(field_value) - else: - category_fields[field_name] = field_value - updated = True - break - - if not updated: - instance_data[field_name] = field_value - - logger.info("Updated %d field(s) for instance %s", len(req.metadata), instance_id) - - return EditProfileResponse( - code=200, - message="success", - data={"profile_instance_id": instance_id}, - ) - - # ----------------------------------------------------------------- - # Internal helpers - # ----------------------------------------------------------------- - def _fill_defaults(self, instance_id: str) -> None: - """ - Ensure every field dict in a freshly instantiated instance has - algorithm_updatable set to True (the default for new bindings). - - Template defaults (内容, 当前描述, 变化趋势, etc.) are already copied - in via deepcopy; this pass only guarantees the control flag is present. - """ - instance_data = self.store.instances.get(instance_id) - if instance_data is None: - return - - for _category_name, category_fields in instance_data.items(): - if not isinstance(category_fields, dict): - continue - for _field_name, field_value in category_fields.items(): - if not isinstance(field_value, dict): - continue - # Ensure algorithm_updatable has a default - if "algorithm_updatable" not in field_value: - field_value["algorithm_updatable"] = True - - logger.debug("Filled defaults for instance %s", instance_id) - - # ----------------------------------------------------------------- - # Get profile data (for SearchMemory mock) - # ----------------------------------------------------------------- - def get_profile_for_entity(self, entity_id: str) -> list[dict[str, Any]]: - """ - Retrieve stored profile data as profile_detail_list items. - - Used by SearchMemory mock when include_memory_view contains "profile". - Returns a flat list of profile fields formatted for the API response. - """ - binding = self.store.bindings.get(entity_id) - if binding is None: - return [] - - instance_id = self.store.make_instance_id(entity_id, binding) - instance_data = self.store.instances.get(instance_id) - if instance_data is None: - return [] - - results: list[dict[str, Any]] = [] - for category_name, category_fields in instance_data.items(): - if not isinstance(category_fields, dict): - continue - for field_name, field_value in category_fields.items(): - if not isinstance(field_value, dict): - continue - - # Build content string from field value - content_parts = [] - for k, v in field_value.items(): - if k != "algorithm_updatable": - content_parts.append(f"{k}: {v}") - content_str = "; ".join(content_parts) if content_parts else str(field_value) - - results.append( - { - "type": "ProfileMemory", - "content": f"{category_name}.{field_name}: {content_str}", - "score": 0.85, # Mock score - "metadata": { - "profile_field": f"{category_name}.{field_name}", - "profile_category": category_name, - "algorithm_updatable": field_value.get("algorithm_updatable", True), - "template_id": binding, - "profile_instance_id": instance_id, - }, - } - ) - - return results diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index fb5ec1ec7..267d1bb28 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -12,7 +12,6 @@ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.handlers.formatters_handler import rerank_knowledge_mem -from memos.api.handlers.profile_handler import ProfileHandler from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -44,7 +43,6 @@ def __init__(self, dependencies: HandlerDependencies): self._validate_dependencies( "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) - self.profile_handler = ProfileHandler() def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -92,7 +90,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse top_k=search_req_local.top_k, file_mem_proportion=0.5, ) - self._attach_profile_and_event_mock_results(results, search_req_local) self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" @@ -103,21 +100,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse data=results, ) - def _attach_profile_and_event_mock_results( - self, results: dict[str, Any], search_req: APISearchRequest - ) -> None: - include_views = set(search_req.include_memory_view or []) - if "profile" in include_views: - profile_detail_list = self.profile_handler.get_profile_for_entity(search_req.user_id) - else: - profile_detail_list = [] - - # Event mock is wired for interface compatibility in this phase. - event_detail_list: list[dict[str, Any]] = [] - - results["profile_detail_list"] = profile_detail_list - results["event_detail_list"] = event_detail_list - @staticmethod def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]: if relativity <= 0: diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 6e35d7841..5bf27e985 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -12,22 +12,6 @@ logger = get_logger(__name__) T = TypeVar("T") -VALID_MEMORY_VIEWS: set[str] = { - "detail_factual", - "preference", - "skill", - "profile", - "event", - "tool_memory", -} -DEFAULT_SEARCH_MEMORY_VIEWS: list[str] = [ - "detail_factual", - "preference", - "skill", - "profile", - "event", -] - class BaseRequest(BaseModel): """Base model for all requests.""" @@ -415,16 +399,6 @@ class APISearchRequest(BaseRequest): description="Number of skill memories to retrieve (top-K). Default: 3.", ) - include_memory_view: list[str] | None = Field( - None, - description=( - "Controls which memory views to search. " - "Values: 'detail_factual', 'preference', 'skill', 'profile', 'event', 'tool_memory'. " - "When provided, overrides individual bool flags (include_preference, etc). " - "Default (None): ['detail_factual', 'preference', 'skill', 'profile', 'event']." - ), - ) - # ==== Filter conditions ==== # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( @@ -514,8 +488,6 @@ def _convert_deprecated_fields(self) -> "APISearchRequest": - moscube is ignored with warning - operation ignored """ - fields_set = self.model_fields_set - # Convert mem_cube_id to readable_cube_ids (new field takes priority) if self.mem_cube_id is not None: if not self.readable_cube_ids: @@ -540,49 +512,6 @@ def _convert_deprecated_fields(self) -> "APISearchRequest": "This field is deprecated and ignored." ) - # include_memory_view takes highest priority. Legacy bool fields are ignored. - if self.include_memory_view is not None: - normalized_views: list[str] = [] - seen: set[str] = set() - for view in self.include_memory_view: - if view not in VALID_MEMORY_VIEWS: - logger.warning( - "Unknown memory view '%s' in include_memory_view; it will be ignored.", - view, - ) - continue - if view not in seen: - normalized_views.append(view) - seen.add(view) - - self.include_memory_view = normalized_views - self.include_preference = "preference" in normalized_views - self.include_skill_memory = "skill" in normalized_views - self.search_tool_memory = "tool_memory" in normalized_views - return self - - # Compatibility strategy: - # 1) legacy bool fields explicitly provided -> derive include_memory_view from defaults - # 2) none provided -> use new default include_memory_view - legacy_fields = {"include_preference", "include_skill_memory", "search_tool_memory"} - legacy_given = any(field in fields_set for field in legacy_fields) - - views = list(DEFAULT_SEARCH_MEMORY_VIEWS) - if legacy_given: - if not self.include_preference and "preference" in views: - views.remove("preference") - if not self.include_skill_memory and "skill" in views: - views.remove("skill") - if self.search_tool_memory and "tool_memory" not in views: - views.append("tool_memory") - - self.include_memory_view = views - return self - - self.include_memory_view = list(DEFAULT_SEARCH_MEMORY_VIEWS) - self.include_preference = "preference" in self.include_memory_view - self.include_skill_memory = "skill" in self.include_memory_view - self.search_tool_memory = "tool_memory" in self.include_memory_view return self @@ -600,12 +529,8 @@ class APIADDRequest(BaseRequest): project_id: str | None = Field(None, description="Project ID") # ==== Multi-cube writing ==== - writable_cube_ids: list[str] | dict[str, Any] | None = Field( - None, - description=( - "Writable cube IDs. Supports legacy list[str] and dict format " - "{cube_id: {cube_type, user_or_agent_id, ...}}." - ), + writable_cube_ids: list[str] | None = Field( + None, description="List of cube IDs user can write for multi-cube add" ) # ==== Async control ==== @@ -683,16 +608,6 @@ class APIADDRequest(BaseRequest): description=("Whether this request represents user feedback. Default: False."), ) - # ==== Memory view control ==== - allow_memory_view: list[str] | None = Field( - None, - description=( - "Controls which memory views are allowed for this add request. " - "Example: ['detail_factual', 'preference', 'skill', 'profile', 'event', 'tool_memory']. " - "Default (None) means all views are allowed." - ), - ) - # ==== Backward compatibility fields (will delete later) ==== mem_cube_id: str | None = Field( None, @@ -741,25 +656,6 @@ def _convert_deprecated_fields(self) -> "APIADDRequest": ) self.mode = None - # writable_cube_ids dict is preserved as-is so downstream code can read - # cube-level metadata (cube_type, user_or_agent_id, etc.). - # Cube ID extraction (keys) is done lazily in _resolve_cube_ids. - - if self.allow_memory_view is not None: - normalized_views: list[str] = [] - seen: set[str] = set() - for view in self.allow_memory_view: - if view not in VALID_MEMORY_VIEWS: - logger.warning( - "Unknown memory view '%s' in allow_memory_view; it will be ignored.", - view, - ) - continue - if view not in seen: - normalized_views.append(view) - seen.add(view) - self.allow_memory_view = normalized_views - # Convert mem_cube_id to writable_cube_ids (new field takes priority) if self.mem_cube_id: logger.warning( @@ -1025,25 +921,11 @@ class SearchMemoryData(BaseModel): tool_memory_detail_list: list[MessageDetail] | None = Field( None, alias="tool_memory_detail_list", - description="List of tool_memory details (usually None)", + description="List of tool_memor details (usually None)", ) preference_note: str = Field( None, alias="preference_note", description="String of preference_note" ) - profile_detail_list: list[dict[str, Any]] | None = Field( - None, - alias="profile_detail_list", - description=( - "Attribute-tree fields that matched the query (mixed-rank results). " - "Each item: {type, content, score, metadata{profile_field, profile_category, " - "algorithm_updatable, template_id, profile_instance_id}}." - ), - ) - event_detail_list: list[dict[str, Any]] | None = Field( - None, - alias="event_detail_list", - description="Event memory items that matched the query (reserved, empty in mock phase).", - ) class GetKnowledgebaseFileData(BaseModel): diff --git a/src/memos/api/profile_models.py b/src/memos/api/profile_models.py deleted file mode 100644 index 6e70f70b4..000000000 --- a/src/memos/api/profile_models.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Profile (Attribute Tree) Mock Models. - -Pydantic request/response models for the 3 new profile endpoints: -- CreateProfileTemplate -- BindProfile -- EditProfile - -These are used for the mock phase only. No PolarDB or scheduler dependency. -""" - -from __future__ import annotations - -from typing import Any - -from pydantic import BaseModel, Field - -from memos.api.product_models import BaseResponse - - -DEFAULT_PROFILE_CONFIG_ID = "default_config" - - -# ============================================================================= -# Sub-models -# ============================================================================= - - -class IdProfileMapping(BaseModel): - """Single mapping entry for BindProfile: one id → one profile_config_id.""" - - id: str = Field(..., description="User ID or Agent ID to bind") - profile_config_id: str = Field( - DEFAULT_PROFILE_CONFIG_ID, - description="Profile template ID to bind to this id. Defaults to default_config.", - ) - - -# ============================================================================= -# CreateProfileTemplate -# ============================================================================= - - -class CreateProfileTemplateRequest(BaseModel): - """ - Request for creating a new profile template. - - metadata example: - { - "客观档案": { - "姓名": { - "内容": "张三", - "依据": "...", - "algorithm_updatable": true - } - }, - "人格状态": { - "温柔程度": { - "当前描述": "语气柔和", - "变化趋势": "上升", - "依据": "...", - "algorithm_updatable": true - } - } - } - """ - - metadata: dict[str, Any] = Field( - ..., - description="Profile template structure with categories and fields", - ) - - -class CreateProfileTemplateResponse(BaseResponse): - """Response for CreateProfileTemplate. data contains profile_config_id.""" - - data: dict[str, Any] | None = Field( - None, description="Response data containing profile_config_id" - ) - - -# ============================================================================= -# BindProfile -# ============================================================================= - - -class BindProfileRequest(BaseModel): - """ - Request for binding user/agent IDs to profile templates. - - Currently one ID can only bind to one template. - If the same ID is already bound to the same template, it's idempotent. - If bound to a different template, return an error. - """ - - id_profile_map: list[IdProfileMapping] = Field( - ..., - description="List of {id, profile_config_id} pairs to bind", - min_length=1, - ) - - -class BindProfileResponse(BaseResponse): - """Response for BindProfile. data contains profile_instance_id list.""" - - data: dict[str, Any] | None = Field( - None, - description="Response data containing profile_instance_id list", - ) - - -# ============================================================================= -# EditProfile -# ============================================================================= - - -class EditProfileRequest(BaseModel): - """ - Request for editProfileConfig (§3.3). - - Rules (from final spec): - - If already bound: update the specified fields. - - If not yet bound: auto-bind to the given template, then apply updates. - - If already bound to a *different* template: return 400. - - Update field values regardless of whether they had values before. - - Check fields with default values: if empty, fill with defaults. - - If a field is not in the template tree: add it to this user's instance only - (does not affect the template or other users). - - Fields marked with algorithm_updatable=false are locked from future - algorithm extraction overwrites. - - metadata example: - { - "姓名": {"内容": "张三", "algorithm_updatable": false}, - "性别": {"内容": "女"}, - "Smauel称呼我": {"内容": "小张", "algorithm_updatable": false} - } - """ - - id: str = Field(..., description="User ID or Agent ID whose profile to edit") - profile_config_id: str = Field( - DEFAULT_PROFILE_CONFIG_ID, - description="Profile template ID. Defaults to default_config.", - ) - metadata: dict[str, Any] | None = Field( - None, - description=( - "Fields to update. Keys are field names, values are field data dicts " - "(e.g. {'内容': '张三', 'algorithm_updatable': false}). " - "Fields with algorithm_updatable=false are locked from algorithm overwrites." - ), - ) - - -class EditProfileResponse(BaseResponse): - """Response for editProfileConfig. data contains profile_instance_id.""" - - data: dict[str, Any] | None = Field( - None, - description="Response data containing profile_instance_id", - ) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 748262097..af6ae4fe5 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -22,7 +22,6 @@ from memos.api.handlers.base_handler import HandlerDependencies from memos.api.handlers.chat_handler import ChatHandler from memos.api.handlers.feedback_handler import FeedbackHandler -from memos.api.handlers.profile_handler import ProfileHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( AllStatusResponse, @@ -54,14 +53,6 @@ SuggestionResponse, TaskQueueResponse, ) -from memos.api.profile_models import ( - BindProfileRequest, - BindProfileResponse, - CreateProfileTemplateRequest, - CreateProfileTemplateResponse, - EditProfileRequest, - EditProfileResponse, -) from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -105,9 +96,6 @@ graph_db = components["graph_db"] vector_db = components["vector_db"] -# Initialize profile handler (standalone, no HandlerDependencies needed) -profile_handler = ProfileHandler() - # ============================================================================= # Search API Endpoints @@ -370,41 +358,6 @@ def feedback_memories(feedback_req: APIFeedbackRequest): return feedback_handler.handle_feedback_memories(feedback_req) -# ============================================================================= -# Profile (Attribute Tree) API Endpoints — Mock -# ============================================================================= - - -@router.post( - "/create_profile_template", - summary="Create a profile template", - response_model=CreateProfileTemplateResponse, -) -async def create_profile_template(req: CreateProfileTemplateRequest): - """Create a new profile (attribute tree) template.""" - return await profile_handler.create_template(req) - - -@router.post( - "/bind_profile_config", - summary="Bind user/agent to a profile template", - response_model=BindProfileResponse, -) -async def bind_profile(req: BindProfileRequest): - """Bind user/agent IDs to profile templates (one ID → one template currently).""" - return await profile_handler.bind_profile(req) - - -@router.post( - "/edit_profile_config", - summary="Edit profile values", - response_model=EditProfileResponse, -) -async def edit_profile(req: EditProfileRequest): - """Edit profile values for a bound user/agent. Auto-binds if not yet bound.""" - return await profile_handler.edit_profile(req) - - # ============================================================================= # Other API Endpoints (for internal use) # ============================================================================= diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index cc269e8c4..9dcbe8c56 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception: + except Exception as e: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 78204a410..2a6991e8c 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -85,35 +85,12 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: f"Processing add with mode={sync_mode}, session={target_session_id}" ) - allowed_views = set(add_req.allow_memory_view or []) - allow_all_views = add_req.allow_memory_view is None - allow_text_mem = allow_all_views or "detail_factual" in allowed_views - allow_pref_mem = allow_all_views or "preference" in allowed_views - - if not allow_text_mem and not allow_pref_mem: - self.logger.info( - "[SingleCubeView] cube=%s Skip add: allow_memory_view excludes " - "both detail_factual and preference.", - self.cube_id, - ) - return [] + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) - if allow_text_mem and allow_pref_mem: - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit( - self._process_text_mem, add_req, user_context, sync_mode - ) - pref_future = executor.submit( - self._process_pref_mem, add_req, user_context, sync_mode - ) - text_results = text_future.result() - pref_results = pref_future.result() - elif allow_text_mem: - text_results = self._process_text_mem(add_req, user_context, sync_mode) - pref_results = [] - else: - text_results = [] - pref_results = self._process_pref_mem(add_req, user_context, sync_mode) + text_results = text_future.result() + pref_results = pref_future.result() self.logger.info( f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 580d447ff..8234caf8b 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -53,14 +53,9 @@ class MessageDict(TypedDict, total=False): role: MessageRole content: str - chat_time: str | None # Optional timestamp; any vague or precise time string. - message_id: str | None # Unique message ID. Used for cascade-deletion of profile - # memory fields via source_message_ids. Corresponds to messages[].id in the - # cloud-service (product) API. - role_id: str | None # Speaker identity in multi-user / multi-agent (group-chat) - # scenarios. Must match an entry in the request-level user_id list. - role_name: str | None # Display name of the speaker. Optional; when provided, - # must be accompanied by role_id. + chat_time: str | None # Optional timestamp for the message, format is not + # restricted, it can be any vague or precise time string. + message_id: str | None # Optional unique identifier for the message RawMessageDict: TypeAlias = ChatCompletionContentPartTextParam | File diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index d975164e4..f28796c2d 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -53,11 +53,4 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message. Corresponds to messages[].id in the - cloud-service API; used for cascade-deletion of profile memory fields.""" - - role_id: str | None - """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios.""" - - role_name: str | None - """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index 8d9d83de7..13a9a89af 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -33,11 +33,4 @@ class ChatCompletionSystemMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message. Corresponds to messages[].id in the - cloud-service API; used for cascade-deletion of profile memory fields.""" - - role_id: str | None - """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios.""" - - role_name: str | None - """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py index 482684af1..b5bee9842 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -31,12 +31,4 @@ class ChatCompletionUserMessageParam(TypedDict, total=False): restricted, it can be any vague or precise time string.""" message_id: str | None - """Optional unique identifier for the message. Corresponds to messages[].id in the - cloud-service API; used for cascade-deletion of profile memory fields.""" - - role_id: str | None - """Speaker's user_id or agent_id in multi-user / multi-agent (group-chat) scenarios. - Must match an entry in the request-level user_id list.""" - - role_name: str | None - """Display name of the speaker. Optional; when provided must be accompanied by role_id.""" + """Optional unique identifier for the message""" From 0a1ec108055ea5c29fa25bd1213a49b4d7d0d69c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 27 Feb 2026 10:18:35 +0800 Subject: [PATCH 49/92] fix: user_name (#1131) * fix: user_name * fix: user_name * feat: delete some logs * fix: test_history_manager # Conflicts: # src/memos/memories/textual/tree_text_memory/organize/history_manager.py # tests/memories/textual/test_history_manager.py --- src/memos/embedders/universal_api.py | 1 - src/memos/memories/textual/tree.py | 14 +- .../tree_text_memory/organize/handler.py | 64 +- .../organize/history_manager.py | 820 +----------------- .../tree_text_memory/organize/manager.py | 41 +- .../tree_text_memory/organize/reorganizer.py | 77 +- .../memories/textual/test_history_manager.py | 646 +------------- 7 files changed, 169 insertions(+), 1494 deletions(-) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 538d913ea..2b3bd0967 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -73,7 +73,6 @@ async def _create_embeddings(): ) ) logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds") - logger.info(f"Embeddings request response: {response}") return [r.embedding for r in response.data] except Exception as e: if self.use_backup_client: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5faf8aa09..5b210ba61 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -404,10 +404,10 @@ def delete_by_memory_ids(self, memory_ids: list[str]) -> None: except Exception as e: logger.error(f"An error occurred while deleting memories by memory_ids: {e}") - def delete_all(self) -> None: + def delete_all(self, user_name: str | None = None) -> None: """Delete all memories and their relationships from the graph store.""" try: - self.graph_store.clear() + self.graph_store.clear(user_name=user_name) logger.info("All memories and edges have been deleted from the graph.") except Exception as e: logger.error(f"An error occurred while deleting all memories: {e}") @@ -424,7 +424,7 @@ def delete_by_filter( writable_cube_ids=writable_cube_ids, file_ids=file_ids, filter=filter ) - def load(self, dir: str) -> None: + def load(self, dir: str, user_name: str | None = None) -> None: try: memory_file = os.path.join(dir, self.config.memory_filename) @@ -435,7 +435,7 @@ def load(self, dir: str) -> None: with open(memory_file, encoding="utf-8") as f: memories = json.load(f) - self.graph_store.import_graph(memories) + self.graph_store.import_graph(memories, user_name=user_name) logger.info(f"Loaded {len(memories)} memories from {memory_file}") except FileNotFoundError: @@ -445,10 +445,12 @@ def load(self, dir: str) -> None: except Exception as e: logger.error(f"An error occurred while loading memories: {e}") - def dump(self, dir: str, include_embedding: bool = False) -> None: + def dump(self, dir: str, include_embedding: bool = False, user_name: str | None = None) -> None: """Dump memories to os.path.join(dir, self.config.memory_filename)""" try: - json_memories = self.graph_store.export_graph(include_embedding=include_embedding) + json_memories = self.graph_store.export_graph( + include_embedding=include_embedding, user_name=user_name + ) os.makedirs(dir, exist_ok=True) memory_file = os.path.join(dir, self.config.memory_filename) diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py index 595cf099c..2d776912b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/handler.py +++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py @@ -27,18 +27,24 @@ def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedd self.llm = llm self.embedder = embedder - def detect(self, memory, top_k: int = 5, scope=None): + def detect(self, memory, top_k: int = 5, scope=None, user_name: str | None = None): # 1. Search for similar memories based on embedding embedding = memory.metadata.embedding embedding_candidates_info = self.graph_store.search_by_embedding( - embedding, top_k=top_k, scope=scope, threshold=self.EMBEDDING_THRESHOLD + embedding, + top_k=top_k, + scope=scope, + threshold=self.EMBEDDING_THRESHOLD, + user_name=user_name, ) # 2. Filter based on similarity threshold embedding_candidates_ids = [ info["id"] for info in embedding_candidates_info if info["id"] != memory.id ] # 3. Judge conflicts using LLM - embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids) + embedding_candidates = self.graph_store.get_nodes( + embedding_candidates_ids, user_name=user_name + ) detected_relationships = [] for embedding_candidate in embedding_candidates: embedding_candidate = TextualMemoryItem.from_dict(embedding_candidate) @@ -67,13 +73,20 @@ def detect(self, memory, top_k: int = 5, scope=None): pass return detected_relationships - def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, relation) -> None: + def resolve( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + relation, + user_name: str | None = None, + ) -> None: """ Resolve detected conflicts between two memory items using LLM fusion. Args: memory_a: The first conflicting memory item. memory_b: The second conflicting memory item. relation: relation + user_name: Optional user name for multi-tenant isolation. Returns: A fused TextualMemoryItem representing the resolved memory. """ @@ -105,17 +118,22 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, rela logger.warning( f"{relation} between {memory_a.id} and {memory_b.id} could not be resolved. " ) - self._hard_update(memory_a, memory_b) + self._hard_update(memory_a, memory_b, user_name=user_name) # —————— 2.2 Conflict resolved, update metadata and memory ———— else: fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata) merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata) logger.info(f"Resolved result: {merged_memory}") - self._resolve_in_graph(memory_a, memory_b, merged_memory) + self._resolve_in_graph(memory_a, memory_b, merged_memory, user_name=user_name) except json.decoder.JSONDecodeError: logger.error(f"Failed to parse LLM response: {response}") - def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem): + def _hard_update( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + user_name: str | None = None, + ): """ Hard update: compare updated_at, keep the newer one, overwrite the older one's metadata. """ @@ -125,7 +143,7 @@ def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) newer_mem = memory_a if time_a >= time_b else memory_b older_mem = memory_b if time_a >= time_b else memory_a - self.graph_store.delete_node(older_mem.id) + self.graph_store.delete_node(older_mem.id, user_name=user_name) logger.warning( f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>" ) @@ -135,13 +153,21 @@ def _resolve_in_graph( conflict_a: TextualMemoryItem, conflict_b: TextualMemoryItem, merged: TextualMemoryItem, + user_name: str | None = None, ): - edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY") - edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY") + edges_a = self.graph_store.get_edges( + conflict_a.id, type="ANY", direction="ANY", user_name=user_name + ) + edges_b = self.graph_store.get_edges( + conflict_b.id, type="ANY", direction="ANY", user_name=user_name + ) all_edges = edges_a + edges_b self.graph_store.add_node( - merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True) + merged.id, + merged.memory, + merged.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for edge in all_edges: @@ -150,13 +176,15 @@ def _resolve_in_graph( if new_from == new_to: continue # Check if the edge already exists before adding - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) - - self.graph_store.update_node(conflict_a.id, {"status": "archived"}) - self.graph_store.update_node(conflict_b.id, {"status": "archived"}) - self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO") - self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO") + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) + + self.graph_store.update_node(conflict_a.id, {"status": "archived"}, user_name=user_name) + self.graph_store.update_node(conflict_b.id, {"status": "archived"}, user_name=user_name) + self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO", user_name=user_name) + self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO", user_name=user_name) logger.debug( f"Archive {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}." ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 3e62ce56e..132582a0d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,26 +1,12 @@ import logging -import time -import uuid -from datetime import datetime -from typing import Any, Literal +from typing import Literal from memos.context.context import ContextThreadPoolExecutor -from memos.embedders.base import BaseEmbedder from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB -from memos.llms.base import BaseLLM -from memos.mem_reader.read_multi_modal.utils import detect_lang -from memos.memories.textual.item import ( - ArchivedTextualMemory, - TextualMemoryItem, - TreeNodeTextualMemoryMetadata, -) -from memos.templates.mem_reader_mem_version_prompts import ( - ASYNC_MEMORY_UPDATE_PROMPT_DICT, - MEMORY_MERGE_PROMPT_DICT, -) +from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem logger = logging.getLogger(__name__) @@ -81,155 +67,21 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: return -def _rebuild_fast_node_history( - item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] -) -> None: - """ - Reconstruct the history list of a fast node: - 1. Replace resolved items with their evolved versions. - 2. Deduplicate by ID while preserving the newest versions. - """ - new_history = {} - - def _add(history_item): - item_id = history_item.archived_memory_id - current = new_history.get(item_id) - - if current is None or history_item.version > current.version: - new_history[item_id] = history_item - - # Apply replacements and filter superseded items - for i, h in enumerate(item.metadata.history): - if i in replacements: - # This item is resolved, insert its replacements - for replacement_item in replacements[i]: - _add(replacement_item) - else: - _add(h) - - item.metadata.history = list(new_history.values()) - - -def _sanitize_metadata_dict(data: dict[str, Any] | None) -> dict[str, Any]: - if not data: - return {} - sanitized = data.copy() - for key in ("id", "memory", "graph_id"): - sanitized.pop(key, None) - return sanitized - - -def _sanitize_metadata_model( - metadata: TreeNodeTextualMemoryMetadata, -) -> TreeNodeTextualMemoryMetadata: - data = _sanitize_metadata_dict(metadata.model_dump(exclude_none=True)) - return metadata.__class__(**data) - - -def _determine_lang(sources: list | None, fallback_text: str) -> str: - lang = None - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - if isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(fallback_text) - return lang - - class MemoryHistoryManager: - def __init__( - self, - nli_client: NLIClient, - graph_db: BaseGraphDB, - llm: BaseLLM | None = None, - embedder: BaseEmbedder | None = None, - ) -> None: + def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: """ Initialize the MemoryHistoryManager. Args: nli_client: NLIClient for conflict/duplicate detection. graph_db: GraphDB instance for marking operations during history management. - llm: Optional LLM instance for memory merging during conflicts. """ self.nli_client = nli_client self.graph_db = graph_db - self.llm = llm - self.embedder = embedder - - def _compute_embedding(self, text: str) -> list[float] | None: - if not self.embedder: - return None - try: - return self.embedder.embed([text])[0] - except Exception as e: - logger.error(f"[MemoryHistoryManager] Failed to compute embedding: {e}") - return None - - @staticmethod - def is_applicable(item: TextualMemoryItem) -> bool: - # Only deals with: - # 1. From doc or chat - # 2. LongTermMemory, UserMemory - allowed_sources = ["doc", "chat"] - allowed_memory_types = ["LongTermMemory", "UserMemory"] - return ( - item.metadata.sources[0].type in allowed_sources - and item.metadata.memory_type in allowed_memory_types - ) - - @staticmethod - def update_node_with_history( - item: TextualMemoryItem, - new_memory: str, - update_type: str, - tags: list[str] | None = None, - key: str | None = None, - ) -> tuple[TextualMemoryItem, TextualMemoryItem]: - now = datetime.now().isoformat() - last_update_time = item.metadata.updated_at - - old_id = item.id - archived_id = str(uuid.uuid4()) - # archived memory(need to store this node to the db later) - archived_item = item.model_copy(deep=True) - archived_item.id = archived_id - archived_item.metadata.evolve_to = [old_id] - archived_item.metadata.status = "archived" - archived_item.metadata.created_at = last_update_time - archived_item.metadata.updated_at = now - - # original memory with updated contents and history - archived_history = ArchivedTextualMemory( - version=item.metadata.version or 1, - is_fast=item.metadata.is_fast or False, - memory=item.memory, - update_type=update_type, - archived_memory_id=archived_id, - created_at=getattr(item.metadata, "updated_at", None) or last_update_time, - ) - item.memory = new_memory - item.metadata.version = (item.metadata.version or 1) + 1 - item.metadata.status = "activated" - item.metadata.updated_at = now - if tags is not None: - item.metadata.tags = tags - if key is not None: - item.metadata.key = key - if item.metadata.history is None: - item.metadata.history = [] - item.metadata.history.append(archived_history) - - return item, archived_item def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[str]: + ) -> list[TextualMemoryItem]: """ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, and attach them as history to the new fast item. @@ -239,7 +91,7 @@ def resolve_history_via_nli( related_items: Existing memory items that might be related. Returns: - List of duplicate or conflicting memory ids judged by the NLI service. + List of duplicate or conflicting memory items judged by the NLI service. """ if not related_items: return [] @@ -250,19 +102,15 @@ def resolve_history_via_nli( ) # 2. Process results and attach to history - duplicate_memory_ids = [] - conflict_memory_ids = [] duplicate_memories = [] conflict_memories = [] for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" - duplicate_memory_ids.append(r_item.id) duplicate_memories.append(r_item.memory) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" - conflict_memory_ids.append(r_item.id) conflict_memories.append(r_item.memory) else: update_type = "unrelated" @@ -270,256 +118,45 @@ def resolve_history_via_nli( # Safely get created_at, fallback to updated_at created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at - # TODO: change the way of marking fast nodes by directly using is_fast field. archived = ArchivedTextualMemory( version=r_item.metadata.version or 1, - is_fast=( - r_item.metadata.is_fast - or ("mode:fast" in (getattr(r_item.metadata, "tags", None) or [])) - ), + is_fast=r_item.metadata.is_fast or False, memory=r_item.memory, update_type=update_type, archived_memory_id=r_item.id, created_at=created_at, ) new_item.metadata.history.append(archived) + logger.info( + f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" + ) # 3. Concat duplicate/conflict memories to new_item.memory # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. _append_related_content(new_item, duplicate_memories, conflict_memories) - return duplicate_memory_ids + conflict_memory_ids - - def wait_and_update_fast_history( - self, item: TextualMemoryItem, user_name: str, timeout_sec: int = 30 - ) -> None: - """ - Scan the item's history. If any history item is marked as `is_fast`, - wait for it to be resolved (i.e., status becomes 'deleted' in the DB). - When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. - Finally, deduplicate the history. - - Args: - item: The memory item containing the history to check. - user_name: Required for db query. - timeout_sec: Maximum time to wait for resolution in seconds. - """ - start_time = time.time() - - # 1. Identify pending items (fast nodes) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - while True: - if not pending_indices: - # All fast nodes resolved or none existed - break - - if time.time() - start_time > timeout_sec: - logger.warning( - f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" - ) - # Remove pending fast nodes from history - item.metadata.history = [ - h - for h in item.metadata.history - if not (getattr(h, "is_fast", False) and h.archived_memory_id) - ] - break - - # 2. Check status of the fast nodes and fetch replacements for evolved ones - replacements = self._check_and_fetch_replacements(item, pending_indices) - - # 3. If we have any resolved items, rebuild the history - if replacements: - _rebuild_fast_node_history(item, replacements) - - # Check if we are done (no pending items left) - pending_indices = [ - i - for i, h in enumerate(item.metadata.history) - if getattr(h, "is_fast", False) and h.archived_memory_id - ] - - if pending_indices: - time.sleep(1) # This avoids visiting the DB too frequently - - return - - def format_async_update_prompt( - self, item: TextualMemoryItem, custom_tags_prompt: str = "" - ) -> str: - """ - Format the prompt for asynchronous memory update. - - Args: - item: The TextualMemoryItem containing history candidates. - custom_tags_prompt: Optional custom prompt for tags. - - Returns: - Formatted prompt string. - """ - # First, detach duplicate and conflict memory contents from the new item's memory text - _detach_related_content(item) - - duplicate_candidates = [] - conflict_candidates = [] - unrelated_candidates = [] - - def _fmt_time(ts: str | None) -> str | None: - if not ts or not isinstance(ts, str): - return None - try: - t = datetime.fromisoformat(ts.replace("Z", "")) - return t.strftime("%Y/%m/%d %H:%M:%S") - except Exception: - return ts - - for h in item.metadata.history or []: - created = getattr(h, "created_at", None) - tstr = _fmt_time(created) - time_suffix = f"[Time: {tstr}] " if tstr else "" - candidate_str = f"[ID:{h.archived_memory_id}]{time_suffix}{h.memory}" - - if h.update_type == "duplicate": - duplicate_candidates.append(candidate_str) - elif h.update_type == "conflict": - conflict_candidates.append(candidate_str) - else: - # Includes "unrelated" and any other types - unrelated_candidates.append(candidate_str) - - sources = item.metadata.sources if item.metadata else None - lang = _determine_lang(sources, item.memory) - empty_label = "None" - - def format_list(candidates): - return "\n".join(candidates) if candidates else empty_label - - prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( - lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] - ) - conversation_time_raw = getattr(item.metadata, "created_at", None) - conversation_time = _fmt_time(conversation_time_raw) or conversation_time_raw - - return ( - prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) - .replace("${conflict_candidates}", format_list(conflict_candidates)) - .replace("${unrelated_candidates}", format_list(unrelated_candidates)) - .replace("${custom_tags_prompt}", custom_tags_prompt) - .replace("${conversation_time}", conversation_time) - .replace("${conversation}", item.memory) - ) - - def apply_llm_memory_updates( - self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str - ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: - """ - Apply the updates from the LLM response to the memory graph. - - Args: - llm_response: The parsed JSON response from the LLM. - source_item: The original fast item A whose history contains ArchivedTextualMemory entries. - We derive expected versions and candidate IDs from A.history. - user_name: user_name - - Returns: - List of new or updated memory items. - """ - memory_list = llm_response.get("memory list", []) - restored_memories = llm_response.get("restored_memories", []) - used_source_ids = set() - used_conflict_ids = set() - # Collect IDs referenced by the LLM response - for mem_data in memory_list: - used_source_ids.update(mem_data.get("source_candidate_ids", [])) - used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) - # Collect IDs referenced by restored memories - restored_source_ids = { - mem.get("source_candidate_id") - for mem in restored_memories - if mem.get("source_candidate_id") - } - # All IDs used by the LLM response - used_ids = used_source_ids | used_conflict_ids | restored_source_ids - expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories - candidate_id_set: set[str] = set() - # Recover candidate IDs and their expected versions from the source item's history - if source_item.metadata and source_item.metadata.history: - for h in source_item.metadata.history: - if h.archived_memory_id: - candidate_id_set.add(h.archived_memory_id) - expected_versions[h.archived_memory_id] = h.version - - # IDs not used by the LLM response will be treated as unrelated - unrelated_ids = sorted(candidate_id_set - used_ids) - updated_items: list[TextualMemoryItem] = [] - new_items: list[TextualMemoryItem] = [] - - # 1. Handle Unrelated Candidates - Restore status to activated - self._handle_unrelated_candidates(unrelated_ids, user_name=user_name) - - # 2. Handle Memory List (Update or New) - processed_updates, created_items = self._process_memory_updates( - memory_list, expected_versions, user_name, source_item - ) - updated_items.extend(processed_updates) - new_items.extend(created_items) - - # 3. Handle Restored Memories (Extract from conflict) - new_items.extend(self._handle_restored_memories(restored_memories, source_item, user_name)) - - return updated_items, new_items - - def build_fallback_new_items( - self, item: TextualMemoryItem, user_name: str | None = None - ) -> list[TextualMemoryItem]: - latest_item = item.model_copy(deep=True) - _detach_related_content(latest_item) - - history = latest_item.metadata.history or [] - archived_ids = [h.archived_memory_id for h in history if h.archived_memory_id] - if archived_ids: - self.mark_memory_status(archived_ids, "activated", user_name or "") - - latest_item.id = str(uuid.uuid4()) - latest_item.metadata.is_fast = False - latest_item.metadata.status = "activated" - latest_item.metadata.history = [] - latest_item.metadata.working_binding = None - if hasattr(latest_item.metadata, "background"): - latest_item.metadata.background = "" - - if hasattr(latest_item.metadata, "tags") and latest_item.metadata.tags: - latest_item.metadata.tags = [t for t in latest_item.metadata.tags if t != "mode:fast"] - - latest_item.metadata = _sanitize_metadata_model(latest_item.metadata) - - return [latest_item] + return duplicate_memories + conflict_memories def mark_memory_status( self, - memory_ids: list[str], + memory_items: list[TextualMemoryItem], status: Literal["activated", "resolving", "archived", "deleted"], - user_name: str, + user_name: str | None = None, ) -> None: """ Support status marking operations during history management. Common usages are: 1. Mark conflict/duplicate old memories' status as "resolving", to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to recover their visibility. + 2. Mark resolved memories' status as "activated", to restore their visibility. """ # Execute the actual marking operation - in db. with ContextThreadPoolExecutor() as executor: futures = [] - for mid in memory_ids: + for mem in memory_items: futures.append( executor.submit( self.graph_db.update_node, - id=mid, + id=mem.id, fields={"status": status}, user_name=user_name, ) @@ -529,430 +166,3 @@ def mark_memory_status( for future in futures: future.result() return - - def _check_and_fetch_replacements( - self, item: TextualMemoryItem, pending_indices: list[int], user_name: str - ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: - """ - Check DB status for pending items. If 'deleted', fetch evolved nodes. - - Returns: - replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. - """ - pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] - - # Batch fetch pending nodes to check status - nodes_data = self.graph_db.get_nodes(ids=pending_ids, user_name=user_name) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} - - replacements = {} - - for i in pending_indices: - h_item = item.metadata.history[i] - node_data = nodes_map.get(h_item.archived_memory_id) - - if not node_data: - continue - - metadata = node_data.get("metadata", {}) - status = metadata.get("status") - - # Condition: Fast node is processed when it is marked as 'deleted' - if status == "deleted": - evolve_to_ids = metadata.get("evolve_to", []) - - new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type, user_name) - replacements[i] = new_items - - return replacements - - def _fetch_evolved_nodes( - self, evolve_to_ids: list[str], update_type: str, user_name: str - ) -> list[ArchivedTextualMemory]: - """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" - if not evolve_to_ids: - return [] - - evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids, user_name=user_name) or [] - results = [] - - for enode in evolved_nodes: - if not enode or "id" not in enode: - continue - - enode_meta = enode.get("metadata", {}) - - # Create new archived memory inheriting the update_type (conflict/duplicate) - new_archived = ArchivedTextualMemory( - version=enode_meta.get("version", 1), - is_fast=enode_meta.get("is_fast", False), - memory=enode.get("memory", ""), - update_type=update_type, - archived_memory_id=enode.get("id"), - created_at=enode_meta.get("created_at"), - ) - results.append(new_archived) - - return results - - def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) -> None: - """Handle Unrelated Candidates - Restore status to `activated`.""" - if unrelated_ids: - self.mark_memory_status(unrelated_ids, "activated", user_name) - - def _process_memory_updates( - self, - memory_list: list[dict[str, Any]], - expected_versions: dict[str, int], - user_name: str, - source_item: TextualMemoryItem, - ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: - """Process Memory List (Update or New).""" - updated_items: list[TextualMemoryItem] = [] - new_items: list[TextualMemoryItem] = [] - for mem_data in memory_list: - source_ids = mem_data.get("source_candidate_ids", []) - conflict_ids = mem_data.get("conflicted_candidate_ids", []) - - # Determine if this is an update or a new node - target_ids = source_ids + conflict_ids - - if target_ids: - updated_item, new_item = self._update_existing_memory( - mem_data, target_ids, source_ids, expected_versions, user_name, source_item - ) - if updated_item: - updated_items.append(updated_item) - if new_item: - new_items.append(new_item) - else: - item = self._create_new_memory(mem_data, source_item) - new_items.append(item) - return updated_items, new_items - - def _update_existing_memory( - self, - mem_data: dict[str, Any], - target_ids: list[str], - source_ids: list[str], - expected_versions: dict[str, int], - user_name: str, - fast_item: TextualMemoryItem, - ) -> tuple[TextualMemoryItem | None, TextualMemoryItem | None]: - """ - Update existing memory nodes using the LLM result. - - The first ID in target_ids is treated as the primary node. If additional target IDs - are provided, they are treated as secondary candidates and will be merged into the - primary. Merging means: - 1) Mark secondary nodes as archived and append the primary ID to evolve_to - 2) Merge their history entries into the primary history and re-order by created_at - - The method also applies CAS validation via expected_versions, archives the previous - version of the primary node, and persists the updated node back to the graph DB. - - Returns the updated primary TextualMemoryItem and optional new item when fallback is used. - """ - original_primary_id, primary_id, secondary_ids = ( - target_ids[0], - target_ids[0], - target_ids[1:], - ) - new_value, tags, key = ( - mem_data.get("value", ""), - mem_data.get("tags", []), - mem_data.get("key", ""), - ) - new_value_item = TextualMemoryItem( - memory=new_value, metadata=TreeNodeTextualMemoryMetadata() - ) - _detach_related_content(new_value_item) - new_value = new_value_item.memory - - # Fetch candidate nodes in batch and then select the primary - # We update the primary and then merge the secondaries to the primary - nodes_data = self.graph_db.get_nodes(target_ids, user_name=user_name) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} - node_data = nodes_map.get(primary_id) - if not node_data: - logger.warning( - f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." - ) - # Fallback to create new item when the source_id is hallucinated by llm - new_item = self._create_new_memory(mem_data, fast_item) - return None, new_item - current_item = TextualMemoryItem(**node_data) - - # For concurrency control, need to make sure the primary item has not been modified by others in the meantime - # If it has(version changed), then we need to use llm to merge again. - new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) - - update_type = "duplicate" if original_primary_id in source_ids else "conflict" - current_item, archived_item = self.update_node_with_history( - current_item, - new_value, - update_type, - tags=tags, - key=key, - ) - - # create archived node for storing older versions of the memory, preserving the embedding - emb = TextualMemoryItem( - **self.graph_db.get_node(primary_id, include_embedding=True) - ).metadata.embedding - arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) - arch_meta["embedding"] = emb - self.graph_db.add_node( - id=archived_item.id, - memory=archived_item.memory, - metadata=arch_meta, - user_name=user_name, - ) - - fields = _sanitize_metadata_dict(current_item.metadata.model_dump(exclude_none=True)) - merged_history = list(current_item.metadata.history or []) - new_primary_version = current_item.metadata.version or 1 - # Multiple related ids indicates existing duplicates/conflicts to be merged - if secondary_ids: - merged_history, new_primary_version = self._merge_secondary_nodes( - secondary_ids, primary_id, nodes_map, user_name, merged_history - ) - current_item.metadata.history = merged_history - current_item.metadata.version = new_primary_version - merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] - embedding = self._compute_embedding(current_item.memory) - sources = [s.model_dump(exclude_none=True) for s in (fast_item.metadata.sources or [])] - # update old memory node with new content and updated history - self.graph_db.update_node( - id=primary_id, - fields={ - **fields, - "memory": current_item.memory, - "history": merged_history_dump, - "version": new_primary_version, - "embedding": embedding, - "sources": sources, - "session_id": fast_item.metadata.session_id, - }, - user_name=user_name, - ) - working_binding = getattr(current_item.metadata, "working_binding", None) - if working_binding and working_binding != current_item.id: - try: - self.mark_memory_status([str(working_binding)], "deleted", user_name=user_name) - except Exception as e: - logger.warning( - f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" - ) - - return current_item, None - - def _apply_cas_merge( - self, - primary_id: str, - current_item: TextualMemoryItem, - expected_versions: dict[str, int], - new_value: str, - ) -> str: - expected_version = expected_versions.get(primary_id) - current_version = current_item.metadata.version or 1 - if expected_version is not None and current_version != expected_version: - logger.warning( - f"[MemoryHistoryManager] Version conflict for node {primary_id}: " - f"Expected v{expected_version}, but found v{current_version} in DB. " - "Triggering merge logic." - ) - latest_item = TextualMemoryItem( - memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() - ) - _detach_related_content(latest_item) - merged_content = self._merge_conflicting_memory( - latest_memory=latest_item.memory, - proposed_update=new_value, - ) - return merged_content - - return new_value - - def _merge_secondary_nodes( - self, - secondary_ids: list[str], - primary_id: str, - nodes_map: dict, - user_name: str, - base_history: list[ArchivedTextualMemory], - ) -> tuple[list[ArchivedTextualMemory], int]: - merged_history = list(base_history) - - for memory_id in secondary_ids: - node_data = nodes_map.get(memory_id) - if not node_data: - continue - metadata = node_data.get("metadata", {}) - evolve_to = list(metadata.get("evolve_to", []) or []) - if primary_id not in evolve_to: - evolve_to.append(primary_id) - # set secondary nodes to archived and record their evolving destinations - self.graph_db.update_node( - id=memory_id, - fields={"status": "archived", "evolve_to": evolve_to}, - user_name=user_name, - ) - secondary_item = TextualMemoryItem(**node_data) - if secondary_item.metadata.history: - merged_history.extend(secondary_item.metadata.history) - - # Currently we just sort the versions according to their creation time - def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: - created_at = history_item.created_at - if isinstance(created_at, datetime): - return created_at - if created_at: - try: - return datetime.fromisoformat(created_at) - except ValueError: - return datetime.min - return datetime.min - - def _dedupe_history_by_archived_id( - history: list[ArchivedTextualMemory], - ) -> list[ArchivedTextualMemory]: - seen_archived_ids: set[str] = set() - deduped_history: list[ArchivedTextualMemory] = [] - for history_item in history: - archived_id = history_item.archived_memory_id - if archived_id and archived_id in seen_archived_ids: - continue - if archived_id: - seen_archived_ids.add(archived_id) - deduped_history.append(history_item) - return deduped_history - - merged_history.sort(key=_history_sort_key) - merged_history = _dedupe_history_by_archived_id(merged_history) - max_version = 0 - for idx, history_item in enumerate(merged_history, start=1): - history_item.version = idx - max_version = idx - return merged_history, max_version + 1 - - def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: - """ - Call LLM to merge proposed update with latest memory content. - """ - if not self.llm: - return proposed_update - - lang = _determine_lang(None, f"{latest_memory}\n{proposed_update}") - prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) - prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( - "${proposed_update}", proposed_update - ) - - messages = [{"role": "user", "content": prompt}] - try: - response = self.llm.generate(messages) - return response.strip() - except Exception as e: - logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") - # Fallback: append proposed update? or just return proposed? - # Returning proposed might overwrite latest changes. - # Returning latest might lose proposed changes. - # Let's concatenate as a safe fallback. - return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" - - def _create_new_memory( - self, mem_data: dict[str, Any], fast_item: TextualMemoryItem - ) -> TextualMemoryItem: - """Create New Node.""" - new_value = mem_data.get("value", "") - new_value_item = TextualMemoryItem( - memory=new_value, metadata=TreeNodeTextualMemoryMetadata() - ) - _detach_related_content(new_value_item) - new_value = new_value_item.memory - tags = mem_data.get("tags", []) - key = mem_data.get("key", "") - background = mem_data.get("summary", "") - memory_type = mem_data.get("memory_type", "LongTermMemory") - metadata_updates = { - "is_fast": False, - "version": 1, - "memory_type": memory_type, - "status": "activated", - "background": background, - "working_binding": None, - "tags": tags, - "key": key, - "created_at": datetime.now().isoformat(), - "history": [], - "embedding": self._compute_embedding(new_value), - } - metadata = fast_item.metadata.model_copy(deep=True) - for field_name, value in metadata_updates.items(): - setattr(metadata, field_name, value) - metadata = _sanitize_metadata_model(metadata) - - new_item = TextualMemoryItem( - id=str(uuid.uuid4()), - memory=new_value, - metadata=metadata, - ) - return new_item - - def _handle_restored_memories( - self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem, user_name: str - ) -> list[TextualMemoryItem]: - """Handle Restored Memories (Extract from conflict).""" - source_ids = [r.get("source_candidate_id") for r in restored_memories] - source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) - source_items = [TextualMemoryItem(**i) for i in source_items] - - created_items = [] - for i, data in enumerate(restored_memories): - source_item = source_items[i] - # deal with history - source_history = source_item.history.copy() - value = data.get("value", "") - value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) - _detach_related_content(value_item) - value = value_item.memory - tags = data.get("tags", []) - keys = data.get("keys", []) - memory_type = data.get("memory_type", "LongTermMemory") - version = source_item.verion - new_history_item = ArchivedTextualMemory( - version=version, - is_fast=False, - memory=source_item.memory, - update_type="extract", - archived_memory_id=source_item.id, - created_at=source_item.metadata.created_at, - ) - source_history.append(new_history_item) # Re-use the history of the old node - # Create new node - metadata_updates = { - "memory_type": memory_type, - "status": "activated", - "is_fast": False, - "version": version + 1, - "tags": tags, - "key": keys, - "created_at": datetime.now().isoformat(), - "history": source_history, - "embedding": self._compute_embedding(value), - } - metadata = fast_item.metadata.model_copy(deep=True) - for field_name, value in metadata_updates.items(): - setattr(metadata, field_name, value) - metadata = _sanitize_metadata_model(metadata) - - new_item = TextualMemoryItem( - id=str(uuid.uuid4()), - memory=value, - metadata=metadata, - ) - - created_items.append(new_item) - - return created_items diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 9733553c0..6d6f2389b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -241,7 +241,9 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: - self.reorganizer.add_message(QueueMessage(op="add", after_node=graph_node_ids)) + self.reorganizer.add_message( + QueueMessage(op="add", after_node=graph_node_ids, user_name=user_name) + ) return added_ids @@ -418,16 +420,19 @@ def _add_to_graph_memory( QueueMessage( op="add", after_node=[node_id], + user_name=user_name, ) ) return node_id - def _inherit_edges(self, from_id: str, to_id: str) -> None: + def _inherit_edges(self, from_id: str, to_id: str, user_name: str | None = None) -> None: """ Migrate all non-lineage edges from `from_id` to `to_id`, and remove them from `from_id` after copying. """ - edges = self.graph_store.get_edges(from_id, type="ANY", direction="ANY") + edges = self.graph_store.get_edges( + from_id, type="ANY", direction="ANY", user_name=user_name + ) for edge in edges: if edge["type"] == "MERGED_TO": @@ -440,20 +445,29 @@ def _inherit_edges(self, from_id: str, to_id: str) -> None: continue # Add edge to merged node if it doesn't already exist - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) # Remove original edge if it involved the archived node - self.graph_store.delete_edge(edge["from"], edge["to"], edge["type"]) + self.graph_store.delete_edge( + edge["from"], edge["to"], edge["type"], user_name=user_name + ) def _ensure_structure_path( - self, memory_type: str, metadata: TreeNodeTextualMemoryMetadata + self, + memory_type: str, + metadata: TreeNodeTextualMemoryMetadata, + user_name: str | None = None, ) -> str: """ Ensure structural path exists (ROOT → ... → final node), return last node ID. Args: - path: like ["hobby", "photography"] + memory_type: Memory type for the structure node. + metadata: Metadata containing key and other fields. + user_name: Optional user name for multi-tenant isolation. Returns: Final node ID of the structure path. @@ -463,7 +477,8 @@ def _ensure_structure_path( [ {"field": "memory", "op": "=", "value": metadata.key}, {"field": "memory_type", "op": "=", "value": memory_type}, - ] + ], + user_name=user_name, ) if existing: node_id = existing[0] # Use the first match @@ -486,14 +501,16 @@ def _ensure_structure_path( ), ) self.graph_store.add_node( - id=new_node.id, - memory=new_node.memory, - metadata=new_node.metadata.model_dump(exclude_none=True), + new_node.id, + new_node.memory, + new_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( op="add", after_node=[new_node.id], + user_name=user_name, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index ea06a7c60..b7fb6b1a0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -52,12 +52,14 @@ def __init__( before_edge: list[str] | list[GraphDBEdge] | None = None, after_node: list[str] | list[GraphDBNode] | None = None, after_edge: list[str] | list[GraphDBEdge] | None = None, + user_name: str | None = None, ): self.op = op self.before_node = before_node self.before_edge = before_edge self.after_node = after_node self.after_edge = after_edge + self.user_name = user_name def __str__(self) -> str: return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})" @@ -191,11 +193,15 @@ def handle_add(self, message: QueueMessage): logger.debug(f"Handling add operation: {str(message)[:500]}") added_node = message.after_node[0] detected_relationships = self.resolver.detect( - added_node, scope=added_node.metadata.memory_type + added_node, + scope=added_node.metadata.memory_type, + user_name=message.user_name, ) if detected_relationships: for added_node, existing_node, relation in detected_relationships: - self.resolver.resolve(added_node, existing_node, relation) + self.resolver.resolve( + added_node, existing_node, relation, user_name=message.user_name + ) self._reorganize_needed = True @@ -209,6 +215,7 @@ def optimize_structure( min_cluster_size: int = 4, min_group_size: int = 20, max_duration_sec: int = 600, + user_name: str | None = None, ): """ Periodically reorganize the graph: @@ -232,7 +239,7 @@ def _check_deadline(where: str): logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.") return - if self.graph_store.node_not_exist(scope): + if self.graph_store.node_not_exist(scope, user_name=user_name): logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.") return @@ -244,12 +251,14 @@ def _check_deadline(where: str): logger.debug( f"[GraphStructureReorganize] Num of scope in self.graph_store is" - f" {self.graph_store.get_memory_count(scope)}" + f" {self.graph_store.get_memory_count(scope, user_name=user_name)}" ) # Load candidate nodes if _check_deadline("[GraphStructureReorganize] Before loading candidates"): return - raw_nodes = self.graph_store.get_structure_optimization_candidates(scope) + raw_nodes = self.graph_store.get_structure_optimization_candidates( + scope, user_name=user_name + ) nodes = [GraphDBNode(**n) for n in raw_nodes] if not nodes: @@ -281,6 +290,7 @@ def _check_deadline(where: str): scope, local_tree_threshold, min_cluster_size, + user_name, ) ) @@ -307,6 +317,7 @@ def _process_cluster_and_write( scope: str, local_tree_threshold: int, min_cluster_size: int, + user_name: str | None = None, ): if len(cluster_nodes) <= min_cluster_size: return @@ -319,15 +330,17 @@ def _process_cluster_and_write( if len(sub_nodes) < min_cluster_size: continue # Skip tiny noise sub_parent_node = self._summarize_cluster(sub_nodes, scope) - self._create_parent_node(sub_parent_node) - self._link_cluster_nodes(sub_parent_node, sub_nodes) + self._create_parent_node(sub_parent_node, user_name=user_name) + self._link_cluster_nodes(sub_parent_node, sub_nodes, user_name=user_name) sub_parents.append(sub_parent_node) if sub_parents and len(sub_parents) >= min_cluster_size: cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(cluster_parent_node) + self._create_parent_node(cluster_parent_node, user_name=user_name) for sub_parent in sub_parents: - self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") + self.graph_store.add_edge( + cluster_parent_node.id, sub_parent.id, "PARENT", user_name=user_name + ) logger.info("Adding relations/reasons") nodes_to_check = cluster_nodes @@ -351,10 +364,16 @@ def _process_cluster_and_write( # 1) Add pairwise relations for rel in results["relations"]: if not self.graph_store.edge_exists( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ): self.graph_store.add_edge( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ) # 2) Add inferred nodes and link to sources @@ -363,14 +382,21 @@ def _process_cluster_and_write( inf_node.id, inf_node.memory, inf_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for src_id in inf_node.metadata.sources: - self.graph_store.add_edge(src_id, inf_node.id, "INFERS") + self.graph_store.add_edge( + src_id, inf_node.id, "INFERS", user_name=user_name + ) # 3) Add sequence links for seq in results["sequence_links"]: - if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"): - self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS") + if not self.graph_store.edge_exists( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ): + self.graph_store.add_edge( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ) # 4) Add aggregate concept nodes for agg_node in results["aggregate_nodes"]: @@ -378,9 +404,12 @@ def _process_cluster_and_write( agg_node.id, agg_node.memory, agg_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for child_id in agg_node.metadata.sources: - self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO") + self.graph_store.add_edge( + agg_node.id, child_id, "AGGREGATE_TO", user_name=user_name + ) logger.info("[Reorganizer] Cluster relation/reasoning done.") @@ -577,7 +606,7 @@ def _parse_json_result(self, response_text): ) return {} - def _create_parent_node(self, parent_node: GraphDBNode) -> None: + def _create_parent_node(self, parent_node: GraphDBNode, user_name: str | None = None) -> None: """ Create a new parent node for the cluster. """ @@ -585,17 +614,23 @@ def _create_parent_node(self, parent_node: GraphDBNode) -> None: parent_node.id, parent_node.memory, parent_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) - def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]): + def _link_cluster_nodes( + self, + parent_node: GraphDBNode, + child_nodes: list[GraphDBNode], + user_name: str | None = None, + ): """ Add PARENT edges from the parent node to all nodes in the cluster. """ for child in child_nodes: if not self.graph_store.edge_exists( - parent_node.id, child.id, "PARENT", direction="OUTGOING" + parent_node.id, child.id, "PARENT", direction="OUTGOING", user_name=user_name ): - self.graph_store.add_edge(parent_node.id, child.id, "PARENT") + self.graph_store.add_edge(parent_node.id, child.id, "PARENT", user_name=user_name) def _preprocess_message(self, message: QueueMessage) -> bool: message = self._convert_id_to_node(message) @@ -613,7 +648,9 @@ def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage: for i, node in enumerate(message.after_node or []): if not isinstance(node, str): continue - raw_node = self.graph_store.get_node(node, include_embedding=True) + raw_node = self.graph_store.get_node( + node, include_embedding=True, user_name=message.user_name + ) if raw_node is None: logger.debug(f"Node with ID {node} not found in the graph store.") message.after_node[i] = None diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 11b5d7f44..a6ac186b7 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -8,15 +8,13 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.memories.textual.item import ( - ArchivedTextualMemory, TextualMemoryItem, - TreeNodeTextualMemoryMetadata, + TextualMemoryMetadata, ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, _append_related_content, _detach_related_content, - _rebuild_fast_node_history, ) @@ -38,7 +36,7 @@ def history_manager(mock_nli_client, mock_graph_db): def test_detach_related_content(): original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) duplicates = ["Duplicate 1", "Duplicate 2"] conflicts = ["Conflict 1", "Conflict 2"] @@ -62,7 +60,7 @@ def test_detach_related_content(): def test_detach_only_conflicts(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) duplicates = [] conflicts = ["Conflict A"] @@ -77,7 +75,7 @@ def test_detach_only_conflicts(): def test_detach_only_duplicates(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) duplicates = ["Duplicate A"] conflicts = [] @@ -120,636 +118,20 @@ def test_mark_memory_status(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex id3 = uuid.uuid4().hex - memory_ids = [id1, id2, id3] + items = [ + TextualMemoryItem(memory="M1", id=id1), + TextualMemoryItem(memory="M2", id=id2), + TextualMemoryItem(memory="M3", id=id3), + ] status = "resolving" # Action - history_manager.mark_memory_status(memory_ids, status, user_name="u1") + history_manager.mark_memory_status(items, status) # Assert assert mock_graph_db.update_node.call_count == 3 - # Verify we called it correctly - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name="u1") - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name="u1") - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name="u1") - - -def test_format_async_update_prompt(history_manager): - # Setup - # Create history items - h1 = ArchivedTextualMemory( - version=1, archived_memory_id="101", memory="Duplicate content", update_type="duplicate" - ) - h2 = ArchivedTextualMemory( - version=1, archived_memory_id="201", memory="Conflict content", update_type="conflict" - ) - h3 = ArchivedTextualMemory( - version=1, archived_memory_id="301", memory="Unrelated content", update_type="unrelated" - ) - - item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]), - ) - - # Execute - prompt = history_manager.format_async_update_prompt(item) - - # Verify - assert "[ID:101]" in prompt - assert "Duplicate content" in prompt - assert "[ID:201]" in prompt - assert "Conflict content" in prompt - assert "[ID:301]" in prompt - assert "Unrelated content" in prompt - assert "New user input" in prompt - - # Check that placeholders are gone (basic check) - assert "${duplicate_candidates}" not in prompt - assert "${conflict_candidates}" not in prompt - - -def test_format_async_update_prompt_with_args(history_manager): - item = TextualMemoryItem( - memory="Explicit conversation", - metadata=TreeNodeTextualMemoryMetadata(history=[]), - ) - - prompt = history_manager.format_async_update_prompt(item, custom_tags_prompt="Custom Tags") - - assert "Explicit conversation" in prompt - assert "Custom Tags" in prompt - - -def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): - llm_response = { - "memory list": [ - { - "key": "New Memory", - "memory_type": "LongTermMemory", - "value": "New Content", - "tags": ["tag1"], - "source_candidate_ids": [], - "conflicted_candidate_ids": [], - "history_segments": [], - } - ], - "restored_memories": [], - "summary": "Summary", - } - - source_item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata(history=[]), - ) - updated, new_items = history_manager.apply_llm_memory_updates( - llm_response, source_item=source_item, user_name="u1" - ) - - assert len(updated) == 0 - assert len(new_items) == 1 - new_item = new_items[0] - assert new_item.memory == "New Content" - assert new_item.metadata.tags == ["tag1"] - assert new_item.metadata.history == [] - mock_graph_db.add_node.assert_not_called() - - -def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db): - # Setup existing node - existing_id = uuid.uuid4().hex - existing_node = { - "id": existing_id, - "memory": "Old Content", - "metadata": { - "version": 1, - "created_at": "2023-01-01", - "tags": ["old"], - "status": "resolving", - "embedding": [], - "memory_type": "LongTermMemory", - }, - } - mock_graph_db.get_node.return_value = existing_node - mock_graph_db.get_nodes.return_value = [existing_node] - - llm_response = { - "memory list": [ - { - "key": "Updated Memory", - "memory_type": "LongTermMemory", - "value": "Updated Content", - "tags": ["new"], - "source_candidate_ids": [existing_id], - "conflicted_candidate_ids": [], - "history_segments": [], - } - ], - "restored_memories": [], - "summary": "Summary", - } - - source_item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata( - history=[ - ArchivedTextualMemory( - version=1, - archived_memory_id=existing_id, - memory="Old Content", - update_type="duplicate", - ) - ] - ), - ) - updated, new_items = history_manager.apply_llm_memory_updates( - llm_response, source_item=source_item, user_name="u1" - ) - - assert len(updated) == 1 - assert len(new_items) == 0 - updated_item = updated[0] - assert updated_item.id == existing_id - assert updated_item.memory == "Updated Content" - assert updated_item.metadata.version == 2 - assert updated_item.metadata.tags == ["new"] - assert len(updated_item.metadata.history) == 1 - - history_entry = updated_item.metadata.history[0] - assert history_entry.archived_memory_id != existing_id - assert history_entry.archived_memory_id is not None - assert history_entry.memory == "Old Content" - assert history_entry.update_type == "duplicate" - - mock_graph_db.add_node.assert_called_once() - mock_graph_db.update_node.assert_called_once() - args, kwargs = mock_graph_db.update_node.call_args - assert kwargs["id"] == existing_id - assert kwargs["fields"]["memory"] == "Updated Content" - assert kwargs["fields"]["version"] == 2 - - -def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): - source_id = uuid.uuid4().hex - restored_item = TextualMemoryItem( - memory="Restored Content", - metadata=TreeNodeTextualMemoryMetadata(history=[]), - ) - history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) - llm_response = { - "memory list": [], - "restored_memories": [ - {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} - ], - "summary": "Summary", - } - - source_item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata( - history=[ - ArchivedTextualMemory( - version=1, - archived_memory_id=source_id, - memory="Old Content", - update_type="conflict", - ) - ] - ), - ) - updated, new_items = history_manager.apply_llm_memory_updates( - llm_response, source_item=source_item, user_name="u1" - ) - - assert len(updated) == 0 - assert len(new_items) == 1 - assert new_items[0] == restored_item - history_manager._handle_restored_memories.assert_called_once_with( - llm_response["restored_memories"], source_item - ) - mock_graph_db.add_node.assert_not_called() - - -def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): - id1 = uuid.uuid4().hex - id2 = uuid.uuid4().hex - llm_response = {"memory list": [], "restored_memories": [], "summary": "Summary"} - - source_item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata( - history=[ - ArchivedTextualMemory( - version=1, - archived_memory_id=id1, - memory="M1", - update_type="unrelated", - ), - ArchivedTextualMemory( - version=1, - archived_memory_id=id2, - memory="M2", - update_type="unrelated", - ), - ] - ), - ) - updated, new_items = history_manager.apply_llm_memory_updates( - llm_response, source_item=source_item, user_name="u1" - ) - - assert len(updated) == 0 - assert len(new_items) == 0 - - # Check that update_node was called to set status="activated" - # mark_memory_status calls update_node for each item - assert mock_graph_db.update_node.call_count == 2 - - # We can inspect calls - calls = mock_graph_db.update_node.call_args_list - ids = sorted([c.kwargs["id"] for c in calls]) - assert ids == sorted([id1, id2]) - for c in calls: - assert c.kwargs["fields"]["status"] == "activated" - - -def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): - # Setup existing node (primary) - primary_id = uuid.uuid4().hex - secondary_id = uuid.uuid4().hex - existing_node = { - "id": primary_id, - "memory": "Old Content", - "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, - } - mock_graph_db.get_node.return_value = existing_node - mock_graph_db.get_nodes.return_value = [ - existing_node, - { - "id": secondary_id, - "memory": "Secondary", - "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, - }, - ] - - llm_response = { - "memory list": [ - { - "key": "Conflict Resolved", - "memory_type": "LongTermMemory", - "value": "New Content", - "tags": [], - "source_candidate_ids": [], - "conflicted_candidate_ids": [primary_id, secondary_id], - "history_segments": [], - } - ], - "restored_memories": [], - "summary": "Summary", - } - - source_item = TextualMemoryItem( - memory="New user input", - metadata=TreeNodeTextualMemoryMetadata( - history=[ - ArchivedTextualMemory( - version=1, - archived_memory_id=primary_id, - memory="Old Content", - update_type="conflict", - ), - ArchivedTextualMemory( - version=1, - archived_memory_id=secondary_id, - memory="Secondary", - update_type="conflict", - ), - ] - ), - ) - updated, new_items = history_manager.apply_llm_memory_updates( - llm_response, source_item=source_item, user_name="u1" - ) - - assert len(updated) == 1 - assert len(new_items) == 0 - updated_item = updated[0] - assert updated_item.id == primary_id - assert updated_item.metadata.history[0].update_type == "conflict" - - # Verify primary update - # The mock_graph_db.update_node is called for primary (update) AND secondary (delete) - - # Find call for primary - primary_update_calls = [ - c - for c in mock_graph_db.update_node.call_args_list - if c.kwargs["id"] == primary_id and "memory" in c.kwargs.get("fields", {}) - ] - assert len(primary_update_calls) >= 1 - assert primary_update_calls[0].kwargs["fields"]["memory"] == "New Content" - - # Find call for secondary - secondary_update_calls = [ - c for c in mock_graph_db.update_node.call_args_list if c.kwargs["id"] == secondary_id - ] - assert len(secondary_update_calls) >= 1 - last_secondary_update = secondary_update_calls[-1] - assert last_secondary_update.kwargs["fields"]["status"] == "archived" - assert last_secondary_update.kwargs["fields"]["evolve_to"] == [primary_id] - - -def test_rebuild_fast_node_history_dedup_and_replace(): - h1 = ArchivedTextualMemory( - version=1, archived_memory_id="a", memory="m1", update_type="duplicate" - ) - h2 = ArchivedTextualMemory( - version=1, archived_memory_id="b", memory="m2", update_type="conflict" - ) - h3 = ArchivedTextualMemory( - version=2, archived_memory_id="a", memory="m3", update_type="duplicate" - ) - item = TextualMemoryItem( - memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]) - ) - - r1 = ArchivedTextualMemory( - version=2, archived_memory_id="b", memory="m4", update_type="conflict" - ) - r2 = ArchivedTextualMemory( - version=1, archived_memory_id="c", memory="m5", update_type="duplicate" - ) - - _rebuild_fast_node_history(item, {1: [r1, r2]}) - - by_id = {h.archived_memory_id: h for h in item.metadata.history} - assert set(by_id.keys()) == {"a", "b", "c"} - assert by_id["a"].version == 2 - assert by_id["b"].version == 2 - - -def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): - fast_id = uuid.uuid4().hex - history_item = ArchivedTextualMemory( - version=1, archived_memory_id=fast_id, memory="fast", update_type="conflict", is_fast=True - ) - item = TextualMemoryItem( - memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) - ) - mock_graph_db.get_nodes.return_value = [ - {"id": fast_id, "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}} - ] - - replacement_item = ArchivedTextualMemory( - version=1, archived_memory_id="n1", memory="r1", update_type="conflict" - ) - history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) - - replacements = history_manager._check_and_fetch_replacements(item, [0]) - - assert 0 in replacements - assert replacements[0][0].archived_memory_id == "n1" - history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict") - - -def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): - mock_graph_db.get_nodes.return_value = [ - { - "id": "x1", - "memory": "m1", - "metadata": {"version": 2, "is_fast": False, "created_at": "2024-01-01"}, - }, - { - "id": "x2", - "memory": "m2", - "metadata": {"version": 1, "is_fast": True, "created_at": "2024-01-02"}, - }, - ] - - results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate") - - assert len(results) == 2 - ids = sorted([r.archived_memory_id for r in results]) - assert ids == ["x1", "x2"] - assert all(r.update_type == "duplicate" for r in results) - - -def test_wait_and_update_fast_history_rebuilds(history_manager): - fast_id = uuid.uuid4().hex - fast_item = ArchivedTextualMemory( - version=1, archived_memory_id=fast_id, memory="fast", update_type="duplicate", is_fast=True - ) - other_item = ArchivedTextualMemory( - version=1, archived_memory_id="k1", memory="keep", update_type="unrelated", is_fast=False - ) - item = TextualMemoryItem( - memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[fast_item, other_item]) - ) - - replacement = ArchivedTextualMemory( - version=2, archived_memory_id="n1", memory="new", update_type="duplicate", is_fast=False - ) - history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) - - history_manager.wait_and_update_fast_history(item, timeout_sec=1) - - ids = [h.archived_memory_id for h in item.metadata.history] - assert "n1" in ids - assert fast_id not in ids - history_manager._check_and_fetch_replacements.assert_called_once() - - -def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): - llm = MagicMock() - llm.generate.return_value = "Merged Content" - manager = MemoryHistoryManager( - nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm - ) - - existing_id = uuid.uuid4().hex - mock_graph_db.get_node.return_value = { - "id": existing_id, - "memory": "Old Content", - "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, - } - mock_graph_db.get_nodes.return_value = [ - { - "id": existing_id, - "memory": "Old Content", - "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, - } - ] - - mem_data = { - "key": "k", - "value": "Proposed", - "tags": ["t1"], - "source_candidate_ids": [existing_id], - "conflicted_candidate_ids": [], - } - - updated, new_item = manager._update_existing_memory( - mem_data, - [existing_id], - [existing_id], - {existing_id: 1}, - user_name="u1", - fast_item=TextualMemoryItem( - memory="New user input", metadata=TreeNodeTextualMemoryMetadata() - ), - ) - - assert updated.memory == "Merged Content" - assert updated.metadata.version == 3 - assert new_item is None - mock_graph_db.update_node.assert_called_once() - - -def test_update_existing_memory_marks_working_binding_deleted(history_manager, mock_graph_db): - history_manager.mark_memory_status = MagicMock() - primary_id = uuid.uuid4().hex - working_binding = uuid.uuid4().hex - mock_graph_db.get_node.return_value = { - "id": primary_id, - "memory": "Old Content", - "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, - } - mock_graph_db.get_nodes.return_value = [ - { - "id": primary_id, - "memory": "Old Content", - "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, - } - ] - mem_data = { - "key": "k", - "value": "Updated", - "tags": [], - "source_candidate_ids": [primary_id], - "conflicted_candidate_ids": [], - } - - updated, new_item = history_manager._update_existing_memory( - mem_data, - [primary_id], - [primary_id], - {primary_id: 1}, - user_name="u1", - fast_item=TextualMemoryItem( - memory="New user input", metadata=TreeNodeTextualMemoryMetadata() - ), - ) - - assert updated is not None - assert new_item is None - history_manager.mark_memory_status.assert_called_once_with( - [str(working_binding)], "deleted", user_name="u1" - ) - - -def test_update_existing_memory_no_mark_when_working_binding_matches( - history_manager, mock_graph_db -): - history_manager.mark_memory_status = MagicMock() - primary_id = uuid.uuid4().hex - mock_graph_db.get_node.return_value = { - "id": primary_id, - "memory": "Old Content", - "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, - } - mock_graph_db.get_nodes.return_value = [ - { - "id": primary_id, - "memory": "Old Content", - "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, - } - ] - mem_data = { - "key": "k", - "value": "Updated", - "tags": [], - "source_candidate_ids": [primary_id], - "conflicted_candidate_ids": [], - } - - updated, new_item = history_manager._update_existing_memory( - mem_data, - [primary_id], - [primary_id], - {primary_id: 1}, - user_name="u1", - fast_item=TextualMemoryItem( - memory="New user input", metadata=TreeNodeTextualMemoryMetadata() - ), - ) - - assert updated is not None - assert new_item is None - history_manager.mark_memory_status.assert_not_called() - - -def test_update_existing_memory_node_missing(history_manager, mock_graph_db): - mock_graph_db.get_node.return_value = None - mock_graph_db.get_nodes.return_value = [] - mem_data = {"value": "v", "tags": [], "key": "k"} - - updated, new_item = history_manager._update_existing_memory( - mem_data, - ["missing"], - [], - {}, - user_name="u1", - fast_item=TextualMemoryItem( - memory="New user input", metadata=TreeNodeTextualMemoryMetadata() - ), - ) - - assert updated is None - assert new_item is not None - assert new_item.memory == "v" - mock_graph_db.update_node.assert_not_called() - - -def test_update_node_with_history(): - item = TextualMemoryItem( - memory="Old Content", - metadata=TreeNodeTextualMemoryMetadata( - version=2, - tags=["old"], - key="k1", - history=[], - ), - ) - - updated, archived = MemoryHistoryManager.update_node_with_history( - item, - "New Content", - "conflict", - ) - - assert updated.memory == "New Content" - assert updated.metadata.version == 3 - assert updated.metadata.tags == ["old"] - assert updated.metadata.key == "k1" - assert len(updated.metadata.history) == 1 - history_entry = updated.metadata.history[0] - assert history_entry.memory == "Old Content" - assert history_entry.update_type == "conflict" - assert history_entry.archived_memory_id == archived.id - assert archived.metadata.status == "archived" - assert archived.metadata.evolve_to == [updated.id] - - -def test_merge_conflicting_memory_llm_error(mock_graph_db): - llm = MagicMock() - llm.generate.side_effect = Exception("fail") - manager = MemoryHistoryManager( - nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm - ) - - merged = manager._merge_conflicting_memory("Latest", "Proposed") - - assert "System Merge Fallback" in merged - assert "Latest" in merged - assert "Proposed" in merged + # Verify we called it correctly (user_name=None is passed by mark_memory_status) + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None) From 511a8802954d4fe93da7e8cf5846a5c75f8353d5 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 27 Feb 2026 10:14:34 +0800 Subject: [PATCH 50/92] fix: fix user_name passing --- .../organize/history_manager.py | 820 +++++++++++++++++- 1 file changed, 805 insertions(+), 15 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 132582a0d..7ebc25694 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,12 +1,26 @@ import logging +import time +import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal from memos.context.context import ContextThreadPoolExecutor +from memos.embedders.base import BaseEmbedder from memos.extras.nli_model.client import NLIClient from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem +from memos.llms.base import BaseLLM +from memos.mem_reader.read_multi_modal.utils import detect_lang +from memos.memories.textual.item import ( + ArchivedTextualMemory, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.templates.mem_reader_mem_version_prompts import ( + ASYNC_MEMORY_UPDATE_PROMPT_DICT, + MEMORY_MERGE_PROMPT_DICT, +) logger = logging.getLogger(__name__) @@ -67,21 +81,155 @@ def _detach_related_content(new_item: TextualMemoryItem) -> None: return +def _rebuild_fast_node_history( + item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] +) -> None: + """ + Reconstruct the history list of a fast node: + 1. Replace resolved items with their evolved versions. + 2. Deduplicate by ID while preserving the newest versions. + """ + new_history = {} + + def _add(history_item): + item_id = history_item.archived_memory_id + current = new_history.get(item_id) + + if current is None or history_item.version > current.version: + new_history[item_id] = history_item + + # Apply replacements and filter superseded items + for i, h in enumerate(item.metadata.history): + if i in replacements: + # This item is resolved, insert its replacements + for replacement_item in replacements[i]: + _add(replacement_item) + else: + _add(h) + + item.metadata.history = list(new_history.values()) + + +def _sanitize_metadata_dict(data: dict[str, Any] | None) -> dict[str, Any]: + if not data: + return {} + sanitized = data.copy() + for key in ("id", "memory", "graph_id"): + sanitized.pop(key, None) + return sanitized + + +def _sanitize_metadata_model( + metadata: TreeNodeTextualMemoryMetadata, +) -> TreeNodeTextualMemoryMetadata: + data = _sanitize_metadata_dict(metadata.model_dump(exclude_none=True)) + return metadata.__class__(**data) + + +def _determine_lang(sources: list | None, fallback_text: str) -> str: + lang = None + if sources: + for source in sources: + if hasattr(source, "lang") and source.lang: + lang = source.lang + break + if isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + break + if lang is None: + lang = detect_lang(fallback_text) + return lang + + class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: + def __init__( + self, + nli_client: NLIClient, + graph_db: BaseGraphDB, + llm: BaseLLM | None = None, + embedder: BaseEmbedder | None = None, + ) -> None: """ Initialize the MemoryHistoryManager. Args: nli_client: NLIClient for conflict/duplicate detection. graph_db: GraphDB instance for marking operations during history management. + llm: Optional LLM instance for memory merging during conflicts. """ self.nli_client = nli_client self.graph_db = graph_db + self.llm = llm + self.embedder = embedder + + def _compute_embedding(self, text: str) -> list[float] | None: + if not self.embedder: + return None + try: + return self.embedder.embed([text])[0] + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to compute embedding: {e}") + return None + + @staticmethod + def is_applicable(item: TextualMemoryItem) -> bool: + # Only deals with: + # 1. From doc or chat + # 2. LongTermMemory, UserMemory + allowed_sources = ["doc", "chat"] + allowed_memory_types = ["LongTermMemory", "UserMemory"] + return ( + item.metadata.sources[0].type in allowed_sources + and item.metadata.memory_type in allowed_memory_types + ) + + @staticmethod + def update_node_with_history( + item: TextualMemoryItem, + new_memory: str, + update_type: str, + tags: list[str] | None = None, + key: str | None = None, + ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + now = datetime.now().isoformat() + last_update_time = item.metadata.updated_at + + old_id = item.id + archived_id = str(uuid.uuid4()) + # archived memory(need to store this node to the db later) + archived_item = item.model_copy(deep=True) + archived_item.id = archived_id + archived_item.metadata.evolve_to = [old_id] + archived_item.metadata.status = "archived" + archived_item.metadata.created_at = last_update_time + archived_item.metadata.updated_at = now + + # original memory with updated contents and history + archived_history = ArchivedTextualMemory( + version=item.metadata.version or 1, + is_fast=item.metadata.is_fast or False, + memory=item.memory, + update_type=update_type, + archived_memory_id=archived_id, + created_at=getattr(item.metadata, "updated_at", None) or last_update_time, + ) + item.memory = new_memory + item.metadata.version = (item.metadata.version or 1) + 1 + item.metadata.status = "activated" + item.metadata.updated_at = now + if tags is not None: + item.metadata.tags = tags + if key is not None: + item.metadata.key = key + if item.metadata.history is None: + item.metadata.history = [] + item.metadata.history.append(archived_history) + + return item, archived_item def resolve_history_via_nli( self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: + ) -> list[str]: """ Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, and attach them as history to the new fast item. @@ -91,7 +239,7 @@ def resolve_history_via_nli( related_items: Existing memory items that might be related. Returns: - List of duplicate or conflicting memory items judged by the NLI service. + List of duplicate or conflicting memory ids judged by the NLI service. """ if not related_items: return [] @@ -102,15 +250,19 @@ def resolve_history_via_nli( ) # 2. Process results and attach to history + duplicate_memory_ids = [] + conflict_memory_ids = [] duplicate_memories = [] conflict_memories = [] for r_item, nli_res in zip(related_items, nli_results, strict=False): if nli_res == NLIResult.DUPLICATE: update_type = "duplicate" + duplicate_memory_ids.append(r_item.id) duplicate_memories.append(r_item.memory) elif nli_res == NLIResult.CONTRADICTION: update_type = "conflict" + conflict_memory_ids.append(r_item.id) conflict_memories.append(r_item.memory) else: update_type = "unrelated" @@ -118,45 +270,256 @@ def resolve_history_via_nli( # Safely get created_at, fallback to updated_at created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at + # TODO: change the way of marking fast nodes by directly using is_fast field. archived = ArchivedTextualMemory( version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, + is_fast=( + r_item.metadata.is_fast + or ("mode:fast" in (getattr(r_item.metadata, "tags", None) or [])) + ), memory=r_item.memory, update_type=update_type, archived_memory_id=r_item.id, created_at=created_at, ) new_item.metadata.history.append(archived) - logger.info( - f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" - ) # 3. Concat duplicate/conflict memories to new_item.memory # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. _append_related_content(new_item, duplicate_memories, conflict_memories) - return duplicate_memories + conflict_memories + return duplicate_memory_ids + conflict_memory_ids + + def wait_and_update_fast_history( + self, item: TextualMemoryItem, user_name: str, timeout_sec: int = 30 + ) -> None: + """ + Scan the item's history. If any history item is marked as `is_fast`, + wait for it to be resolved (i.e., status becomes 'deleted' in the DB). + When resolved, replace the fast item with the nodes referenced in its `evolve_to` field. + Finally, deduplicate the history. + + Args: + item: The memory item containing the history to check. + user_name: Required for db query. + timeout_sec: Maximum time to wait for resolution in seconds. + """ + start_time = time.time() + + # 1. Identify pending items (fast nodes) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + while True: + if not pending_indices: + # All fast nodes resolved or none existed + break + + if time.time() - start_time > timeout_sec: + logger.warning( + f"[MemoryHistoryManager] Timeout waiting for fast history resolution for item {item.id}" + ) + # Remove pending fast nodes from history + item.metadata.history = [ + h + for h in item.metadata.history + if not (getattr(h, "is_fast", False) and h.archived_memory_id) + ] + break + + # 2. Check status of the fast nodes and fetch replacements for evolved ones + replacements = self._check_and_fetch_replacements(item, pending_indices) + + # 3. If we have any resolved items, rebuild the history + if replacements: + _rebuild_fast_node_history(item, replacements) + + # Check if we are done (no pending items left) + pending_indices = [ + i + for i, h in enumerate(item.metadata.history) + if getattr(h, "is_fast", False) and h.archived_memory_id + ] + + if pending_indices: + time.sleep(1) # This avoids visiting the DB too frequently + + return + + def format_async_update_prompt( + self, item: TextualMemoryItem, custom_tags_prompt: str = "" + ) -> str: + """ + Format the prompt for asynchronous memory update. + + Args: + item: The TextualMemoryItem containing history candidates. + custom_tags_prompt: Optional custom prompt for tags. + + Returns: + Formatted prompt string. + """ + # First, detach duplicate and conflict memory contents from the new item's memory text + _detach_related_content(item) + + duplicate_candidates = [] + conflict_candidates = [] + unrelated_candidates = [] + + def _fmt_time(ts: str | None) -> str | None: + if not ts or not isinstance(ts, str): + return None + try: + t = datetime.fromisoformat(ts.replace("Z", "")) + return t.strftime("%Y/%m/%d %H:%M:%S") + except Exception: + return ts + + for h in item.metadata.history or []: + created = getattr(h, "created_at", None) + tstr = _fmt_time(created) + time_suffix = f"[Time: {tstr}] " if tstr else "" + candidate_str = f"[ID:{h.archived_memory_id}]{time_suffix}{h.memory}" + + if h.update_type == "duplicate": + duplicate_candidates.append(candidate_str) + elif h.update_type == "conflict": + conflict_candidates.append(candidate_str) + else: + # Includes "unrelated" and any other types + unrelated_candidates.append(candidate_str) + + sources = item.metadata.sources if item.metadata else None + lang = _determine_lang(sources, item.memory) + empty_label = "None" + + def format_list(candidates): + return "\n".join(candidates) if candidates else empty_label + + prompt_template = ASYNC_MEMORY_UPDATE_PROMPT_DICT.get( + lang, ASYNC_MEMORY_UPDATE_PROMPT_DICT["en"] + ) + conversation_time_raw = getattr(item.metadata, "created_at", None) + conversation_time = _fmt_time(conversation_time_raw) or conversation_time_raw + + return ( + prompt_template.replace("${duplicate_candidates}", format_list(duplicate_candidates)) + .replace("${conflict_candidates}", format_list(conflict_candidates)) + .replace("${unrelated_candidates}", format_list(unrelated_candidates)) + .replace("${custom_tags_prompt}", custom_tags_prompt) + .replace("${conversation_time}", conversation_time) + .replace("${conversation}", item.memory) + ) + + def apply_llm_memory_updates( + self, llm_response: dict[str, Any], source_item: TextualMemoryItem, user_name: str + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """ + Apply the updates from the LLM response to the memory graph. + + Args: + llm_response: The parsed JSON response from the LLM. + source_item: The original fast item A whose history contains ArchivedTextualMemory entries. + We derive expected versions and candidate IDs from A.history. + user_name: user_name + + Returns: + List of new or updated memory items. + """ + memory_list = llm_response.get("memory list", []) + restored_memories = llm_response.get("restored_memories", []) + used_source_ids = set() + used_conflict_ids = set() + # Collect IDs referenced by the LLM response + for mem_data in memory_list: + used_source_ids.update(mem_data.get("source_candidate_ids", [])) + used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) + # Collect IDs referenced by restored memories + restored_source_ids = { + mem.get("source_candidate_id") + for mem in restored_memories + if mem.get("source_candidate_id") + } + # All IDs used by the LLM response + used_ids = used_source_ids | used_conflict_ids | restored_source_ids + expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories + candidate_id_set: set[str] = set() + # Recover candidate IDs and their expected versions from the source item's history + if source_item.metadata and source_item.metadata.history: + for h in source_item.metadata.history: + if h.archived_memory_id: + candidate_id_set.add(h.archived_memory_id) + expected_versions[h.archived_memory_id] = h.version + + # IDs not used by the LLM response will be treated as unrelated + unrelated_ids = sorted(candidate_id_set - used_ids) + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + + # 1. Handle Unrelated Candidates - Restore status to activated + self._handle_unrelated_candidates(unrelated_ids, user_name=user_name) + + # 2. Handle Memory List (Update or New) + processed_updates, created_items = self._process_memory_updates( + memory_list, expected_versions, user_name, source_item + ) + updated_items.extend(processed_updates) + new_items.extend(created_items) + + # 3. Handle Restored Memories (Extract from conflict) + new_items.extend(self._handle_restored_memories(restored_memories, source_item, user_name)) + + return updated_items, new_items + + def build_fallback_new_items( + self, item: TextualMemoryItem, user_name: str | None = None + ) -> list[TextualMemoryItem]: + latest_item = item.model_copy(deep=True) + _detach_related_content(latest_item) + + history = latest_item.metadata.history or [] + archived_ids = [h.archived_memory_id for h in history if h.archived_memory_id] + if archived_ids: + self.mark_memory_status(archived_ids, "activated", user_name or "") + + latest_item.id = str(uuid.uuid4()) + latest_item.metadata.is_fast = False + latest_item.metadata.status = "activated" + latest_item.metadata.history = [] + latest_item.metadata.working_binding = None + if hasattr(latest_item.metadata, "background"): + latest_item.metadata.background = "" + + if hasattr(latest_item.metadata, "tags") and latest_item.metadata.tags: + latest_item.metadata.tags = [t for t in latest_item.metadata.tags if t != "mode:fast"] + + latest_item.metadata = _sanitize_metadata_model(latest_item.metadata) + + return [latest_item] def mark_memory_status( self, - memory_items: list[TextualMemoryItem], + memory_ids: list[str], status: Literal["activated", "resolving", "archived", "deleted"], - user_name: str | None = None, + user_name: str, ) -> None: """ Support status marking operations during history management. Common usages are: 1. Mark conflict/duplicate old memories' status as "resolving", to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. + 2. Mark resolved memories' status as "activated", to recover their visibility. """ # Execute the actual marking operation - in db. with ContextThreadPoolExecutor() as executor: futures = [] - for mem in memory_items: + for mid in memory_ids: futures.append( executor.submit( self.graph_db.update_node, - id=mem.id, + id=mid, fields={"status": status}, user_name=user_name, ) @@ -166,3 +529,430 @@ def mark_memory_status( for future in futures: future.result() return + + def _check_and_fetch_replacements( + self, item: TextualMemoryItem, pending_indices: list[int], user_name: str + ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: + """ + Check DB status for pending items. If 'deleted', fetch evolved nodes. + + Returns: + replacements: Dict mapping original history index to list of new ArchivedTextualMemory items. + """ + pending_ids = [item.metadata.history[i].archived_memory_id for i in pending_indices] + + # Batch fetch pending nodes to check status + nodes_data = self.graph_db.get_nodes(ids=pending_ids, user_name=user_name) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + + replacements = {} + + for i in pending_indices: + h_item = item.metadata.history[i] + node_data = nodes_map.get(h_item.archived_memory_id) + + if not node_data: + continue + + metadata = node_data.get("metadata", {}) + status = metadata.get("status") + + # Condition: Fast node is processed when it is marked as 'deleted' + if status == "deleted": + evolve_to_ids = metadata.get("evolve_to", []) + + new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type, user_name) + replacements[i] = new_items + + return replacements + + def _fetch_evolved_nodes( + self, evolve_to_ids: list[str], update_type: str, user_name: str + ) -> list[ArchivedTextualMemory]: + """Fetch the actual nodes that the fast node evolved into and convert to archive format.""" + if not evolve_to_ids: + return [] + + evolved_nodes = self.graph_db.get_nodes(ids=evolve_to_ids, user_name=user_name) or [] + results = [] + + for enode in evolved_nodes: + if not enode or "id" not in enode: + continue + + enode_meta = enode.get("metadata", {}) + + # Create new archived memory inheriting the update_type (conflict/duplicate) + new_archived = ArchivedTextualMemory( + version=enode_meta.get("version", 1), + is_fast=enode_meta.get("is_fast", False), + memory=enode.get("memory", ""), + update_type=update_type, + archived_memory_id=enode.get("id"), + created_at=enode_meta.get("created_at"), + ) + results.append(new_archived) + + return results + + def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) -> None: + """Handle Unrelated Candidates - Restore status to `activated`.""" + if unrelated_ids: + self.mark_memory_status(unrelated_ids, "activated", user_name) + + def _process_memory_updates( + self, + memory_list: list[dict[str, Any]], + expected_versions: dict[str, int], + user_name: str, + source_item: TextualMemoryItem, + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: + """Process Memory List (Update or New).""" + updated_items: list[TextualMemoryItem] = [] + new_items: list[TextualMemoryItem] = [] + for mem_data in memory_list: + source_ids = mem_data.get("source_candidate_ids", []) + conflict_ids = mem_data.get("conflicted_candidate_ids", []) + + # Determine if this is an update or a new node + target_ids = source_ids + conflict_ids + + if target_ids: + updated_item, new_item = self._update_existing_memory( + mem_data, target_ids, source_ids, expected_versions, user_name, source_item + ) + if updated_item: + updated_items.append(updated_item) + if new_item: + new_items.append(new_item) + else: + item = self._create_new_memory(mem_data, source_item) + new_items.append(item) + return updated_items, new_items + + def _update_existing_memory( + self, + mem_data: dict[str, Any], + target_ids: list[str], + source_ids: list[str], + expected_versions: dict[str, int], + user_name: str, + fast_item: TextualMemoryItem, + ) -> tuple[TextualMemoryItem | None, TextualMemoryItem | None]: + """ + Update existing memory nodes using the LLM result. + + The first ID in target_ids is treated as the primary node. If additional target IDs + are provided, they are treated as secondary candidates and will be merged into the + primary. Merging means: + 1) Mark secondary nodes as archived and append the primary ID to evolve_to + 2) Merge their history entries into the primary history and re-order by created_at + + The method also applies CAS validation via expected_versions, archives the previous + version of the primary node, and persists the updated node back to the graph DB. + + Returns the updated primary TextualMemoryItem and optional new item when fallback is used. + """ + original_primary_id, primary_id, secondary_ids = ( + target_ids[0], + target_ids[0], + target_ids[1:], + ) + new_value, tags, key = ( + mem_data.get("value", ""), + mem_data.get("tags", []), + mem_data.get("key", ""), + ) + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory + + # Fetch candidate nodes in batch and then select the primary + # We update the primary and then merge the secondaries to the primary + nodes_data = self.graph_db.get_nodes(target_ids, user_name=user_name) or [] + nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + node_data = nodes_map.get(primary_id) + if not node_data: + logger.warning( + f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." + ) + # Fallback to create new item when the source_id is hallucinated by llm + new_item = self._create_new_memory(mem_data, fast_item) + return None, new_item + current_item = TextualMemoryItem(**node_data) + + # For concurrency control, need to make sure the primary item has not been modified by others in the meantime + # If it has(version changed), then we need to use llm to merge again. + new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) + + update_type = "duplicate" if original_primary_id in source_ids else "conflict" + current_item, archived_item = self.update_node_with_history( + current_item, + new_value, + update_type, + tags=tags, + key=key, + ) + + # create archived node for storing older versions of the memory, preserving the embedding + emb = TextualMemoryItem( + **self.graph_db.get_node(primary_id, user_name=user_name, include_embedding=True) + ).metadata.embedding + arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) + arch_meta["embedding"] = emb + self.graph_db.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=arch_meta, + user_name=user_name, + ) + + fields = _sanitize_metadata_dict(current_item.metadata.model_dump(exclude_none=True)) + merged_history = list(current_item.metadata.history or []) + new_primary_version = current_item.metadata.version or 1 + # Multiple related ids indicates existing duplicates/conflicts to be merged + if secondary_ids: + merged_history, new_primary_version = self._merge_secondary_nodes( + secondary_ids, primary_id, nodes_map, user_name, merged_history + ) + current_item.metadata.history = merged_history + current_item.metadata.version = new_primary_version + merged_history_dump = [h.model_dump(exclude_none=True) for h in merged_history] + embedding = self._compute_embedding(current_item.memory) + sources = [s.model_dump(exclude_none=True) for s in (fast_item.metadata.sources or [])] + # update old memory node with new content and updated history + self.graph_db.update_node( + id=primary_id, + fields={ + **fields, + "memory": current_item.memory, + "history": merged_history_dump, + "version": new_primary_version, + "embedding": embedding, + "sources": sources, + "session_id": fast_item.metadata.session_id, + }, + user_name=user_name, + ) + working_binding = getattr(current_item.metadata, "working_binding", None) + if working_binding and working_binding != current_item.id: + try: + self.mark_memory_status([str(working_binding)], "deleted", user_name=user_name) + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Failed to mark WorkingMemory {working_binding} as deleted: {e}" + ) + + return current_item, None + + def _apply_cas_merge( + self, + primary_id: str, + current_item: TextualMemoryItem, + expected_versions: dict[str, int], + new_value: str, + ) -> str: + expected_version = expected_versions.get(primary_id) + current_version = current_item.metadata.version or 1 + if expected_version is not None and current_version != expected_version: + logger.warning( + f"[MemoryHistoryManager] Version conflict for node {primary_id}: " + f"Expected v{expected_version}, but found v{current_version} in DB. " + "Triggering merge logic." + ) + latest_item = TextualMemoryItem( + memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(latest_item) + merged_content = self._merge_conflicting_memory( + latest_memory=latest_item.memory, + proposed_update=new_value, + ) + return merged_content + + return new_value + + def _merge_secondary_nodes( + self, + secondary_ids: list[str], + primary_id: str, + nodes_map: dict, + user_name: str, + base_history: list[ArchivedTextualMemory], + ) -> tuple[list[ArchivedTextualMemory], int]: + merged_history = list(base_history) + + for memory_id in secondary_ids: + node_data = nodes_map.get(memory_id) + if not node_data: + continue + metadata = node_data.get("metadata", {}) + evolve_to = list(metadata.get("evolve_to", []) or []) + if primary_id not in evolve_to: + evolve_to.append(primary_id) + # set secondary nodes to archived and record their evolving destinations + self.graph_db.update_node( + id=memory_id, + fields={"status": "archived", "evolve_to": evolve_to}, + user_name=user_name, + ) + secondary_item = TextualMemoryItem(**node_data) + if secondary_item.metadata.history: + merged_history.extend(secondary_item.metadata.history) + + # Currently we just sort the versions according to their creation time + def _history_sort_key(history_item: ArchivedTextualMemory) -> datetime: + created_at = history_item.created_at + if isinstance(created_at, datetime): + return created_at + if created_at: + try: + return datetime.fromisoformat(created_at) + except ValueError: + return datetime.min + return datetime.min + + def _dedupe_history_by_archived_id( + history: list[ArchivedTextualMemory], + ) -> list[ArchivedTextualMemory]: + seen_archived_ids: set[str] = set() + deduped_history: list[ArchivedTextualMemory] = [] + for history_item in history: + archived_id = history_item.archived_memory_id + if archived_id and archived_id in seen_archived_ids: + continue + if archived_id: + seen_archived_ids.add(archived_id) + deduped_history.append(history_item) + return deduped_history + + merged_history.sort(key=_history_sort_key) + merged_history = _dedupe_history_by_archived_id(merged_history) + max_version = 0 + for idx, history_item in enumerate(merged_history, start=1): + history_item.version = idx + max_version = idx + return merged_history, max_version + 1 + + def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> str: + """ + Call LLM to merge proposed update with latest memory content. + """ + if not self.llm: + return proposed_update + + lang = _determine_lang(None, f"{latest_memory}\n{proposed_update}") + prompt_template = MEMORY_MERGE_PROMPT_DICT.get(lang, MEMORY_MERGE_PROMPT_DICT["en"]) + prompt = prompt_template.replace("${latest_memory}", latest_memory).replace( + "${proposed_update}", proposed_update + ) + + messages = [{"role": "user", "content": prompt}] + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") + # Fallback: append proposed update? or just return proposed? + # Returning proposed might overwrite latest changes. + # Returning latest might lose proposed changes. + # Let's concatenate as a safe fallback. + return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" + + def _create_new_memory( + self, mem_data: dict[str, Any], fast_item: TextualMemoryItem + ) -> TextualMemoryItem: + """Create New Node.""" + new_value = mem_data.get("value", "") + new_value_item = TextualMemoryItem( + memory=new_value, metadata=TreeNodeTextualMemoryMetadata() + ) + _detach_related_content(new_value_item) + new_value = new_value_item.memory + tags = mem_data.get("tags", []) + key = mem_data.get("key", "") + background = mem_data.get("summary", "") + memory_type = mem_data.get("memory_type", "LongTermMemory") + metadata_updates = { + "is_fast": False, + "version": 1, + "memory_type": memory_type, + "status": "activated", + "background": background, + "working_binding": None, + "tags": tags, + "key": key, + "created_at": datetime.now().isoformat(), + "history": [], + "embedding": self._compute_embedding(new_value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, value in metadata_updates.items(): + setattr(metadata, field_name, value) + metadata = _sanitize_metadata_model(metadata) + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=new_value, + metadata=metadata, + ) + return new_item + + def _handle_restored_memories( + self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem, user_name: str + ) -> list[TextualMemoryItem]: + """Handle Restored Memories (Extract from conflict).""" + source_ids = [r.get("source_candidate_id") for r in restored_memories] + source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) + source_items = [TextualMemoryItem(**i) for i in source_items] + + created_items = [] + for i, data in enumerate(restored_memories): + source_item = source_items[i] + # deal with history + source_history = source_item.history.copy() + value = data.get("value", "") + value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) + _detach_related_content(value_item) + value = value_item.memory + tags = data.get("tags", []) + keys = data.get("keys", []) + memory_type = data.get("memory_type", "LongTermMemory") + version = source_item.verion + new_history_item = ArchivedTextualMemory( + version=version, + is_fast=False, + memory=source_item.memory, + update_type="extract", + archived_memory_id=source_item.id, + created_at=source_item.metadata.created_at, + ) + source_history.append(new_history_item) # Re-use the history of the old node + # Create new node + metadata_updates = { + "memory_type": memory_type, + "status": "activated", + "is_fast": False, + "version": version + 1, + "tags": tags, + "key": keys, + "created_at": datetime.now().isoformat(), + "history": source_history, + "embedding": self._compute_embedding(value), + } + metadata = fast_item.metadata.model_copy(deep=True) + for field_name, value in metadata_updates.items(): + setattr(metadata, field_name, value) + metadata = _sanitize_metadata_model(metadata) + + new_item = TextualMemoryItem( + id=str(uuid.uuid4()), + memory=value, + metadata=metadata, + ) + + created_items.append(new_item) + + return created_items From 54b52004423524d127adaf08c4bc88b017400dae Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 27 Feb 2026 11:30:03 +0800 Subject: [PATCH 51/92] fix: fix history manager test --- .../memories/textual/test_history_manager.py | 646 +++++++++++++++++- 1 file changed, 632 insertions(+), 14 deletions(-) diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index a6ac186b7..11b5d7f44 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -8,13 +8,15 @@ from memos.extras.nli_model.types import NLIResult from memos.graph_dbs.base import BaseGraphDB from memos.memories.textual.item import ( + ArchivedTextualMemory, TextualMemoryItem, - TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, _append_related_content, _detach_related_content, + _rebuild_fast_node_history, ) @@ -36,7 +38,7 @@ def history_manager(mock_nli_client, mock_graph_db): def test_detach_related_content(): original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate 1", "Duplicate 2"] conflicts = ["Conflict 1", "Conflict 2"] @@ -60,7 +62,7 @@ def test_detach_related_content(): def test_detach_only_conflicts(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = [] conflicts = ["Conflict A"] @@ -75,7 +77,7 @@ def test_detach_only_conflicts(): def test_detach_only_duplicates(): original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) + item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) duplicates = ["Duplicate A"] conflicts = [] @@ -118,20 +120,636 @@ def test_mark_memory_status(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex id3 = uuid.uuid4().hex - items = [ - TextualMemoryItem(memory="M1", id=id1), - TextualMemoryItem(memory="M2", id=id2), - TextualMemoryItem(memory="M3", id=id3), - ] + memory_ids = [id1, id2, id3] status = "resolving" # Action - history_manager.mark_memory_status(items, status) + history_manager.mark_memory_status(memory_ids, status, user_name="u1") # Assert assert mock_graph_db.update_node.call_count == 3 - # Verify we called it correctly (user_name=None is passed by mark_memory_status) - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None) + # Verify we called it correctly + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name="u1") + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name="u1") + + +def test_format_async_update_prompt(history_manager): + # Setup + # Create history items + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="101", memory="Duplicate content", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="201", memory="Conflict content", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=1, archived_memory_id="301", memory="Unrelated content", update_type="unrelated" + ) + + item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]), + ) + + # Execute + prompt = history_manager.format_async_update_prompt(item) + + # Verify + assert "[ID:101]" in prompt + assert "Duplicate content" in prompt + assert "[ID:201]" in prompt + assert "Conflict content" in prompt + assert "[ID:301]" in prompt + assert "Unrelated content" in prompt + assert "New user input" in prompt + + # Check that placeholders are gone (basic check) + assert "${duplicate_candidates}" not in prompt + assert "${conflict_candidates}" not in prompt + + +def test_format_async_update_prompt_with_args(history_manager): + item = TextualMemoryItem( + memory="Explicit conversation", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + + prompt = history_manager.format_async_update_prompt(item, custom_tags_prompt="Custom Tags") + + assert "Explicit conversation" in prompt + assert "Custom Tags" in prompt + + +def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): + llm_response = { + "memory list": [ + { + "key": "New Memory", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": ["tag1"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + new_item = new_items[0] + assert new_item.memory == "New Content" + assert new_item.metadata.tags == ["tag1"] + assert new_item.metadata.history == [] + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db): + # Setup existing node + existing_id = uuid.uuid4().hex + existing_node = { + "id": existing_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=existing_id, + memory="Old Content", + update_type="duplicate", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == existing_id + assert updated_item.memory == "Updated Content" + assert updated_item.metadata.version == 2 + assert updated_item.metadata.tags == ["new"] + assert len(updated_item.metadata.history) == 1 + + history_entry = updated_item.metadata.history[0] + assert history_entry.archived_memory_id != existing_id + assert history_entry.archived_memory_id is not None + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "duplicate" + + mock_graph_db.add_node.assert_called_once() + mock_graph_db.update_node.assert_called_once() + args, kwargs = mock_graph_db.update_node.call_args + assert kwargs["id"] == existing_id + assert kwargs["fields"]["memory"] == "Updated Content" + assert kwargs["fields"]["version"] == 2 + + +def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): + source_id = uuid.uuid4().hex + restored_item = TextualMemoryItem( + memory="Restored Content", + metadata=TreeNodeTextualMemoryMetadata(history=[]), + ) + history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) + llm_response = { + "memory list": [], + "restored_memories": [ + {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=source_id, + memory="Old Content", + update_type="conflict", + ) + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 1 + assert new_items[0] == restored_item + history_manager._handle_restored_memories.assert_called_once_with( + llm_response["restored_memories"], source_item + ) + mock_graph_db.add_node.assert_not_called() + + +def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): + id1 = uuid.uuid4().hex + id2 = uuid.uuid4().hex + llm_response = {"memory list": [], "restored_memories": [], "summary": "Summary"} + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=id1, + memory="M1", + update_type="unrelated", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=id2, + memory="M2", + update_type="unrelated", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 0 + assert len(new_items) == 0 + + # Check that update_node was called to set status="activated" + # mark_memory_status calls update_node for each item + assert mock_graph_db.update_node.call_count == 2 + + # We can inspect calls + calls = mock_graph_db.update_node.call_args_list + ids = sorted([c.kwargs["id"] for c in calls]) + assert ids == sorted([id1, id2]) + for c in calls: + assert c.kwargs["fields"]["status"] == "activated" + + +def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): + # Setup existing node (primary) + primary_id = uuid.uuid4().hex + secondary_id = uuid.uuid4().hex + existing_node = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [ + existing_node, + { + "id": secondary_id, + "memory": "Secondary", + "metadata": {"version": 1, "embedding": [], "memory_type": "LongTermMemory"}, + }, + ] + + llm_response = { + "memory list": [ + { + "key": "Conflict Resolved", + "memory_type": "LongTermMemory", + "value": "New Content", + "tags": [], + "source_candidate_ids": [], + "conflicted_candidate_ids": [primary_id, secondary_id], + "history_segments": [], + } + ], + "restored_memories": [], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=primary_id, + memory="Old Content", + update_type="conflict", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=secondary_id, + memory="Secondary", + update_type="conflict", + ), + ] + ), + ) + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + updated_item = updated[0] + assert updated_item.id == primary_id + assert updated_item.metadata.history[0].update_type == "conflict" + + # Verify primary update + # The mock_graph_db.update_node is called for primary (update) AND secondary (delete) + + # Find call for primary + primary_update_calls = [ + c + for c in mock_graph_db.update_node.call_args_list + if c.kwargs["id"] == primary_id and "memory" in c.kwargs.get("fields", {}) + ] + assert len(primary_update_calls) >= 1 + assert primary_update_calls[0].kwargs["fields"]["memory"] == "New Content" + + # Find call for secondary + secondary_update_calls = [ + c for c in mock_graph_db.update_node.call_args_list if c.kwargs["id"] == secondary_id + ] + assert len(secondary_update_calls) >= 1 + last_secondary_update = secondary_update_calls[-1] + assert last_secondary_update.kwargs["fields"]["status"] == "archived" + assert last_secondary_update.kwargs["fields"]["evolve_to"] == [primary_id] + + +def test_rebuild_fast_node_history_dedup_and_replace(): + h1 = ArchivedTextualMemory( + version=1, archived_memory_id="a", memory="m1", update_type="duplicate" + ) + h2 = ArchivedTextualMemory( + version=1, archived_memory_id="b", memory="m2", update_type="conflict" + ) + h3 = ArchivedTextualMemory( + version=2, archived_memory_id="a", memory="m3", update_type="duplicate" + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[h1, h2, h3]) + ) + + r1 = ArchivedTextualMemory( + version=2, archived_memory_id="b", memory="m4", update_type="conflict" + ) + r2 = ArchivedTextualMemory( + version=1, archived_memory_id="c", memory="m5", update_type="duplicate" + ) + + _rebuild_fast_node_history(item, {1: [r1, r2]}) + + by_id = {h.archived_memory_id: h for h in item.metadata.history} + assert set(by_id.keys()) == {"a", "b", "c"} + assert by_id["a"].version == 2 + assert by_id["b"].version == 2 + + +def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): + fast_id = uuid.uuid4().hex + history_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="conflict", is_fast=True + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) + ) + mock_graph_db.get_nodes.return_value = [ + {"id": fast_id, "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}} + ] + + replacement_item = ArchivedTextualMemory( + version=1, archived_memory_id="n1", memory="r1", update_type="conflict" + ) + history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) + + replacements = history_manager._check_and_fetch_replacements(item, [0]) + + assert 0 in replacements + assert replacements[0][0].archived_memory_id == "n1" + history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict") + + +def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): + mock_graph_db.get_nodes.return_value = [ + { + "id": "x1", + "memory": "m1", + "metadata": {"version": 2, "is_fast": False, "created_at": "2024-01-01"}, + }, + { + "id": "x2", + "memory": "m2", + "metadata": {"version": 1, "is_fast": True, "created_at": "2024-01-02"}, + }, + ] + + results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate") + + assert len(results) == 2 + ids = sorted([r.archived_memory_id for r in results]) + assert ids == ["x1", "x2"] + assert all(r.update_type == "duplicate" for r in results) + + +def test_wait_and_update_fast_history_rebuilds(history_manager): + fast_id = uuid.uuid4().hex + fast_item = ArchivedTextualMemory( + version=1, archived_memory_id=fast_id, memory="fast", update_type="duplicate", is_fast=True + ) + other_item = ArchivedTextualMemory( + version=1, archived_memory_id="k1", memory="keep", update_type="unrelated", is_fast=False + ) + item = TextualMemoryItem( + memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[fast_item, other_item]) + ) + + replacement = ArchivedTextualMemory( + version=2, archived_memory_id="n1", memory="new", update_type="duplicate", is_fast=False + ) + history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) + + history_manager.wait_and_update_fast_history(item, timeout_sec=1) + + ids = [h.archived_memory_id for h in item.metadata.history] + assert "n1" in ids + assert fast_id not in ids + history_manager._check_and_fetch_replacements.assert_called_once() + + +def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): + llm = MagicMock() + llm.generate.return_value = "Merged Content" + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + existing_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": existing_id, + "memory": "Old Content", + "metadata": {"version": 2, "embedding": [], "memory_type": "LongTermMemory"}, + } + ] + + mem_data = { + "key": "k", + "value": "Proposed", + "tags": ["t1"], + "source_candidate_ids": [existing_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = manager._update_existing_memory( + mem_data, + [existing_id], + [existing_id], + {existing_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated.memory == "Merged Content" + assert updated.metadata.version == 3 + assert new_item is None + mock_graph_db.update_node.assert_called_once() + + +def test_update_existing_memory_marks_working_binding_deleted(history_manager, mock_graph_db): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + working_binding = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": working_binding, "embedding": []}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is not None + assert new_item is None + history_manager.mark_memory_status.assert_called_once_with( + [str(working_binding)], "deleted", user_name="u1" + ) + + +def test_update_existing_memory_no_mark_when_working_binding_matches( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + primary_id = uuid.uuid4().hex + mock_graph_db.get_node.return_value = { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, + } + mock_graph_db.get_nodes.return_value = [ + { + "id": primary_id, + "memory": "Old Content", + "metadata": {"version": 1, "working_binding": primary_id, "embedding": []}, + } + ] + mem_data = { + "key": "k", + "value": "Updated", + "tags": [], + "source_candidate_ids": [primary_id], + "conflicted_candidate_ids": [], + } + + updated, new_item = history_manager._update_existing_memory( + mem_data, + [primary_id], + [primary_id], + {primary_id: 1}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is not None + assert new_item is None + history_manager.mark_memory_status.assert_not_called() + + +def test_update_existing_memory_node_missing(history_manager, mock_graph_db): + mock_graph_db.get_node.return_value = None + mock_graph_db.get_nodes.return_value = [] + mem_data = {"value": "v", "tags": [], "key": "k"} + + updated, new_item = history_manager._update_existing_memory( + mem_data, + ["missing"], + [], + {}, + user_name="u1", + fast_item=TextualMemoryItem( + memory="New user input", metadata=TreeNodeTextualMemoryMetadata() + ), + ) + + assert updated is None + assert new_item is not None + assert new_item.memory == "v" + mock_graph_db.update_node.assert_not_called() + + +def test_update_node_with_history(): + item = TextualMemoryItem( + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + tags=["old"], + key="k1", + history=[], + ), + ) + + updated, archived = MemoryHistoryManager.update_node_with_history( + item, + "New Content", + "conflict", + ) + + assert updated.memory == "New Content" + assert updated.metadata.version == 3 + assert updated.metadata.tags == ["old"] + assert updated.metadata.key == "k1" + assert len(updated.metadata.history) == 1 + history_entry = updated.metadata.history[0] + assert history_entry.memory == "Old Content" + assert history_entry.update_type == "conflict" + assert history_entry.archived_memory_id == archived.id + assert archived.metadata.status == "archived" + assert archived.metadata.evolve_to == [updated.id] + + +def test_merge_conflicting_memory_llm_error(mock_graph_db): + llm = MagicMock() + llm.generate.side_effect = Exception("fail") + manager = MemoryHistoryManager( + nli_client=MagicMock(spec=NLIClient), graph_db=mock_graph_db, llm=llm + ) + + merged = manager._merge_conflicting_memory("Latest", "Proposed") + + assert "System Merge Fallback" in merged + assert "Latest" in merged + assert "Proposed" in merged From a47fd071c1de7f8f4ee97f3f0123c1173a959c97 Mon Sep 17 00:00:00 2001 From: Jiang <33757498+hijzy@users.noreply.github.com> Date: Fri, 27 Feb 2026 14:28:12 +0800 Subject: [PATCH 52/92] fix: get embedding of pref mem from db instead of recompute (#1133) * fix: add embedding for pref memory * fix: recompute embedding for missing memory instead of all memory * reformat * fix: get embedding from pref instead of pref.payload * fix: add include embedding params to pref mem * fix: decrease top k --------- Co-authored-by: jiang --- src/memos/api/handlers/search_handler.py | 72 +++++++++++++++---- .../textual/prefer_text_memory/retrievers.py | 52 +++++++++----- .../tree_text_memory/retrieve/searcher.py | 2 +- src/memos/multi_mem_cube/single_cube.py | 5 +- 4 files changed, 98 insertions(+), 33 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 267d1bb28..8e7785ad5 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -64,7 +64,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Expand top_k for deduplication (5x to ensure enough candidates) if search_req_local.dedup in ("sim", "mmr"): - search_req_local.top_k = search_req_local.top_k * 5 + search_req_local.top_k = search_req_local.top_k * 3 # Search and deduplicate cube_view = self._build_cube_view(search_req_local) @@ -152,9 +152,6 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di return results embeddings = self._extract_embeddings([mem for _, mem, _ in flat]) - if embeddings is None: - documents = [mem.get("memory", "") for _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(embeddings) @@ -235,12 +232,39 @@ def _mmr_dedup_text_memories( if len(flat) <= 1: return results + total_by_type: dict[str, int] = {"text": 0, "preference": 0} + existing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_indices: list[int] = [] + for idx, (mem_type, _, mem, _) in enumerate(flat): + if mem_type not in total_by_type: + total_by_type[mem_type] = 0 + existing_by_type[mem_type] = 0 + missing_by_type[mem_type] = 0 + total_by_type[mem_type] += 1 + + embedding = mem.get("metadata", {}).get("embedding") + if embedding: + existing_by_type[mem_type] += 1 + else: + missing_by_type[mem_type] += 1 + missing_indices.append(idx) + + self.logger.info( + "[SearchHandler] MMR embedding metadata scan: total=%s total_by_type=%s existing_by_type=%s missing_by_type=%s", + len(flat), + total_by_type, + existing_by_type, + missing_by_type, + ) + if missing_indices: + self.logger.warning( + "[SearchHandler] MMR embedding metadata missing; will compute missing embeddings: missing_total=%s", + len(missing_indices), + ) + # Get or compute embeddings embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat]) - if embeddings is None: - self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings") - documents = [mem.get("memory", "") for _, _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) # Compute similarity matrix using NumPy-optimized method # Returns numpy array but compatible with list[i][j] indexing @@ -404,14 +428,32 @@ def _max_similarity( return 0.0 return max(similarity_matrix[index][j] for j in selected_indices) - @staticmethod - def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None: + def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float]]: embeddings: list[list[float]] = [] - for mem in memories: - embedding = mem.get("metadata", {}).get("embedding") - if not embedding: - return None - embeddings.append(embedding) + missing_indices: list[int] = [] + missing_documents: list[str] = [] + + for idx, mem in enumerate(memories): + metadata = mem.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + mem["metadata"] = metadata + + embedding = metadata.get("embedding") + if embedding: + embeddings.append(embedding) + continue + + embeddings.append([]) + missing_indices.append(idx) + missing_documents.append(mem.get("memory", "")) + + if missing_indices: + computed = self.searcher.embedder.embed(missing_documents) + for idx, embedding in zip(missing_indices, computed, strict=False): + embeddings[idx] = embedding + memories[idx]["metadata"]["embedding"] = embedding + return embeddings @staticmethod diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 6352d5840..8483a5151 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -124,25 +124,45 @@ def retrieve( explicit_prefs.sort(key=lambda x: x.score, reverse=True) implicit_prefs.sort(key=lambda x: x.score, reverse=True) - explicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + explicit_prefs_mem = [] + for pref in explicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + explicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in explicit_prefs - if pref.payload.get("preference", None) - ] - implicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + implicit_prefs_mem = [] + for pref in implicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + implicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in implicit_prefs - if pref.payload.get("preference", None) - ] reranker_map = { "naive": self._naive_reranker, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9dcbe8c56..cc269e8c4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception as e: + except Exception: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2a6991e8c..02723e28c 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -443,7 +443,10 @@ def _search_pref( }, search_filter=search_req.filter, ) - formatted_results = self._postformat_memories(results, user_context.mem_cube_id) + include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" + formatted_results = self._postformat_memories( + results, user_context.mem_cube_id, include_embedding=include_embedding + ) # For each returned item, tackle with metadata.info project_id / # operation / manager_user_id From 984fbcc1e3533070b8eec86949485e611590bb3a Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 27 Feb 2026 17:52:11 +0800 Subject: [PATCH 53/92] fix: Cancel marking in mem version fast process for performance considerations. --- src/memos/mem_reader/multi_modal_struct.py | 4 - .../organize/history_manager.py | 94 +------------------ 2 files changed, 3 insertions(+), 95 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 94b7d3de6..b07ce53d6 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1003,10 +1003,6 @@ def _fast_resolve_memory_duplicates_and_conflicts( len(related), len(conflicting_or_duplicate_ids), ) - # mark delete(temporarily) - self.history_manager.mark_memory_status( - conflicting_or_duplicate_ids, "resolving", user_name=user_name - ) except Exception as e: logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 7ebc25694..4e33af59f 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -25,61 +25,6 @@ logger = logging.getLogger(__name__) -CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]" -DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]" - - -def _append_related_content( - new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str] -) -> None: - """ - Append duplicate and conflict memory contents to the new item's memory text, - truncated to avoid excessive length. - """ - max_per_item_len = 200 - max_section_len = 1000 - - def _format_section(title: str, items: list[str]) -> str: - if not items: - return "" - - section_content = "" - for mem in items: - # Truncate individual item - snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem - # Check total section length - if len(section_content) + len(snippet) + 5 > max_section_len: - section_content += "\n- ... (more items truncated)" - break - section_content += f"\n- {snippet}" - - return f"\n\n{title}:{section_content}" - - append_text = "" - append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts) - append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates) - - if append_text: - new_item.memory += append_text - - -def _detach_related_content(new_item: TextualMemoryItem) -> None: - """ - Detach duplicate and conflict memory contents from the new item's memory text. - """ - markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"] - - cut_index = -1 - for marker in markers: - idx = new_item.memory.find(marker) - if idx != -1 and (cut_index == -1 or idx < cut_index): - cut_index = idx - - if cut_index != -1: - new_item.memory = new_item.memory[:cut_index] - - return - def _rebuild_fast_node_history( item: TextualMemoryItem, replacements: dict[int, list[ArchivedTextualMemory]] @@ -284,10 +229,6 @@ def resolve_history_via_nli( ) new_item.metadata.history.append(archived) - # 3. Concat duplicate/conflict memories to new_item.memory - # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. - _append_related_content(new_item, duplicate_memories, conflict_memories) - return duplicate_memory_ids + conflict_memory_ids def wait_and_update_fast_history( @@ -331,7 +272,7 @@ def wait_and_update_fast_history( break # 2. Check status of the fast nodes and fetch replacements for evolved ones - replacements = self._check_and_fetch_replacements(item, pending_indices) + replacements = self._check_and_fetch_replacements(item, pending_indices, user_name) # 3. If we have any resolved items, rebuild the history if replacements: @@ -362,9 +303,6 @@ def format_async_update_prompt( Returns: Formatted prompt string. """ - # First, detach duplicate and conflict memory contents from the new item's memory text - _detach_related_content(item) - duplicate_candidates = [] conflict_candidates = [] unrelated_candidates = [] @@ -437,14 +375,7 @@ def apply_llm_memory_updates( for mem_data in memory_list: used_source_ids.update(mem_data.get("source_candidate_ids", [])) used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) - # Collect IDs referenced by restored memories - restored_source_ids = { - mem.get("source_candidate_id") - for mem in restored_memories - if mem.get("source_candidate_id") - } - # All IDs used by the LLM response - used_ids = used_source_ids | used_conflict_ids | restored_source_ids + expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories candidate_id_set: set[str] = set() # Recover candidate IDs and their expected versions from the source item's history @@ -454,14 +385,10 @@ def apply_llm_memory_updates( candidate_id_set.add(h.archived_memory_id) expected_versions[h.archived_memory_id] = h.version - # IDs not used by the LLM response will be treated as unrelated - unrelated_ids = sorted(candidate_id_set - used_ids) updated_items: list[TextualMemoryItem] = [] new_items: list[TextualMemoryItem] = [] - # 1. Handle Unrelated Candidates - Restore status to activated - self._handle_unrelated_candidates(unrelated_ids, user_name=user_name) - + # 1. Handle Unrelated Candidates - Do nothing # 2. Handle Memory List (Update or New) processed_updates, created_items = self._process_memory_updates( memory_list, expected_versions, user_name, source_item @@ -478,12 +405,6 @@ def build_fallback_new_items( self, item: TextualMemoryItem, user_name: str | None = None ) -> list[TextualMemoryItem]: latest_item = item.model_copy(deep=True) - _detach_related_content(latest_item) - - history = latest_item.metadata.history or [] - archived_ids = [h.archived_memory_id for h in history if h.archived_memory_id] - if archived_ids: - self.mark_memory_status(archived_ids, "activated", user_name or "") latest_item.id = str(uuid.uuid4()) latest_item.metadata.is_fast = False @@ -595,11 +516,6 @@ def _fetch_evolved_nodes( return results - def _handle_unrelated_candidates(self, unrelated_ids: list[str], user_name: str) -> None: - """Handle Unrelated Candidates - Restore status to `activated`.""" - if unrelated_ids: - self.mark_memory_status(unrelated_ids, "activated", user_name) - def _process_memory_updates( self, memory_list: list[dict[str, Any]], @@ -666,7 +582,6 @@ def _update_existing_memory( new_value_item = TextualMemoryItem( memory=new_value, metadata=TreeNodeTextualMemoryMetadata() ) - _detach_related_content(new_value_item) new_value = new_value_item.memory # Fetch candidate nodes in batch and then select the primary @@ -765,7 +680,6 @@ def _apply_cas_merge( latest_item = TextualMemoryItem( memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() ) - _detach_related_content(latest_item) merged_content = self._merge_conflicting_memory( latest_memory=latest_item.memory, proposed_update=new_value, @@ -869,7 +783,6 @@ def _create_new_memory( new_value_item = TextualMemoryItem( memory=new_value, metadata=TreeNodeTextualMemoryMetadata() ) - _detach_related_content(new_value_item) new_value = new_value_item.memory tags = mem_data.get("tags", []) key = mem_data.get("key", "") @@ -915,7 +828,6 @@ def _handle_restored_memories( source_history = source_item.history.copy() value = data.get("value", "") value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) - _detach_related_content(value_item) value = value_item.memory tags = data.get("tags", []) keys = data.get("keys", []) From 3cfa17b08953cc2246ad04c2fd77d8d74da3c6c6 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 27 Feb 2026 20:18:42 +0800 Subject: [PATCH 54/92] fix: Fix test --- .../memories/textual/test_history_manager.py | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 11b5d7f44..74962e6b7 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -14,8 +14,6 @@ ) from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, - _append_related_content, - _detach_related_content, _rebuild_fast_node_history, ) @@ -36,60 +34,6 @@ def history_manager(mock_nli_client, mock_graph_db): return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db) -def test_detach_related_content(): - original_memory = "This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) - - duplicates = ["Duplicate 1", "Duplicate 2"] - conflicts = ["Conflict 1", "Conflict 2"] - - # 1. Append content - _append_related_content(item, duplicates, conflicts) - - # Verify content was appended - assert item.memory != original_memory - assert "[possibly conflicting memories]" in item.memory - assert "[possibly duplicate memories]" in item.memory - assert "Duplicate 1" in item.memory - assert "Conflict 1" in item.memory - - # 2. Detach content - _detach_related_content(item) - - # 3. Verify content is restored - assert item.memory == original_memory - - -def test_detach_only_conflicts(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) - - duplicates = [] - conflicts = ["Conflict A"] - - _append_related_content(item, duplicates, conflicts) - assert "Conflict A" in item.memory - assert "Duplicate" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - -def test_detach_only_duplicates(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TreeNodeTextualMemoryMetadata()) - - duplicates = ["Duplicate A"] - conflicts = [] - - _append_related_content(item, duplicates, conflicts) - assert "Duplicate A" in item.memory - assert "Conflict" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - def test_truncation(history_manager, mock_nli_client): # Setup new_item = TextualMemoryItem(memory="Test") From 5c6a712d4418aa1bb187f8a755ef8b8de77e575c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Sat, 28 Feb 2026 14:46:05 +0800 Subject: [PATCH 55/92] feat: move recall & nli judging to async mode --- src/memos/mem_reader/multi_modal_struct.py | 73 +++++++++++----------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index b07ce53d6..3a9a859cc 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -959,7 +959,7 @@ def _process_tool_trajectory_fine( return fine_memory_items def _fast_resolve_memory_duplicates_and_conflicts( - self, fast_memory_items: list[TextualMemoryItem], user_name: str + self, item: list[TextualMemoryItem], user_name: str ) -> None: """ 1. Recall related memories @@ -968,44 +968,44 @@ def _fast_resolve_memory_duplicates_and_conflicts( 4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible to /search, but still visible for other conflict/duplication checks' recalls. """ + if not self.history_manager.is_applicable(item): + return + if not self.pre_update_retriever or not self.history_manager: logger.warning( "[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized." ) return - for item in fast_memory_items: - if not self.history_manager.is_applicable(item): - continue - try: - # recall related memories - retrieve_start = time.perf_counter() - related = self.pre_update_retriever.retrieve( - item=item, - user_name=user_name, - ) - retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 - logger.info( - "[MultiModalStruct] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", - retrieve_ms, - getattr(item, "id", None), - ) - # NLI check & attaching contents - nli_start = time.perf_counter() - conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( - item, related - ) - nli_ms = (time.perf_counter() - nli_start) * 1000 - logger.info( - "[MultiModalStruct] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", - nli_ms, - getattr(item, "id", None), - len(related), - len(conflicting_or_duplicate_ids), - ) + try: + # recall related memories + retrieve_start = time.perf_counter() + related = self.pre_update_retriever.retrieve( + item=item, + user_name=user_name, + ) + retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 + logger.info( + "[MultiModalStruct] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", + retrieve_ms, + getattr(item, "id", None), + ) + # NLI check & attaching contents + nli_start = time.perf_counter() + conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( + item, related + ) + nli_ms = (time.perf_counter() - nli_start) * 1000 + logger.info( + "[MultiModalStruct] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", + nli_ms, + getattr(item, "id", None), + len(related), + len(conflicting_or_duplicate_ids), + ) - except Exception as e: - logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + except Exception as e: + logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") def _process_async_versioning_update( self, item: TextualMemoryItem, mem_str: str, custom_tags: dict[str, str], **kwargs @@ -1016,7 +1016,10 @@ def _process_async_versioning_update( 3. Call LLM and parse JSON response 4. Apply LLM updates to memory graph and return new items """ + # resolve via nli first user_name = kwargs.get("user_name") + self._fast_resolve_memory_duplicates_and_conflicts(item, user_name) + self.history_manager.wait_and_update_fast_history(item, user_name, timeout_sec=30) lang = detect_lang(kwargs.get("chat_history") or mem_str) custom_tags_prompt = ( @@ -1091,12 +1094,6 @@ def _process_multi_modal_data( ) fast_memory_items = self._concat_multi_modal_memories(all_memory_items) - # Perform conflict/duplicate check with old memories - # TODO: find a better way to pass in the user_name - user_name = kwargs.get("user_name") - if self.memory_version_switch == "on": - self._fast_resolve_memory_duplicates_and_conflicts(fast_memory_items, user_name) - if mode == "fast": return fast_memory_items else: From a329e8fe24ed6abcbf3333f1d4e8d5498587f0f5 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 3 Mar 2026 10:55:04 +0800 Subject: [PATCH 56/92] fix: exclude self-search in PreUpdateRetriever --- .../tree_text_memory/retrieve/pre_update.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index 1d66732d2..3c69ee31a 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -138,7 +138,7 @@ def vector_search( results = self.graph_db.search_by_embedding( vector=q_embed, top_k=top_k, - status=None, + status="activated", threshold=threshold, user_name=user_name, filter=search_filter, @@ -211,12 +211,9 @@ def retrieve( futures = [] common_filter = { "or": [ - {"status": "activated", "memory_type": "LongTermMemory"}, - {"status": "activated", "memory_type": "UserMemory"}, - {"status": "activated", "memory_type": "WorkingMemory"}, - {"status": "resolving", "memory_type": "LongTermMemory"}, - {"status": "resolving", "memory_type": "UserMemory"}, - {"status": "resolving", "memory_type": "WorkingMemory"}, + {"memory_type": "LongTermMemory"}, + {"memory_type": "UserMemory"}, + {"memory_type": "WorkingMemory"}, ] } @@ -247,7 +244,10 @@ def retrieve( continue for r in res: - retrieved_ids.add(r["id"]) + # exclude self and working binding + working_binding = item.metadata.working_binding or "" + if r["id"] != item.id and r["id"] != working_binding: + retrieved_ids.add(r["id"]) except Exception as e: logger.error(f"[PreUpdateRetriever] Search future task failed: {e}") From edecdda730f556aa35c30a030a17bf842ee0fa5b Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 3 Mar 2026 15:21:31 +0800 Subject: [PATCH 57/92] fix: fix mem-version behaviors under concurrent situations --- .../handlers/mem_read_handler.py | 32 ++++++++++++++--- src/memos/memories/textual/tree.py | 30 ++++++++++++++++ .../organize/history_manager.py | 34 ++++++------------- .../tree_text_memory/retrieve/pre_update.py | 10 +++++- 4 files changed, 77 insertions(+), 29 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 3f44915fc..e0baf63ff 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -390,10 +390,34 @@ def _process_memories_with_reader( delete_ids = list(dict.fromkeys(delete_ids)) if delete_ids: try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name - ) + if getattr(mem_reader, "memory_version_switch", "off") != "on": + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) + else: + # change to soft-delete for mem versions + flattened_memories = [] + if processed_memories and len(processed_memories) > 0: + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + allowed_types = ["UserMemory", "LongTermMemory"] + text_mem.soft_delete( + delete_ids, + user_name, + [ + mem.id + for mem in flattened_memories + if mem.metadata.memory_type in allowed_types + ], + ) + logger.info( + "Soft delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) except Exception as e: logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) else: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5b210ba61..6e2284f55 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -605,3 +605,33 @@ def add_graph_edges( future.result() except Exception as e: logger.exception("Add edge error: ", exc_info=e) + + def soft_delete( + self, + memory_ids: list[str], + user_name: str, + evolve_to_ids: list[str] | None = None, + ) -> None: + # for ruff check... + if not evolve_to_ids: + update_fields = {"status": "deleted"} + else: + update_fields = {"status": "deleted", "evolve_to": evolve_to_ids} + + # Execute the actual marking operation - in db. + with ContextThreadPoolExecutor() as executor: + futures = [] + for mid in memory_ids: + futures.append( + executor.submit( + self.graph_store.update_node, + id=mid, + fields=update_fields, + user_name=user_name, + ) + ) + + # Wait for all tasks to complete and raise any exceptions + for future in futures: + future.result() + return diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 4e33af59f..2af1290dc 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -369,20 +369,12 @@ def apply_llm_memory_updates( """ memory_list = llm_response.get("memory list", []) restored_memories = llm_response.get("restored_memories", []) - used_source_ids = set() - used_conflict_ids = set() - # Collect IDs referenced by the LLM response - for mem_data in memory_list: - used_source_ids.update(mem_data.get("source_candidate_ids", [])) - used_conflict_ids.update(mem_data.get("conflicted_candidate_ids", [])) expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories - candidate_id_set: set[str] = set() # Recover candidate IDs and their expected versions from the source item's history if source_item.metadata and source_item.metadata.history: for h in source_item.metadata.history: if h.archived_memory_id: - candidate_id_set.add(h.archived_memory_id) expected_versions[h.archived_memory_id] = h.version updated_items: list[TextualMemoryItem] = [] @@ -464,23 +456,21 @@ def _check_and_fetch_replacements( # Batch fetch pending nodes to check status nodes_data = self.graph_db.get_nodes(ids=pending_ids, user_name=user_name) or [] - nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} + nodes_map = {n["id"]: TextualMemoryItem(**n) for n in nodes_data if n and "id" in n} replacements = {} for i in pending_indices: h_item = item.metadata.history[i] - node_data = nodes_map.get(h_item.archived_memory_id) + node = nodes_map.get(h_item.archived_memory_id) - if not node_data: + if not node: continue - metadata = node_data.get("metadata", {}) - status = metadata.get("status") - + metadata = _sanitize_metadata_model(node.metadata) # deal with embedded metadata # Condition: Fast node is processed when it is marked as 'deleted' - if status == "deleted": - evolve_to_ids = metadata.get("evolve_to", []) + if metadata.status == "deleted": + evolve_to_ids = metadata.evolve_to new_items = self._fetch_evolved_nodes(evolve_to_ids, h_item.update_type, user_name) replacements[i] = new_items @@ -523,14 +513,14 @@ def _process_memory_updates( user_name: str, source_item: TextualMemoryItem, ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]]: - """Process Memory List (Update or New).""" + """Process Memory List (Update or Create).""" updated_items: list[TextualMemoryItem] = [] new_items: list[TextualMemoryItem] = [] for mem_data in memory_list: source_ids = mem_data.get("source_candidate_ids", []) conflict_ids = mem_data.get("conflicted_candidate_ids", []) - # Determine if this is an update or a new node + # Determine if this is an update or a creation target_ids = source_ids + conflict_ids if target_ids: @@ -569,11 +559,7 @@ def _update_existing_memory( Returns the updated primary TextualMemoryItem and optional new item when fallback is used. """ - original_primary_id, primary_id, secondary_ids = ( - target_ids[0], - target_ids[0], - target_ids[1:], - ) + primary_id, secondary_ids = target_ids[0], target_ids[1:] new_value, tags, key = ( mem_data.get("value", ""), mem_data.get("tags", []), @@ -602,7 +588,7 @@ def _update_existing_memory( # If it has(version changed), then we need to use llm to merge again. new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) - update_type = "duplicate" if original_primary_id in source_ids else "conflict" + update_type = "duplicate" if primary_id in source_ids else "conflict" current_item, archived_item = self.update_node_with_history( current_item, new_value, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index 3c69ee31a..db3c5d294 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -1,6 +1,7 @@ import concurrent.futures import re +from datetime import datetime from typing import Any from memos.context.context import ContextThreadPoolExecutor @@ -142,6 +143,7 @@ def vector_search( threshold=threshold, user_name=user_name, filter=search_filter, + return_fields=["id", "is_fast", "created_at"], ) return results except Exception as e: @@ -155,6 +157,7 @@ def keyword_search( top_k: int, search_filter: dict[str, Any] | None = None, ) -> list[dict]: + # Currently not used for large latency try: # 1. Tokenize using existing tokenizer keywords = self.tokenizer.tokenize_mixed(query_text) @@ -245,8 +248,13 @@ def retrieve( for r in res: # exclude self and working binding + # also exclude fast nodes that's created after current node to avoid deadlock later working_binding = item.metadata.working_binding or "" - if r["id"] != item.id and r["id"] != working_binding: + if (r["id"] != item.id and r["id"] != working_binding) and ( + not r["is_fast"] + or datetime.fromisoformat(r["created_at"]) + < datetime.fromisoformat(item.metadata.created_at) + ): retrieved_ids.add(r["id"]) except Exception as e: From 9591638bc8455494123a44ad11d34466ea686869 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 3 Mar 2026 19:46:32 +0800 Subject: [PATCH 58/92] fix: deal with absence of is_fast --- .../memories/textual/tree_text_memory/retrieve/pre_update.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py index db3c5d294..b500f6b61 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -250,8 +250,9 @@ def retrieve( # exclude self and working binding # also exclude fast nodes that's created after current node to avoid deadlock later working_binding = item.metadata.working_binding or "" + is_fast = bool(r.get("is_fast", False)) if (r["id"] != item.id and r["id"] != working_binding) and ( - not r["is_fast"] + not is_fast or datetime.fromisoformat(r["created_at"]) < datetime.fromisoformat(item.metadata.created_at) ): From 445d8326287186dbee6fc3d365cae25229c72ca0 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Mar 2026 14:30:03 +0800 Subject: [PATCH 59/92] refactor: simplify memory version call to minimize influence to opensource code. --- src/memos/api/handlers/component_init.py | 9 +- src/memos/mem_reader/factory.py | 8 - src/memos/mem_reader/multi_modal_struct.py | 105 ++------------ src/memos/mem_reader/simple_struct.py | 5 - .../init_components_for_scheduler.py | 8 +- .../organize/history_manager.py | 137 +++++++++++++++++- .../memories/textual/test_history_manager.py | 14 +- 7 files changed, 160 insertions(+), 126 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index b3492e7bb..53e7bfe16 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -172,15 +172,18 @@ def init_server() -> dict[str, Any]: ) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) + pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) memory_history_manager = MemoryHistoryManager( - nli_client=nli_client, graph_db=graph_db, llm=llm, embedder=embedder + nli_client=nli_client, + graph_db=graph_db, + llm=llm, + embedder=embedder, + pre_update_retriever=pre_update_retriever, ) - pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( mem_reader_config, graph_db=graph_db, - pre_update_retriever=pre_update_retriever, history_manager=memory_history_manager, ) reranker = RerankerFactory.from_config(reranker_config) diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 0907168a7..8e54873c8 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -5,7 +5,6 @@ from memos.mem_reader.multi_modal_struct import MultiModalStructMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_reader.strategy_struct import StrategyStructMemReader -from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memos_tools.singleton import singleton_factory @@ -33,7 +32,6 @@ def from_config( config_factory: MemReaderConfigFactory, graph_db: Optional["BaseGraphDB | None"] = None, searcher: Optional["Searcher | None"] = None, - pre_update_retriever: PreUpdateRetriever | None = None, history_manager: Optional["MemoryHistoryManager | None"] = None, ) -> BaseMemReader: """ @@ -61,12 +59,6 @@ def from_config( if searcher is not None: reader.set_searcher(searcher) - if pre_update_retriever is not None: - if hasattr(reader, "set_pre_update_retriever"): - reader.set_pre_update_retriever(pre_update_retriever) - else: - reader.pre_update_retriever = pre_update_retriever - if history_manager is not None: if hasattr(reader, "set_history_manager"): reader.set_history_manager(history_manager) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index e1b78a080..f7120148a 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1,7 +1,6 @@ import concurrent.futures import json import re -import time import traceback from typing import TYPE_CHECKING, Any @@ -58,7 +57,6 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) - self.pre_update_retriever = None self.history_manager = None self.memory_version_switch = getattr(config, "memory_version_switch", "off") @@ -704,19 +702,27 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) - # ========== Stage 0: Memory versioning async update pipeline ========== + # ========== Stage 0: Memory version async extraction/update pipeline ========== if ( self.memory_version_switch == "on" and self.history_manager is not None and self.history_manager.is_applicable(fast_item) ): try: - new_items = self._process_async_versioning_update( - fast_item, mem_str, custom_tags, **kwargs + user_name = kwargs.get("user_name") + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang] + new_items = self.history_manager.apply_mem_version_update( + fast_item, + user_name, + self.qwen_llm, + custom_tags=custom_tags, + custom_tags_prompt_template=custom_tags_prompt_template, + timeout_sec=30, ) return new_items - except Exception as e: - logger.warning(f"[MultiModalFine] Async versioning pipeline failed: {e}") + except Exception as ex: + logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: {ex}") return [] # ========== Stage 1: Normal extraction (without reference) ========== @@ -959,91 +965,6 @@ def _process_tool_trajectory_fine( return fine_memory_items - def _fast_resolve_memory_duplicates_and_conflicts( - self, item: list[TextualMemoryItem], user_name: str - ) -> None: - """ - 1. Recall related memories - 2. Fast conflict/duplication check with NLI model - 3. Attach conflicting/duplicate old memory contents onto fast memory items - 4. Mark conflicting/duplicate old memory nodes as "resolving", making them invisible to /search, - but still visible for other conflict/duplication checks' recalls. - """ - if not self.history_manager.is_applicable(item): - return - - if not self.pre_update_retriever or not self.history_manager: - logger.warning( - "[MultiModalStruct] PreUpdateRetriever or HistoryManager is not initialized." - ) - return - - try: - # recall related memories - retrieve_start = time.perf_counter() - related = self.pre_update_retriever.retrieve( - item=item, - user_name=user_name, - ) - retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 - logger.info( - "[MultiModalStruct] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", - retrieve_ms, - getattr(item, "id", None), - ) - # NLI check & attaching contents - nli_start = time.perf_counter() - conflicting_or_duplicate_ids = self.history_manager.resolve_history_via_nli( - item, related - ) - nli_ms = (time.perf_counter() - nli_start) * 1000 - logger.info( - "[MultiModalStruct] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", - nli_ms, - getattr(item, "id", None), - len(related), - len(conflicting_or_duplicate_ids), - ) - - except Exception as e: - logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") - - def _process_async_versioning_update( - self, item: TextualMemoryItem, mem_str: str, custom_tags: dict[str, str], **kwargs - ) -> list[TextualMemoryItem]: - """ - 1. Wait for fast node resolution and rebuild its history - 2. Build async update prompt (include custom tags and conversation context) - 3. Call LLM and parse JSON response - 4. Apply LLM updates to memory graph and return new items - """ - # resolve via nli first - user_name = kwargs.get("user_name") - self._fast_resolve_memory_duplicates_and_conflicts(item, user_name) - - self.history_manager.wait_and_update_fast_history(item, user_name, timeout_sec=30) - lang = detect_lang(kwargs.get("chat_history") or mem_str) - custom_tags_prompt = ( - PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) - if custom_tags - else "" - ) - prompt = self.history_manager.format_async_update_prompt(item, custom_tags_prompt) - try: - response_text = self.qwen_llm.generate([{"role": "user", "content": prompt}]) - if not response_text: - raise ValueError("Empty LLM response") - response_json = parse_json_result(response_text) - if not response_json: - raise ValueError("Empty LLM JSON response") - _, new_items = self.history_manager.apply_llm_memory_updates( - response_json, item, user_name=user_name - ) - return new_items - except Exception as e: - logger.warning(f"[MultiModalStruct] Async update fallback due to LLM failure: {e}") - return self.history_manager.build_fallback_new_items(item, user_name=user_name) - @timed def _process_multi_modal_data( self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 180fd6ee8..09fd8ca12 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -24,7 +24,6 @@ from memos.memories.textual.tree_text_memory.organize.history_manager import ( MemoryHistoryManager, ) - from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang from memos.mem_reader.utils import ( @@ -199,7 +198,6 @@ def __init__(self, config: SimpleStructMemReaderConfig): # Initialize graph_db as None, can be set later via set_graph_db for # recall operations self.graph_db = None - self.pre_update_retriever = None self.history_manager = None def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: @@ -208,9 +206,6 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: def set_searcher(self, searcher: "Searcher | None") -> None: self.searcher = searcher - def set_pre_update_retriever(self, pre_update_retriever: "PreUpdateRetriever | None") -> None: - self.pre_update_retriever = pre_update_retriever - def set_history_manager(self, history_manager: "MemoryHistoryManager | None") -> None: self.history_manager = history_manager diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 9aad26854..7a4bdf326 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -261,13 +261,17 @@ def init_components() -> dict[str, Any]: llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) pre_update_retriever = PreUpdateRetriever(graph_db=graph_db, embedder=embedder) + memory_history_manager = MemoryHistoryManager( + nli_client=nli_client, + graph_db=graph_db, + embedder=embedder, + pre_update_retriever=pre_update_retriever, + ) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config( mem_reader_config, graph_db=graph_db, - pre_update_retriever=pre_update_retriever, history_manager=memory_history_manager, ) reranker = RerankerFactory.from_config(reranker_config) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 2af1290dc..d5eb69117 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -1,4 +1,6 @@ +import json import logging +import re import time import uuid @@ -17,6 +19,7 @@ TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever from memos.templates.mem_reader_mem_version_prompts import ( ASYNC_MEMORY_UPDATE_PROMPT_DICT, MEMORY_MERGE_PROMPT_DICT, @@ -86,6 +89,48 @@ def _determine_lang(sources: list | None, fallback_text: str) -> str: return lang +def _parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.warning( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} + + class MemoryHistoryManager: def __init__( self, @@ -93,6 +138,7 @@ def __init__( graph_db: BaseGraphDB, llm: BaseLLM | None = None, embedder: BaseEmbedder | None = None, + pre_update_retriever: PreUpdateRetriever | None = None, ) -> None: """ Initialize the MemoryHistoryManager. @@ -106,6 +152,7 @@ def __init__( self.graph_db = graph_db self.llm = llm self.embedder = embedder + self.pre_update_retriever = pre_update_retriever def _compute_embedding(self, text: str) -> list[float] | None: if not self.embedder: @@ -290,9 +337,7 @@ def wait_and_update_fast_history( return - def format_async_update_prompt( - self, item: TextualMemoryItem, custom_tags_prompt: str = "" - ) -> str: + def format_prompt(self, item: TextualMemoryItem, custom_tags_prompt: str = "") -> str: """ Format the prompt for asynchronous memory update. @@ -443,6 +488,92 @@ def mark_memory_status( future.result() return + def prepare_history_candidates_via_nli(self, item: TextualMemoryItem, user_name: str) -> None: + """ + 1. Recall related memories + 2. Fast conflict/duplication check with NLI model + 3. Attach conflicting/duplicate old memory contents onto fast memory items + """ + if not self.is_applicable(item): + return + + if not self.pre_update_retriever: + logger.warning("[MemoryHistoryManager] PreUpdateRetriever is not initialized.") + return + + try: + # recall related memories + retrieve_start = time.perf_counter() + related = self.pre_update_retriever.retrieve( + item=item, + user_name=user_name, + ) + retrieve_ms = (time.perf_counter() - retrieve_start) * 1000 + logger.info( + "[MemoryHistoryManager] pre_update_retriever.retrieve latency_ms=%.2f item_id=%s", + retrieve_ms, + getattr(item, "id", None), + ) + # NLI check & attaching contents + nli_start = time.perf_counter() + conflicting_or_duplicate_ids = self.resolve_history_via_nli(item, related) + nli_ms = (time.perf_counter() - nli_start) * 1000 + logger.info( + "[MemoryHistoryManager] history_manager.resolve_history_via_nli latency_ms=%.2f item_id=%s related_count=%s result_count=%s", + nli_ms, + getattr(item, "id", None), + len(related), + len(conflicting_or_duplicate_ids), + ) + + except Exception as e: + logger.warning(f"[MultiModalStruct] Fast recall failed: {e}") + + def apply_mem_version_update( + self, + original_item: TextualMemoryItem, + user_name: str, + llm: BaseLLM | None, + custom_tags: dict[str, str] | None, + custom_tags_prompt_template: str | None, + timeout_sec: int = 30, + ) -> list[TextualMemoryItem]: + """ + 1. Wait for 'fast histories' in the item to resolve, and rebuild its history + 2. Build memory extraction/update prompt (include custom tags and conversation context) + 3. Call LLM and parse JSON response + 4. Apply LLM updates to memory graph and return new items + """ + self.prepare_history_candidates_via_nli(original_item, user_name) + self.wait_and_update_fast_history(original_item, user_name, timeout_sec=timeout_sec) + + custom_tags_prompt = ( + custom_tags_prompt_template.replace("{custom_tags}", str(custom_tags)) + if custom_tags_prompt_template and custom_tags + else "" + ) + prompt = self.format_prompt(original_item, custom_tags_prompt) + try: + if llm is None: + raise ValueError("LLM is not initialized") + response_text = llm.generate([{"role": "user", "content": prompt}]) + if not response_text: + raise ValueError("Empty LLM response") + response_json = _parse_json_result(response_text) + if not response_json: + raise ValueError("Empty LLM JSON response") + + _, new_items = self.apply_llm_memory_updates( + response_json, original_item, user_name=user_name + ) + return new_items + + except Exception as e: + logger.warning( + f"[MemoryHistoryManager] Memory extraction/update fallback due to LLM failure: {e}" + ) + return self.build_fallback_new_items(original_item, user_name=user_name) + def _check_and_fetch_replacements( self, item: TextualMemoryItem, pending_indices: list[int], user_name: str ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 74962e6b7..813a4b622 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -98,7 +98,7 @@ def test_format_async_update_prompt(history_manager): ) # Execute - prompt = history_manager.format_async_update_prompt(item) + prompt = history_manager.format_prompt(item) # Verify assert "[ID:101]" in prompt @@ -114,18 +114,6 @@ def test_format_async_update_prompt(history_manager): assert "${conflict_candidates}" not in prompt -def test_format_async_update_prompt_with_args(history_manager): - item = TextualMemoryItem( - memory="Explicit conversation", - metadata=TreeNodeTextualMemoryMetadata(history=[]), - ) - - prompt = history_manager.format_async_update_prompt(item, custom_tags_prompt="Custom Tags") - - assert "Explicit conversation" in prompt - assert "Custom Tags" in prompt - - def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): llm_response = { "memory list": [ From 186f85732144d825b2e640345ee4333eb53b2f5f Mon Sep 17 00:00:00 2001 From: Jiang <33757498+hijzy@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:57:45 +0800 Subject: [PATCH 60/92] fix: Use relativity instead of score for preference memory (#1153) * test: add rerank model * test: test reranker model * test: test reranker model * test: delete useless log * test: reformat --------- Co-authored-by: jiang --- src/memos/api/handlers/search_handler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 58121776e..ba1c50b07 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -120,10 +120,7 @@ def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> d if not isinstance(mem, dict): continue meta = mem.get("metadata", {}) - if key == "text_mem": - score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 - else: - score = meta.get("score", 1.0) if isinstance(meta, dict) else 1.0 + score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 try: score_val = float(score) if score is not None else 1.0 except (TypeError, ValueError): From f4e73932bc796c69817db37d6a2a21125b5aa57a Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 4 Mar 2026 14:55:18 +0800 Subject: [PATCH 61/92] fix: image bug; single item in multi-mudal-reader has no embedding; (#1154) * fix: image url bug * fix: single item has no embedding bug * fix: image lang bug * fix: image lang bug --- src/memos/mem_reader/multi_modal_struct.py | 25 ++++++++----------- .../read_multi_modal/image_parser.py | 6 ++--- .../mem_reader/read_multi_modal/utils.py | 21 +++++++++++++++- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index f7120148a..08dc8acf8 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -193,8 +193,16 @@ def _concat_multi_modal_memories( else: processed_items.append(item) - # If only one item after processing, return as-is + # If only one item after processing, compute embedding and return if len(processed_items) == 1: + single_item = processed_items[0] + if single_item and single_item.memory: + try: + single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0] + except Exception as e: + logger.error( + f"[MultiModalStruct] Error computing embedding for single item: {e}" + ) return processed_items windows = [] @@ -292,7 +300,6 @@ def _build_window_from_items( # Collect all memory texts and sources memory_texts = [] all_sources = [] - seen_content = set() # Track seen source content to avoid duplicates roles = set() aggregated_file_ids: list[str] = [] @@ -306,18 +313,8 @@ def _build_window_from_items( item_sources = [item_sources] for source in item_sources: - # Get content from source for deduplication - source_content = None - if isinstance(source, dict): - source_content = source.get("content", "") - else: - source_content = getattr(source, "content", "") or "" - - # Only add if content is different (empty content is considered unique) - content_key = source_content if source_content else None - if content_key and content_key not in seen_content: - seen_content.add(content_key) - all_sources.append(source) + # Add source to all_sources + all_sources.append(source) # Extract role from source if hasattr(source, "role") and source.role: diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 97400ca26..d66642edb 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -137,10 +137,10 @@ def parse_fine( # Get context items if available context_items = kwargs.get("context_items") - # Determine language: prioritize lang from source (passed via kwargs), - # fallback to detecting from context_items if lang not provided + # Determine language: prioritize lang from context_items, + # fallback to kwargs lang = kwargs.get("lang") - if lang is None and context_items: + if context_items: for item in context_items: if hasattr(item, "memory") and item.memory: lang = detect_lang(item.memory) diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index be82587bf..96918589b 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -341,13 +341,32 @@ def detect_lang(text): if not text or not isinstance(text, str): return "en" cleaned_text = text - # remove role and timestamp + # remove role and timestamp-like prefixes cleaned_text = re.sub( r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE ) + # timestamps like [11:32 AM on 04 March, 2026] + cleaned_text = re.sub( + r"\[\s*\d{1,2}:\d{2}\s*(?:AM|PM)\s+on\s+\d{2}\s+[A-Za-z]+\s*,\s*\d{4}\s*\]", + "", + cleaned_text, + flags=re.IGNORECASE, + ) + # purely numeric timestamps like [2025-01-01 10:00] cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) # remove URLs to prevent the dilution of Chinese characters cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text) + # remove MessageType schema keywords (multimodal JSON noise) + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url)\b", "", cleaned_text, flags=re.IGNORECASE + ) + # remove schema keywords like text / type / image_url / url + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url|file|file_id)\b", + "", + cleaned_text, + flags=re.IGNORECASE, + ) # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" chinese_chars = re.findall(chinese_pattern, cleaned_text) From ef355b0218412586a9614c66d4c10ade4d69b2ce Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 17:49:49 +0800 Subject: [PATCH 62/92] feat: plugin system and demo plugin --- extensions/memos_demo_plugin/__init__.py | 6 + extensions/memos_demo_plugin/hook_defs.py | 36 ++ extensions/memos_demo_plugin/hooks.py | 87 ++++ extensions/memos_demo_plugin/middleware.py | 25 + extensions/memos_demo_plugin/plugin.py | 78 ++++ extensions/memos_demo_plugin/routes.py | 92 ++++ .../memos_demo_plugin/tests/__init__.py | 0 .../memos_demo_plugin/tests/conftest.py | 14 + .../memos_demo_plugin/tests/test_hooks.py | 110 +++++ .../memos_demo_plugin/tests/test_lifecycle.py | 99 ++++ .../tests/test_middleware.py | 62 +++ .../memos_demo_plugin/tests/test_routes.py | 194 ++++++++ poetry.lock | 105 ++++- pyproject.toml | 14 +- src/memos/api/handlers/add_handler.py | 6 + src/memos/api/handlers/search_handler.py | 2 + src/memos/api/server_api.py | 5 + src/memos/plugins/__init__.py | 20 + src/memos/plugins/base.py | 72 +++ src/memos/plugins/hook_defs.py | 88 ++++ src/memos/plugins/hooks.py | 124 +++++ src/memos/plugins/manager.py | 75 +++ tests/plugins/__init__.py | 0 tests/plugins/conftest.py | 17 + tests/plugins/run_plugin_server.py | 0 tests/plugins/test_plugin_demo.py | 439 ++++++++++++++++++ 26 files changed, 1750 insertions(+), 20 deletions(-) create mode 100644 extensions/memos_demo_plugin/__init__.py create mode 100644 extensions/memos_demo_plugin/hook_defs.py create mode 100644 extensions/memos_demo_plugin/hooks.py create mode 100644 extensions/memos_demo_plugin/middleware.py create mode 100644 extensions/memos_demo_plugin/plugin.py create mode 100644 extensions/memos_demo_plugin/routes.py create mode 100644 extensions/memos_demo_plugin/tests/__init__.py create mode 100644 extensions/memos_demo_plugin/tests/conftest.py create mode 100644 extensions/memos_demo_plugin/tests/test_hooks.py create mode 100644 extensions/memos_demo_plugin/tests/test_lifecycle.py create mode 100644 extensions/memos_demo_plugin/tests/test_middleware.py create mode 100644 extensions/memos_demo_plugin/tests/test_routes.py create mode 100644 src/memos/plugins/__init__.py create mode 100644 src/memos/plugins/base.py create mode 100644 src/memos/plugins/hook_defs.py create mode 100644 src/memos/plugins/hooks.py create mode 100644 src/memos/plugins/manager.py create mode 100644 tests/plugins/__init__.py create mode 100644 tests/plugins/conftest.py create mode 100644 tests/plugins/run_plugin_server.py create mode 100644 tests/plugins/test_plugin_demo.py diff --git a/extensions/memos_demo_plugin/__init__.py b/extensions/memos_demo_plugin/__init__.py new file mode 100644 index 000000000..3e8d0ff78 --- /dev/null +++ b/extensions/memos_demo_plugin/__init__.py @@ -0,0 +1,6 @@ +"""memos-demo-plugin — a complete example plugin demonstrating the MemOS plugin system.""" + +from memos_demo_plugin.plugin import DemoPlugin + + +__all__ = ["DemoPlugin"] diff --git a/extensions/memos_demo_plugin/hook_defs.py b/extensions/memos_demo_plugin/hook_defs.py new file mode 100644 index 000000000..7ae920f48 --- /dev/null +++ b/extensions/memos_demo_plugin/hook_defs.py @@ -0,0 +1,36 @@ +"""Demo plugin-owned hook declarations. + +Hooks that the plugin declares, triggers, and registers callbacks for are defined here. +CE-exposed hooks (e.g. add.before/after) are managed by CE's hook_defs.py; the plugin only needs to reference them. +""" + +from memos.plugins.hook_defs import define_hook + + +class DemoH: + """Demo plugin hook name constants.""" + + # @hookable("demo.test") — auto-generates before/after + TEST_BEFORE = "demo.test.before" + TEST_AFTER = "demo.test.after" + + # Manually triggered via trigger_hook + TEST_POST_PROCESS = "demo.test.post_process" + REPORT_ENRICH = "demo.report.enrich" + + +# ── Custom hook declarations (@hookable-generated before/after need not be declared here) ── + +define_hook( + DemoH.TEST_POST_PROCESS, + description="post-process result after demo test endpoint business logic runs", + params=["request", "result"], + pipe_key="result", +) + +define_hook( + DemoH.REPORT_ENRICH, + description="after user activity report is generated, allows callbacks to extend report data", + params=["user_id", "report"], + pipe_key="report", +) diff --git a/extensions/memos_demo_plugin/hooks.py b/extensions/memos_demo_plugin/hooks.py new file mode 100644 index 000000000..776f9e877 --- /dev/null +++ b/extensions/memos_demo_plugin/hooks.py @@ -0,0 +1,87 @@ +"""Demo plugin hook callbacks. + +Two groups: + 1. CE hook responders — plugin listens to CE-exposed extension points (add/search etc.) + 2. Plugin-owned hooks — extension points the plugin declares and triggers (demo.test / demo.report) + +All callbacks are bound to the plugin instance via functools.partial(callback, plugin_instance). +""" + +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from memos_demo_plugin.plugin import DemoPlugin + +logger = logging.getLogger(__name__) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# 1. CE hook responders — listen to CE-exposed @hookable / trigger_hook extension points +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +def log_operation(plugin: DemoPlugin, *, request, **kw) -> None: + """[add.before / search.before] Log operation (notification-style).""" + uid = getattr(request, "user_id", "unknown") + plugin.request_log.append({"user_id": uid}) + logger.info("[Demo] operation logged user=%s", uid) + + +def count_add(plugin: DemoPlugin, *, request, result, **kw) -> None: + """[add.after] Count add calls per user (notification-style).""" + uid = getattr(request, "user_id", "unknown") + plugin.add_counter[uid] = plugin.add_counter.get(uid, 0) + 1 + logger.info("[Demo] add count user=%s total=%d", uid, plugin.add_counter[uid]) + + +def post_process_add(plugin: DemoPlugin, *, request, result, **kw): + """[add.memories.post_process] Post-process add_memories result (pipe-style, returns result).""" + uid = getattr(request, "user_id", "unknown") + plugin.post_process_log.append({"user_id": uid, "result_count": len(result)}) + logger.info("[Demo] post_process_add user=%s count=%d", uid, len(result)) + return result + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# 2. Plugin-owned hooks — declared by plugin and triggered in routes.py +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +def on_test_before(plugin: DemoPlugin, *, request, **kw): + """[demo.test.before] @hookable auto-triggered, can modify request (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "before", "user_id": uid}) + logger.info("[Demo] test.before user=%s", uid) + return request + + +def on_test_after(plugin: DemoPlugin, *, request, result, **kw): + """[demo.test.after] @hookable auto-triggered, can modify result (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "after", "user_id": uid}) + result["hook_after_injected"] = True + logger.info("[Demo] test.after user=%s", uid) + return result + + +def on_test_post_process(plugin: DemoPlugin, *, request, result, **kw): + """[demo.test.post_process] trigger_hook manual trigger, can modify result (pipe-style).""" + uid = getattr(request, "user_id", "anonymous") + plugin.hook_test_log.append({"phase": "post_process", "user_id": uid}) + result["hook_post_process_injected"] = True + logger.info("[Demo] test.post_process user=%s", uid) + return result + + +def enrich_report(plugin: DemoPlugin, *, user_id, report, **kw): + """[demo.report.enrich] trigger_hook manual trigger, extend user activity report (pipe-style).""" + report["total_users_tracked"] = len(plugin.add_counter) + report["is_active_user"] = plugin.add_counter.get(user_id, 0) > 0 + report["enriched_by"] = plugin.name + logger.info("[Demo] enrich_report user=%s", user_id) + return report diff --git a/extensions/memos_demo_plugin/middleware.py b/extensions/memos_demo_plugin/middleware.py new file mode 100644 index 000000000..f2d4a73cd --- /dev/null +++ b/extensions/memos_demo_plugin/middleware.py @@ -0,0 +1,25 @@ +"""Demo plugin middleware — global audit logging for each request's method, path, status code, and duration.""" + +import logging +import time + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + + +logger = logging.getLogger(__name__) + + +class DemoAuditMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + start = time.time() + response = await call_next(request) + elapsed_ms = (time.time() - start) * 1000 + logger.info( + "[Demo Audit] %s %s → %s (%.1fms)", + request.method, + request.url.path, + response.status_code, + elapsed_ms, + ) + return response diff --git a/extensions/memos_demo_plugin/plugin.py b/extensions/memos_demo_plugin/plugin.py new file mode 100644 index 000000000..d3ef11c8d --- /dev/null +++ b/extensions/memos_demo_plugin/plugin.py @@ -0,0 +1,78 @@ +""" +Demo plugin main logic — complete demonstration of MemOS plugin's three extension capabilities. + +Scope: + 1. Register routes — self.register_router() + 2. Register middleware — self.register_middleware() + 3. Register hooks — self.register_hook() / self.register_hooks() + +Both community developers and enterprise self-hosted deployments can reference this plugin structure. +Package naming convention: memos-xx-plugin / memos_xx_plugin. +""" + +import logging + +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H +from memos_demo_plugin.hook_defs import DemoH + + +logger = logging.getLogger(__name__) + + +class DemoPlugin(MemOSPlugin): + name = "demo" + version = "0.1.0" + description = "Demo plugin — showcases routes, middleware, and hooks" + + # ── Lifecycle ──────────────────────────────────────────────────── + + def on_load(self) -> None: + self.add_counter: dict[str, int] = {} + self.request_log: list[dict] = [] + self.post_process_log: list[dict] = [] + self.hook_test_log: list[dict] = [] + logger.info("[Demo] plugin loaded") + + def init_app(self) -> None: + from memos_demo_plugin.hooks import ( + count_add, + enrich_report, + log_operation, + on_test_after, + on_test_before, + on_test_post_process, + post_process_add, + ) + from memos_demo_plugin.middleware import DemoAuditMiddleware + from memos_demo_plugin.routes import create_router + + # 1) Routes + self.register_router(create_router(self)) + + # 2) Middleware + self.register_middleware(DemoAuditMiddleware) + + # 3) Hooks — respond to CE @hookable extension points + self.register_hook(H.ADD_AFTER, partial(count_add, self)) + self.register_hooks([H.ADD_BEFORE, H.SEARCH_BEFORE], partial(log_operation, self)) + self.register_hook(H.ADD_MEMORIES_POST_PROCESS, partial(post_process_add, self)) + + # 4) Hooks — plugin-owned extension points (constants from DemoH) + self.register_hook(DemoH.TEST_BEFORE, partial(on_test_before, self)) + self.register_hook(DemoH.TEST_AFTER, partial(on_test_after, self)) + self.register_hook(DemoH.TEST_POST_PROCESS, partial(on_test_post_process, self)) + self.register_hook(DemoH.REPORT_ENRICH, partial(enrich_report, self)) + + logger.info("[Demo] plugin initialized") + + def on_shutdown(self) -> None: + logger.info( + "[Demo] plugin shutdown — users=%d, ops=%d, post_process=%d, hook_tests=%d", + len(self.add_counter), + len(self.request_log), + len(self.post_process_log), + len(self.hook_test_log), + ) diff --git a/extensions/memos_demo_plugin/routes.py b/extensions/memos_demo_plugin/routes.py new file mode 100644 index 000000000..af6a18983 --- /dev/null +++ b/extensions/memos_demo_plugin/routes.py @@ -0,0 +1,92 @@ +"""Demo plugin routes — demonstrates full usage of plugin-registered routes + both hook trigger styles.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter +from pydantic import BaseModel + +from memos.plugins.hooks import hookable, trigger_hook +from memos_demo_plugin.hook_defs import DemoH + + +if TYPE_CHECKING: + from memos_demo_plugin.plugin import DemoPlugin + + +# ── Request models ──────────────────────────────────────────────────────── + + +class TestHookRequest(BaseModel): + user_id: str = "anonymous" + message: str = "hello" + + +# ── Router factory ──────────────────────────────────────────────────────── + + +def create_router(plugin: DemoPlugin) -> APIRouter: + router = APIRouter(prefix="/demo", tags=["demo"]) + + # ── Basic routes ── + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name, "version": plugin.version} + + @router.get("/stats") + async def stats(): + return { + "add_counter": plugin.add_counter, + "total_adds": sum(plugin.add_counter.values()), + "recent_requests": plugin.request_log[-20:], + } + + # ── Hook demo routes ── + + class _HookDemoHandler: + """Demonstrates @hookable decorator: auto-triggers demo.test.before / demo.test.after.""" + + @hookable("demo.test") + def handle(self, request: TestHookRequest): + result = { + "user_id": request.user_id, + "echo": request.message, + "processed": True, + } + rv = trigger_hook(DemoH.TEST_POST_PROCESS, request=request, result=result) + return rv if rv is not None else result + + handler = _HookDemoHandler() + + @router.post("/test-hook") + async def test_hook(req: TestHookRequest): + """Full hook demo endpoint. + + Call chain: + 1. demo.test.before — @hookable auto, pipe-style, can modify request + 2. handler business logic + 3. demo.test.post_process — trigger_hook manual, pipe-style, can modify result + 4. demo.test.after — @hookable auto, pipe-style, can modify result + 5. demo.report.enrich — trigger_hook manual, pipe-style, can modify report + """ + result = handler.handle(req) + + report = { + "user_id": req.user_id, + "add_count": plugin.add_counter.get(req.user_id, 0), + "operation_count": sum( + 1 for r in plugin.request_log if r.get("user_id") == req.user_id + ), + } + rv = trigger_hook(DemoH.REPORT_ENRICH, user_id=req.user_id, report=report) + report = rv if rv is not None else report + + return { + "hook_test": result, + "user_report": report, + "plugin_state": {"hook_test_log": plugin.hook_test_log[-10:]}, + } + + return router diff --git a/extensions/memos_demo_plugin/tests/__init__.py b/extensions/memos_demo_plugin/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extensions/memos_demo_plugin/tests/conftest.py b/extensions/memos_demo_plugin/tests/conftest.py new file mode 100644 index 000000000..91aaf61af --- /dev/null +++ b/extensions/memos_demo_plugin/tests/conftest.py @@ -0,0 +1,14 @@ +"""memos_demo_plugin tests — ensure hooks used by the plugin are declared. + +CE @hookable declarations are triggered by manually calling hookable(). +Plugin-owned hook declarations are triggered by importing the hook_defs module (module-level define_hook calls). +""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("demo.test") + +import memos_demo_plugin.hook_defs # noqa: E402, F401 — triggers plugin-owned hook declarations diff --git a/extensions/memos_demo_plugin/tests/test_hooks.py b/extensions/memos_demo_plugin/tests/test_hooks.py new file mode 100644 index 000000000..40b2197a4 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_hooks.py @@ -0,0 +1,110 @@ +"""DemoPlugin hook callback verification — including @hookable before/after and custom trigger_hook.""" + +import logging + +from fastapi import FastAPI + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_plugin(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return plugin + + +class TestHookCallbacks: + """Verify business logic of each hook callback in the Demo plugin.""" + + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_add_after_counts(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "alice" + + trigger_hook("add.after", request=Req(), result={}) + trigger_hook("add.after", request=Req(), result={}) + + assert plugin.add_counter["alice"] == 2 + + def test_add_before_logs(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "bob" + + trigger_hook("add.before", request=Req()) + + assert len(plugin.request_log) == 1 + assert plugin.request_log[0]["user_id"] == "bob" + + def test_search_before_logs(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "charlie" + + trigger_hook("search.before", request=Req()) + + assert len(plugin.request_log) == 1 + assert plugin.request_log[0]["user_id"] == "charlie" + + def test_multiple_users(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class ReqA: + user_id = "alice" + + class ReqB: + user_id = "bob" + + trigger_hook("add.before", request=ReqA()) + trigger_hook("add.after", request=ReqA(), result={}) + trigger_hook("add.before", request=ReqB()) + trigger_hook("add.after", request=ReqB(), result={}) + trigger_hook("add.before", request=ReqA()) + trigger_hook("add.after", request=ReqA(), result={}) + + assert plugin.add_counter == {"alice": 2, "bob": 1} + assert len(plugin.request_log) == 3 + + def test_post_process_hook_is_pipeline(self): + """add.memories.post_process is a pipeline-style hook; callbacks can modify and return result.""" + from memos.plugins.hook_defs import H + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "dave" + + original = [{"id": 1}, {"id": 2}] + rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=Req(), result=original) + + assert rv is original + assert len(plugin.post_process_log) == 1 + assert plugin.post_process_log[0]["user_id"] == "dave" + assert plugin.post_process_log[0]["result_count"] == 2 diff --git a/extensions/memos_demo_plugin/tests/test_lifecycle.py b/extensions/memos_demo_plugin/tests/test_lifecycle.py new file mode 100644 index 000000000..4c6156980 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_lifecycle.py @@ -0,0 +1,99 @@ +"""DemoPlugin lifecycle & PluginManager integration tests.""" + +import logging + +from fastapi import FastAPI + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + """Simulate the PluginManager initialization flow.""" + plugin._bind_app(app) + plugin.init_app() + + +class TestDemoPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + assert plugin.name == "demo" + assert plugin.version == "0.1.0" + + def test_on_load_initializes_state(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + plugin.on_load() + + assert plugin.add_counter == {} + assert plugin.request_log == [] + assert plugin.post_process_log == [] + assert plugin.hook_test_log == [] + + def test_on_shutdown_no_error(self): + from memos_demo_plugin.plugin import DemoPlugin + + plugin = DemoPlugin() + plugin.on_load() + plugin.on_shutdown() + + def test_full_lifecycle(self): + """Full lifecycle: on_load → init_app → normal operation → on_shutdown.""" + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + assert "/demo/test-hook" in paths + + plugin.on_shutdown() + + +class TestPluginManager: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_manual_registration_and_init(self): + from memos.plugins.manager import PluginManager + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + manager = PluginManager() + + plugin = DemoPlugin() + plugin.on_load() + manager._plugins[plugin.name] = plugin + + assert "demo" in manager.plugins + + manager.init_app(app) + + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + + def test_shutdown(self): + from memos.plugins.manager import PluginManager + from memos_demo_plugin.plugin import DemoPlugin + + manager = PluginManager() + plugin = DemoPlugin() + plugin.on_load() + manager._plugins[plugin.name] = plugin + + manager.shutdown() diff --git a/extensions/memos_demo_plugin/tests/test_middleware.py b/extensions/memos_demo_plugin/tests/test_middleware.py new file mode 100644 index 000000000..45082fdb9 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_middleware.py @@ -0,0 +1,62 @@ +"""DemoPlugin middleware integration tests.""" + +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_app(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return app, plugin + + +class TestMiddlewareRegistration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_audit_middleware_logs(self, caplog): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + return {"ok": True} + + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + client = TestClient(app) + with caplog.at_level(logging.INFO): + resp = client.get("/test") + + assert resp.status_code == 200 + assert any("[Demo Audit]" in r.message for r in caplog.records) + + def test_audit_middleware_on_plugin_routes(self, caplog): + app, _ = _make_app() + + client = TestClient(app) + with caplog.at_level(logging.INFO): + client.get("/demo/health") + + assert any( + "[Demo Audit]" in r.message and "/demo/health" in r.message for r in caplog.records + ) diff --git a/extensions/memos_demo_plugin/tests/test_routes.py b/extensions/memos_demo_plugin/tests/test_routes.py new file mode 100644 index 000000000..82e83d558 --- /dev/null +++ b/extensions/memos_demo_plugin/tests/test_routes.py @@ -0,0 +1,194 @@ +"""DemoPlugin routes and /demo/test-hook endpoint tests.""" + +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +def _make_app(): + from memos_demo_plugin.plugin import DemoPlugin + + app = FastAPI() + plugin = DemoPlugin() + plugin.on_load() + _init_plugin(plugin, app) + return app, plugin + + +# ========================================================================= # +# Route registration verification +# ========================================================================= # + + +class TestRouteRegistration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_routes_exist(self): + app, _ = _make_app() + paths = [r.path for r in app.routes] + assert "/demo/health" in paths + assert "/demo/stats" in paths + assert "/demo/test-hook" in paths + + def test_health_endpoint(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/demo/health") + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["plugin"] == "demo" + assert data["version"] == "0.1.0" + + def test_stats_endpoint_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/demo/stats") + + assert resp.status_code == 200 + data = resp.json() + assert data["add_counter"] == {} + assert data["total_adds"] == 0 + assert data["recent_requests"] == [] + + def test_stats_endpoint_after_hooks(self): + from memos.plugins.hooks import trigger_hook + + app, plugin = _make_app() + + class FakeRequest: + user_id = "user_42" + + trigger_hook("add.before", request=FakeRequest()) + trigger_hook("add.after", request=FakeRequest(), result={"ok": True}) + trigger_hook("add.before", request=FakeRequest()) + trigger_hook("add.after", request=FakeRequest(), result={"ok": True}) + + client = TestClient(app) + resp = client.get("/demo/stats") + data = resp.json() + + assert data["add_counter"]["user_42"] == 2 + assert data["total_adds"] == 2 + assert len(data["recent_requests"]) == 2 + + +# ========================================================================= # +# /demo/test-hook endpoint — @hookable + custom trigger_hook full chain +# ========================================================================= # + + +class TestHookEndpoint: + """Verify full hook call chain of the test endpoint: + demo.test.before → business logic → demo.test.post_process → demo.test.after + """ + + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_basic_response(self): + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "tester", "message": "ping"}) + assert resp.status_code == 200 + + data = resp.json() + hook_result = data["hook_test"] + assert hook_result["user_id"] == "tester" + assert hook_result["echo"] == "ping" + assert hook_result["processed"] is True + + def test_after_hook_injects_field(self): + """demo.test.after callback injects hook_after_injected=True.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u1", "message": "hi"}) + assert resp.json()["hook_test"]["hook_after_injected"] is True + + def test_post_process_hook_injects_field(self): + """demo.test.post_process custom hook injects hook_post_process_injected=True.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u2", "message": "world"}) + assert resp.json()["hook_test"]["hook_post_process_injected"] is True + + def test_records_all_three_phases(self): + """plugin.hook_test_log should record all three phases: before / post_process / after.""" + app, plugin = _make_app() + client = TestClient(app) + + client.post("/demo/test-hook", json={"user_id": "u3", "message": "test"}) + + phases = [entry["phase"] for entry in plugin.hook_test_log] + assert "before" in phases + assert "post_process" in phases + assert "after" in phases + + def test_state_in_response(self): + """plugin_state.hook_test_log in response should contain records for all three phases.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "u4", "message": "check"}) + data = resp.json() + + log = data["plugin_state"]["hook_test_log"] + assert len(log) >= 3 + assert any(e["phase"] == "before" for e in log) + assert any(e["phase"] == "post_process" for e in log) + assert any(e["phase"] == "after" for e in log) + + def test_multiple_calls_accumulate(self): + """hook_test_log should accumulate after multiple calls.""" + app, plugin = _make_app() + client = TestClient(app) + + client.post("/demo/test-hook", json={"user_id": "a"}) + client.post("/demo/test-hook", json={"user_id": "b"}) + + assert len(plugin.hook_test_log) >= 6 + + def test_default_values(self): + """Call with default parameters.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={}) + data = resp.json() + + assert data["hook_test"]["user_id"] == "anonymous" + assert data["hook_test"]["echo"] == "hello" + + def test_custom_hook_enrich_report(self): + """demo.report.enrich custom hook example — response contains user_report extended by callback.""" + app, _ = _make_app() + client = TestClient(app) + + resp = client.post("/demo/test-hook", json={"user_id": "alice"}) + data = resp.json() + + report = data["user_report"] + assert report["user_id"] == "alice" + assert "add_count" in report + assert "operation_count" in report + assert report["enriched_by"] == "demo" + assert "total_users_tracked" in report + assert "is_active_user" in report diff --git a/poetry.lock b/poetry.lock index ba31d1a31..dccd154d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -68,10 +68,10 @@ trio = ["trio (>=0.26.1)"] name = "async-timeout" version = "5.0.1" description = "Timeout context manager for asyncio programs" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_full_version < \"3.11.3\" and (extra == \"mem-scheduler\" or extra == \"all\")" +markers = "python_full_version < \"3.11.3\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -394,10 +394,9 @@ files = [ name = "chonkie" version = "1.1.1" description = "🦛 CHONK your texts with Chonkie ✨ - The no-nonsense chunking library" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"mem-reader\" or extra == \"all\"" files = [ {file = "chonkie-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c56cff89f38ff5cc06b2e8b4c9a802b85b77ba8ecdda6896f5dba6b0c54d4303"}, {file = "chonkie-1.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8bb0d88b4254a9bac7b494349ee3c94f94d3ee2f8cd4970d23e0c0ef3e6392a4"}, @@ -1296,7 +1295,6 @@ files = [ {file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"}, {file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"}, ] -markers = {main = "extra == \"pref-mem\" or extra == \"all\""} [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] @@ -1607,10 +1605,9 @@ files = [ name = "jieba" version = "0.42" description = "Chinese Words Segmentation Utilities" -optional = true +optional = false python-versions = "*" groups = ["main"] -markers = "extra == \"all\"" files = [ {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, ] @@ -3235,7 +3232,6 @@ files = [ {file = "pandas-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b4b0de34dc8499c2db34000ef8baad684cfa4cbd836ecee05f323ebfba348c7d"}, {file = "pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [package.dependencies] numpy = [ @@ -3298,10 +3294,9 @@ image = ["Pillow"] name = "pika" version = "1.3.2" description = "Pika Python AMQP Client Library" -optional = true +optional = false python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"mem-scheduler\" or extra == \"all\"" files = [ {file = "pika-1.3.2-py3-none-any.whl", hash = "sha256:0779a7c1fafd805672796085560d290213a465e4f6f76a6fb19e378d8041a14f"}, {file = "pika-1.3.2.tar.gz", hash = "sha256:b2a327ddddf8570b4965b3576ac77091b850262d34ce8c1d8cb4e4146aa4145f"}, @@ -3568,7 +3563,83 @@ files = [ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} + +[[package]] +name = "psycopg2-binary" +version = "2.9.11" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "psycopg2-binary-2.9.11.tar.gz", hash = "sha256:b6aed9e096bf63f9e75edf2581aa9a7e7186d97ab5c177aa6c87797cd591236c"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6fe6b47d0b42ce1c9f1fa3e35bb365011ca22e39db37074458f27921dca40f2"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6c0e4262e089516603a09474ee13eabf09cb65c332277e39af68f6233911087"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c47676e5b485393f069b4d7a811267d3168ce46f988fa602658b8bb901e9e64d"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a28d8c01a7b27a1e3265b11250ba7557e5f72b5ee9e5f3a2fa8d2949c29bf5d2"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5f3f2732cf504a1aa9e9609d02f79bea1067d99edf844ab92c247bbca143303b"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:865f9945ed1b3950d968ec4690ce68c55019d79e4497366d36e090327ce7db14"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:91537a8df2bde69b1c1db01d6d944c831ca793952e4f57892600e96cee95f2cd"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4dca1f356a67ecb68c81a7bc7809f1569ad9e152ce7fd02c2f2036862ca9f66b"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0da4de5c1ac69d94ed4364b6cbe7190c1a70d325f112ba783d83f8440285f152"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37d8412565a7267f7d79e29ab66876e55cb5e8e7b3bbf94f8206f6795f8f7e7e"}, + {file = "psycopg2_binary-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:c665f01ec8ab273a61c62beeb8cce3014c214429ced8a308ca1fc410ecac3a39"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0e8480afd62362d0a6a27dd09e4ca2def6fa50ed3a4e7c09165266106b2ffa10"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:763c93ef1df3da6d1a90f86ea7f3f806dc06b21c198fa87c3c25504abec9404a"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908"}, + {file = "psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f090b7ddd13ca842ebfe301cd587a76a4cf0913b1e429eb92c1be5dbeb1a19bc"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d"}, + {file = "psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:366df99e710a2acd90efed3764bb1e28df6c675d33a7fb40df9b7281694432ee"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1"}, + {file = "psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:92e3b669236327083a2e33ccfa0d320dd01b9803b3e14dd986a4fc54aa00f4e1"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e0deeb03da539fa3577fcb0b3f2554a97f7e5477c246098dbb18091a4a01c16f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b52a3f9bb540a3e4ec0f6ba6d31339727b2950c9772850d6545b7eae0b9d7c5"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:db4fd476874ccfdbb630a54426964959e58da4c61c9feba73e6094d51303d7d8"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47f212c1d3be608a12937cc131bd85502954398aaa1320cb4c14421a0ffccf4c"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e35b7abae2b0adab776add56111df1735ccc71406e56203515e228a8dc07089f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fcf21be3ce5f5659daefd2b3b3b6e4727b028221ddc94e6c1523425579664747"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:9bd81e64e8de111237737b29d68039b9c813bdf520156af36d26819c9a979e5f"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:32770a4d666fbdafab017086655bcddab791d7cb260a16679cc5a7338b64343b"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3cb3a676873d7506825221045bd70e0427c905b9c8ee8d6acd70cfcbd6e576d"}, + {file = "psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20e7fb94e20b03dcc783f76c0865f9da39559dcc0c28dd1a3fce0d01902a6b9c"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4bdab48575b6f870f465b397c38f1b415520e9879fdf10a53ee4f49dcbdf8a21"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9d3a9edcfbe77a3ed4bc72836d466dfce4174beb79eda79ea155cc77237ed9e8"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:44fc5c2b8fa871ce7f0023f619f1349a0aa03a0857f2c96fbc01c657dcbbdb49"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9c55460033867b4622cda1b6872edf445809535144152e5d14941ef591980edf"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2d11098a83cca92deaeaed3d58cfd150d49b3b06ee0d0852be466bf87596899e"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:691c807d94aecfbc76a14e1408847d59ff5b5906a04a23e12a89007672b9e819"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b81627b691f29c4c30a8f322546ad039c40c328373b11dff7490a3e1b517855"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:b637d6d941209e8d96a072d7977238eea128046effbf37d1d8b2c0764750017d"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41360b01c140c2a03d346cec3280cf8a71aa07d94f3b1509fa0161c366af66b4"}, + {file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"}, +] [[package]] name = "pycparser" @@ -3837,10 +3908,9 @@ windows-terminal = ["colorama (>=0.4.6)"] name = "pymilvus" version = "2.6.2" description = "Python Sdk for Milvus" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"pref-mem\" or extra == \"all\"" files = [ {file = "pymilvus-2.6.2-py3-none-any.whl", hash = "sha256:933e447e09424d490dcf595053b01a7277dadea7ae3235cd704363bd6792509d"}, {file = "pymilvus-2.6.2.tar.gz", hash = "sha256:b4802cc954de8f2d47bf8d6230e92196514dcb8a3726ba6098dc27909d4bc8e3"}, @@ -4033,7 +4103,6 @@ files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] -markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\" or extra == \"pref-mem\""} [[package]] name = "pywin32" @@ -4200,10 +4269,9 @@ dev = ["pytest"] name = "redis" version = "6.2.0" description = "Python client for Redis database and key-value store" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"mem-scheduler\" or extra == \"all\"" files = [ {file = "redis-6.2.0-py3-none-any.whl", hash = "sha256:c8ddf316ee0aab65f04a11229e94a64b2618451dab7a67cb2f77eb799d872d5e"}, {file = "redis-6.2.0.tar.gz", hash = "sha256:e821f129b75dde6cb99dd35e5c76e8c49512a5a0d8dfdc560b2fbd44b85ca977"}, @@ -5080,7 +5148,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5638,7 +5706,6 @@ files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "ujson" @@ -6373,4 +6440,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "faff240c05a74263a404e8d9324ffd2f342cb4f0a4c1f5455b87349f6ccc61a5" +content-hash = "0fd4408ce33b59ac489d4d9b0e632bb17538853f049a4edf17425ba83027b74a" diff --git a/pyproject.toml b/pyproject.toml index bd37b0c71..3998fb53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,12 @@ dependencies = [ "python-dateutil (>=2.9.0.post0,<3.0.0)", "prometheus-client (>=0.23.1,<0.24.0)", "concurrent-log-handler (>=0.9.28,<1.0.0)", # Process-safe rotating file handler + "redis (>=6.2.0,<7.0.0)", # Key-value store + "pika (>=1.3.2,<2.0.0)", # RabbitMQ client + "jieba (>=0.38.1,<0.42.1)", # Chinese text segmentation + "chonkie (>=1.0.7,<2.0.0)", # Sentence chunking + "pymilvus (>=2.5.12,<3.0.0)", # Milvus vector DB + "psycopg2-binary (>=2.9.9,<3.0.0)", # PostgreSQL / PolarDB driver ] [project.urls] @@ -62,6 +68,9 @@ issues = "https://github.com/MemTensor/MemOS/issues" [project.scripts] memos = "memos.cli:main" +[project.entry-points."memos.plugins"] +demo = "memos_demo_plugin:DemoPlugin" + [project.optional-dependencies] # These are optional dependencies for various features of MemoryOS. # Developers install: `poetry install --extras `. e.g., `poetry install --extras general-mem` @@ -153,7 +162,10 @@ build-backend = "poetry.core.masonry.api" # https://python-poetry.org/docs/dependency-specification#caret-requirements ############################################################################## -packages = [{include = "memos", from = "src"}] +packages = [ + {include = "memos", from = "src"}, + {include = "memos_demo_plugin", from = "extensions"}, +] requires-poetry = ">=2.0" dependencies = { "python" = ">=3.10,<4.0" } diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..e53ef9393 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,8 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hook_defs import H +from memos.plugins.hooks import hookable, trigger_hook from memos.types import MessageList @@ -37,6 +39,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. @@ -106,6 +109,9 @@ def _check_messages(messages: MessageList) -> None: results = cube_view.add_memories(add_req) + rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=add_req, result=results) + results = rv if rv is not None else results + self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse( diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 91980bdeb..42b3ddb4e 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -20,6 +20,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable logger = get_logger(__name__) @@ -44,6 +45,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) + @hookable("search") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ Main handler for search memories endpoint. diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 529a709a4..78185a035 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -9,10 +9,13 @@ from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +from memos.plugins.manager import plugin_manager load_dotenv() +plugin_manager.discover() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -38,6 +41,8 @@ # Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) +plugin_manager.init_app(app) + if __name__ == "__main__": import argparse diff --git a/src/memos/plugins/__init__.py b/src/memos/plugins/__init__.py new file mode 100644 index 000000000..0a0f8cde3 --- /dev/null +++ b/src/memos/plugins/__init__.py @@ -0,0 +1,20 @@ +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec +from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook +from memos.plugins.manager import PluginManager, plugin_manager + + +__all__ = [ + "H", + "HookSpec", + "MemOSPlugin", + "PluginManager", + "all_hook_specs", + "define_hook", + "get_hook_spec", + "hookable", + "plugin_manager", + "register_hook", + "register_hooks", + "trigger_hook", +] diff --git a/src/memos/plugins/base.py b/src/memos/plugins/base.py new file mode 100644 index 000000000..f55d81b75 --- /dev/null +++ b/src/memos/plugins/base.py @@ -0,0 +1,72 @@ +"""MemOS plugin base class — all plugins must inherit from this class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastapi import FastAPI + from starlette.middleware.base import BaseHTTPMiddleware + + +class MemOSPlugin: + """MemOS plugin base class. + + Provides three unified registration methods. Plugin developers need only + inherit from this class and register capabilities via self.register_* + in init_app. + """ + + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + + _app: FastAPI | None = None + + # ------------------------------------------------------------------ # + # Registration methods — called by plugins in init_app + # ------------------------------------------------------------------ # + + def register_router(self, router, **kwargs) -> None: + """Register a router.""" + self._app.include_router(router, **kwargs) + + def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None: + """Register middleware.""" + self._app.add_middleware(middleware_cls, **kwargs) + + def register_hook(self, name: str, callback: Callable) -> None: + """Register a single Hook callback.""" + from memos.plugins.hooks import register_hook + + register_hook(name, callback) + + def register_hooks(self, names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple Hook points.""" + from memos.plugins.hooks import register_hooks + + register_hooks(names, callback) + + # ------------------------------------------------------------------ # + # Internal methods — called by PluginManager, plugin developers need not care + # ------------------------------------------------------------------ # + + def _bind_app(self, app: FastAPI) -> None: + """Bind FastAPI instance so that register_* methods are available.""" + self._app = app + + # ------------------------------------------------------------------ # + # Lifecycle methods — override in subclasses + # ------------------------------------------------------------------ # + + def on_load(self) -> None: + """Called after the plugin is discovered. Used for initialization logic, e.g. checking dependencies, reading config.""" + + def init_app(self) -> None: + """Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here.""" + + def on_shutdown(self) -> None: + """Called when the service shuts down. Used for resource cleanup.""" diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py new file mode 100644 index 000000000..a3b9dbaf8 --- /dev/null +++ b/src/memos/plugins/hook_defs.py @@ -0,0 +1,88 @@ +"""Hook declaration registry — single source of truth for CE repo Hook points. + +The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook. +Hooks triggered by custom trigger_hook must be explicitly declared in this file. + +Plugin-owned Hooks should be declared within each plugin package, not in this file. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +_specs: dict[str, HookSpec] = {} + + +@dataclass(frozen=True) +class HookSpec: + """Hook spec definition.""" + + name: str + description: str + params: list[str] + pipe_key: str | None = None + + +def define_hook( + name: str, + *, + description: str, + params: list[str], + pipe_key: str | None = None, +) -> None: + """Declare a Hook point. Skips if already exists (idempotent).""" + if name in _specs: + return + _specs[name] = HookSpec( + name=name, + description=description, + params=params, + pipe_key=pipe_key, + ) + logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key) + + +def get_hook_spec(name: str) -> HookSpec | None: + return _specs.get(name) + + +def all_hook_specs() -> dict[str, HookSpec]: + """Return all declared Hooks (including @hookable auto-declared + plugin-declared).""" + return dict(_specs) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE Hook name constants +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +class H: + """CE Hook name constants. Plugin-owned Hook constants should be defined within the plugin package.""" + + # @hookable("add") — AddHandler.handle_add_memories + ADD_BEFORE = "add.before" + ADD_AFTER = "add.after" + + # @hookable("search") — SearchHandler.handle_search_memories + SEARCH_BEFORE = "search.before" + SEARCH_AFTER = "search.after" + + # Custom Hook (manually triggered via trigger_hook) + ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE custom Hook declarations (@hookable-generated ones need not be declared here) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +define_hook( + H.ADD_MEMORIES_POST_PROCESS, + description="Post-process result after add_memories returns, before constructing Response", + params=["request", "result"], + pipe_key="result", +) diff --git a/src/memos/plugins/hooks.py b/src/memos/plugins/hooks.py new file mode 100644 index 000000000..eda98f98a --- /dev/null +++ b/src/memos/plugins/hooks.py @@ -0,0 +1,124 @@ +"""Hook runtime — registration, triggering, and @hookable decorator.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections import defaultdict +from functools import wraps +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +_hooks: dict[str, list[Callable]] = defaultdict(list) + + +def register_hook(name: str, callback: Callable) -> None: + """Register a hook callback. Undeclared hook names will log a warning.""" + from memos.plugins.hook_defs import get_hook_spec + + if get_hook_spec(name) is None: + logger.warning( + "Registering callback for undeclared hook: %s (callback=%s)", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + _hooks[name].append(callback) + logger.debug( + "Hook registered: %s -> %s", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + + +def register_hooks(names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple hook points.""" + for name in names: + register_hook(name, callback) + + +def trigger_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook, invoking all registered callbacks in order. + + - Zero overhead when no callbacks are registered + - Undeclared hook names will log a warning and be skipped + - pipe_key is auto-fetched from HookSpec, supports piped return value passing + """ + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + logger.warning("Undeclared hook triggered: %s — ignored", name) + return None + + pipe_key = spec.pipe_key + + for cb in _hooks.get(name, []): + try: + rv = cb(**kwargs) + if pipe_key is not None and rv is not None: + kwargs[pipe_key] = rv + except Exception: + logger.exception( + "Hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + + return kwargs.get(pipe_key) if pipe_key else None + + +def hookable(name: str): + """Decorator: automatically triggers name.before / name.after hook before and after the method. + + Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py. + Supports piped return values: before can modify request, after can modify result. + Compatible with both sync and async methods. + """ + from memos.plugins.hook_defs import define_hook + + define_hook( + f"{name}.before", + description=f"Before {name} executes; can modify request", + params=["request"], + pipe_key="request", + ) + define_hook( + f"{name}.after", + description=f"After {name} executes; can modify result", + params=["request", "result"], + pipe_key="result", + ) + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = await func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return sync_wrapper + + return decorator diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py new file mode 100644 index 000000000..3706855a9 --- /dev/null +++ b/src/memos/plugins/manager.py @@ -0,0 +1,75 @@ +"""Plugin manager — discover, load, and manage MemOS plugins.""" + +from __future__ import annotations + +import importlib.metadata +import logging + +from typing import TYPE_CHECKING + +from memos.plugins.base import MemOSPlugin + + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "memos.plugins" + + +class PluginManager: + """Discover, load, and manage MemOS plugins.""" + + def __init__(self): + self._plugins: dict[str, MemOSPlugin] = {} + + @property + def plugins(self) -> dict[str, MemOSPlugin]: + return dict(self._plugins) + + def discover(self) -> None: + """Discover and load all installed plugins via entry_points.""" + try: + eps = importlib.metadata.entry_points() + if hasattr(eps, "select"): + plugin_eps = eps.select(group=ENTRY_POINT_GROUP) + else: + plugin_eps = eps.get(ENTRY_POINT_GROUP, []) + except Exception: + logger.exception("Failed to query entry_points") + return + + for ep in plugin_eps: + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not isinstance(plugin, MemOSPlugin): + logger.warning("Plugin %s does not extend MemOSPlugin, skipped", ep.name) + continue + plugin.on_load() + self._plugins[plugin.name] = plugin + logger.info("Plugin discovered: %s v%s", plugin.name, plugin.version) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + + def init_app(self, app: FastAPI) -> None: + """Bind app and initialize all loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin._bind_app(app) + plugin.init_app() + logger.info("Plugin initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin: %s", plugin.name) + + def shutdown(self) -> None: + """Shut down all plugins and release resources.""" + for plugin in self._plugins.values(): + try: + plugin.on_shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) + + +plugin_manager = PluginManager() diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/conftest.py b/tests/plugins/conftest.py new file mode 100644 index 000000000..6a1a16b68 --- /dev/null +++ b/tests/plugins/conftest.py @@ -0,0 +1,17 @@ +"""Ensure @hookable-generated hooks are declared for core framework tests. + +In production, @hookable("add") runs at import time of add_handler.py, +declaring add.before / add.after. Core framework tests don't import handler +modules (to avoid heavy dependencies), so we trigger declarations here. + +Plugin-specific hooks are declared in each plugin's own tests/conftest.py. +""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("chat") +hookable("feedback") +hookable("memory.get") diff --git a/tests/plugins/run_plugin_server.py b/tests/plugins/run_plugin_server.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py new file mode 100644 index 000000000..51997cba4 --- /dev/null +++ b/tests/plugins/test_plugin_demo.py @@ -0,0 +1,439 @@ +""" +Plugin system core framework tests. + +Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations): +1. Hook declaration registry (hook_defs) +2. Hook registration and triggering / pipe_key pipeline return value +3. @hookable decorator (sync + async + auto-declaration + pipeline return value) +4. MemOSPlugin base class register_* methods + +Plugin-specific functional tests are located in each plugin package: + extensions/memos_demo_plugin/tests/ +""" + +import asyncio +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +# ========================================================================= # +# 1. Hook declaration registry (hook_defs) +# ========================================================================= # + + +class TestHookDefs: + def test_define_hook_and_get_spec(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook( + "test.custom.hook", + description="test hook", + params=["request", "result"], + pipe_key="result", + ) + + spec = get_hook_spec("test.custom.hook") + assert spec is not None + assert spec.name == "test.custom.hook" + assert spec.params == ["request", "result"] + assert spec.pipe_key == "result" + + def test_define_hook_is_idempotent(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook("test.idempotent", description="first", params=["a"], pipe_key="a") + define_hook("test.idempotent", description="second", params=["b"], pipe_key="b") + + spec = get_hook_spec("test.idempotent") + assert spec.description == "first" + + def test_get_hook_spec_returns_none_for_unknown(self): + from memos.plugins.hook_defs import get_hook_spec + + assert get_hook_spec("definitely.does.not.exist") is None + + def test_all_hook_specs_includes_custom(self): + from memos.plugins.hook_defs import H, all_hook_specs + + specs = all_hook_specs() + assert H.ADD_MEMORIES_POST_PROCESS in specs + + def test_h_constants(self): + from memos.plugins.hook_defs import H + + assert H.ADD_BEFORE == "add.before" + assert H.ADD_AFTER == "add.after" + assert H.SEARCH_BEFORE == "search.before" + assert H.SEARCH_AFTER == "search.after" + assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process" + + +# ========================================================================= # +# 2. Hook registration and triggering / pipe_key pipeline return value +# ========================================================================= # + + +class TestHookMechanism: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_and_trigger(self): + from memos.plugins.hooks import register_hook, trigger_hook + + captured = {} + + def my_callback(*, request, **kwargs): + captured["request"] = request + + register_hook("add.before", my_callback) + trigger_hook("add.before", request="test_request") + + assert captured["request"] == "test_request" + + def test_register_hooks_batch(self): + from memos.plugins.hooks import register_hooks, trigger_hook + + call_count = 0 + + def my_callback(**kwargs): + nonlocal call_count + call_count += 1 + + register_hooks(["add.before", "search.before"], my_callback) + trigger_hook("add.before") + trigger_hook("search.before") + + assert call_count == 2 + + def test_trigger_undeclared_hook_returns_none(self): + from memos.plugins.hooks import trigger_hook + + result = trigger_hook("nonexistent.undeclared.hook", request="anything") + assert result is None + + def test_hook_exception_does_not_propagate(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook("test.exception", description="test", params=["x"]) + + results = [] + + def bad_callback(**kwargs): + raise ValueError("intentional error") + + def good_callback(**kwargs): + results.append("ok") + + register_hook("test.exception", bad_callback) + register_hook("test.exception", good_callback) + trigger_hook("test.exception", x=1) + + assert results == ["ok"] + + def test_trigger_hook_pipe_key_returns_modified_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.pipe", + description="pipe test", + params=["request", "result"], + pipe_key="result", + ) + + def double_result(*, request, result, **kwargs): + return result * 2 + + register_hook("test.pipe", double_result) + rv = trigger_hook("test.pipe", request="req", result=5) + + assert rv == 10 + + def test_trigger_hook_pipe_key_chains_callbacks(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.chain", + description="chain test", + params=["result"], + pipe_key="result", + ) + + def add_one(*, result, **kwargs): + return result + 1 + + def add_ten(*, result, **kwargs): + return result + 10 + + register_hook("test.chain", add_one) + register_hook("test.chain", add_ten) + + rv = trigger_hook("test.chain", result=0) + assert rv == 11 + + def test_trigger_hook_pipe_key_none_callback_no_modify(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.noop", + description="noop test", + params=["result"], + pipe_key="result", + ) + + def noop(*, result, **kwargs): + return None # explicitly return None — should not modify + + register_hook("test.noop", noop) + rv = trigger_hook("test.noop", result="original") + + assert rv == "original" + + def test_trigger_hook_notification_mode(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.notify", + description="notification test", + params=["data"], + pipe_key=None, + ) + + captured = [] + + def observer(*, data, **kwargs): + captured.append(data) + + register_hook("test.notify", observer) + rv = trigger_hook("test.notify", data="hello") + + assert rv is None + assert captured == ["hello"] + + +# ========================================================================= # +# 3. @hookable decorator +# ========================================================================= # + + +class TestHookableDecorator: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_hookable_auto_declares_specs(self): + from memos.plugins.hook_defs import get_hook_spec + from memos.plugins.hooks import hookable + + @hookable("auto_test") + def dummy(self, request): + return request + + before_spec = get_hook_spec("auto_test.before") + after_spec = get_hook_spec("auto_test.after") + + assert before_spec is not None + assert before_spec.pipe_key == "request" + assert after_spec is not None + assert after_spec.pipe_key == "result" + + def test_hookable_sync(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append(("before", request)) + + def on_after(*, request, result, **kwargs): + events.append(("after", result)) + + register_hook("sync_demo.before", on_before) + register_hook("sync_demo.after", on_after) + + class FakeHandler: + @hookable("sync_demo") + def do_work(self, request): + return f"processed:{request}" + + result = FakeHandler().do_work("my_input") + + assert result == "processed:my_input" + assert events == [("before", "my_input"), ("after", "processed:my_input")] + + def test_hookable_async(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append("before") + + def on_after(*, request, result, **kwargs): + events.append("after") + + register_hook("async_demo.before", on_before) + register_hook("async_demo.after", on_after) + + class FakeHandler: + @hookable("async_demo") + async def do_work(self, request): + return "async_result" + + result = asyncio.get_event_loop().run_until_complete(FakeHandler().do_work("req")) + + assert result == "async_result" + assert events == ["before", "after"] + + def test_hookable_before_can_modify_request(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_request(*, request, **kwargs): + return "modified_request" + + register_hook("modify_req.before", rewrite_request) + + class FakeHandler: + @hookable("modify_req") + def do_work(self, request): + return f"got:{request}" + + result = FakeHandler().do_work("original") + assert result == "got:modified_request" + + def test_hookable_after_can_modify_result(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_result(*, request, result, **kwargs): + return f"{result}+modified" + + register_hook("modify_res.after", rewrite_result) + + class FakeHandler: + @hookable("modify_res") + def do_work(self, request): + return "original_result" + + result = FakeHandler().do_work("req") + assert result == "original_result+modified" + + def test_hookable_falsy_return_preserved(self): + """ensure empty list / 0 / empty string are not treated as None""" + from memos.plugins.hooks import hookable, register_hook + + def return_empty_list(*, request, result, **kwargs): + return [] + + register_hook("falsy_test.after", return_empty_list) + + class FakeHandler: + @hookable("falsy_test") + def do_work(self, request): + return [1, 2, 3] + + result = FakeHandler().do_work("req") + assert result == [] + + +# ========================================================================= # +# 4. Base class register_* methods +# ========================================================================= # + + +class TestBaseClassRegisterMethods: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_router(self): + from fastapi import APIRouter + + from memos.plugins.base import MemOSPlugin + + app = FastAPI() + plugin = MemOSPlugin() + plugin._bind_app(app) + + router = APIRouter(prefix="/test") + + @router.get("/ping") + async def ping(): + return {"pong": True} + + plugin.register_router(router) + + paths = [r.path for r in app.routes] + assert "/test/ping" in paths + + def test_register_middleware(self): + from starlette.middleware.base import BaseHTTPMiddleware + + from memos.plugins.base import MemOSPlugin + + class NoopMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = FastAPI() + + @app.get("/x") + async def x(): + return {} + + plugin = MemOSPlugin() + plugin._bind_app(app) + plugin.register_middleware(NoopMiddleware) + + client = TestClient(app) + resp = client.get("/x") + assert resp.status_code == 200 + + def test_register_hook(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("test.reg.event", description="test", params=["x"]) + + called = [] + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hook("test.reg.event", lambda **kw: called.append(True)) + + trigger_hook("test.reg.event", x=1) + assert called == [True] + + def test_register_hooks_batch(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("batch.a", description="a", params=["x"]) + define_hook("batch.b", description="b", params=["x"]) + + count = 0 + + def cb(**kw): + nonlocal count + count += 1 + + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hooks(["batch.a", "batch.b"], cb) + + trigger_hook("batch.a", x=1) + trigger_hook("batch.b", x=2) + assert count == 2 From 8a4be2096ecf35b9f1056a2a4888b9cfe24e5903 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 17:56:18 +0800 Subject: [PATCH 63/92] feat: plugin system and demo plugin --- src/memos/api/handlers/add_handler.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index e53ef9393..1bd52c108 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,8 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hook_defs import H -from memos.plugins.hooks import hookable, trigger_hook +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -109,9 +108,6 @@ def _check_messages(messages: MessageList) -> None: results = cube_view.add_memories(add_req) - rv = trigger_hook(H.ADD_MEMORIES_POST_PROCESS, request=add_req, result=results) - results = rv if rv is not None else results - self.logger.info(f"[AddHandler] Final add results count={len(results)}") return MemoryResponse( From 6979f5c99bd2ba30d26bbf4320e35e5c001799e8 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 20:21:49 +0800 Subject: [PATCH 64/92] feat: add sync script --- .private-paths | 11 ++++++ Makefile | 5 +++ scripts/check-public-push.sh | 70 ++++++++++++++++++++++++++++++++++++ scripts/sync-public.sh | 62 ++++++++++++++++++++++++++++++++ 4 files changed, 148 insertions(+) create mode 100644 .private-paths create mode 100755 scripts/check-public-push.sh create mode 100755 scripts/sync-public.sh diff --git a/.private-paths b/.private-paths new file mode 100644 index 000000000..1df5fa57d --- /dev/null +++ b/.private-paths @@ -0,0 +1,11 @@ +# Paths exclusive to the enterprise repo (one per line). +# These will NOT be synced to the public repository. +# This file itself is also excluded from the public repo. + +extensions/ +pyproject.toml +poetry.lock +.private-paths +scripts/sync-public.sh +scripts/check-public-push.sh +Makefile diff --git a/Makefile b/Makefile index 57ede5838..178a4c19a 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ install: poetry install --extras all --with dev --with test poetry run pre-commit install --install-hooks + cp scripts/check-public-push.sh .git/hooks/pre-push + chmod +x .git/hooks/pre-push clean: rm -rf .memos @@ -25,3 +27,6 @@ serve: openapi: poetry run memos export_openapi --output docs/openapi.json + +sync-public: + @bash scripts/sync-public.sh "$(msg)" $(commit) diff --git a/scripts/check-public-push.sh b/scripts/check-public-push.sh new file mode 100755 index 000000000..0e18cff55 --- /dev/null +++ b/scripts/check-public-push.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash +# Pre-push hook: block private files from being pushed to the public repo. +# Private paths are read from .private-paths (one per line, # comments allowed). +# Installed by `make install` into .git/hooks/pre-push. + +REMOTE_NAME="$1" +REMOTE_URL="$2" + +# Only enforce on the public remote (skip MemOS-Enterprise) +if [[ "${REMOTE_URL}" != *"MemTensor/MemOS.git"* ]] || [[ "${REMOTE_URL}" == *"MemOS-Enterprise"* ]]; then + exit 0 +fi + +PRIVATE_PATHS_FILE=".private-paths" +if [ ! -f "${PRIVATE_PATHS_FILE}" ]; then + echo "⚠️ ${PRIVATE_PATHS_FILE} not found — skipping private-path check." + exit 0 +fi + +# Read private paths into regex patterns +PATTERNS=() +while IFS= read -r line; do + line="$(echo "${line}" | sed 's/#.*//; s/^[[:space:]]*//; s/[[:space:]]*$//')" + [ -z "${line}" ] && continue + # Convert path to regex: strip trailing /, add ^ anchor + pattern="^$(echo "${line}" | sed 's|/$||')" + PATTERNS+=("${pattern}") +done < "${PRIVATE_PATHS_FILE}" + +ERRORS=0 + +while read local_ref local_sha remote_ref remote_sha; do + # Skip delete operations + if [ "${local_sha}" = "0000000000000000000000000000000000000000" ]; then + continue + fi + + # For new remote refs, compare against public/main + if [ "${remote_sha}" = "0000000000000000000000000000000000000000" ]; then + base=$(git merge-base public/main "${local_sha}" 2>/dev/null || echo "public/main") + range="${base}..${local_sha}" + else + range="${remote_sha}..${local_sha}" + fi + + files=$(git diff --name-only "${range}" 2>/dev/null || true) + if [ -z "${files}" ]; then + continue + fi + + for pattern in "${PATTERNS[@]}"; do + matched=$(echo "${files}" | grep -E "${pattern}" || true) + if [ -n "${matched}" ]; then + echo "❌ BLOCKED: Private files detected in push to public repo!" + echo "" + echo " Pattern: ${pattern}" + echo " Files:" + echo "${matched}" | sed 's/^/ /' + echo "" + ERRORS=1 + fi + done +done + +if [ "${ERRORS}" -ne 0 ]; then + echo "💡 Use 'git sync-public \"\"' to safely sync CE code." + exit 1 +fi + +exit 0 diff --git a/scripts/sync-public.sh b/scripts/sync-public.sh new file mode 100755 index 000000000..4dcceabee --- /dev/null +++ b/scripts/sync-public.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Sync CE-only changes from the enterprise repo to the public repo. +# Private paths are read from .private-paths (one per line, # comments allowed). +# +# Usage: +# git sync-public "" [commit-ref] +# make sync-public msg="" [commit=] + +PUBLIC_REMOTE="public" +PRIVATE_PATHS_FILE=".private-paths" + +CE_MSG="${1:?Usage: git sync-public \"\" [commit-ref]}" +COMMIT="${2:-HEAD}" +EE_BRANCH="$(git branch --show-current)" +PUBLIC_BRANCH="public-$(echo "${EE_BRANCH}" | tr '/' '-')" + +# Read private paths from config file +if [ ! -f "${PRIVATE_PATHS_FILE}" ]; then + echo "❌ ${PRIVATE_PATHS_FILE} not found. Cannot determine private paths." + exit 1 +fi + +EXCLUDE_ARGS="" +while IFS= read -r line; do + line="$(echo "${line}" | sed 's/#.*//; s/^[[:space:]]*//; s/[[:space:]]*$//')" + [ -z "${line}" ] && continue + EXCLUDE_ARGS="${EXCLUDE_ARGS} ':!${line}'" +done < "${PRIVATE_PATHS_FILE}" + +git fetch "${PUBLIC_REMOTE}" main + +# Find CE files changed in the specified commit +CE_FILES=$(eval git diff --name-only "${COMMIT}^..${COMMIT}" -- . ${EXCLUDE_ARGS}) + +if [ -z "${CE_FILES}" ]; then + echo "✅ No CE changes in commit $(git rev-parse --short "${COMMIT}"). Done." + exit 0 +fi + +echo "▶ CE changes from $(git log -1 --format='%h %s' "${COMMIT}"):" +echo "${CE_FILES}" | sed 's/^/ /' + +# Reuse existing public branch or create from public/main +if git show-ref --verify --quiet "refs/heads/${PUBLIC_BRANCH}"; then + git checkout "${PUBLIC_BRANCH}" +else + git checkout -B "${PUBLIC_BRANCH}" "${PUBLIC_REMOTE}/main" +fi + +# Checkout CE files from the enterprise commit +echo "${CE_FILES}" | xargs git checkout "${COMMIT}" -- + +git commit --no-verify -m "${CE_MSG}" +echo "▶ Pushing ${PUBLIC_BRANCH} to ${PUBLIC_REMOTE}..." +git push "${PUBLIC_REMOTE}" "${PUBLIC_BRANCH}" +git checkout "${EE_BRANCH}" + +echo "" +echo "✅ Done. Create PR:" +echo " https://github.com/MemTensor/MemOS/pull/new/${PUBLIC_BRANCH}" From 30d4f8ee761c348d0cc81afee66c794027b5771c Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 20:53:41 +0800 Subject: [PATCH 65/92] feat: test modify ce code --- src/memos/api/handlers/add_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 1bd52c108..3cdbedabf 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,7 +15,6 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hooks import hookable from memos.types import MessageList @@ -38,7 +37,6 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) - @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. From 3b0bfd20497a21bb35cbcf4554a0c8444fc9b4c7 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 4 Mar 2026 21:01:20 +0800 Subject: [PATCH 66/92] fix: retrieving item.metadata.history instead of item.history. --- .../textual/tree_text_memory/organize/history_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index d5eb69117..1d61e7774 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -942,7 +942,7 @@ def _handle_restored_memories( for i, data in enumerate(restored_memories): source_item = source_items[i] # deal with history - source_history = source_item.history.copy() + source_history = source_item.metadata.history.copy() value = data.get("value", "") value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) value = value_item.memory From 6d6468c3431847adeb9e9f12e164d9b2542eb696 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 21:05:19 +0800 Subject: [PATCH 67/92] feat: test modify ce code --- src/memos/api/handlers/search_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 2877e5138..8e7785ad5 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -20,7 +20,6 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hooks import hookable logger = get_logger(__name__) @@ -45,7 +44,6 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) - @hookable("search") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ Main handler for search memories endpoint. From 5cc45a98697fee31ea2c96c4d1d08570747b8ad9 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 21:57:53 +0800 Subject: [PATCH 68/92] feat: addMemory add hookable --- src/memos/api/handlers/add_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..1bd52c108 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -37,6 +38,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. From 9b7c42e5073ffdb0d54db469b70cc4cd2dae83cd Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Wed, 4 Mar 2026 22:05:16 +0800 Subject: [PATCH 69/92] fix: sync-public --- scripts/sync-public.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scripts/sync-public.sh b/scripts/sync-public.sh index 4dcceabee..b71526fd0 100755 --- a/scripts/sync-public.sh +++ b/scripts/sync-public.sh @@ -52,6 +52,13 @@ fi # Checkout CE files from the enterprise commit echo "${CE_FILES}" | xargs git checkout "${COMMIT}" -- +# If there is no staged diff, the CE file content is already present on public branch. +if git diff --cached --quiet; then + echo "✅ All CE changes already synced on ${PUBLIC_BRANCH}. Nothing new to commit." + git checkout "${EE_BRANCH}" + exit 0 +fi + git commit --no-verify -m "${CE_MSG}" echo "▶ Pushing ${PUBLIC_BRANCH} to ${PUBLIC_REMOTE}..." git push "${PUBLIC_REMOTE}" "${PUBLIC_BRANCH}" From d92e4364368d2273d966c2c0442e78d7f936962c Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Thu, 5 Mar 2026 14:27:12 +0800 Subject: [PATCH 70/92] ci: fix test --- tests/plugins/test_plugin_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py index 51997cba4..77ea8dfce 100644 --- a/tests/plugins/test_plugin_demo.py +++ b/tests/plugins/test_plugin_demo.py @@ -292,7 +292,7 @@ class FakeHandler: async def do_work(self, request): return "async_result" - result = asyncio.get_event_loop().run_until_complete(FakeHandler().do_work("req")) + result = asyncio.run(FakeHandler().do_work("req")) assert result == "async_result" assert events == ["before", "after"] From d9ae69e825f0c1e840884d2a36e5b6dca12bce76 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Fri, 6 Mar 2026 10:13:14 +0800 Subject: [PATCH 71/92] fix: typo --- .../textual/tree_text_memory/organize/history_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 1d61e7774..0bbd6f27a 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -949,7 +949,7 @@ def _handle_restored_memories( tags = data.get("tags", []) keys = data.get("keys", []) memory_type = data.get("memory_type", "LongTermMemory") - version = source_item.verion + version = source_item.metadata.version new_history_item = ArchivedTextualMemory( version=version, is_fast=False, From e8364fea6c8f536ea8bb5121521b043eafc3ac79 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Mar 2026 20:09:14 +0800 Subject: [PATCH 72/92] fix: Fix missing sources and minor bugs in feedback --- src/memos/mem_feedback/feedback.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 18045af2c..7e46ae68a 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -243,7 +243,7 @@ def _single_add_operation( datetime.now().isoformat() ) to_add_memory.metadata.background = new_memory_item.metadata.background - to_add_memory.metadata.sources = [] + to_add_memory.metadata.sources = new_memory_item.metadata.sources added_ids = self._retry_db_operation( lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False) @@ -330,9 +330,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> f"[Memory Feedback UPDATE] Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" ) - delete_ids = [] - if bindings_to_delete: - delete_ids = list({bindings_to_delete}) + delete_ids = list(bindings_to_delete) for mid in delete_ids: try: @@ -345,6 +343,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> logger.warning( f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) + return bindings_to_delete def semantics_feedback( self, @@ -470,7 +469,7 @@ def semantics_feedback( f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) - if update_results: + if update_results and getattr(self.mem_reader, "memory_version_switch", "off") != "on": updated_ids = [item["archived_id"] for item in update_results] self._del_working_binding(updated_ids, user_name) @@ -1059,7 +1058,14 @@ def check_validity(item): tags=tags, key=key, embedding=embedding, - sources=[{"type": "chat"}], + sources=[ + { + "type": "feedback", + "role": "user", + "chat_time": feedback_time, + "content": feedback_content, + } + ], background=background, type="fine", info=info, From 9be5df3fcf63db84859aaeb477d4577c5b74f2cc Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 9 Mar 2026 20:15:30 +0800 Subject: [PATCH 73/92] feat: backward-compatible memory version feature for feedback. --- src/memos/api/handlers/component_init.py | 1 + src/memos/mem_feedback/feedback.py | 136 +++++++++++++++--- src/memos/mem_feedback/simple_feedback.py | 3 + src/memos/memories/textual/item.py | 4 +- .../organize/history_manager.py | 95 +++++++++--- .../memories/textual/test_history_manager.py | 52 +++++++ 6 files changed, 243 insertions(+), 48 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 53e7bfe16..a4cfcc77f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -261,6 +261,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + history_manager=memory_history_manager, pref_feedback=True, ) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 7e46ae68a..7a26c0849 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -27,6 +27,7 @@ from memos.mem_reader.factory import MemReaderFactory from memos.mem_reader.read_multi_modal import detect_lang from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import ( MemoryManager, extract_working_binding_ids, @@ -90,6 +91,12 @@ def __init__(self, config: MemFeedbackConfig): }, is_reorganize=self.is_reorganize, ) + # Actually this is initialized through SimpleMemFeedback, so it's fine. + self.history_manager = MemoryHistoryManager( + nli_client=None, + graph_db=self.graph_store, + embedder=self.embedder, + ) self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None @@ -287,33 +294,43 @@ def _single_update_operation( new_memory_item.memory = operation["text"] new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] - if memory_type == "WorkingMemory": - fields = { - "memory": new_memory_item.memory, - "key": new_memory_item.metadata.key, - "tags": new_memory_item.metadata.tags, - "embedding": new_memory_item.metadata.embedding, - "background": new_memory_item.metadata.background, - "covered_history": old_memory_item.id, - } - self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) - item_id = old_memory_item.id - else: - done = self._single_add_operation( - old_memory_item, new_memory_item, user_id, user_name, async_mode + if getattr(self.mem_reader, "memory_version_switch", "off") != "on": + if memory_type == "WorkingMemory": + fields = { + "memory": new_memory_item.memory, + "key": new_memory_item.metadata.key, + "tags": new_memory_item.metadata.tags, + "embedding": new_memory_item.metadata.embedding, + "background": new_memory_item.metadata.background, + "covered_history": old_memory_item.id, + } + self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) + item_id = old_memory_item.id + else: + done = self._single_add_operation( + old_memory_item, new_memory_item, user_id, user_name, async_mode + ) + item_id = done.get("id") + self.graph_store.update_node( + item_id, {"covered_history": old_memory_item.id}, user_name=user_name + ) + self.graph_store.update_node( + old_memory_item.id, {"status": "archived"}, user_name=user_name + ) + + logger.info( + f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" ) - item_id = done.get("id") - self.graph_store.update_node( - item_id, {"covered_history": old_memory_item.id}, user_name=user_name + else: + item_id = self._single_update_operation_with_versions( + old_memory_item=old_memory_item, + new_memory_item=new_memory_item, + user_name=user_name, ) - self.graph_store.update_node( - old_memory_item.id, {"status": "archived"}, user_name=user_name + logger.info( + f"[Memory Feedback UPDATE] Updated:{item_id} | history appended | memory_type: {old_memory_item.metadata.memory_type}" ) - logger.info( - f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" - ) - return { "id": item_id, "text": new_memory_item.memory, @@ -322,6 +339,79 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_operation_with_versions( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_name: str, + ) -> str: + try: + updated_item, archived_item, archived_metadata, updated_fields = ( + self.history_manager.update_from_feedback( + old_item=old_memory_item, + new_item=new_memory_item, + user_name=user_name, + ) + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] history fallback for %s: %s", old_memory_item.id, e + ) + updated_item = old_memory_item.model_copy(deep=True) + updated_item.memory = new_memory_item.memory + updated_item.metadata.key = new_memory_item.metadata.key + updated_item.metadata.tags = new_memory_item.metadata.tags + updated_item.metadata.background = new_memory_item.metadata.background + if getattr(new_memory_item.metadata, "sources", None) is not None: + current_sources = list(updated_item.metadata.sources or []) + updated_item.metadata.sources = ( + list(new_memory_item.metadata.sources or []) + current_sources + ) + if getattr(new_memory_item.metadata, "embedding", None) is not None: + updated_item.metadata.embedding = new_memory_item.metadata.embedding + if updated_item.metadata.memory_type == "PreferenceMemory": + updated_item.metadata.preference = updated_item.memory + updated_fields = { + "memory": updated_item.memory, + "key": updated_item.metadata.key, + "tags": updated_item.metadata.tags, + "embedding": updated_item.metadata.embedding, + "background": updated_item.metadata.background, + "sources": [ + source.model_dump(exclude_none=True) + if hasattr(source, "model_dump") + else source + for source in (updated_item.metadata.sources or []) + ], + "covered_history": old_memory_item.id, + } + archived_item = None + archived_metadata = None + + if archived_item and archived_metadata: + try: + self.graph_store.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_metadata, + user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] archive add failed for %s: %s", + old_memory_item.id, + e, + ) + self._retry_db_operation( + lambda: self.graph_store.update_node( + id=updated_item.id, + fields=updated_fields, + user_name=user_name, + ) + ) + self._del_working_binding(user_name, [old_memory_item]) + return updated_item.id + def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index dfc9b9fdf..d28a8e9da 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,6 +4,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -23,12 +24,14 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, + history_manager: MemoryHistoryManager, pref_feedback: bool = False, ): self.llm = llm self.embedder = embedder self.graph_store = graph_store self.memory_manager = memory_manager + self.history_manager = history_manager self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 28b99b05a..b0f90b537 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel): memory: str | None = Field( default_factory=lambda: "", description="The content of the archived version of the memory." ) - update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field( default="unrelated", - description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).", ) archived_memory_id: str | None = Field( default=None, diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 0bbd6f27a..4ed4e232b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -183,6 +183,12 @@ def update_node_with_history( tags: list[str] | None = None, key: str | None = None, ) -> tuple[TextualMemoryItem, TextualMemoryItem]: + """ + This method is used to update a given item. + It updates the item.memory to new_memory, and pushes the old item.memory content to its history. + Instead, it also creates an archived_item to store the embeddings and sources of the old memory content, + and stores it to the graph_db. + """ now = datetime.now().isoformat() last_update_time = item.metadata.updated_at @@ -197,7 +203,7 @@ def update_node_with_history( archived_item.metadata.updated_at = now # original memory with updated contents and history - archived_history = ArchivedTextualMemory( + history_item = ArchivedTextualMemory( version=item.metadata.version or 1, is_fast=item.metadata.is_fast or False, memory=item.memory, @@ -215,7 +221,7 @@ def update_node_with_history( item.metadata.key = key if item.metadata.history is None: item.metadata.history = [] - item.metadata.history.append(archived_history) + item.metadata.history.append(history_item) return item, archived_item @@ -574,6 +580,55 @@ def apply_mem_version_update( ) return self.build_fallback_new_items(original_item, user_name=user_name) + def update_from_feedback( + self, + old_item: TextualMemoryItem, + new_item: TextualMemoryItem, + user_name: str, + update_type: Literal[ + "conflict", "duplicate", "extract", "unrelated", "feedback" + ] = "feedback", + ) -> tuple[TextualMemoryItem, TextualMemoryItem, dict[str, Any], dict[str, Any]]: + current_item, archived_item = self.update_node_with_history( + item=old_item.model_copy(deep=True), + new_memory=new_item.memory, + update_type=update_type, + tags=new_item.metadata.tags, + key=new_item.metadata.key, + ) + current_item.metadata.background = new_item.metadata.background + if getattr(new_item.metadata, "sources", None) is not None: + current_sources = list(current_item.metadata.sources or []) + current_item.metadata.sources = list(new_item.metadata.sources or []) + current_sources + if getattr(new_item.metadata, "embedding", None) is not None: + current_item.metadata.embedding = new_item.metadata.embedding + elif self.embedder: + current_item.metadata.embedding = self._compute_embedding(current_item.memory) + if current_item.metadata.memory_type == "PreferenceMemory": + current_item.metadata.preference = current_item.memory + + archived_embedding = getattr(old_item.metadata, "embedding", None) + if archived_embedding is None: + archived_embedding = TextualMemoryItem( + **self.graph_db.get_node(old_item.id, user_name=user_name, include_embedding=True) + ).metadata.embedding + arch_meta = _sanitize_metadata_dict(archived_item.metadata.model_dump(exclude_none=True)) + arch_meta["embedding"] = archived_embedding + metadata_fields = _sanitize_metadata_dict( + current_item.metadata.model_dump(exclude_none=True) + ) + history_dump = [ + h.model_dump(exclude_none=True) for h in (current_item.metadata.history or []) + ] + update_fields = { + **metadata_fields, + "memory": current_item.memory, + "history": history_dump, + "version": current_item.metadata.version, + "covered_history": old_item.id, + } + return current_item, archived_item, arch_meta, update_fields + def _check_and_fetch_replacements( self, item: TextualMemoryItem, pending_indices: list[int], user_name: str ) -> tuple[dict[int, list[ArchivedTextualMemory]], list[str]]: @@ -691,15 +746,11 @@ def _update_existing_memory( Returns the updated primary TextualMemoryItem and optional new item when fallback is used. """ primary_id, secondary_ids = target_ids[0], target_ids[1:] - new_value, tags, key = ( + new_memory_value, tags, key = ( mem_data.get("value", ""), mem_data.get("tags", []), mem_data.get("key", ""), ) - new_value_item = TextualMemoryItem( - memory=new_value, metadata=TreeNodeTextualMemoryMetadata() - ) - new_value = new_value_item.memory # Fetch candidate nodes in batch and then select the primary # We update the primary and then merge the secondaries to the primary @@ -710,19 +761,21 @@ def _update_existing_memory( logger.warning( f"[MemoryHistoryManager] Target node {primary_id} not found for update. Skipping." ) - # Fallback to create new item when the source_id is hallucinated by llm + # Fallback to create new item when the source_id is not valid(hallucination from llm) new_item = self._create_new_memory(mem_data, fast_item) return None, new_item current_item = TextualMemoryItem(**node_data) - # For concurrency control, need to make sure the primary item has not been modified by others in the meantime + # For concurrency control, need to make sure the primary item has not been modified by others during the run. # If it has(version changed), then we need to use llm to merge again. - new_value = self._apply_cas_merge(primary_id, current_item, expected_versions, new_value) + new_memory_value = self._apply_cas_merge( + primary_id, current_item, expected_versions, new_memory_value + ) update_type = "duplicate" if primary_id in source_ids else "conflict" current_item, archived_item = self.update_node_with_history( current_item, - new_value, + new_memory_value, update_type, tags=tags, key=key, @@ -784,7 +837,7 @@ def _apply_cas_merge( primary_id: str, current_item: TextualMemoryItem, expected_versions: dict[str, int], - new_value: str, + new_memory_value: str, ) -> str: expected_version = expected_versions.get(primary_id) current_version = current_item.metadata.version or 1 @@ -794,16 +847,13 @@ def _apply_cas_merge( f"Expected v{expected_version}, but found v{current_version} in DB. " "Triggering merge logic." ) - latest_item = TextualMemoryItem( - memory=current_item.memory, metadata=TreeNodeTextualMemoryMetadata() - ) merged_content = self._merge_conflicting_memory( - latest_memory=latest_item.memory, - proposed_update=new_value, + latest_memory=current_item.memory, + proposed_update=new_memory_value, ) return merged_content - return new_value + return new_memory_value def _merge_secondary_nodes( self, @@ -883,14 +933,13 @@ def _merge_conflicting_memory(self, latest_memory: str, proposed_update: str) -> messages = [{"role": "user", "content": prompt}] try: response = self.llm.generate(messages) + if not response: + raise ValueError("LLM response is None.") return response.strip() except Exception as e: logger.error(f"[MemoryHistoryManager] Failed to merge memory via LLM: {e}") - # Fallback: append proposed update? or just return proposed? - # Returning proposed might overwrite latest changes. - # Returning latest might lose proposed changes. - # Let's concatenate as a safe fallback. - return f"{latest_memory}\n\n[System Merge Fallback] New Info: {proposed_update}" + # Fallback: concatenate as a safe fallback. + return f"{latest_memory}\n\n[New Info]: {proposed_update}" def _create_new_memory( self, mem_data: dict[str, Any], fast_item: TextualMemoryItem diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 813a4b622..a3dac0e02 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -618,6 +618,58 @@ def test_update_existing_memory_no_mark_when_working_binding_matches( assert updated is not None assert new_item is None + + +def test_update_from_feedback_returns_persistence_payload_without_side_effects( + history_manager, mock_graph_db +): + history_manager.mark_memory_status = MagicMock() + memory_id = str(uuid.uuid4()) + old_item = TextualMemoryItem( + id=memory_id, + memory="Old Content", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + memory_type="LongTermMemory", + embedding=[0.1, 0.2], + sources=[{"type": "chat", "content": "old source"}], + history=[], + ), + ) + new_item = TextualMemoryItem( + memory="Updated Content", + metadata=TreeNodeTextualMemoryMetadata( + tags=["fresh"], + key="topic", + background="new background", + embedding=[0.3, 0.4], + sources=[{"type": "feedback", "content": "new feedback source"}], + memory_type="LongTermMemory", + ), + ) + + current_item, archived_item, archived_metadata, update_fields = ( + history_manager.update_from_feedback( + old_item=old_item, + new_item=new_item, + user_name="u1", + ) + ) + + assert current_item.id == memory_id + assert current_item.memory == "Updated Content" + assert archived_item.memory == "Old Content" + assert current_item.metadata.sources[0].content == "new feedback source" + assert current_item.metadata.sources[0].type == "feedback" + assert current_item.metadata.sources[1].content == "old source" + assert archived_item.metadata.sources[0].content == "old source" + assert archived_metadata["embedding"] == [0.1, 0.2] + assert update_fields["memory"] == "Updated Content" + assert update_fields["covered_history"] == memory_id + assert update_fields["embedding"] == [0.3, 0.4] + mock_graph_db.get_node.assert_not_called() + mock_graph_db.add_node.assert_not_called() + mock_graph_db.update_node.assert_not_called() history_manager.mark_memory_status.assert_not_called() From b0c2848c07c2a9d1b43266d3097e12000686e2a7 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Mar 2026 14:20:46 +0800 Subject: [PATCH 74/92] fix: Metadata key validation error fix. --- .../textual/tree_text_memory/organize/history_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 4ed4e232b..14ccaac46 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -996,7 +996,7 @@ def _handle_restored_memories( value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) value = value_item.memory tags = data.get("tags", []) - keys = data.get("keys", []) + key = data.get("key", "") memory_type = data.get("memory_type", "LongTermMemory") version = source_item.metadata.version new_history_item = ArchivedTextualMemory( @@ -1015,7 +1015,7 @@ def _handle_restored_memories( "is_fast": False, "version": version + 1, "tags": tags, - "key": keys, + "key": key, "created_at": datetime.now().isoformat(), "history": source_history, "embedding": self._compute_embedding(value), From a3b761d2a09b5d95b67c0e4537d13bab7dd83794 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Tue, 10 Mar 2026 16:15:42 +0800 Subject: [PATCH 75/92] fix: restored_memory bug fix. --- .../textual/tree_text_memory/organize/history_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 14ccaac46..89e4ba298 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -968,8 +968,8 @@ def _create_new_memory( "embedding": self._compute_embedding(new_value), } metadata = fast_item.metadata.model_copy(deep=True) - for field_name, value in metadata_updates.items(): - setattr(metadata, field_name, value) + for field_name, field_value in metadata_updates.items(): + setattr(metadata, field_name, field_value) metadata = _sanitize_metadata_model(metadata) new_item = TextualMemoryItem( @@ -1021,8 +1021,8 @@ def _handle_restored_memories( "embedding": self._compute_embedding(value), } metadata = fast_item.metadata.model_copy(deep=True) - for field_name, value in metadata_updates.items(): - setattr(metadata, field_name, value) + for field_name, field_value in metadata_updates.items(): + setattr(metadata, field_name, field_value) metadata = _sanitize_metadata_model(metadata) new_item = TextualMemoryItem( From fc500facc54f1a016881c6f71abf1353961ca97b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 10 Mar 2026 16:16:53 +0800 Subject: [PATCH 76/92] doc: add extension plugin README --- extensions/README.md | 615 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 615 insertions(+) create mode 100644 extensions/README.md diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 000000000..84c7de004 --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,615 @@ +# MemOS 插件开发 README + +这是一份面向新人的标准操作流程。按顺序执行即可完成以下全流程: + +`环境搭建 -> 插件开发 -> 测试 -> 提交 -> 部署` + +如需了解设计原理,请参考《MemOS 插件系统设计与开发指南》。 + +## 快速导航 + +1. [环境搭建](#环境搭建) +2. [创建插件](#创建插件) +3. [注册插件](#注册插件) +4. [编写测试](#编写测试) +5. [代码提交](#代码提交) +6. [部署与验证](#部署与验证) +7. [开发 Checklist](#开发-checklist) +8. [常用命令速查](#常用命令速查) +9. [文件模板清单](#文件模板清单) + +## 环境搭建 + +这是一次性操作,首次接入时完成即可。 + +### 1. 克隆企业仓库 + +```bash +git clone git@github.com:MemTensor/MemOS-Enterprise.git +cd MemOS-Enterprise +``` + +### 2. 添加公开仓库 remote + +```bash +git remote add public git@github.com:MemTensor/MemOS.git +``` + +### 3. 安装依赖与 Git Hooks + +`make install` 会同时完成依赖安装,以及以下 Git Hooks 配置: + +- `pre-commit`:代码检查 +- `pre-push`:私有代码拦截 + +```bash +make install +``` + +### 4. 配置同步 alias(可选,推荐) + +```bash +git config alias.sync-public '!bash scripts/sync-public.sh' +``` + +### 5. 环境验证 + +#### 插件框架测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest tests/plugins/ -v +``` + +#### Demo 插件测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_demo_plugin/tests/ -v +``` + +#### 启动服务并验证插件加载 + +```bash +uvicorn memos.api.server_api:app --port 8001 +curl http://127.0.0.1:8001/demo/health +``` + +预期返回: + +```json +{"status":"ok","plugin":"demo","version":"0.1.0"} +``` + +## 创建插件 + +以下以开发 `memos_foo_plugin` 为例,实际使用时将 `foo` 替换为你的插件名。 + +### 1. 创建目录 + +```bash +mkdir -p extensions/memos_foo_plugin/tests +touch extensions/memos_foo_plugin/__init__.py +touch extensions/memos_foo_plugin/tests/__init__.py +``` + +### 2. 包入口 + +文件:`extensions/memos_foo_plugin/__init__.py` + +```python +from memos_foo_plugin.plugin import FooPlugin + +__all__ = ["FooPlugin"] +``` + +### 3. 编写 Plugin 主类 + +文件:`extensions/memos_foo_plugin/plugin.py` + +```python +import logging +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H + +logger = logging.getLogger(__name__) + + +class FooPlugin(MemOSPlugin): + name = "foo" + version = "0.1.0" + description = "Foo plugin - brief description" + + def on_load(self) -> None: + self.counter: dict[str, int] = {} + logger.info("[Foo] plugin loaded") + + def init_app(self) -> None: + from memos_foo_plugin.hooks import on_add_after + from memos_foo_plugin.routes import create_router + + self.register_router(create_router(self)) + self.register_hook(H.ADD_AFTER, partial(on_add_after, self)) + + # from memos_foo_plugin.middleware import FooMiddleware + # self.register_middleware(FooMiddleware) + + logger.info("[Foo] plugin initialized") + + def on_shutdown(self) -> None: + logger.info("[Foo] plugin shutdown") +``` + +### 4. 编写路由 + +文件:`extensions/memos_foo_plugin/routes.py` + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + + +def create_router(plugin: FooPlugin) -> APIRouter: + router = APIRouter(prefix="/foo", tags=["foo"]) + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name} + + @router.get("/stats") + async def stats(): + return {"counter": plugin.counter} + + return router +``` + +### 5. 编写 Hook 回调 + +文件:`extensions/memos_foo_plugin/hooks.py` + +```python +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + +logger = logging.getLogger(__name__) + + +def on_add_after(plugin: FooPlugin, *, request, result, **kw) -> None: + """[add.after] Count add operations per user.""" + uid = getattr(request, "user_id", "unknown") + plugin.counter[uid] = plugin.counter.get(uid, 0) + 1 + logger.info("[Foo] add counted user=%s total=%d", uid, plugin.counter[uid]) +``` + +### 6. 编写中间件(可选) + +文件:`extensions/memos_foo_plugin/middleware.py` + +```python +import logging +import time + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +logger = logging.getLogger(__name__) + + +class FooMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + start = time.time() + response = await call_next(request) + elapsed_ms = (time.time() - start) * 1000 + logger.info( + "[Foo] %s %s -> %s (%.1fms)", + request.method, + request.url.path, + response.status_code, + elapsed_ms, + ) + return response +``` + +### 7. 自定义 Hook(可选) + +如果插件需要自己定义并触发 Hook,而不只是监听 CE 提供的 Hook,可以新增 `hook_defs.py`。 + +文件:`extensions/memos_foo_plugin/hook_defs.py` + +```python +from memos.plugins.hook_defs import define_hook + + +class FooH: + """Foo plugin hook name constants.""" + + RESULT_ENRICH = "foo.result.enrich" + + +define_hook( + FooH.RESULT_ENRICH, + description="Enrich result data after processing", + params=["user_id", "result"], + pipe_key="result", +) +``` + +在路由或业务逻辑中触发: + +```python +from memos.plugins.hooks import trigger_hook +from memos_foo_plugin.hook_defs import FooH + +rv = trigger_hook(FooH.RESULT_ENRICH, user_id="alice", result=data) +data = rv if rv is not None else data +``` + +在 `plugin.py` 中注册回调: + +```python +from memos_foo_plugin.hook_defs import FooH + +self.register_hook(FooH.RESULT_ENRICH, partial(enrich_result, self)) +``` + +## 注册插件 + +需要在 `pyproject.toml` 中添加两处配置。 + +### 1. 声明包路径 + +```toml +[tool.poetry] +packages = [ + {include = "memos", from = "src"}, + {include = "memos_foo_plugin", from = "extensions"}, +] +``` + +### 2. 注册 entry point + +```toml +[project.entry-points."memos.plugins"] +demo = "memos_demo_plugin:DemoPlugin" +foo = "memos_foo_plugin:FooPlugin" +``` + +### 3. 重新安装使 entry point 生效 + +```bash +pip install -e . +``` + +> 注意: +> 仅修改已安装插件的代码时,在 editable 模式下通常重启服务即可。 +> 如果是新增插件,或修改了 `pyproject.toml`,则必须重新安装。 + +## 编写测试 + +### 1. `conftest.py` + +文件:`extensions/memos_foo_plugin/tests/conftest.py` + +```python +"""Ensure hooks used by FooPlugin are declared for testing.""" + +from memos.plugins.hooks import hookable + +# Declare CE hooks (normally declared at import time of handler modules) +hookable("add") +hookable("search") + +# If plugin has custom hook_defs, import to trigger declarations: +# import memos_foo_plugin.hook_defs # noqa: F401 +``` + +### 2. 生命周期测试 + +文件:`extensions/memos_foo_plugin/tests/test_lifecycle.py` + +```python +from fastapi import FastAPI + + +def _init_plugin(plugin, app): + plugin._bind_app(app) + plugin.init_app() + + +class TestFooPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_foo_plugin.plugin import FooPlugin + + plugin = FooPlugin() + assert plugin.name == "foo" + assert plugin.version == "0.1.0" + + def test_on_load_state(self): + from memos_foo_plugin.plugin import FooPlugin + + plugin = FooPlugin() + plugin.on_load() + assert plugin.counter == {} + + def test_full_lifecycle(self): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + _init_plugin(plugin, app) + + paths = [r.path for r in app.routes] + assert "/foo/health" in paths + assert "/foo/stats" in paths + + plugin.on_shutdown() +``` + +### 3. Hook 回调测试 + +文件:`extensions/memos_foo_plugin/tests/test_hooks.py` + +```python +from fastapi import FastAPI + + +def _make_plugin(): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return plugin + + +class TestHookCallbacks: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_add_after_counts(self): + from memos.plugins.hooks import trigger_hook + + plugin = _make_plugin() + + class Req: + user_id = "alice" + + trigger_hook("add.after", request=Req(), result={}) + trigger_hook("add.after", request=Req(), result={}) + assert plugin.counter["alice"] == 2 +``` + +### 4. 路由测试 + +文件:`extensions/memos_foo_plugin/tests/test_routes.py` + +```python +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def _make_app(): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return app, plugin + + +class TestRoutes: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_health(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/foo/health") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + def test_stats_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/foo/stats") + assert resp.json()["counter"] == {} +``` + +### 5. 运行测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_foo_plugin/tests/ -v +``` + +## 代码提交 + +### 1. 提交到企业仓库 + +```bash +git add -A +git commit -m "feat: add foo plugin" +git push origin feature/foo +``` + +这是标准 Git 流程,完整代码推送到企业仓库。 + +### 2. 同步 CE 代码到公开仓库 + +如果本次改动包含 CE 代码,例如 `src/memos/`、`tests/plugins/` 等,需要执行同步。 + +同步最近一次 commit 的 CE 改动: + +```bash +git sync-public "feat: add plugin framework enhancement" +``` + +同步指定 commit: + +```bash +git sync-public "fix: hook trigger" abc1234 +``` + +或使用 `make`: + +```bash +make sync-public msg="feat: add plugin framework enhancement" +``` + +推送后,在 GitHub 创建 PR 合入 `public/main`。 + +### 3. 判断是否需要 `sync-public` + +| 改动内容 | 需要 `sync-public` | +| --- | --- | +| `extensions/` 下的插件代码 | ❌ | +| `pyproject.toml` / `poetry.lock` | ❌ | +| `scripts/` / `Makefile` / `.private-paths` | ❌ | +| `src/memos/plugins/` 框架代码 | ✅ | +| `src/memos/api/` 中新增 `@hookable` | ✅ | +| `tests/plugins/` 框架测试 | ✅ | + +### 4. 新增私有路径 + +如果新增了不应同步到公开仓库的文件或目录,请编辑 `.private-paths`,每行添加一个路径。 + +```text +extensions/ +pyproject.toml +poetry.lock +.private-paths +scripts/sync-public.sh +scripts/check-public-push.sh +Makefile +docs/internal/ +``` + +## 部署与验证 + +### 1. 启动服务 + +```bash +uvicorn memos.api.server_api:app --port 8001 +``` + +启动日志中应看到: + +```text +INFO: Plugin discovered: foo v0.1.0 +INFO: Plugin initialized: foo +``` + +### 2. 验证接口 + +插件健康检查: + +```bash +curl http://127.0.0.1:8001/foo/health +``` + +插件业务接口: + +```bash +curl http://127.0.0.1:8001/foo/stats +``` + +### 3. 验证 Hook 生效 + +通过调用 CE 接口触发 Hook,再检查插件状态。 + +触发 `add` 接口,使插件的 `add.after` hook 被调用: + +```bash +curl -X POST http://127.0.0.1:8001/product/add \ + -H "Content-Type: application/json" \ + -d '{"user_id": "test_user", ...}' +``` + +查看插件统计: + +```bash +curl http://127.0.0.1:8001/foo/stats +``` + +预期返回: + +```json +{"counter": {"test_user": 1}} +``` + +## 开发 Checklist + +开发完成后,逐项确认: + +- [ ] `Plugin` 类继承 `MemOSPlugin` +- [ ] 已实现 `name`、`version`、`description` +- [ ] 已在 `init_app()` 中注册路由 +- [ ] Hook 回调使用 `self.register_hook(...)` 正确注册 +- [ ] 如有中间件,已使用 `self.register_middleware(...)` 注册 +- [ ] 已在 `pyproject.toml` 中声明包路径 +- [ ] 已在 entry points 中注册插件 +- [ ] 测试通过:插件测试可完整运行 +- [ ] 服务启动日志出现插件发现与初始化信息 +- [ ] 插件接口返回预期结果 +- [ ] 代码已推送到企业仓库 +- [ ] 如涉及 CE 代码,已完成 `sync-public` + +## 常用命令速查 + +| 操作 | 命令 | +| --- | --- | +| 安装依赖 + hooks | `make install` | +| 运行全部测试 | `PYTHONPATH="src:extensions" python -m pytest tests/plugins/ extensions/ -v` | +| 运行单个插件测试 | `PYTHONPATH="src:extensions" python -m pytest extensions/memos_foo_plugin/tests/ -v` | +| 启动服务 | `uvicorn memos.api.server_api:app --port 8001` | +| 代码格式化 | `make format` | +| 代码检查 | `make pre_commit` | +| 提交到企业仓库 | `git commit + git push origin ` | +| 同步 CE 到公开仓库 | `git sync-public "message"` | +| 同步指定 commit | `git sync-public "message" ` | + +## 文件模板清单 + +新建插件时,通常需要创建如下文件: + +```text +extensions/memos_foo_plugin/ +├── __init__.py # 必须:包入口,re-export Plugin 类 +├── plugin.py # 必须:继承 MemOSPlugin,注册能力 +├── routes.py # 按需:FastAPI 路由 +├── hooks.py # 按需:Hook 回调函数 +├── middleware.py # 按需:Starlette 中间件 +├── hook_defs.py # 按需:插件自有 Hook 声明(有自定义 Hook 时需要) +└── tests/ + ├── __init__.py # 必须 + ├── conftest.py # 必须:声明测试中用到的 Hook + ├── test_lifecycle.py # 推荐:生命周期测试 + ├── test_hooks.py # 推荐:Hook 回调测试 + └── test_routes.py # 推荐:路由端点测试 +``` + +也可以直接复制 `extensions/memos_demo_plugin/` 作为模板,然后全局替换 `demo -> foo`、`Demo -> Foo`。 From 917ebccf8d98e0f7eaa20ef0517e42ef8856631c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 10 Mar 2026 21:14:46 +0800 Subject: [PATCH 77/92] feat: add memos_prompt_strategy_plugin --- .../memos_prompt_strategy_plugin/README.md | 134 +++++++ .../memos_prompt_strategy_plugin/__init__.py | 4 + .../classifier.py | 198 ++++++++++ .../memos_prompt_strategy_plugin/example.py | 132 +++++++ .../memos_prompt_strategy_plugin/hooks.py | 46 +++ .../memos_prompt_strategy_plugin/plugin.py | 44 +++ .../memos_prompt_strategy_plugin/routes.py | 34 ++ .../strategies.py | 346 ++++++++++++++++++ .../tests/__init__.py | 0 .../tests/conftest.py | 9 + .../tests/test_classifier.py | 106 ++++++ .../tests/test_lifecycle.py | 164 +++++++++ .../tests/test_strategies.py | 82 +++++ pyproject.toml | 2 + src/memos/mem_reader/multi_modal_struct.py | 14 + src/memos/plugins/hook_defs.py | 10 + 16 files changed, 1325 insertions(+) create mode 100644 extensions/memos_prompt_strategy_plugin/README.md create mode 100644 extensions/memos_prompt_strategy_plugin/__init__.py create mode 100644 extensions/memos_prompt_strategy_plugin/classifier.py create mode 100644 extensions/memos_prompt_strategy_plugin/example.py create mode 100644 extensions/memos_prompt_strategy_plugin/hooks.py create mode 100644 extensions/memos_prompt_strategy_plugin/plugin.py create mode 100644 extensions/memos_prompt_strategy_plugin/routes.py create mode 100644 extensions/memos_prompt_strategy_plugin/strategies.py create mode 100644 extensions/memos_prompt_strategy_plugin/tests/__init__.py create mode 100644 extensions/memos_prompt_strategy_plugin/tests/conftest.py create mode 100644 extensions/memos_prompt_strategy_plugin/tests/test_classifier.py create mode 100644 extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py create mode 100644 extensions/memos_prompt_strategy_plugin/tests/test_strategies.py diff --git a/extensions/memos_prompt_strategy_plugin/README.md b/extensions/memos_prompt_strategy_plugin/README.md new file mode 100644 index 000000000..42ebef812 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/README.md @@ -0,0 +1,134 @@ +# Prompt Strategy Plugin + +根据消息内容自动分类,为不同类型的对话选择最优的记忆提取 prompt。 + +## 解决什么问题 + +默认的 mem-reader 对所有消息使用同一套 prompt 模板(chat / doc / general\_string),无法针对不同对话场景做精细化提取。例如: + +- 闲聊中的偏好信息容易被忽略 +- 任务型对话的截止日期和行动项提取不够结构化 +- 代码讨论中的技术细节会被当作普通文本处理 + +本插件通过 **规则分类器 + 策略 prompt 注册表** 解决这个问题,整个过程对 CE 代码完全透明。 + +## 工作原理 + +``` +消息进入 mem-reader + ↓ +CE: _get_llm_response() 构建默认 prompt + ↓ +CE: trigger_hook("mem_reader.pre_extract") ← 通用扩展点 + ↓ +插件回调 on_pre_extract(): + 1. MessageClassifier 对消息分类 → 得到类别标签 + 2. StrategyRegistry 根据标签选择对应的 prompt 模板 + 3. 返回定制 prompt(或 None 走默认) + ↓ +CE: LLM.generate(prompt) +``` + +分类和 prompt 路由全部在插件内部完成,CE 只暴露了一个通用的 `mem_reader.pre_extract` 扩展点。 + +## 支持的分类类别 + +| 类别 | 判定规则 | Prompt 策略 | +|------|---------|------------| +| `casual_chat` | 单轮、短文本、无明确任务意图 | 轻量提取,关注偏好和习惯 | +| `task_oriented` | 含任务关键词(请、安排、deadline 等) | 结构化提取,关注任务、截止日期、约束 | +| `knowledge_sharing` | 长文本(>800 字符)、多段落 | 类文档提取,关注概念和定义 | +| `emotional` | 含情感词汇(开心、担心、grateful 等) | 关注情感状态、人际关系 | +| `code_discussion` | 含代码块或技术关键词 | 技术记忆提取,关注工具、框架、解决方案 | +| `multi_turn_qa` | 4+ 轮对话且含多个问句 | 关注结论和最终答案 | + +分类器采用**规则优先**策略,零 LLM 开销。可选配置 LLM 兜底处理模糊场景。 + +## 文件结构 + +``` +extensions/memos_prompt_strategy_plugin/ +├── __init__.py # 包入口,导出 PromptStrategyPlugin +├── plugin.py # 插件主类:生命周期 + 注册 +├── hooks.py # Hook 回调:on_pre_extract +├── classifier.py # 消息分类器(规则 + 可选 LLM) +├── strategies.py # 策略注册表 + 6 套 prompt 模板(中英文) +├── routes.py # 管理接口 +├── example.py # 可直接运行的分类演示 +└── tests/ + ├── conftest.py + ├── test_classifier.py + ├── test_strategies.py + └── test_lifecycle.py +``` + +## 快速体验 + +不需要启动服务,直接运行 example 查看分类效果: + +```bash +PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py +``` + +## 安装与注册 + +`pyproject.toml` 中已包含以下配置: + +```toml +# 包路径 +[tool.poetry] +packages = [ + {include = "memos_prompt_strategy_plugin", from = "extensions"}, +] + +# Entry point +[project.entry-points."memos.plugins"] +prompt_strategy = "memos_prompt_strategy_plugin:PromptStrategyPlugin" +``` + +如果是首次安装,需要执行: + +```bash +pip install -e . +``` + +## 运行测试 + +```bash +PYTHONPATH="src:extensions" python -m pytest extensions/memos_prompt_strategy_plugin/tests/ -v +``` + +## 管理接口 + +插件注册了以下 HTTP 端点: + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/prompt_strategy/health` | GET | 插件健康检查 | +| `/prompt_strategy/strategies` | GET | 列出所有已注册的策略及描述 | +| `/prompt_strategy/stats` | GET | 查看各分类的命中次数统计 | + +## 自定义策略 + +可以在插件初始化后动态注册新策略: + +```python +from memos_prompt_strategy_plugin.strategies import PromptStrategy + +plugin.registry.register(PromptStrategy( + name="legal_discussion", + template_en="Extract legal terms, clauses...\n${conversation}\n${custom_tags_prompt}", + template_zh="提取法律术语、条款...\n${conversation}\n${custom_tags_prompt}", + description="法律讨论场景的记忆提取", +)) +``` + +同时在 `classifier.py` 中添加对应的分类规则即可生效。 + +## CE 依赖 + +本插件依赖 CE 侧的一个扩展点: + +- `mem_reader.pre_extract`:在 `MultiModalStructMemReader._get_llm_response()` 中,LLM 调用前触发 + +该扩展点声明在 `src/memos/plugins/hook_defs.py`,修改后需通过 `sync-public` 同步到公开仓库。 diff --git a/extensions/memos_prompt_strategy_plugin/__init__.py b/extensions/memos_prompt_strategy_plugin/__init__.py new file mode 100644 index 000000000..f4bd07412 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/__init__.py @@ -0,0 +1,4 @@ +from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + +__all__ = ["PromptStrategyPlugin"] diff --git a/extensions/memos_prompt_strategy_plugin/classifier.py b/extensions/memos_prompt_strategy_plugin/classifier.py new file mode 100644 index 000000000..29b78c835 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/classifier.py @@ -0,0 +1,198 @@ +"""Message classifier — categorize messages for prompt strategy selection. + +Uses a rule-based pipeline (zero LLM overhead) with optional LLM fallback +for ambiguous cases. +""" + +from __future__ import annotations + +import logging +import re + +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = logging.getLogger(__name__) + +# Category constants +CASUAL_CHAT = "casual_chat" +TASK_ORIENTED = "task_oriented" +KNOWLEDGE_SHARING = "knowledge_sharing" +EMOTIONAL = "emotional" +CODE_DISCUSSION = "code_discussion" +MULTI_TURN_QA = "multi_turn_qa" + +ALL_CATEGORIES = [ + CASUAL_CHAT, + TASK_ORIENTED, + KNOWLEDGE_SHARING, + EMOTIONAL, + CODE_DISCUSSION, + MULTI_TURN_QA, +] + +_CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```") +_TECH_KEYWORDS = re.compile( + r"\b(import|def |class |function |const |let |var |return |async |await " + r"|API|SDK|HTTP|JSON|SQL|docker|kubernetes|git|npm|pip)\b", + re.IGNORECASE, +) +_TASK_KEYWORDS_EN = re.compile( + r"\b(please|todo|deadline|schedule|remind|plan|book|order|buy|send|create" + r"|set up|configure|install|deploy|migrate|update)\b", + re.IGNORECASE, +) +_TASK_KEYWORDS_ZH = re.compile( + r"(请|帮我|提醒|安排|预定|预约|计划|任务|截止|部署|配置|安装|迁移|更新|设置|发送|创建|购买)", +) +_EMOTION_KEYWORDS_EN = re.compile( + r"\b(feel|happy|sad|angry|love|hate|miss|worry|afraid|grateful" + r"|excited|lonely|anxious|stressed|depressed|proud)\b", + re.IGNORECASE, +) +_EMOTION_KEYWORDS_ZH = re.compile( + r"(开心|难过|伤心|生气|爱|恨|想念|担心|害怕|感谢|感恩|兴奋|孤独|焦虑|压力|骄傲|烦|累|郁闷|失望|幸福)", +) +_QUESTION_PATTERNS = re.compile( + r"(\?\s*$|?\s*$|what|why|how|when|where|who|which|do you|can you|is it" + r"|什么|为什么|怎么|何时|哪里|谁|哪个|是否|能不能|会不会)", + re.IGNORECASE | re.MULTILINE, +) + + +def _extract_text(sources: list) -> str: + """Extract raw text content from sources for rule matching.""" + parts: list[str] = [] + for src in sources: + if isinstance(src, str): + parts.append(src) + elif hasattr(src, "content"): + parts.append(str(src.content)) + elif isinstance(src, dict): + parts.append(str(src.get("content", ""))) + return "\n".join(parts) + + +def _count_roles(sources: list) -> dict[str, int]: + """Count occurrences of each role in sources.""" + counts: dict[str, int] = {} + for src in sources: + role = None + if hasattr(src, "role"): + role = src.role + elif isinstance(src, dict): + role = src.get("role") + if role: + counts[role] = counts.get(role, 0) + 1 + return counts + + +class MessageClassifier: + """Rule-based message classifier with optional LLM fallback.""" + + def __init__(self, llm: Any | None = None): + self.llm = llm + self._rules: list[tuple[str, Callable[[list, str], str | None]]] = [ + ("code_discussion", self._check_code), + ("task_oriented", self._check_task), + ("emotional", self._check_emotion), + ("knowledge_sharing", self._check_knowledge), + ("multi_turn_qa", self._check_multi_turn_qa), + ("casual_chat", self._check_casual), + ] + + def classify( + self, + sources: list, + mem_str: str, + default_prompt_type: str, + info: dict[str, Any], + ) -> str: + """Classify messages and return a category label. + + Returns the default_prompt_type unchanged when no rule matches + and no LLM fallback is configured. + """ + text = _extract_text(sources) if sources else mem_str + + for _name, rule_fn in self._rules: + result = rule_fn(sources, text) + if result is not None: + return result + + if self.llm is not None: + return self._llm_classify(text, default_prompt_type) + + return default_prompt_type + + # ── Rule functions ────────────────────────────────────────────── + + @staticmethod + def _check_code(sources: list, text: str) -> str | None: + has_code_block = bool(_CODE_BLOCK_RE.search(text)) + tech_hits = len(_TECH_KEYWORDS.findall(text)) + if has_code_block or tech_hits >= 3: + return CODE_DISCUSSION + return None + + @staticmethod + def _check_task(sources: list, text: str) -> str | None: + en_hits = len(_TASK_KEYWORDS_EN.findall(text)) + zh_hits = len(_TASK_KEYWORDS_ZH.findall(text)) + if en_hits + zh_hits >= 2: + return TASK_ORIENTED + return None + + @staticmethod + def _check_emotion(sources: list, text: str) -> str | None: + en_hits = len(_EMOTION_KEYWORDS_EN.findall(text)) + zh_hits = len(_EMOTION_KEYWORDS_ZH.findall(text)) + if en_hits + zh_hits >= 2: + return EMOTIONAL + return None + + @staticmethod + def _check_knowledge(sources: list, text: str) -> str | None: + if len(text) > 800 and text.count("\n") >= 5: + return KNOWLEDGE_SHARING + return None + + @staticmethod + def _check_multi_turn_qa(sources: list, text: str) -> str | None: + role_counts = _count_roles(sources) + total_turns = sum(role_counts.values()) + question_hits = len(_QUESTION_PATTERNS.findall(text)) + if total_turns >= 4 and question_hits >= 2: + return MULTI_TURN_QA + return None + + @staticmethod + def _check_casual(sources: list, text: str) -> str | None: + role_counts = _count_roles(sources) + total_turns = sum(role_counts.values()) + if total_turns <= 2 and len(text) < 200: + return CASUAL_CHAT + return None + + # ── LLM fallback ──────────────────────────────────────────────── + + def _llm_classify(self, text: str, default: str) -> str: + categories_str = ", ".join(ALL_CATEGORIES) + prompt = ( + f"Classify the following conversation into exactly one category.\n" + f"Categories: {categories_str}\n\n" + f"Conversation:\n{text[:2000]}\n\n" + f"Reply with ONLY the category name, nothing else." + ) + try: + result = self.llm.generate([{"role": "user", "content": prompt}]) + label = result.strip().lower().replace(" ", "_") + if label in ALL_CATEGORIES: + return label + logger.warning("[PromptStrategy] LLM returned unknown category: %s", label) + except Exception: + logger.exception("[PromptStrategy] LLM classification failed") + return default diff --git a/extensions/memos_prompt_strategy_plugin/example.py b/extensions/memos_prompt_strategy_plugin/example.py new file mode 100644 index 000000000..8d28a7fc6 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/example.py @@ -0,0 +1,132 @@ +"""Quick demo — run directly to see classification + prompt strategy in action. + +Usage: + PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py +""" + +from memos_prompt_strategy_plugin.classifier import MessageClassifier +from memos_prompt_strategy_plugin.strategies import StrategyRegistry + + +def _src(role: str, content: str): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + +DEMO_CONVERSATIONS = { + "casual_chat": { + "sources": [_src("user", "Hey! I really like Italian food, especially pasta.")], + "mem_str": "Hey! I really like Italian food, especially pasta.", + }, + "task_oriented": { + "sources": [ + _src("user", "请帮我安排明天下午3点的会议,提醒我截止日期是周五"), + _src("assistant", "好的,已经帮你安排好了"), + ], + "mem_str": "请帮我安排明天下午3点的会议,提醒我截止日期是周五\n好的,已经帮你安排好了", + }, + "code_discussion": { + "sources": [ + _src( + "user", + "I'm getting an error with my FastAPI app:\n" + "```python\n" + "from fastapi import APIRouter\n" + "router = APIRouter()\n" + "@router.get('/health')\n" + "async def health(): return {'ok': True}\n" + "```\n" + "The import fails after upgrading the SDK.", + ), + ], + "mem_str": ( + "I'm getting an error with my FastAPI app:\n" + "```python\nfrom fastapi import APIRouter\n```\n" + "The import fails after upgrading the SDK." + ), + }, + "emotional": { + "sources": [ + _src("user", "今天特别开心,终于完成了马拉松,虽然很累但是非常骄傲"), + ], + "mem_str": "今天特别开心,终于完成了马拉松,虽然很累但是非常骄傲", + }, + "multi_turn_qa": { + "sources": [ + _src("user", "What's the difference between Redis and Memcached?"), + _src("assistant", "Redis supports more data structures. What's your use case?"), + _src("user", "I need pub/sub and sorted sets. Which one fits?"), + _src("assistant", "Redis is the clear choice for pub/sub and sorted sets."), + ], + "mem_str": ( + "What's the difference between Redis and Memcached?\n" + "Redis supports more data structures. What's your use case?\n" + "I need pub/sub and sorted sets. Which one fits?\n" + "Redis is the clear choice for pub/sub and sorted sets." + ), + }, + "knowledge_sharing": { + "sources": [ + _src( + "user", + "\n".join( + [ + "Transformers are a neural network architecture introduced in 2017.", + "They use self-attention mechanisms to process sequences in parallel.", + "Unlike RNNs, transformers don't require sequential computation.", + "The key components are: multi-head attention, feed-forward networks,", + "layer normalization, and positional encoding.", + "Pre-training on large corpora followed by fine-tuning has become", + "the dominant paradigm in NLP since BERT (2018).", + ] + * 3 + ), + ), + ], + "mem_str": "Transformers are a neural network architecture..." + "x" * 800, + }, +} + +SEPARATOR = "=" * 72 + + +def main(): + clf = MessageClassifier() + reg = StrategyRegistry() + reg.register_defaults() + + print(SEPARATOR) + print(" Prompt Strategy Plugin — Classification Demo") + print(SEPARATOR) + + for label, data in DEMO_CONVERSATIONS.items(): + sources = data["sources"] + mem_str = data["mem_str"] + + category = clf.classify(sources, mem_str, "chat", {}) + prompt = reg.build_prompt( + category, "zh" if any("\u4e00" <= c <= "\u9fff" for c in mem_str) else "en", mem_str + ) + + status = "MATCH" if category == label else "MISMATCH" + print(f"\n{'—' * 72}") + print(f" Scenario : {label}") + print(f" Classified: {category} [{status}]") + print(f" Input : {mem_str[:80]}{'...' if len(mem_str) > 80 else ''}") + print( + f" Prompt : {prompt[:120]}..." if prompt else " Prompt : (default, no override)" + ) + print(f"{'—' * 72}") + + print(f"\n{SEPARATOR}") + print(" All scenarios processed.") + print(SEPARATOR) + + +if __name__ == "__main__": + main() diff --git a/extensions/memos_prompt_strategy_plugin/hooks.py b/extensions/memos_prompt_strategy_plugin/hooks.py new file mode 100644 index 000000000..d48ceae0a --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/hooks.py @@ -0,0 +1,46 @@ +"""Prompt strategy plugin hook callbacks. + +All callbacks are bound to the plugin instance via functools.partial(callback, plugin_instance). +""" + +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + +logger = logging.getLogger(__name__) + + +def on_pre_extract( + plugin: PromptStrategyPlugin, + *, + prompt: str, + prompt_type: str, + mem_str: str, + lang: str, + sources: list, + **_kw: Any, +) -> str | None: + """[mem_reader.pre_extract] Classify the message and swap in a tailored prompt.""" + category = plugin.classifier.classify(sources, mem_str, prompt_type, info={}) + plugin.stats[category] += 1 + + if category != prompt_type: + logger.info("[PromptStrategy] Classified as %s (was %s)", category, prompt_type) + + custom_prompt = plugin.registry.build_prompt( + category=category, + lang=lang, + mem_str=mem_str, + custom_tags=None, + ) + if custom_prompt is not None: + logger.debug("[PromptStrategy] Using strategy prompt for %s", category) + return custom_prompt + + return None diff --git a/extensions/memos_prompt_strategy_plugin/plugin.py b/extensions/memos_prompt_strategy_plugin/plugin.py new file mode 100644 index 000000000..74cbfebce --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/plugin.py @@ -0,0 +1,44 @@ +"""PromptStrategyPlugin — classify messages and select specialised prompts.""" + +from __future__ import annotations + +import logging + +from collections import defaultdict +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H + + +logger = logging.getLogger(__name__) + + +class PromptStrategyPlugin(MemOSPlugin): + name = "prompt_strategy" + version = "0.1.0" + description = "Classify messages and apply category-specific extraction prompts" + + def on_load(self) -> None: + from memos_prompt_strategy_plugin.classifier import MessageClassifier + from memos_prompt_strategy_plugin.strategies import StrategyRegistry + + self.classifier = MessageClassifier() + self.registry = StrategyRegistry() + self.registry.register_defaults() + self.stats: dict[str, int] = defaultdict(int) + logger.info("[PromptStrategy] plugin loaded") + + def init_app(self) -> None: + from memos_prompt_strategy_plugin.hooks import on_pre_extract + from memos_prompt_strategy_plugin.routes import create_router + + self.register_router(create_router(self)) + self.register_hook(H.MEM_READER_PRE_EXTRACT, partial(on_pre_extract, self)) + logger.info("[PromptStrategy] plugin initialized") + + def on_shutdown(self) -> None: + logger.info( + "[PromptStrategy] plugin shutdown — classification stats: %s", + dict(self.stats), + ) diff --git a/extensions/memos_prompt_strategy_plugin/routes.py b/extensions/memos_prompt_strategy_plugin/routes.py new file mode 100644 index 000000000..649b08370 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/routes.py @@ -0,0 +1,34 @@ +"""Admin routes for the Prompt Strategy plugin.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter + + +if TYPE_CHECKING: + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + +def create_router(plugin: PromptStrategyPlugin) -> APIRouter: + router = APIRouter(prefix="/prompt_strategy", tags=["prompt_strategy"]) + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name, "version": plugin.version} + + @router.get("/strategies") + async def list_strategies(): + """Return all registered prompt strategies.""" + return { + name: {"description": s.description} + for name, s in plugin.registry.all_strategies().items() + } + + @router.get("/stats") + async def classification_stats(): + """Return per-category classification hit counts.""" + return {"stats": dict(plugin.stats)} + + return router diff --git a/extensions/memos_prompt_strategy_plugin/strategies.py b/extensions/memos_prompt_strategy_plugin/strategies.py new file mode 100644 index 000000000..60cf6e00b --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/strategies.py @@ -0,0 +1,346 @@ +"""Prompt strategy registry — maps message categories to specialised prompts. + +All strategy prompts produce the same JSON output format as the default +mem-reader (memory list + summary) so downstream processing stays unchanged. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PromptStrategy: + name: str + template_en: str + template_zh: str + description: str + + +class StrategyRegistry: + """Thread-safe registry that maps category labels to prompt strategies.""" + + def __init__(self): + self._strategies: dict[str, PromptStrategy] = {} + + def register(self, strategy: PromptStrategy) -> None: + self._strategies[strategy.name] = strategy + logger.info("[PromptStrategy] Registered strategy: %s", strategy.name) + + def get(self, name: str) -> PromptStrategy | None: + return self._strategies.get(name) + + def all_strategies(self) -> dict[str, PromptStrategy]: + return dict(self._strategies) + + def build_prompt( + self, + category: str, + lang: str, + mem_str: str, + custom_tags: list[str] | None = None, + ) -> str | None: + """Build a prompt for *category*. Returns ``None`` when the category + has no registered strategy (caller should fall back to the default).""" + strategy = self._strategies.get(category) + if strategy is None: + return None + + template = strategy.template_zh if lang == "zh" else strategy.template_en + prompt = template.replace("${conversation}", mem_str) + prompt = prompt.replace("{chunk_text}", mem_str) + + if custom_tags: + tags_instruction = ( + f"\n额外要求:提取的记忆请尽量关联以下标签:{custom_tags}" + if lang == "zh" + else f"\nAdditional: try to associate extracted memories with these tags: {custom_tags}" + ) + else: + tags_instruction = "" + prompt = prompt.replace("${custom_tags_prompt}", tags_instruction) + prompt = prompt.replace("{custom_tags_prompt}", tags_instruction) + + return prompt + + def register_defaults(self) -> None: + """Register built-in strategies for all standard categories.""" + for strategy in _DEFAULT_STRATEGIES: + self.register(strategy) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Default prompt templates per category +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +_CASUAL_CHAT_EN = """\ +You are a memory extraction expert specialised in casual conversations. +Extract lightweight memories focusing on the user's **personal preferences, habits, opinions, and lifestyle details**. +Skip trivial greetings or filler. Only retain information the user would realistically remember. + +Conversation: +${conversation} + +${custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "UserMemory", "value": "", "tags": ["", ...]} + ], + "summary": "<1-2 sentence summary from the user's perspective>" +} +""" + +_CASUAL_CHAT_ZH = """\ +你是一位擅长提取日常闲聊记忆的专家。 +请提取**用户的个人偏好、习惯、观点和生活细节**相关的轻量级记忆。 +跳过无意义的寒暄和填充语。只保留用户真正会记住的信息。 + +对话内容: +${conversation} + +${custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<简短标题>", "memory_type": "UserMemory", "value": "<第三人称陈述>", "tags": ["<标签>", ...]} + ], + "summary": "<1-2 句话从用户视角总结>" +} +""" + +_TASK_ORIENTED_EN = """\ +You are a memory extraction expert specialised in task-oriented conversations. +Focus on extracting **tasks, plans, deadlines, action items, constraints, and commitments** the user discussed. +Resolve all time references to absolute dates when possible. + +Conversation: +${conversation} + +${custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + ], + "summary": "" +} +""" + +_TASK_ORIENTED_ZH = """\ +你是一位擅长提取任务型对话记忆的专家。 +专注提取用户讨论的**任务、计划、截止日期、行动项、约束条件和承诺**。 +尽可能将所有时间引用转换为绝对日期。 + +对话内容: +${conversation} + +${custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<简短标题>", "memory_type": "LongTermMemory", "value": "<第三人称,包含日期和约束的详细陈述>", "tags": ["<标签>", ...]} + ], + "summary": "<总结任务和截止日期的段落,80-150字>" +} +""" + +_KNOWLEDGE_SHARING_EN = """\ +You are a memory extraction expert specialised in knowledge-sharing content. +Extract **key concepts, definitions, explanations, facts, and learned insights** from the conversation. +Treat this like document-level extraction: capture knowledge points completely. + +Content: +{chunk_text} + +{custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + ], + "summary": "" +} +""" + +_KNOWLEDGE_SHARING_ZH = """\ +你是一位擅长提取知识分享内容的记忆专家。 +提取对话中的**核心概念、定义、解释、事实和学习到的见解**。 +像文档级别提取一样,完整捕获知识点。 + +内容: +{chunk_text} + +{custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<概念或主题>", "memory_type": "LongTermMemory", "value": "<完整、自包含的解释>", "tags": ["<标签>", ...]} + ], + "summary": "<总结分享知识的段落,100-200字>" +} +""" + +_EMOTIONAL_EN = """\ +You are a memory extraction expert specialised in emotional and relational conversations. +Focus on extracting the user's **emotional states, feelings, relationship dynamics, personal concerns, and significant life events**. +Preserve emotional nuance — do not flatten sentiment into generic labels. + +Conversation: +${conversation} + +${custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "UserMemory", "value": "", "tags": ["", ...]} + ], + "summary": "" +} +""" + +_EMOTIONAL_ZH = """\ +你是一位擅长提取情感与人际关系对话记忆的专家。 +专注提取用户的**情感状态、感受、人际关系动态、个人关切和重要生活事件**。 +保留情感细节,不要将情绪简化为泛泛的标签。 + +对话内容: +${conversation} + +${custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<简短标题>", "memory_type": "UserMemory", "value": "<第三人称陈述,捕获情感和上下文>", "tags": ["<标签>", ...]} + ], + "summary": "<从用户视角出发的共情式总结,80-150字>" +} +""" + +_CODE_DISCUSSION_EN = """\ +You are a memory extraction expert specialised in technical and code discussions. +Extract **tools, frameworks, libraries, technical decisions, code patterns, bugs, solutions, and architecture choices** discussed by the user. +Include version numbers, configuration details, and error descriptions when available. + +Conversation: +${conversation} + +${custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + ], + "summary": "" +} +""" + +_CODE_DISCUSSION_ZH = """\ +你是一位擅长提取技术和代码讨论记忆的专家。 +提取用户讨论的**工具、框架、库、技术决策、代码模式、Bug、解决方案和架构选择**。 +在可用时包含版本号、配置细节和错误描述。 + +对话内容: +${conversation} + +${custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<技术主题>", "memory_type": "LongTermMemory", "value": "<第三人称陈述,包含技术细节>", "tags": ["<标签>", ...]} + ], + "summary": "<技术总结,80-150字>" +} +""" + +_MULTI_TURN_QA_EN = """\ +You are a memory extraction expert specialised in multi-turn Q&A conversations. +Extract the **conclusions, clarifications, and final answers** that emerged through the Q&A process. +Focus on what the user learned or decided, not intermediate back-and-forth. + +Conversation: +${conversation} + +${custom_tags_prompt} + +Return a single valid JSON object: +{ + "memory list": [ + {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + ], + "summary": "" +} +""" + +_MULTI_TURN_QA_ZH = """\ +你是一位擅长提取多轮问答对话记忆的专家。 +提取通过问答过程得出的**结论、澄清和最终答案**。 +关注用户学到了什么或做了什么决定,而非中间的来回讨论。 + +对话内容: +${conversation} + +${custom_tags_prompt} + +返回一个合法 JSON 对象: +{ + "memory list": [ + {"key": "<问题主题>", "memory_type": "LongTermMemory", "value": "<第三人称陈述结论或答案>", "tags": ["<标签>", ...]} + ], + "summary": "<总结问答中的关键结论,80-150字>" +} +""" + +_DEFAULT_STRATEGIES = [ + PromptStrategy( + name="casual_chat", + template_en=_CASUAL_CHAT_EN, + template_zh=_CASUAL_CHAT_ZH, + description="Lightweight extraction for casual conversation — preferences, habits, opinions", + ), + PromptStrategy( + name="task_oriented", + template_en=_TASK_ORIENTED_EN, + template_zh=_TASK_ORIENTED_ZH, + description="Structured extraction for tasks, plans, deadlines, and action items", + ), + PromptStrategy( + name="knowledge_sharing", + template_en=_KNOWLEDGE_SHARING_EN, + template_zh=_KNOWLEDGE_SHARING_ZH, + description="Document-style extraction for concepts, definitions, and learned insights", + ), + PromptStrategy( + name="emotional", + template_en=_EMOTIONAL_EN, + template_zh=_EMOTIONAL_ZH, + description="Emotion-aware extraction for feelings, relationships, and personal concerns", + ), + PromptStrategy( + name="code_discussion", + template_en=_CODE_DISCUSSION_EN, + template_zh=_CODE_DISCUSSION_ZH, + description="Technical extraction for tools, frameworks, bugs, and architecture decisions", + ), + PromptStrategy( + name="multi_turn_qa", + template_en=_MULTI_TURN_QA_EN, + template_zh=_MULTI_TURN_QA_ZH, + description="Conclusion-focused extraction for multi-turn Q&A conversations", + ), +] diff --git a/extensions/memos_prompt_strategy_plugin/tests/__init__.py b/extensions/memos_prompt_strategy_plugin/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extensions/memos_prompt_strategy_plugin/tests/conftest.py b/extensions/memos_prompt_strategy_plugin/tests/conftest.py new file mode 100644 index 000000000..1729d8fdf --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/conftest.py @@ -0,0 +1,9 @@ +"""Ensure hooks used by PromptStrategyPlugin are declared for testing.""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") + +import memos.plugins.hook_defs # noqa: E402, F401 — triggers CE hook declarations diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py new file mode 100644 index 000000000..5fe55c1d8 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py @@ -0,0 +1,106 @@ +"""Tests for MessageClassifier rule-based classification.""" + +from memos_prompt_strategy_plugin.classifier import ( + CASUAL_CHAT, + CODE_DISCUSSION, + EMOTIONAL, + KNOWLEDGE_SHARING, + MULTI_TURN_QA, + TASK_ORIENTED, + MessageClassifier, +) + + +def _src(role: str, content: str): + """Helper: lightweight source-like object.""" + + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + +class TestClassifierRules: + def setup_method(self): + self.clf = MessageClassifier() + + # ── code_discussion ────────────────────────────────────────── + + def test_code_block_triggers_code(self): + sources = [_src("user", "Here is my code:\n```python\nprint('hello')\n```")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == CODE_DISCUSSION + + def test_tech_keywords_trigger_code(self): + sources = [ + _src("user", "I need to import the SDK and call the API via HTTP with JSON payload") + ] + result = self.clf.classify(sources, "", "chat", {}) + assert result == CODE_DISCUSSION + + # ── task_oriented ──────────────────────────────────────────── + + def test_task_keywords_en(self): + sources = [_src("user", "Please schedule a meeting and set up the deployment pipeline")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == TASK_ORIENTED + + def test_task_keywords_zh(self): + sources = [_src("user", "请帮我安排明天的会议,并提醒我截止日期")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == TASK_ORIENTED + + # ── emotional ──────────────────────────────────────────────── + + def test_emotion_en(self): + sources = [_src("user", "I feel so happy and grateful for everything today")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == EMOTIONAL + + def test_emotion_zh(self): + sources = [_src("user", "今天特别开心,也很感恩身边的朋友")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == EMOTIONAL + + # ── knowledge_sharing ──────────────────────────────────────── + + def test_long_text_knowledge(self): + long_text = "This is a detailed explanation of how transformers work.\n" * 20 + sources = [_src("user", long_text)] + result = self.clf.classify(sources, "", "chat", {}) + assert result == KNOWLEDGE_SHARING + + # ── multi_turn_qa ──────────────────────────────────────────── + + def test_multi_turn_qa(self): + sources = [ + _src("user", "What is the best approach for caching?"), + _src("assistant", "It depends on your use case. What latency do you need?"), + _src("user", "Under 100ms. Which solution fits?"), + _src("assistant", "Redis would be ideal for that latency requirement."), + ] + text = "\n".join(s.content for s in sources) + result = self.clf.classify(sources, text, "chat", {}) + assert result == MULTI_TURN_QA + + # ── casual_chat ────────────────────────────────────────────── + + def test_short_casual(self): + sources = [_src("user", "Hey, nice weather today!")] + result = self.clf.classify(sources, "", "chat", {}) + assert result == CASUAL_CHAT + + # ── fallback ───────────────────────────────────────────────── + + def test_no_match_returns_default(self): + sources = [ + _src("user", "Let me think about that for a moment"), + _src("assistant", "Sure, take your time"), + _src("user", "Okay I have decided"), + ] + text = "\n".join(s.content for s in sources) + result = self.clf.classify(sources, text, "chat", {}) + assert result == "chat" diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py new file mode 100644 index 000000000..cc3edca7e --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py @@ -0,0 +1,164 @@ +"""Tests for PromptStrategyPlugin lifecycle and hook integration.""" + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def _make_app(): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + app = FastAPI() + plugin = PromptStrategyPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + return app, plugin + + +class TestPluginLifecycle: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_metadata(self): + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + plugin = PromptStrategyPlugin() + assert plugin.name == "prompt_strategy" + assert plugin.version == "0.1.0" + + def test_on_load_initialises_components(self): + from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin + + plugin = PromptStrategyPlugin() + plugin.on_load() + assert plugin.classifier is not None + assert plugin.registry is not None + assert len(plugin.registry.all_strategies()) > 0 + assert dict(plugin.stats) == {} + + def test_full_lifecycle(self): + app, plugin = _make_app() + paths = [r.path for r in app.routes] + assert "/prompt_strategy/health" in paths + assert "/prompt_strategy/strategies" in paths + assert "/prompt_strategy/stats" in paths + plugin.on_shutdown() + + +class TestPluginRoutes: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_health(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["plugin"] == "prompt_strategy" + + def test_strategies_list(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/strategies") + assert resp.status_code == 200 + strategies = resp.json() + assert "casual_chat" in strategies + assert "task_oriented" in strategies + + def test_stats_empty(self): + app, _ = _make_app() + client = TestClient(app) + resp = client.get("/prompt_strategy/stats") + assert resp.status_code == 200 + assert resp.json()["stats"] == {} + + +class TestHookIntegration: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_pre_extract_classifies_and_swaps_prompt(self): + """When the plugin classifies a message as task_oriented, it returns + the task-specific prompt containing the original mem_str.""" + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + result = trigger_hook( + "mem_reader.pre_extract", + prompt="original prompt", + prompt_type="chat", + mem_str="Please schedule a meeting and remind me about the deadline", + lang="en", + sources=[], + ) + assert result != "original prompt" + assert "schedule a meeting" in result + assert plugin.stats["task_oriented"] >= 1 + + def test_pre_extract_preserves_prompt_when_no_rule_matches(self): + """When no classifier rule matches and the default prompt_type has no + registered strategy, the original prompt passes through unchanged.""" + from memos.plugins.hooks import trigger_hook + + def _src(role, content): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + _make_app() + + sources = [ + _src("user", "Let me think about that"), + _src("assistant", "Sure, take your time"), + _src("user", "Okay I have decided"), + ] + result = trigger_hook( + "mem_reader.pre_extract", + prompt="original prompt", + prompt_type="chat", + mem_str="Let me think about that\nSure, take your time\nOkay I have decided", + lang="en", + sources=sources, + ) + assert result == "original prompt" + + def test_pre_extract_tracks_stats(self): + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + trigger_hook( + "mem_reader.pre_extract", + prompt="p", + prompt_type="chat", + mem_str="I feel so happy and grateful today", + lang="en", + sources=[], + ) + trigger_hook( + "mem_reader.pre_extract", + prompt="p", + prompt_type="chat", + mem_str="```python\nimport os\nprint(os.getcwd())\n```", + lang="en", + sources=[], + ) + assert plugin.stats["emotional"] >= 1 + assert plugin.stats["code_discussion"] >= 1 diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py new file mode 100644 index 000000000..dcb41a0a1 --- /dev/null +++ b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py @@ -0,0 +1,82 @@ +"""Tests for StrategyRegistry and prompt building.""" + +from memos_prompt_strategy_plugin.strategies import PromptStrategy, StrategyRegistry + + +class TestStrategyRegistry: + def setup_method(self): + self.reg = StrategyRegistry() + self.reg.register_defaults() + + def test_all_default_strategies_registered(self): + strategies = self.reg.all_strategies() + expected = { + "casual_chat", + "task_oriented", + "knowledge_sharing", + "emotional", + "code_discussion", + "multi_turn_qa", + } + assert set(strategies.keys()) == expected + + def test_build_prompt_returns_none_for_unknown(self): + result = self.reg.build_prompt("nonexistent_category", "en", "hello") + assert result is None + + def test_build_prompt_en(self): + prompt = self.reg.build_prompt("casual_chat", "en", "Hey, how are you?") + assert prompt is not None + assert "Hey, how are you?" in prompt + assert "preferences" in prompt.lower() or "habits" in prompt.lower() + + def test_build_prompt_zh(self): + prompt = self.reg.build_prompt("casual_chat", "zh", "你好,最近怎么样?") + assert prompt is not None + assert "你好,最近怎么样?" in prompt + assert "偏好" in prompt or "习惯" in prompt + + def test_build_prompt_with_custom_tags(self): + prompt = self.reg.build_prompt( + "task_oriented", "en", "Please deploy by Friday", custom_tags=["deadline", "ops"] + ) + assert prompt is not None + assert "deadline" in prompt + assert "ops" in prompt + + def test_task_oriented_has_deadline_focus(self): + prompt = self.reg.build_prompt("task_oriented", "en", "meeting at 3pm") + assert prompt is not None + assert "deadline" in prompt.lower() or "task" in prompt.lower() + + def test_code_discussion_has_tech_focus(self): + prompt = self.reg.build_prompt("code_discussion", "en", "fix the bug") + assert prompt is not None + assert "framework" in prompt.lower() or "tool" in prompt.lower() + + def test_emotional_has_feeling_focus(self): + prompt = self.reg.build_prompt("emotional", "en", "I feel sad") + assert prompt is not None + assert "emotion" in prompt.lower() or "feeling" in prompt.lower() + + def test_custom_strategy_registration(self): + custom = PromptStrategy( + name="custom_test", + template_en="Extract from: ${conversation} ${custom_tags_prompt}", + template_zh="提取:${conversation} ${custom_tags_prompt}", + description="Test strategy", + ) + self.reg.register(custom) + prompt = self.reg.build_prompt("custom_test", "en", "hello world") + assert prompt is not None + assert "hello world" in prompt + + +class TestStrategyRegistryIsolation: + def test_empty_registry_returns_none(self): + reg = StrategyRegistry() + assert reg.build_prompt("casual_chat", "en", "hi") is None + + def test_get_unknown_returns_none(self): + reg = StrategyRegistry() + assert reg.get("nonexistent") is None diff --git a/pyproject.toml b/pyproject.toml index 137f114e2..793ae1fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ memos = "memos.cli:main" [project.entry-points."memos.plugins"] demo = "memos_demo_plugin:DemoPlugin" +prompt_strategy = "memos_prompt_strategy_plugin:PromptStrategyPlugin" [project.optional-dependencies] # These are optional dependencies for various features of MemoryOS. @@ -165,6 +166,7 @@ build-backend = "poetry.core.masonry.api" packages = [ {include = "memos", from = "src"}, {include = "memos_demo_plugin", from = "extensions"}, + {include = "memos_prompt_strategy_plugin", from = "extensions"}, ] requires-poetry = ">=2.0" dependencies = { "python" = ">=3.10,<4.0" } diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index f9167319f..77ac9e5f2 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -451,6 +451,20 @@ def _get_llm_response( if self.config.remove_prompt_example and examples: prompt = prompt.replace(examples, "") + + from memos.plugins.hook_defs import H as _H + from memos.plugins.hooks import trigger_hook as _trigger_hook + + _rv = _trigger_hook( + _H.MEM_READER_PRE_EXTRACT, + prompt=prompt, + prompt_type=prompt_type, + mem_str=mem_str, + lang=lang, + sources=sources, + ) + prompt = _rv if _rv is not None else prompt + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py index a3b9dbaf8..f7fb237ca 100644 --- a/src/memos/plugins/hook_defs.py +++ b/src/memos/plugins/hook_defs.py @@ -75,6 +75,9 @@ class H: # Custom Hook (manually triggered via trigger_hook) ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + # mem_reader — generic extension point before LLM extraction + MEM_READER_PRE_EXTRACT = "mem_reader.pre_extract" + # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # CE custom Hook declarations (@hookable-generated ones need not be declared here) @@ -86,3 +89,10 @@ class H: params=["request", "result"], pipe_key="result", ) + +define_hook( + H.MEM_READER_PRE_EXTRACT, + description="Customize prompt before mem_reader LLM extraction", + params=["prompt", "prompt_type", "mem_str", "lang", "sources"], + pipe_key="prompt", +) From 968ce3ee7fc842b3d9fe88d553e006ee89d4351a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:00:20 +0800 Subject: [PATCH 78/92] feat: update memos_prompt_strategy --- .../memos_prompt_strategy_plugin/README.md | 72 +-- .../classifier.py | 187 +++----- .../memos_prompt_strategy_plugin/example.py | 126 ++---- .../memos_prompt_strategy_plugin/hooks.py | 13 +- .../memos_prompt_strategy_plugin/plugin.py | 6 +- .../strategies.py | 416 +++++++++--------- .../tests/test_classifier.py | 180 ++++---- .../tests/test_lifecycle.py | 74 ++-- .../tests/test_strategies.py | 53 +-- 9 files changed, 482 insertions(+), 645 deletions(-) diff --git a/extensions/memos_prompt_strategy_plugin/README.md b/extensions/memos_prompt_strategy_plugin/README.md index 42ebef812..862d082fe 100644 --- a/extensions/memos_prompt_strategy_plugin/README.md +++ b/extensions/memos_prompt_strategy_plugin/README.md @@ -1,16 +1,12 @@ # Prompt Strategy Plugin -根据消息内容自动分类,为不同类型的对话选择最优的记忆提取 prompt。 +检测对话中的身份/亲属关系命名模式(如"我叫xxx"、"我的儿子叫xxx"),自动切换为专用的身份关系提取 prompt。 ## 解决什么问题 -默认的 mem-reader 对所有消息使用同一套 prompt 模板(chat / doc / general\_string),无法针对不同对话场景做精细化提取。例如: +默认的 mem-reader prompt 是通用型的,在遇到"我叫王沐辰,我的儿子叫王明泽"这类包含姓名和关系信息的对话时,可能无法精确提取出所有人名和关系。本插件增加一条专用规则:一旦检测到身份/关系命名句式,就替换为专门强调"不遗漏任何人名和关系"的提取 prompt。 -- 闲聊中的偏好信息容易被忽略 -- 任务型对话的截止日期和行动项提取不够结构化 -- 代码讨论中的技术细节会被当作普通文本处理 - -本插件通过 **规则分类器 + 策略 prompt 注册表** 解决这个问题,整个过程对 CE 代码完全透明。 +如果对话不包含这类句式,插件不做任何改动,走 CE 默认流程。 ## 工作原理 @@ -22,27 +18,23 @@ CE: _get_llm_response() 构建默认 prompt CE: trigger_hook("mem_reader.pre_extract") ← 通用扩展点 ↓ 插件回调 on_pre_extract(): - 1. MessageClassifier 对消息分类 → 得到类别标签 - 2. StrategyRegistry 根据标签选择对应的 prompt 模板 - 3. 返回定制 prompt(或 None 走默认) + 1. MessageClassifier 检查是否命中身份/关系命名规则 + 2. 命中 → 返回专用 identity_relation prompt + 3. 不命中 → 返回 None,CE 使用默认 prompt ↓ CE: LLM.generate(prompt) ``` -分类和 prompt 路由全部在插件内部完成,CE 只暴露了一个通用的 `mem_reader.pre_extract` 扩展点。 - -## 支持的分类类别 +## 命中规则 -| 类别 | 判定规则 | Prompt 策略 | -|------|---------|------------| -| `casual_chat` | 单轮、短文本、无明确任务意图 | 轻量提取,关注偏好和习惯 | -| `task_oriented` | 含任务关键词(请、安排、deadline 等) | 结构化提取,关注任务、截止日期、约束 | -| `knowledge_sharing` | 长文本(>800 字符)、多段落 | 类文档提取,关注概念和定义 | -| `emotional` | 含情感词汇(开心、担心、grateful 等) | 关注情感状态、人际关系 | -| `code_discussion` | 含代码块或技术关键词 | 技术记忆提取,关注工具、框架、解决方案 | -| `multi_turn_qa` | 4+ 轮对话且含多个问句 | 关注结论和最终答案 | +| 模式 | 示例 | +|------|------| +| 自我命名(中文) | 我叫xxx、我是xxx、我的名字是xxx | +| 亲属/社交关系命名(中文) | 我的儿子叫xxx、我老婆是xxx、我妈妈叫xxx、我朋友叫xxx | +| 自我命名(英文) | My name is xxx、I'm xxx、Call me xxx | +| 关系命名(英文) | My son is called xxx、My wife's name is xxx | -分类器采用**规则优先**策略,零 LLM 开销。可选配置 LLM 兜底处理模糊场景。 +支持的关系词:儿子、女儿、老婆、老公、爸爸、妈妈、哥哥、姐姐、弟弟、妹妹、爷爷、奶奶、朋友、同事、同学、宠物等。 ## 文件结构 @@ -51,10 +43,10 @@ extensions/memos_prompt_strategy_plugin/ ├── __init__.py # 包入口,导出 PromptStrategyPlugin ├── plugin.py # 插件主类:生命周期 + 注册 ├── hooks.py # Hook 回调:on_pre_extract -├── classifier.py # 消息分类器(规则 + 可选 LLM) -├── strategies.py # 策略注册表 + 6 套 prompt 模板(中英文) +├── classifier.py # 身份/关系命名规则检测 +├── strategies.py # identity_relation prompt 模板(中英文) ├── routes.py # 管理接口 -├── example.py # 可直接运行的分类演示 +├── example.py # 可直接运行的检测演示 └── tests/ ├── conftest.py ├── test_classifier.py @@ -64,8 +56,6 @@ extensions/memos_prompt_strategy_plugin/ ## 快速体验 -不需要启动服务,直接运行 example 查看分类效果: - ```bash PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py ``` @@ -75,18 +65,16 @@ PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/examp `pyproject.toml` 中已包含以下配置: ```toml -# 包路径 [tool.poetry] packages = [ {include = "memos_prompt_strategy_plugin", from = "extensions"}, ] -# Entry point [project.entry-points."memos.plugins"] prompt_strategy = "memos_prompt_strategy_plugin:PromptStrategyPlugin" ``` -如果是首次安装,需要执行: +首次安装需执行: ```bash pip install -e . @@ -100,30 +88,10 @@ PYTHONPATH="src:extensions" python -m pytest extensions/memos_prompt_strategy_pl ## 管理接口 -插件注册了以下 HTTP 端点: - | 端点 | 方法 | 说明 | |------|------|------| | `/prompt_strategy/health` | GET | 插件健康检查 | -| `/prompt_strategy/strategies` | GET | 列出所有已注册的策略及描述 | -| `/prompt_strategy/stats` | GET | 查看各分类的命中次数统计 | - -## 自定义策略 - -可以在插件初始化后动态注册新策略: - -```python -from memos_prompt_strategy_plugin.strategies import PromptStrategy - -plugin.registry.register(PromptStrategy( - name="legal_discussion", - template_en="Extract legal terms, clauses...\n${conversation}\n${custom_tags_prompt}", - template_zh="提取法律术语、条款...\n${conversation}\n${custom_tags_prompt}", - description="法律讨论场景的记忆提取", -)) -``` - -同时在 `classifier.py` 中添加对应的分类规则即可生效。 +| `/prompt_strategy/stats` | GET | 查看 identity_relation 命中次数 | ## CE 依赖 @@ -131,4 +99,4 @@ plugin.registry.register(PromptStrategy( - `mem_reader.pre_extract`:在 `MultiModalStructMemReader._get_llm_response()` 中,LLM 调用前触发 -该扩展点声明在 `src/memos/plugins/hook_defs.py`,修改后需通过 `sync-public` 同步到公开仓库。 +该扩展点声明在 `src/memos/plugins/hook_defs.py`。 diff --git a/extensions/memos_prompt_strategy_plugin/classifier.py b/extensions/memos_prompt_strategy_plugin/classifier.py index 29b78c835..7adee817b 100644 --- a/extensions/memos_prompt_strategy_plugin/classifier.py +++ b/extensions/memos_prompt_strategy_plugin/classifier.py @@ -1,7 +1,8 @@ -"""Message classifier — categorize messages for prompt strategy selection. +"""Message classifier — rule-chain architecture with extensible rules. -Uses a rule-based pipeline (zero LLM overhead) with optional LLM fallback -for ambiguous cases. +Currently only one rule is registered (identity/relation naming detection). +To add a new rule, write a static method that returns a category string on +match or ``None`` on miss, then append it to ``self._rules`` in ``__init__``. """ from __future__ import annotations @@ -17,49 +18,37 @@ logger = logging.getLogger(__name__) -# Category constants -CASUAL_CHAT = "casual_chat" -TASK_ORIENTED = "task_oriented" -KNOWLEDGE_SHARING = "knowledge_sharing" -EMOTIONAL = "emotional" -CODE_DISCUSSION = "code_discussion" -MULTI_TURN_QA = "multi_turn_qa" - -ALL_CATEGORIES = [ - CASUAL_CHAT, - TASK_ORIENTED, - KNOWLEDGE_SHARING, - EMOTIONAL, - CODE_DISCUSSION, - MULTI_TURN_QA, -] - -_CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```") -_TECH_KEYWORDS = re.compile( - r"\b(import|def |class |function |const |let |var |return |async |await " - r"|API|SDK|HTTP|JSON|SQL|docker|kubernetes|git|npm|pip)\b", - re.IGNORECASE, +# ── Category constants ────────────────────────────────────────────── +IDENTITY_RELATION = "identity_relation" + +# ── Regex patterns for identity / relation detection ──────────────── +_SELF_NAME_RE = re.compile( + r"我(?:的名字)?(?:是|叫)\s*(?P\S+)", ) -_TASK_KEYWORDS_EN = re.compile( - r"\b(please|todo|deadline|schedule|remind|plan|book|order|buy|send|create" - r"|set up|configure|install|deploy|migrate|update)\b", - re.IGNORECASE, + +_RELATION_WORDS = ( + "儿子|女儿|孩子|小孩" + "|老婆|妻子|老公|丈夫|爱人|伴侣|对象" + "|爸爸|妈妈|父亲|母亲|爸|妈" + "|哥哥|姐姐|弟弟|妹妹|哥|姐|弟|妹" + "|爷爷|奶奶|外公|外婆|姥姥|姥爷" + "|叔叔|阿姨|舅舅|舅妈|姑姑|姑父" + "|朋友|同事|同学|室友|闺蜜|兄弟" + "|男朋友|女朋友|前任" + "|宠物|狗|猫" ) -_TASK_KEYWORDS_ZH = re.compile( - r"(请|帮我|提醒|安排|预定|预约|计划|任务|截止|部署|配置|安装|迁移|更新|设置|发送|创建|购买)", +_RELATION_NAME_RE = re.compile( + rf"我(?:的)?(?:{_RELATION_WORDS})(?:的名字)?(?:是|叫)\s*(?P\S+)", ) -_EMOTION_KEYWORDS_EN = re.compile( - r"\b(feel|happy|sad|angry|love|hate|miss|worry|afraid|grateful" - r"|excited|lonely|anxious|stressed|depressed|proud)\b", + +_MY_NAME_IS_EN = re.compile( + r"(?:my name is|i'?m|call me)\s+(?P[A-Z]\w+)", re.IGNORECASE, ) -_EMOTION_KEYWORDS_ZH = re.compile( - r"(开心|难过|伤心|生气|爱|恨|想念|担心|害怕|感谢|感恩|兴奋|孤独|焦虑|压力|骄傲|烦|累|郁闷|失望|幸福)", -) -_QUESTION_PATTERNS = re.compile( - r"(\?\s*$|?\s*$|what|why|how|when|where|who|which|do you|can you|is it" - r"|什么|为什么|怎么|何时|哪里|谁|哪个|是否|能不能|会不会)", - re.IGNORECASE | re.MULTILINE, +_MY_RELATION_IS_EN = re.compile( + r"my\s+(?:son|daughter|wife|husband|father|mother|brother|sister|friend)" + r"(?:'s name)?\s+is\s+(?P[A-Z]\w+)", + re.IGNORECASE, ) @@ -76,32 +65,21 @@ def _extract_text(sources: list) -> str: return "\n".join(parts) -def _count_roles(sources: list) -> dict[str, int]: - """Count occurrences of each role in sources.""" - counts: dict[str, int] = {} - for src in sources: - role = None - if hasattr(src, "role"): - role = src.role - elif isinstance(src, dict): - role = src.get("role") - if role: - counts[role] = counts.get(role, 0) + 1 - return counts +class MessageClassifier: + """Rule-chain classifier. + Rules are evaluated in registration order; the first match wins. + If no rule matches, ``classify()`` returns ``None`` so the caller + keeps the default prompt unchanged. -class MessageClassifier: - """Rule-based message classifier with optional LLM fallback.""" + To add a new rule: + 1. Define a static/class method ``_check_xxx(sources, text) -> str | None`` + 2. Append ``("category_name", self._check_xxx)`` to ``self._rules`` + """ - def __init__(self, llm: Any | None = None): - self.llm = llm + def __init__(self) -> None: self._rules: list[tuple[str, Callable[[list, str], str | None]]] = [ - ("code_discussion", self._check_code), - ("task_oriented", self._check_task), - ("emotional", self._check_emotion), - ("knowledge_sharing", self._check_knowledge), - ("multi_turn_qa", self._check_multi_turn_qa), - ("casual_chat", self._check_casual), + ("identity_relation", self._check_identity_relation), ] def classify( @@ -110,89 +88,28 @@ def classify( mem_str: str, default_prompt_type: str, info: dict[str, Any], - ) -> str: - """Classify messages and return a category label. - - Returns the default_prompt_type unchanged when no rule matches - and no LLM fallback is configured. - """ + ) -> str | None: + """Walk the rule chain; return the first matching category or ``None``.""" text = _extract_text(sources) if sources else mem_str + if not text: + return None for _name, rule_fn in self._rules: result = rule_fn(sources, text) if result is not None: return result - if self.llm is not None: - return self._llm_classify(text, default_prompt_type) - - return default_prompt_type - - # ── Rule functions ────────────────────────────────────────────── - - @staticmethod - def _check_code(sources: list, text: str) -> str | None: - has_code_block = bool(_CODE_BLOCK_RE.search(text)) - tech_hits = len(_TECH_KEYWORDS.findall(text)) - if has_code_block or tech_hits >= 3: - return CODE_DISCUSSION return None - @staticmethod - def _check_task(sources: list, text: str) -> str | None: - en_hits = len(_TASK_KEYWORDS_EN.findall(text)) - zh_hits = len(_TASK_KEYWORDS_ZH.findall(text)) - if en_hits + zh_hits >= 2: - return TASK_ORIENTED - return None + # ── Rules ─────────────────────────────────────────────────────── @staticmethod - def _check_emotion(sources: list, text: str) -> str | None: - en_hits = len(_EMOTION_KEYWORDS_EN.findall(text)) - zh_hits = len(_EMOTION_KEYWORDS_ZH.findall(text)) - if en_hits + zh_hits >= 2: - return EMOTIONAL - return None + def _check_identity_relation(sources: list, text: str) -> str | None: + has_self = bool(_SELF_NAME_RE.search(text)) or bool(_MY_NAME_IS_EN.search(text)) + has_relation = bool(_RELATION_NAME_RE.search(text)) or bool(_MY_RELATION_IS_EN.search(text)) - @staticmethod - def _check_knowledge(sources: list, text: str) -> str | None: - if len(text) > 800 and text.count("\n") >= 5: - return KNOWLEDGE_SHARING - return None - - @staticmethod - def _check_multi_turn_qa(sources: list, text: str) -> str | None: - role_counts = _count_roles(sources) - total_turns = sum(role_counts.values()) - question_hits = len(_QUESTION_PATTERNS.findall(text)) - if total_turns >= 4 and question_hits >= 2: - return MULTI_TURN_QA - return None + if has_self or has_relation: + logger.info("[PromptStrategy] Identity/relation pattern detected") + return IDENTITY_RELATION - @staticmethod - def _check_casual(sources: list, text: str) -> str | None: - role_counts = _count_roles(sources) - total_turns = sum(role_counts.values()) - if total_turns <= 2 and len(text) < 200: - return CASUAL_CHAT return None - - # ── LLM fallback ──────────────────────────────────────────────── - - def _llm_classify(self, text: str, default: str) -> str: - categories_str = ", ".join(ALL_CATEGORIES) - prompt = ( - f"Classify the following conversation into exactly one category.\n" - f"Categories: {categories_str}\n\n" - f"Conversation:\n{text[:2000]}\n\n" - f"Reply with ONLY the category name, nothing else." - ) - try: - result = self.llm.generate([{"role": "user", "content": prompt}]) - label = result.strip().lower().replace(" ", "_") - if label in ALL_CATEGORIES: - return label - logger.warning("[PromptStrategy] LLM returned unknown category: %s", label) - except Exception: - logger.exception("[PromptStrategy] LLM classification failed") - return default diff --git a/extensions/memos_prompt_strategy_plugin/example.py b/extensions/memos_prompt_strategy_plugin/example.py index 8d28a7fc6..9e948fcdb 100644 --- a/extensions/memos_prompt_strategy_plugin/example.py +++ b/extensions/memos_prompt_strategy_plugin/example.py @@ -1,11 +1,11 @@ -"""Quick demo — run directly to see classification + prompt strategy in action. +"""Quick demo — run directly to see identity/relation detection in action. Usage: PYTHONPATH="src:extensions" python extensions/memos_prompt_strategy_plugin/example.py """ from memos_prompt_strategy_plugin.classifier import MessageClassifier -from memos_prompt_strategy_plugin.strategies import StrategyRegistry +from memos_prompt_strategy_plugin.strategies import build_identity_relation_prompt def _src(role: str, content: str): @@ -18,113 +18,71 @@ class _S: return s -DEMO_CONVERSATIONS = { - "casual_chat": { - "sources": [_src("user", "Hey! I really like Italian food, especially pasta.")], - "mem_str": "Hey! I really like Italian food, especially pasta.", +DEMO_CONVERSATIONS = [ + { + "label": "自我介绍 + 亲属关系", + "sources": [_src("user", "你好,我叫王沐辰,我的儿子叫王明泽")], + "mem_str": "你好,我叫王沐辰,我的儿子叫王明泽", }, - "task_oriented": { - "sources": [ - _src("user", "请帮我安排明天下午3点的会议,提醒我截止日期是周五"), - _src("assistant", "好的,已经帮你安排好了"), - ], - "mem_str": "请帮我安排明天下午3点的会议,提醒我截止日期是周五\n好的,已经帮你安排好了", + { + "label": "仅自我介绍", + "sources": [_src("user", "我是李明,今年30岁")], + "mem_str": "我是李明,今年30岁", }, - "code_discussion": { - "sources": [ - _src( - "user", - "I'm getting an error with my FastAPI app:\n" - "```python\n" - "from fastapi import APIRouter\n" - "router = APIRouter()\n" - "@router.get('/health')\n" - "async def health(): return {'ok': True}\n" - "```\n" - "The import fails after upgrading the SDK.", - ), - ], - "mem_str": ( - "I'm getting an error with my FastAPI app:\n" - "```python\nfrom fastapi import APIRouter\n```\n" - "The import fails after upgrading the SDK." - ), + { + "label": "英文自我介绍 + 关系", + "sources": [_src("user", "Hi, my name is Alice. My son is called Bob.")], + "mem_str": "Hi, my name is Alice. My son is called Bob.", }, - "emotional": { - "sources": [ - _src("user", "今天特别开心,终于完成了马拉松,虽然很累但是非常骄傲"), - ], - "mem_str": "今天特别开心,终于完成了马拉松,虽然很累但是非常骄傲", + { + "label": "多种关系", + "sources": [_src("user", "我叫张三,我老婆叫李四,我女儿叫张小花,我妈妈叫王秀英")], + "mem_str": "我叫张三,我老婆叫李四,我女儿叫张小花,我妈妈叫王秀英", }, - "multi_turn_qa": { - "sources": [ - _src("user", "What's the difference between Redis and Memcached?"), - _src("assistant", "Redis supports more data structures. What's your use case?"), - _src("user", "I need pub/sub and sorted sets. Which one fits?"), - _src("assistant", "Redis is the clear choice for pub/sub and sorted sets."), - ], - "mem_str": ( - "What's the difference between Redis and Memcached?\n" - "Redis supports more data structures. What's your use case?\n" - "I need pub/sub and sorted sets. Which one fits?\n" - "Redis is the clear choice for pub/sub and sorted sets." - ), + { + "label": "普通闲聊(不应命中)", + "sources": [_src("user", "今天天气不错,出去走走吧")], + "mem_str": "今天天气不错,出去走走吧", }, - "knowledge_sharing": { - "sources": [ - _src( - "user", - "\n".join( - [ - "Transformers are a neural network architecture introduced in 2017.", - "They use self-attention mechanisms to process sequences in parallel.", - "Unlike RNNs, transformers don't require sequential computation.", - "The key components are: multi-head attention, feed-forward networks,", - "layer normalization, and positional encoding.", - "Pre-training on large corpora followed by fine-tuning has become", - "the dominant paradigm in NLP since BERT (2018).", - ] - * 3 - ), - ), - ], - "mem_str": "Transformers are a neural network architecture..." + "x" * 800, + { + "label": "任务型(不应命中)", + "sources": [_src("user", "请帮我安排明天下午3点的会议")], + "mem_str": "请帮我安排明天下午3点的会议", }, -} +] SEPARATOR = "=" * 72 def main(): clf = MessageClassifier() - reg = StrategyRegistry() - reg.register_defaults() print(SEPARATOR) - print(" Prompt Strategy Plugin — Classification Demo") + print(" Prompt Strategy Plugin — Identity/Relation Detection Demo") print(SEPARATOR) - for label, data in DEMO_CONVERSATIONS.items(): + for data in DEMO_CONVERSATIONS: + label = data["label"] sources = data["sources"] mem_str = data["mem_str"] category = clf.classify(sources, mem_str, "chat", {}) - prompt = reg.build_prompt( - category, "zh" if any("\u4e00" <= c <= "\u9fff" for c in mem_str) else "en", mem_str - ) + hit = category is not None - status = "MATCH" if category == label else "MISMATCH" print(f"\n{'—' * 72}") - print(f" Scenario : {label}") - print(f" Classified: {category} [{status}]") - print(f" Input : {mem_str[:80]}{'...' if len(mem_str) > 80 else ''}") - print( - f" Prompt : {prompt[:120]}..." if prompt else " Prompt : (default, no override)" - ) + print(f" Scenario: {label}") + print(f" Input : {mem_str[:80]}{'...' if len(mem_str) > 80 else ''}") + print(f" Hit : {'YES → identity_relation' if hit else 'NO → use default prompt'}") + + if hit: + lang = "zh" if any("\u4e00" <= c <= "\u9fff" for c in mem_str) else "en" + prompt = build_identity_relation_prompt(lang=lang, mem_str=mem_str) + print(f" Prompt : {prompt[:120]}...") + print(f"{'—' * 72}") print(f"\n{SEPARATOR}") - print(" All scenarios processed.") + print(" Done.") print(SEPARATOR) diff --git a/extensions/memos_prompt_strategy_plugin/hooks.py b/extensions/memos_prompt_strategy_plugin/hooks.py index d48ceae0a..ba68c102f 100644 --- a/extensions/memos_prompt_strategy_plugin/hooks.py +++ b/extensions/memos_prompt_strategy_plugin/hooks.py @@ -26,18 +26,21 @@ def on_pre_extract( sources: list, **_kw: Any, ) -> str | None: - """[mem_reader.pre_extract] Classify the message and swap in a tailored prompt.""" + """[mem_reader.pre_extract] If a classifier rule matches and a + corresponding strategy is registered, swap in the specialised prompt; + otherwise return None to keep the default.""" category = plugin.classifier.classify(sources, mem_str, prompt_type, info={}) - plugin.stats[category] += 1 - if category != prompt_type: - logger.info("[PromptStrategy] Classified as %s (was %s)", category, prompt_type) + if category is None: + return None + + plugin.stats[category] += 1 + logger.info("[PromptStrategy] Matched rule: %s", category) custom_prompt = plugin.registry.build_prompt( category=category, lang=lang, mem_str=mem_str, - custom_tags=None, ) if custom_prompt is not None: logger.debug("[PromptStrategy] Using strategy prompt for %s", category) diff --git a/extensions/memos_prompt_strategy_plugin/plugin.py b/extensions/memos_prompt_strategy_plugin/plugin.py index 74cbfebce..29f9a4de8 100644 --- a/extensions/memos_prompt_strategy_plugin/plugin.py +++ b/extensions/memos_prompt_strategy_plugin/plugin.py @@ -1,4 +1,4 @@ -"""PromptStrategyPlugin — classify messages and select specialised prompts.""" +"""PromptStrategyPlugin — rule-chain classifier + strategy registry.""" from __future__ import annotations @@ -17,7 +17,7 @@ class PromptStrategyPlugin(MemOSPlugin): name = "prompt_strategy" version = "0.1.0" - description = "Classify messages and apply category-specific extraction prompts" + description = "Rule-chain classifier with strategy registry for specialised extraction prompts" def on_load(self) -> None: from memos_prompt_strategy_plugin.classifier import MessageClassifier @@ -39,6 +39,6 @@ def init_app(self) -> None: def on_shutdown(self) -> None: logger.info( - "[PromptStrategy] plugin shutdown — classification stats: %s", + "[PromptStrategy] plugin shutdown — stats: %s", dict(self.stats), ) diff --git a/extensions/memos_prompt_strategy_plugin/strategies.py b/extensions/memos_prompt_strategy_plugin/strategies.py index 60cf6e00b..00337cdd7 100644 --- a/extensions/memos_prompt_strategy_plugin/strategies.py +++ b/extensions/memos_prompt_strategy_plugin/strategies.py @@ -1,4 +1,8 @@ -"""Prompt strategy registry — maps message categories to specialised prompts. +"""Prompt strategy registry — maps classifier categories to specialised prompts. + +Currently only one strategy is registered (identity_relation). To add a new +strategy, create a ``PromptStrategy`` and call ``register()`` or append it to +``_DEFAULT_STRATEGIES``. All strategy prompts produce the same JSON output format as the default mem-reader (memory list + summary) so downstream processing stays unchanged. @@ -23,7 +27,7 @@ class PromptStrategy: class StrategyRegistry: - """Thread-safe registry that maps category labels to prompt strategies.""" + """Registry that maps category labels to prompt strategies.""" def __init__(self): self._strategies: dict[str, PromptStrategy] = {} @@ -59,7 +63,7 @@ def build_prompt( tags_instruction = ( f"\n额外要求:提取的记忆请尽量关联以下标签:{custom_tags}" if lang == "zh" - else f"\nAdditional: try to associate extracted memories with these tags: {custom_tags}" + else f"\nAdditional: associate extracted memories with these tags: {custom_tags}" ) else: tags_instruction = "" @@ -69,278 +73,292 @@ def build_prompt( return prompt def register_defaults(self) -> None: - """Register built-in strategies for all standard categories.""" + """Register built-in strategies.""" for strategy in _DEFAULT_STRATEGIES: self.register(strategy) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -# Default prompt templates per category +# Default prompt templates # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -_CASUAL_CHAT_EN = """\ -You are a memory extraction expert specialised in casual conversations. -Extract lightweight memories focusing on the user's **personal preferences, habits, opinions, and lifestyle details**. -Skip trivial greetings or filler. Only retain information the user would realistically remember. +_IDENTITY_RELATION_EN = """\ +You are a memory extraction expert. Extract all kinds of memories, including accurate identity and relationship information about people. -Conversation: -${conversation} +Your task is to extract memories from the perspective of the user based on the conversation between the user and the assistant. This means identifying information that the user may remember — including the user’s own experiences, thoughts, plans, or relevant statements and actions made by others (such as the assistant) that affect the user or are acknowledged by the user. -${custom_tags_prompt} +Please perform the following: -Return a single valid JSON object: -{ - "memory list": [ - {"key": "", "memory_type": "UserMemory", "value": "", "tags": ["", ...]} - ], - "summary": "<1-2 sentence summary from the user's perspective>" -} -""" +1. If the current conversation contains the user’s self-reported name or information about family/social relationships, the extracted content must precisely include: + - The user’s own name (e.g., "My name is xxx", "I am xxx") + - All related persons mentioned by the user: relationship type + name (e.g., "My son’s name is xxx", "My wife is xxx") + - If there are further relationship descriptions among related persons, extract them as well + - If other content exists, extract it as usual -_CASUAL_CHAT_ZH = """\ -你是一位擅长提取日常闲聊记忆的专家。 -请提取**用户的个人偏好、习惯、观点和生活细节**相关的轻量级记忆。 -跳过无意义的寒暄和填充语。只保留用户真正会记住的信息。 + Extraction requirements: + - **Absolutely do not omit any name or relationship** + - Use third person ("The user’s son is named Wang Mingze" rather than "My son is Wang Mingze") + - Each identity/relationship item must be extracted as a separate memory -对话内容: -${conversation} +2. Identify information that reflects the user’s experiences, beliefs, concerns, decisions, plans, or reactions — including meaningful factual information from the assistant that the user acknowledges or responds to. +If the message is from the user, extract memories related to the user. If the message is from the assistant, only extract factual memories that are acknowledged or responded to by the user. -${custom_tags_prompt} +3. Clearly resolve all references to time, people, and events: + - If possible, use message timestamps to convert relative temporal expressions (such as “yesterday” or “next Friday”) into absolute dates. + - Clearly distinguish between event time and message time. + - If uncertainty exists, explicitly state it (e.g., “around June 2025”, “exact date unclear”). + - If a specific location is mentioned, include it. + - Resolve all pronouns, aliases, and vague references into full names or explicit identities. + - If there are multiple people with the same name, distinguish them clearly. -返回一个合法 JSON 对象: -{ - "memory list": [ - {"key": "<简短标题>", "memory_type": "UserMemory", "value": "<第三人称陈述>", "tags": ["<标签>", ...]} - ], - "summary": "<1-2 句话从用户视角总结>" -} -""" +4. Always write from the third-person perspective, using “the user” or the user’s name to refer to the user, rather than first person (“I”, “we”, “my”). +For example, write “The user feels tired...” rather than “I feel tired...”. -_TASK_ORIENTED_EN = """\ -You are a memory extraction expert specialised in task-oriented conversations. -Focus on extracting **tasks, plans, deadlines, action items, constraints, and commitments** the user discussed. -Resolve all time references to absolute dates when possible. +5. Do not omit any information that the user may remember. + - Include all key experiences, thoughts, emotional reactions, and plans — even if they seem minor. + - Prioritize completeness and fidelity over brevity. + - Do not generalize or skip details that may have personal significance to the user. -Conversation: -${conversation} +6. Please avoid including any content in the extracted memories that violates laws or regulations or involves politically sensitive information. -${custom_tags_prompt} +Return a valid JSON object with the following structure: -Return a single valid JSON object: { "memory list": [ - {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + { + "key": , + "memory_type": , + "value": , + "tags": + }, + ... ], - "summary": "" + "summary": } -""" -_TASK_ORIENTED_ZH = """\ -你是一位擅长提取任务型对话记忆的专家。 -专注提取用户讨论的**任务、计划、截止日期、行动项、约束条件和承诺**。 -尽可能将所有时间引用转换为绝对日期。 - -对话内容: -${conversation} +Language rules: +- The fields `key`, `value`, `tags`, and `summary` must match the main language of the input conversation. **If the input is Chinese, output in Chinese.** +- `memory_type` must remain in English. ${custom_tags_prompt} -返回一个合法 JSON 对象: -{ - "memory list": [ - {"key": "<简短标题>", "memory_type": "LongTermMemory", "value": "<第三人称,包含日期和约束的详细陈述>", "tags": ["<标签>", ...]} - ], - "summary": "<总结任务和截止日期的段落,80-150字>" -} -""" - -_KNOWLEDGE_SHARING_EN = """\ -You are a memory extraction expert specialised in knowledge-sharing content. -Extract **key concepts, definitions, explanations, facts, and learned insights** from the conversation. -Treat this like document-level extraction: capture knowledge points completely. - -Content: -{chunk_text} - -{custom_tags_prompt} +Example: +Conversation: +user: [June 26, 2025, 3:00 PM]: Hi Jerry, my name is Tom! Yesterday at 3:00 PM I had a meeting with my team to discuss a new project. +assistant: Do you think the team can finish by December 15? +user: [June 26, 2025, 3:00 PM]: I’m a little worried. The backend won’t be finished until December 10, so testing time will be tight. +assistant: [June 26, 2025, 3:00 PM]: Maybe suggest postponing it? +user: [June 26, 2025, 4:21 PM]: Good idea. I’ll bring it up at tomorrow’s 9:30 AM meeting — maybe push the deadline to January 5. -Return a single valid JSON object: +Output: { "memory list": [ - {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + { + "key": "User name", + "memory_type": "UserMemory", + "value": "The user’s name is Tom.", + "tags": ["identity information", "name"] + }, + { + "key": "Initial project meeting", + "memory_type": "LongTermMemory", + "value": "On June 25, 2025 at 3:00 PM, Tom met with his team to discuss a new project. The meeting involved the timeline and raised concerns about whether the December 15, 2025 deadline was feasible.", + "tags": ["project", "timeline", "meeting", "deadline"] + }, + { + "key": "Planned deadline adjustment", + "memory_type": "UserMemory", + "value": "Tom plans to suggest at the June 27, 2025 9:30 AM meeting that the team reprioritize work and postpone the project deadline to January 5, 2026.", + "tags": ["plan", "deadline change", "prioritization"] + } ], - "summary": "" + "summary": "Tom is currently focused on managing a new project with a tight schedule. After the team meeting on June 25, 2025, he realized that the original December 15, 2025 deadline might not be achievable because the backend is expected to be completed only by December 10, leaving very little time for testing. Because of this concern, Tom accepted Jerry’s suggestion to propose a delay. He plans to raise the idea of postponing the deadline to January 5, 2026 at the next morning’s meeting. His actions reflect concern about the timeline as well as a proactive, team-oriented approach to problem solving." } -""" - -_KNOWLEDGE_SHARING_ZH = """\ -你是一位擅长提取知识分享内容的记忆专家。 -提取对话中的**核心概念、定义、解释、事实和学习到的见解**。 -像文档级别提取一样,完整捕获知识点。 -内容: -{chunk_text} +Input: + user: [July 1, 2025, 10:00 AM]: My name is Li Ming. My wife’s name is Wang Ting, and my son’s name is Li Haoran. Next week we are planning to travel to Shanghai together. + assistant: That sounds great. How many days are you planning to stay? + user: [July 1, 2025, 10:05 AM]: About three days. -{custom_tags_prompt} - -返回一个合法 JSON 对象: +Output: { "memory list": [ - {"key": "<概念或主题>", "memory_type": "LongTermMemory", "value": "<完整、自包含的解释>", "tags": ["<标签>", ...]} + { + "key": "User name", + "memory_type": "UserMemory", + "value": "The user’s name is Li Ming.", + "tags": ["identity information", "name"] + }, + { + "key": "Spouse's name", + "memory_type": "UserMemory", + "value": "The user’s wife is named Wang Ting.", + "tags": ["relationship information", "wife", "name"] + }, + { + "key": "Son's name", + "memory_type": "UserMemory", + "value": "The user’s son is named Li Haoran.", + "tags": ["relationship information", "son", "name"] + }, + { + "key": "Family travel plan", + "memory_type": "LongTermMemory", + "value": "The user plans to travel to Shanghai together with his wife Wang Ting and son Li Haoran during the week following July 1, 2025, and expects the trip to last about three days. The exact departure date is not specified.", + "tags": ["travel", "family", "plan", "Shanghai"] + } ], - "summary": "<总结分享知识的段落,100-200字>" + "summary": "Li Ming is planning a family trip to Shanghai in the week after July 1, 2025, and expects the trip to last about three days. The conversation explicitly states that the user’s wife is named Wang Ting and the user’s son is named Li Haoran. This indicates that the user has a near-term travel plan involving close family members." } -""" -_EMOTIONAL_EN = """\ -You are a memory extraction expert specialised in emotional and relational conversations. -Focus on extracting the user's **emotional states, feelings, relationship dynamics, personal concerns, and significant life events**. -Preserve emotional nuance — do not flatten sentiment into generic labels. +Please always reply in the same language as the conversation. Conversation: ${conversation} -${custom_tags_prompt} - -Return a single valid JSON object: -{ - "memory list": [ - {"key": "", "memory_type": "UserMemory", "value": "", "tags": ["", ...]} - ], - "summary": "" -} +Your output: """ -_EMOTIONAL_ZH = """\ -你是一位擅长提取情感与人际关系对话记忆的专家。 -专注提取用户的**情感状态、感受、人际关系动态、个人关切和重要生活事件**。 -保留情感细节,不要将情绪简化为泛泛的标签。 +_IDENTITY_RELATION_ZH = """\ -对话内容: -${conversation} +您是记忆提取专家,提取各类记忆,包括准确的人物身份和关系信息。 +您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 -${custom_tags_prompt} +请执行以下操作: +1. 如果当前对话中包含用户自述的姓名或亲属/社交关系信息。你提取的内容需要精确包括 + - 用户本人的姓名(如"我叫xxx"、"我是xxx") + - 用户提及的所有关系人:关系类型 + 姓名(如"我的儿子叫xxx"、"我老婆是xxx") + - 关系人之间如果存在进一步的关系描述,也要提取; + - 其他内容如果存在照常提取; + 提取要求: + - **绝对不能遗漏任何人名和关系** + - 使用第三人称("用户的儿子叫王明泽"而非"我的儿子叫王明泽") + - 每组身份/关系信息单独作为一条记忆 -返回一个合法 JSON 对象: -{ - "memory list": [ - {"key": "<简短标题>", "memory_type": "UserMemory", "value": "<第三人称陈述,捕获情感和上下文>", "tags": ["<标签>", ...]} - ], - "summary": "<从用户视角出发的共情式总结,80-150字>" -} -""" +2. 识别反映用户经历、信念、关切、决策、计划或反应的信息——包括用户认可或回应的来自助手的有意义信息。 +如果消息来自用户,请提取与用户相关的记忆;如果来自助手,则仅提取用户认可或回应的事实性记忆。 -_CODE_DISCUSSION_EN = """\ -You are a memory extraction expert specialised in technical and code discussions. -Extract **tools, frameworks, libraries, technical decisions, code patterns, bugs, solutions, and architecture choices** discussed by the user. -Include version numbers, configuration details, and error descriptions when available. +3. 清晰解析所有时间、人物和事件的指代: + - 如果可能,使用消息时间戳将相对时间表达(如“昨天”、“下周五”)转换为绝对日期。 + - 明确区分事件时间和消息时间。 + - 如果存在不确定性,需明确说明(例如,“约2025年6月”,“具体日期不详”)。 + - 若提及具体地点,请包含在内。 + - 将所有代词、别名和模糊指代解析为全名或明确身份。 + - 如有同名人物,需加以区分。 -Conversation: -${conversation} +4. 始终以第三人称视角撰写,使用“用户”或提及的姓名来指代用户,而不是使用第一人称(“我”、“我们”、“我的”)。 +例如,写“用户感到疲惫……”而不是“我感到疲惫……”。 -${custom_tags_prompt} +5. 不要遗漏用户可能记住的任何信息。 + - 包括所有关键经历、想法、情绪反应和计划——即使看似微小。 + - 优先考虑完整性和保真度,而非简洁性。 + - 不要泛化或跳过对用户具有个人意义的细节。 + +6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 + +返回一个有效的JSON对象,结构如下: -Return a single valid JSON object: { "memory list": [ - {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + { + "key": <字符串,唯一且简洁的记忆标题>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, + "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, + "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + }, + ... ], - "summary": "" + "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> } -""" - -_CODE_DISCUSSION_ZH = """\ -你是一位擅长提取技术和代码讨论记忆的专家。 -提取用户讨论的**工具、框架、库、技术决策、代码模式、Bug、解决方案和架构选择**。 -在可用时包含版本号、配置细节和错误描述。 -对话内容: -${conversation} +语言规则: +- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** +- `memory_type` 保持英文。 ${custom_tags_prompt} -返回一个合法 JSON 对象: +示例: +对话: +user: [2025年6月26日下午3:00]:嗨Jerry,我叫Tom!昨天下午3点我和团队开了个会,讨论新项目。 +assistant: 你觉得团队能在12月15日前完成吗? +user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 +assistant: [2025年6月26日下午3:00]:也许提议延期? +user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 + +输出: { "memory list": [ - {"key": "<技术主题>", "memory_type": "LongTermMemory", "value": "<第三人称陈述,包含技术细节>", "tags": ["<标签>", ...]} + { + "key": "用户姓名", + "memory_type": "UserMemory", + "value": "用户名叫Tom。", + "tags": ["身份信息", "姓名"] + }, + { + "key": "项目初期会议", + "memory_type": "LongTermMemory", + "value": "2025年6月25日下午3:00,Tom与团队开会讨论新项目。会议涉及时间表,并提出了对2025年12月15日截止日期可行性的担忧。", + "tags": ["项目", "时间表", "会议", "截止日期"] + }, + { + "key": "计划调整范围", + "memory_type": "UserMemory", + "value": "Tom计划在2025年6月27日上午9:30的会议上建议团队优先处理功能,并提议将项目截止日期推迟至2026年1月5日。", + "tags": ["计划", "截止日期变更", "功能优先级"] + } ], - "summary": "<技术总结,80-150字>" + "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" } -""" - -_MULTI_TURN_QA_EN = """\ -You are a memory extraction expert specialised in multi-turn Q&A conversations. -Extract the **conclusions, clarifications, and final answers** that emerged through the Q&A process. -Focus on what the user learned or decided, not intermediate back-and-forth. - -Conversation: -${conversation} -${custom_tags_prompt} +输入: + user: [2025年7月1日上午10:00]:我叫李明,我老婆叫王婷,我儿子叫李浩然。下周我们打算一起去上海旅游。 + assistant: 听起来很不错,你们准备去几天? + user: [2025年7月1日上午10:05]:大概三天。 -Return a single valid JSON object: +输出: { "memory list": [ - {"key": "", "memory_type": "LongTermMemory", "value": "", "tags": ["", ...]} + { + "key": "用户姓名", + "memory_type": "UserMemory", + "value": "用户名叫李明。", + "tags": ["身份信息", "姓名"] + }, + { + "key": "配偶姓名", + "memory_type": "UserMemory", + "value": "用户的妻子叫王婷。", + "tags": ["关系信息", "妻子", "姓名"] + }, + { + "key": "儿子姓名", + "memory_type": "UserMemory", + "value": "用户的儿子叫李浩然。", + "tags": ["关系信息", "儿子", "姓名"] + }, + { + "key": "家庭出行计划", + "memory_type": "LongTermMemory", + "value": "用户计划于2025年7月8日所在周与妻子王婷和儿子李浩然一起前往上海旅游,预计行程约三天。具体出发日期未明确。", + "tags": ["旅行", "家庭", "计划", "上海"] + } ], - "summary": "" + "summary": "李明计划在2025年7月1日之后的下一周与家人一起去上海旅游,预计停留约三天。对话中明确提到用户的妻子名叫王婷,儿子名叫李浩然。这表明用户近期有一项与家庭相关的出行安排。" } -""" -_MULTI_TURN_QA_ZH = """\ -你是一位擅长提取多轮问答对话记忆的专家。 -提取通过问答过程得出的**结论、澄清和最终答案**。 -关注用户学到了什么或做了什么决定,而非中间的来回讨论。 +请始终使用与对话相同的语言进行回复。 -对话内容: +对话: ${conversation} -${custom_tags_prompt} - -返回一个合法 JSON 对象: -{ - "memory list": [ - {"key": "<问题主题>", "memory_type": "LongTermMemory", "value": "<第三人称陈述结论或答案>", "tags": ["<标签>", ...]} - ], - "summary": "<总结问答中的关键结论,80-150字>" -} +您的输出: """ _DEFAULT_STRATEGIES = [ PromptStrategy( - name="casual_chat", - template_en=_CASUAL_CHAT_EN, - template_zh=_CASUAL_CHAT_ZH, - description="Lightweight extraction for casual conversation — preferences, habits, opinions", - ), - PromptStrategy( - name="task_oriented", - template_en=_TASK_ORIENTED_EN, - template_zh=_TASK_ORIENTED_ZH, - description="Structured extraction for tasks, plans, deadlines, and action items", - ), - PromptStrategy( - name="knowledge_sharing", - template_en=_KNOWLEDGE_SHARING_EN, - template_zh=_KNOWLEDGE_SHARING_ZH, - description="Document-style extraction for concepts, definitions, and learned insights", - ), - PromptStrategy( - name="emotional", - template_en=_EMOTIONAL_EN, - template_zh=_EMOTIONAL_ZH, - description="Emotion-aware extraction for feelings, relationships, and personal concerns", - ), - PromptStrategy( - name="code_discussion", - template_en=_CODE_DISCUSSION_EN, - template_zh=_CODE_DISCUSSION_ZH, - description="Technical extraction for tools, frameworks, bugs, and architecture decisions", - ), - PromptStrategy( - name="multi_turn_qa", - template_en=_MULTI_TURN_QA_EN, - template_zh=_MULTI_TURN_QA_ZH, - description="Conclusion-focused extraction for multi-turn Q&A conversations", + name="identity_relation", + template_en=_IDENTITY_RELATION_EN, + template_zh=_IDENTITY_RELATION_ZH, + description="Precise extraction of names and family/social relationships", ), ] diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py index 5fe55c1d8..930a83fb9 100644 --- a/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py +++ b/extensions/memos_prompt_strategy_plugin/tests/test_classifier.py @@ -1,19 +1,12 @@ -"""Tests for MessageClassifier rule-based classification.""" +"""Tests for the single-rule identity/relation classifier.""" from memos_prompt_strategy_plugin.classifier import ( - CASUAL_CHAT, - CODE_DISCUSSION, - EMOTIONAL, - KNOWLEDGE_SHARING, - MULTI_TURN_QA, - TASK_ORIENTED, + IDENTITY_RELATION, MessageClassifier, ) def _src(role: str, content: str): - """Helper: lightweight source-like object.""" - class _S: pass @@ -23,84 +16,97 @@ class _S: return s -class TestClassifierRules: +class TestIdentityRelationRule: def setup_method(self): self.clf = MessageClassifier() - # ── code_discussion ────────────────────────────────────────── - - def test_code_block_triggers_code(self): - sources = [_src("user", "Here is my code:\n```python\nprint('hello')\n```")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == CODE_DISCUSSION - - def test_tech_keywords_trigger_code(self): - sources = [ - _src("user", "I need to import the SDK and call the API via HTTP with JSON payload") - ] - result = self.clf.classify(sources, "", "chat", {}) - assert result == CODE_DISCUSSION - - # ── task_oriented ──────────────────────────────────────────── - - def test_task_keywords_en(self): - sources = [_src("user", "Please schedule a meeting and set up the deployment pipeline")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == TASK_ORIENTED - - def test_task_keywords_zh(self): - sources = [_src("user", "请帮我安排明天的会议,并提醒我截止日期")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == TASK_ORIENTED - - # ── emotional ──────────────────────────────────────────────── - - def test_emotion_en(self): - sources = [_src("user", "I feel so happy and grateful for everything today")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == EMOTIONAL - - def test_emotion_zh(self): - sources = [_src("user", "今天特别开心,也很感恩身边的朋友")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == EMOTIONAL - - # ── knowledge_sharing ──────────────────────────────────────── - - def test_long_text_knowledge(self): - long_text = "This is a detailed explanation of how transformers work.\n" * 20 - sources = [_src("user", long_text)] - result = self.clf.classify(sources, "", "chat", {}) - assert result == KNOWLEDGE_SHARING - - # ── multi_turn_qa ──────────────────────────────────────────── - - def test_multi_turn_qa(self): - sources = [ - _src("user", "What is the best approach for caching?"), - _src("assistant", "It depends on your use case. What latency do you need?"), - _src("user", "Under 100ms. Which solution fits?"), - _src("assistant", "Redis would be ideal for that latency requirement."), - ] - text = "\n".join(s.content for s in sources) - result = self.clf.classify(sources, text, "chat", {}) - assert result == MULTI_TURN_QA - - # ── casual_chat ────────────────────────────────────────────── - - def test_short_casual(self): - sources = [_src("user", "Hey, nice weather today!")] - result = self.clf.classify(sources, "", "chat", {}) - assert result == CASUAL_CHAT - - # ── fallback ───────────────────────────────────────────────── - - def test_no_match_returns_default(self): - sources = [ - _src("user", "Let me think about that for a moment"), - _src("assistant", "Sure, take your time"), - _src("user", "Okay I have decided"), - ] - text = "\n".join(s.content for s in sources) - result = self.clf.classify(sources, text, "chat", {}) - assert result == "chat" + # ── Chinese: self-naming ──────────────────────────────────── + + def test_wo_jiao(self): + sources = [_src("user", "你好,我叫王沐辰")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wo_shi(self): + sources = [_src("user", "我是李明,今年30岁")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wo_de_mingzi_shi(self): + sources = [_src("user", "我的名字是张三")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── Chinese: relation naming ──────────────────────────────── + + def test_son(self): + sources = [_src("user", "我的儿子叫王明泽")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_daughter(self): + sources = [_src("user", "我女儿叫小红")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_wife(self): + sources = [_src("user", "我老婆是刘芳")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_mother(self): + sources = [_src("user", "我妈妈叫李秀英")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_friend(self): + sources = [_src("user", "我朋友叫赵磊")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_pet(self): + sources = [_src("user", "我的猫叫小花")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── Chinese: combined self + relation ─────────────────────── + + def test_combined(self): + sources = [_src("user", "我叫王沐辰,我的儿子叫王明泽")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── English ───────────────────────────────────────────────── + + def test_my_name_is(self): + sources = [_src("user", "Hi, my name is Alice")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_im(self): + sources = [_src("user", "I'm Bob, nice to meet you")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_call_me(self): + sources = [_src("user", "Just call me Charlie")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_my_son_is(self): + sources = [_src("user", "My son is called David")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + def test_my_wife_is(self): + sources = [_src("user", "My wife's name is Emma")] + assert self.clf.classify(sources, "", "chat", {}) == IDENTITY_RELATION + + # ── No match → None ───────────────────────────────────────── + + def test_no_identity_returns_none(self): + sources = [_src("user", "今天天气不错")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_task_text_returns_none(self): + sources = [_src("user", "请帮我安排明天的会议")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_code_returns_none(self): + sources = [_src("user", "```python\nprint('hello')\n```")] + assert self.clf.classify(sources, "", "chat", {}) is None + + def test_empty_returns_none(self): + assert self.clf.classify([], "", "chat", {}) is None + + # ── mem_str fallback ──────────────────────────────────────── + + def test_uses_mem_str_when_no_sources(self): + result = self.clf.classify([], "我叫王沐辰,我的儿子叫王明泽", "chat", {}) + assert result == IDENTITY_RELATION diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py index cc3edca7e..54bc83d63 100644 --- a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py +++ b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py @@ -19,6 +19,16 @@ def _make_app(): return app, plugin +def _src(role, content): + class _S: + pass + + s = _S() + s.role = role + s.content = content + return s + + class TestPluginLifecycle: def setup_method(self): from memos.plugins.hooks import _hooks @@ -30,7 +40,7 @@ def test_metadata(self): plugin = PromptStrategyPlugin() assert plugin.name == "prompt_strategy" - assert plugin.version == "0.1.0" + assert plugin.version == "0.2.0" def test_on_load_initialises_components(self): from memos_prompt_strategy_plugin.plugin import PromptStrategyPlugin @@ -72,8 +82,7 @@ def test_strategies_list(self): resp = client.get("/prompt_strategy/strategies") assert resp.status_code == 200 strategies = resp.json() - assert "casual_chat" in strategies - assert "task_oriented" in strategies + assert "identity_relation" in strategies def test_stats_empty(self): app, _ = _make_app() @@ -89,9 +98,8 @@ def setup_method(self): _hooks.clear() - def test_pre_extract_classifies_and_swaps_prompt(self): - """When the plugin classifies a message as task_oriented, it returns - the task-specific prompt containing the original mem_str.""" + def test_pre_extract_swaps_prompt_for_identity(self): + """When identity/relation pattern is detected, the prompt is swapped.""" from memos.plugins.hooks import trigger_hook _, plugin = _make_app() @@ -100,65 +108,45 @@ def test_pre_extract_classifies_and_swaps_prompt(self): "mem_reader.pre_extract", prompt="original prompt", prompt_type="chat", - mem_str="Please schedule a meeting and remind me about the deadline", - lang="en", + mem_str="我叫王沐辰,我的儿子叫王明泽", + lang="zh", sources=[], ) assert result != "original prompt" - assert "schedule a meeting" in result - assert plugin.stats["task_oriented"] >= 1 + assert "王沐辰" in result + assert "王明泽" in result + assert plugin.stats["identity_relation"] >= 1 - def test_pre_extract_preserves_prompt_when_no_rule_matches(self): - """When no classifier rule matches and the default prompt_type has no - registered strategy, the original prompt passes through unchanged.""" + def test_pre_extract_preserves_prompt_for_normal_text(self): + """When no classifier rule matches, original prompt passes through.""" from memos.plugins.hooks import trigger_hook - def _src(role, content): - class _S: - pass - - s = _S() - s.role = role - s.content = content - return s - _make_app() - sources = [ - _src("user", "Let me think about that"), - _src("assistant", "Sure, take your time"), - _src("user", "Okay I have decided"), - ] + sources = [_src("user", "今天天气不错,出去走走吧")] result = trigger_hook( "mem_reader.pre_extract", prompt="original prompt", prompt_type="chat", - mem_str="Let me think about that\nSure, take your time\nOkay I have decided", - lang="en", + mem_str="今天天气不错,出去走走吧", + lang="zh", sources=sources, ) assert result == "original prompt" - def test_pre_extract_tracks_stats(self): + def test_pre_extract_english_identity(self): from memos.plugins.hooks import trigger_hook _, plugin = _make_app() - trigger_hook( - "mem_reader.pre_extract", - prompt="p", - prompt_type="chat", - mem_str="I feel so happy and grateful today", - lang="en", - sources=[], - ) - trigger_hook( + result = trigger_hook( "mem_reader.pre_extract", - prompt="p", + prompt="original prompt", prompt_type="chat", - mem_str="```python\nimport os\nprint(os.getcwd())\n```", + mem_str="My name is Alice and my son is Bob", lang="en", sources=[], ) - assert plugin.stats["emotional"] >= 1 - assert plugin.stats["code_discussion"] >= 1 + assert result != "original prompt" + assert "Alice" in result + assert plugin.stats["identity_relation"] >= 1 diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py index dcb41a0a1..c0aaf5a97 100644 --- a/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py +++ b/extensions/memos_prompt_strategy_plugin/tests/test_strategies.py @@ -8,56 +8,35 @@ def setup_method(self): self.reg = StrategyRegistry() self.reg.register_defaults() - def test_all_default_strategies_registered(self): + def test_default_strategies_registered(self): strategies = self.reg.all_strategies() - expected = { - "casual_chat", - "task_oriented", - "knowledge_sharing", - "emotional", - "code_discussion", - "multi_turn_qa", - } - assert set(strategies.keys()) == expected + assert "identity_relation" in strategies def test_build_prompt_returns_none_for_unknown(self): result = self.reg.build_prompt("nonexistent_category", "en", "hello") assert result is None - def test_build_prompt_en(self): - prompt = self.reg.build_prompt("casual_chat", "en", "Hey, how are you?") + def test_build_prompt_zh(self): + prompt = self.reg.build_prompt("identity_relation", "zh", "我叫王沐辰,我的儿子叫王明泽") assert prompt is not None - assert "Hey, how are you?" in prompt - assert "preferences" in prompt.lower() or "habits" in prompt.lower() + assert "王沐辰" in prompt + assert "王明泽" in prompt + assert "身份" in prompt or "关系" in prompt - def test_build_prompt_zh(self): - prompt = self.reg.build_prompt("casual_chat", "zh", "你好,最近怎么样?") + def test_build_prompt_en(self): + prompt = self.reg.build_prompt("identity_relation", "en", "My name is Alice, my son is Bob") assert prompt is not None - assert "你好,最近怎么样?" in prompt - assert "偏好" in prompt or "习惯" in prompt + assert "Alice" in prompt + assert "Bob" in prompt + assert "identity" in prompt.lower() or "relationship" in prompt.lower() def test_build_prompt_with_custom_tags(self): prompt = self.reg.build_prompt( - "task_oriented", "en", "Please deploy by Friday", custom_tags=["deadline", "ops"] + "identity_relation", "zh", "我叫张三", custom_tags=["family", "name"] ) assert prompt is not None - assert "deadline" in prompt - assert "ops" in prompt - - def test_task_oriented_has_deadline_focus(self): - prompt = self.reg.build_prompt("task_oriented", "en", "meeting at 3pm") - assert prompt is not None - assert "deadline" in prompt.lower() or "task" in prompt.lower() - - def test_code_discussion_has_tech_focus(self): - prompt = self.reg.build_prompt("code_discussion", "en", "fix the bug") - assert prompt is not None - assert "framework" in prompt.lower() or "tool" in prompt.lower() - - def test_emotional_has_feeling_focus(self): - prompt = self.reg.build_prompt("emotional", "en", "I feel sad") - assert prompt is not None - assert "emotion" in prompt.lower() or "feeling" in prompt.lower() + assert "family" in prompt + assert "name" in prompt def test_custom_strategy_registration(self): custom = PromptStrategy( @@ -75,7 +54,7 @@ def test_custom_strategy_registration(self): class TestStrategyRegistryIsolation: def test_empty_registry_returns_none(self): reg = StrategyRegistry() - assert reg.build_prompt("casual_chat", "en", "hi") is None + assert reg.build_prompt("identity_relation", "en", "hi") is None def test_get_unknown_returns_none(self): reg = StrategyRegistry() From d8419cd9c749cf21c68011162d8da3f3d1a9015e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:05:21 +0800 Subject: [PATCH 79/92] feat: update memos_prompt_strategy --- .../memos_prompt_strategy_plugin/classifier.py | 17 ++++++++++++----- .../memos_prompt_strategy_plugin/hooks.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/extensions/memos_prompt_strategy_plugin/classifier.py b/extensions/memos_prompt_strategy_plugin/classifier.py index 7adee817b..2f7d0cba9 100644 --- a/extensions/memos_prompt_strategy_plugin/classifier.py +++ b/extensions/memos_prompt_strategy_plugin/classifier.py @@ -105,11 +105,18 @@ def classify( @staticmethod def _check_identity_relation(sources: list, text: str) -> str | None: - has_self = bool(_SELF_NAME_RE.search(text)) or bool(_MY_NAME_IS_EN.search(text)) - has_relation = bool(_RELATION_NAME_RE.search(text)) or bool(_MY_RELATION_IS_EN.search(text)) - - if has_self or has_relation: - logger.info("[PromptStrategy] Identity/relation pattern detected") + self_names = [m.group("name") for m in _SELF_NAME_RE.finditer(text)] + self_names += [m.group("name") for m in _MY_NAME_IS_EN.finditer(text)] + relation_names = [m.group("name") for m in _RELATION_NAME_RE.finditer(text)] + relation_names += [m.group("name") for m in _MY_RELATION_IS_EN.finditer(text)] + + if self_names or relation_names: + logger.info( + "[PromptStrategy] Identity/relation pattern detected — " + "self_names=%s, relation_names=%s", + self_names, + relation_names, + ) return IDENTITY_RELATION return None diff --git a/extensions/memos_prompt_strategy_plugin/hooks.py b/extensions/memos_prompt_strategy_plugin/hooks.py index ba68c102f..756a335d7 100644 --- a/extensions/memos_prompt_strategy_plugin/hooks.py +++ b/extensions/memos_prompt_strategy_plugin/hooks.py @@ -35,7 +35,13 @@ def on_pre_extract( return None plugin.stats[category] += 1 - logger.info("[PromptStrategy] Matched rule: %s", category) + logger.info( + "[PromptStrategy] Matched rule: %s | prompt_type=%s, lang=%s, text=%s", + category, + prompt_type, + lang, + mem_str[:120] + ("..." if len(mem_str) > 120 else ""), + ) custom_prompt = plugin.registry.build_prompt( category=category, @@ -43,7 +49,7 @@ def on_pre_extract( mem_str=mem_str, ) if custom_prompt is not None: - logger.debug("[PromptStrategy] Using strategy prompt for %s", category) + logger.info("[PromptStrategy] Prompt swapped to strategy: %s", category) return custom_prompt return None From ce9cd0e1beade96274409f95e29d119bb6532a9e Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Mar 2026 16:13:06 +0800 Subject: [PATCH 80/92] fix: change the covered_history direction in feedback process when mem_version_switch is on. --- .../textual/tree_text_memory/organize/history_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 89e4ba298..6cb285ff3 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -625,7 +625,7 @@ def update_from_feedback( "memory": current_item.memory, "history": history_dump, "version": current_item.metadata.version, - "covered_history": old_item.id, + "covered_history": archived_item.id, } return current_item, archived_item, arch_meta, update_fields From f3703607ad261e1c8c0b33e7c2fadfc26829624f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:32:47 +0800 Subject: [PATCH 81/92] fix: config conflict --- src/memos/api/config.py | 120 ---------------------------------------- 1 file changed, 120 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index b6d6cd94c..013577da5 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -338,126 +338,6 @@ def get_memreader_config() -> dict[str, Any]: }, } - @staticmethod - def get_qwen_llm_config() -> dict[str, Any] | None: - if not os.getenv("QWEN_API_KEY"): - return None - return { - "backend": "qwen", - "config": { - "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"), - "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")), - "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")), - "top_p": float(os.getenv("QWEN_TOP_P", "0.9")), - "top_k": int(os.getenv("QWEN_TOP_K", "50")), - "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower() - == "true", - "api_key": os.getenv("QWEN_API_KEY", ""), - "api_base": os.getenv("QWEN_API_BASE", ""), - "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"), - }, - } - - def get_memreader_general_llm_config() -> dict[str, Any]: - """Get general LLM configuration for non-chat/doc tasks. - - Used for: hallucination filter, memory rewrite, memory merge, - tool trajectory extraction, skill memory extraction. - - This is the fallback for image_parser_llm and preference_extractor_llm. - Fallback chain: MEMREADER_GENERAL_MODEL -> MEMRADER_MODEL (memreader config) - - Note: If you have fine-tuned a custom model for chat/doc extraction only, - you should configure MEMREADER_GENERAL_MODEL to use a general-purpose LLM - for other tasks. Otherwise, all tasks will use the same MEMRADER_MODEL. - """ - # Check if specific general model is configured - general_model = os.getenv("MEMREADER_GENERAL_MODEL") - if general_model: - return { - "backend": os.getenv("MEMREADER_GENERAL_BACKEND", "openai"), - "config": { - "model_name_or_path": general_model, - "temperature": 0.6, - "max_tokens": int(os.getenv("MEMREADER_GENERAL_MAX_TOKENS", "8000")), - "top_p": 0.95, - "top_k": 20, - "api_key": os.getenv( - "MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY") - ), - "api_base": os.getenv( - "MEMREADER_GENERAL_API_BASE", - os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), - ), - "remove_think_prefix": True, - }, - } - # Fallback to memreader config (same behavior as before for users who don't customize) - return APIConfig.get_memreader_config() - - @staticmethod - def get_image_parser_llm_config() -> dict[str, Any]: - """Get LLM configuration for image parsing (requires vision model). - - Used for: image content extraction and analysis. - Requires a vision-capable model like GPT-4V, GPT-4o, etc. - - Fallback chain: IMAGE_PARSER_MODEL -> general_llm -> OpenAI config - """ - image_model = os.getenv("IMAGE_PARSER_MODEL") - if image_model: - return { - "backend": os.getenv("IMAGE_PARSER_BACKEND", "openai"), - "config": { - "model_name_or_path": image_model, - "temperature": 0.6, - "max_tokens": int(os.getenv("IMAGE_PARSER_MAX_TOKENS", "4096")), - "top_p": 0.95, - "top_k": 20, - "api_key": os.getenv( - "IMAGE_PARSER_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY") - ), - "api_base": os.getenv( - "IMAGE_PARSER_API_BASE", - os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), - ), - "remove_think_prefix": True, - }, - } - # Fallback to general_llm config (which itself falls back to OpenAI) - return APIConfig.get_memreader_general_llm_config() - - @staticmethod - def get_preference_extractor_llm_config() -> dict[str, Any]: - """Get LLM configuration for preference extraction. - - Used for: extracting user preferences from conversations. - - Fallback chain: PREFERENCE_EXTRACTOR_MODEL -> general_llm -> OpenAI config - """ - pref_model = os.getenv("PREFERENCE_EXTRACTOR_MODEL") - if pref_model: - return { - "backend": os.getenv("PREFERENCE_EXTRACTOR_BACKEND", "openai"), - "config": { - "model_name_or_path": pref_model, - "temperature": 0.6, - "max_tokens": int(os.getenv("PREFERENCE_EXTRACTOR_MAX_TOKENS", "8000")), - "top_p": 0.95, - "top_k": 20, - "api_key": os.getenv( - "PREFERENCE_EXTRACTOR_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY") - ), - "api_base": os.getenv( - "PREFERENCE_EXTRACTOR_API_BASE", - os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), - ), - "remove_think_prefix": True, - }, - } - # Fallback to general_llm config (which itself falls back to OpenAI) - return APIConfig.get_memreader_general_llm_config() - @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. From dd1dab27ce098cff1b29915e1855b5720e4ec15b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:34:50 +0800 Subject: [PATCH 82/92] fix: simple struct conflict --- src/memos/mem_reader/simple_struct.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 8ef9b2c91..b27a41ee4 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -195,13 +195,6 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.qwen_llm = LLMFactory.from_config(qwen_llm_config) except Exception as e: logger.warning(f"[LLM] Qwen initialization failed: {e}") - # General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.) - # Falls back to main llm if not configured - self.general_llm = ( - LLMFactory.from_config(config.general_llm) - if config.general_llm is not None - else self.llm - ) self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) self.save_rawfile = self.chunker.config.save_rawfile From a27ac3eb634c41eb76be3ef3c6bdae78e48399e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:36:34 +0800 Subject: [PATCH 83/92] fix: config conflict --- src/memos/api/config.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 013577da5..df3799b5d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -338,6 +338,26 @@ def get_memreader_config() -> dict[str, Any]: }, } + @staticmethod + def get_qwen_llm_config() -> dict[str, Any] | None: + if not os.getenv("QWEN_API_KEY"): + return None + return { + "backend": "qwen", + "config": { + "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"), + "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")), + "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")), + "top_p": float(os.getenv("QWEN_TOP_P", "0.9")), + "top_k": int(os.getenv("QWEN_TOP_K", "50")), + "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower() + == "true", + "api_key": os.getenv("QWEN_API_KEY", ""), + "api_base": os.getenv("QWEN_API_BASE", ""), + "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"), + }, + } + @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. From 76830408761f2471f6f8dab80720bf363be7255c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 16:38:04 +0800 Subject: [PATCH 84/92] fix: multi-modal-parser conflict --- src/memos/mem_reader/read_multi_modal/multi_modal_parser.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py index 485008c5a..a08aadc0c 100644 --- a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py +++ b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py @@ -76,8 +76,6 @@ def __init__( direct_markdown_hostnames=direct_markdown_hostnames, image_parser=self.image_parser, ) - # Use dedicated image_parser_llm for image parsing (requires vision model) - self.image_parser = ImageParser(embedder, self.image_parser_llm) self.audio_parser = None # future self.role_parsers = { From 78d954864f6eac80c830d0d7d7246822f242f9d6 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Mar 2026 17:53:31 +0800 Subject: [PATCH 85/92] fix: change sources in the "restore memory" pathway to make it more reasonable --- .../textual/tree_text_memory/organize/history_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 6cb285ff3..e26d934b2 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -998,6 +998,7 @@ def _handle_restored_memories( tags = data.get("tags", []) key = data.get("key", "") memory_type = data.get("memory_type", "LongTermMemory") + original_sources = source_item.metadata.sources version = source_item.metadata.version new_history_item = ArchivedTextualMemory( version=version, @@ -1014,6 +1015,7 @@ def _handle_restored_memories( "status": "activated", "is_fast": False, "version": version + 1, + "sources": original_sources, "tags": tags, "key": key, "created_at": datetime.now().isoformat(), From 49e152a0e350f7ccc1c9d4e1cf78c8363ab9ec90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 11 Mar 2026 21:18:45 +0800 Subject: [PATCH 86/92] feat: add prompt plugin log --- src/memos/mem_reader/multi_modal_struct.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 21992629f..4d4d004f9 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -477,6 +477,7 @@ def _get_llm_response( sources=sources, ) prompt = _rv if _rv is not None else prompt + logger.info(f"[MultiModalParser] Process String Fine After Plugin: {prompt}") messages = [{"role": "user", "content": prompt}] try: From d07cc2c79b967adf200b8569666b6be52c9e1ec6 Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 11 Mar 2026 21:57:56 +0800 Subject: [PATCH 87/92] fix: fix false memory restorations caused by wrong llm responses. --- .../organize/history_manager.py | 46 +++++- .../mem_reader_mem_version_prompts.py | 18 ++- .../memories/textual/test_history_manager.py | 143 ++++++++++++++---- 3 files changed, 169 insertions(+), 38 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index e26d934b2..94140beb3 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -4,6 +4,7 @@ import time import uuid +from copy import deepcopy from datetime import datetime from typing import Any, Literal @@ -440,7 +441,24 @@ def apply_llm_memory_updates( new_items.extend(created_items) # 3. Handle Restored Memories (Extract from conflict) - new_items.extend(self._handle_restored_memories(restored_memories, source_item, user_name)) + valid_conflict_ids = { + conflict_id + for mem_data in memory_list + for conflict_id in mem_data.get("conflicted_candidate_ids", []) + } + filtered_restored_memories = [ + mem for mem in restored_memories if mem.get("source_candidate_id") in valid_conflict_ids + ] + dropped_restored_count = len(restored_memories) - len(filtered_restored_memories) + if dropped_restored_count: + logger.warning( + "[MemoryHistoryManager] Dropping %s restored_memories not tied to any " + "conflicted_candidate_ids in this LLM response.", + dropped_restored_count, + ) + new_items.extend( + self._handle_restored_memories(filtered_restored_memories, source_item, user_name) + ) return updated_items, new_items @@ -983,22 +1001,34 @@ def _handle_restored_memories( self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem, user_name: str ) -> list[TextualMemoryItem]: """Handle Restored Memories (Extract from conflict).""" - source_ids = [r.get("source_candidate_id") for r in restored_memories] - source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) - source_items = [TextualMemoryItem(**i) for i in source_items] + if not restored_memories: + return [] + + source_ids = [ + r.get("source_candidate_id") for r in restored_memories if r.get("source_candidate_id") + ] + source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) or [] + source_item_map = {item["id"]: TextualMemoryItem(**item) for item in source_items if item} created_items = [] - for i, data in enumerate(restored_memories): - source_item = source_items[i] + for data in restored_memories: + source_candidate_id = data.get("source_candidate_id") + source_item = source_item_map.get(source_candidate_id) + if source_item is None: + logger.warning( + "[MemoryHistoryManager] Restored memory source %s not found. Skipping.", + source_candidate_id, + ) + continue # deal with history - source_history = source_item.metadata.history.copy() + source_history = deepcopy(source_item.metadata.history) value = data.get("value", "") value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) value = value_item.memory tags = data.get("tags", []) key = data.get("key", "") memory_type = data.get("memory_type", "LongTermMemory") - original_sources = source_item.metadata.sources + original_sources = deepcopy(source_item.metadata.sources) version = source_item.metadata.version new_history_item = ArchivedTextualMemory( version=version, diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index f330bac03..4de7493da 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -64,10 +64,14 @@ - 确保 `source_candidate_ids` 和 `conflicted_candidate_ids` 均为 `[]`。 5. **挽救剩余信息 (Restored Memories)**: + - `restored_memories` **只能**来自“本次被放入 `conflicted_candidate_ids` 的旧节点”。如果某条候选记忆未被本次冲突更新命中,则**绝不能**放入 `restored_memories`。 + - `restored_memories` 中的每一条内容,都必须能在其 `source_candidate_id` 对应的那条旧记忆原文中直接定位到;它只是“拆分/改写该旧节点内部原本就存在的子事实”,**绝不能**从其他 candidate 挪用、拼接、概括或猜测内容。 - 当您决定更新一个冲突节点(情况 b)时,请检查该旧节点是否包含**未被新信息否定且依然独立的有效事实**。 - - 如果有,请将这些“剩余有效信息”提取出来,作为 `restored_memories` 返回。这能防止因节点更新导致非冲突信息的丢失。 + - 只有在“旧冲突节点本身是一条混合记忆,且本次更新只改动其中一部分”时,才返回 `restored_memories`。这能防止因节点更新导致同一旧节点内其他事实丢失。 - 若旧节点包含多个事实(如“人物状态 + 电话号码”),而新对话仅更新其中一部分,则必须将未被否定的事实拆分为独立条目放入 `restored_memories`,并在冲突更新的 `value` 中仅保留被更新的那一部分。 + - 对于“与本次对话无关的候选记忆”“仅仅没有被提及的旧记忆”“来自 unrelated candidates 的记忆”,一律忽略,不要输出到 `memory list`,也不要输出到 `restored_memories`。 - 对于可长期独立存在的属性(如电话号码、出生地、所属组织),优先拆分为独立事实,避免与可变状态混写在同一条记忆中。 + - 如果一个 candidate 被判断为无关,请将它的 ID 放入 `unrelated_candidate_ids`,并且不要在 `memory list` 或 `restored_memories` 中再次出现它的内容。 6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 @@ -93,6 +97,7 @@ }, ... ], + "unrelated_candidate_ids": [<字符串列表,被判断为与本次对话无关、应忽略的 candidate ID>], "summary": <从用户视角自然总结本次记忆更新操作的段落,120–200字> } @@ -159,6 +164,7 @@ "source_candidate_id": "201" } ], + "unrelated_candidate_ids": ["301"], "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了近期下雨频繁的记忆。对天气的喜好未受影响。" } @@ -244,10 +250,14 @@ - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. 5. **Restored Memories**: - - When you decide to update a conflict node (case b), check whether the old node contains valid facts that were NOT negated by the new information and remain independently valid. - - If any, extract these "remaining valid facts" and return them as `restored_memories` to prevent loss of non-conflict information due to node updates. + - `restored_memories` may come ONLY from old nodes that you also place into `conflicted_candidate_ids` in this same response. If a candidate is not updated as a conflict this turn, it must NEVER appear in `restored_memories`. + - Every restored item must be directly traceable to the original text of its own `source_candidate_id`. It is only a split-out/rephrased sub-fact that already exists inside that same old node, and must NEVER borrow, merge, summarize, or infer content from any other candidate. + - When you decide to update a conflict node (case b), check whether that old node contains valid facts that were NOT negated by the new information and remain independently valid. + - Return `restored_memories` only when the old conflict node itself is a mixed memory and this update changes only one part of it. This prevents losing the other facts embedded in that same old node. - If the old node contains multiple facts (e.g., "status + phone number") and the new conversation only updates one part, you must split the unaffected facts into separate entries in `restored_memories`, and keep only the updated part in the conflict `value`. + - Completely unrelated candidates, merely unmentioned old memories, and anything from the Unrelated Candidates section must be ignored. Do not output them in `memory list` or `restored_memories`. - For long-lived independent attributes (e.g., phone number, birthplace, affiliation), prefer splitting them into standalone facts instead of mixing them with mutable states. + - If a candidate is judged unrelated, put its ID into `unrelated_candidate_ids`, and do not let its content appear again in `memory list` or `restored_memories`. 6. Avoid including any memories that violate laws or involve politically sensitive information. @@ -273,6 +283,7 @@ }, ... ], + "unrelated_candidate_ids": [], "summary": } @@ -339,6 +350,7 @@ "source_candidate_id": "201" } ], + "unrelated_candidate_ids": ["301"], "summary": "In this update, the user refined their latte preference (adds oat milk), and due to a knee injury updated their sports habit (stop playing badminton). The system preserves the record that the user dislikes skiing. Additionally, a new memory is added that the user adopted a cat. Weather preference is unaffected." } diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index a3dac0e02..a80b65b16 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -43,12 +43,14 @@ def test_truncation(history_manager, mock_nli_client): mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE] # Action - history_manager.resolve_history_via_nli(new_item, [related_item]) + resolved_ids = history_manager.resolve_history_via_nli(new_item, [related_item]) # Assert - assert "possibly duplicate memories" in new_item.memory - assert "..." in new_item.memory # Should be truncated - assert len(new_item.memory) < 1000 # Ensure reasonable length + assert new_item.memory == "Test" + assert resolved_ids == [related_item.id] + assert len(new_item.metadata.history) == 1 + assert new_item.metadata.history[0].memory == long_memory + assert new_item.metadata.history[0].update_type == "duplicate" def test_empty_related_items(history_manager, mock_nli_client): @@ -224,13 +226,37 @@ def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): source_id = uuid.uuid4().hex + existing_node = { + "id": source_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] restored_item = TextualMemoryItem( memory="Restored Content", metadata=TreeNodeTextualMemoryMetadata(history=[]), ) history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) llm_response = { - "memory list": [], + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [source_id], + "history_segments": [], + } + ], "restored_memories": [ {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} ], @@ -254,13 +280,82 @@ def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): llm_response, source_item=source_item, user_name="u1" ) - assert len(updated) == 0 + assert len(updated) == 1 assert len(new_items) == 1 assert new_items[0] == restored_item history_manager._handle_restored_memories.assert_called_once_with( - llm_response["restored_memories"], source_item + llm_response["restored_memories"], source_item, "u1" ) - mock_graph_db.add_node.assert_not_called() + mock_graph_db.add_node.assert_called_once() + + +def test_apply_llm_memory_updates_ignores_restored_from_unrelated_candidates( + history_manager, mock_graph_db +): + conflict_id = uuid.uuid4().hex + unrelated_id = uuid.uuid4().hex + history_manager._handle_restored_memories = MagicMock(return_value=[]) + + existing_node = { + "id": conflict_id, + "memory": "Old Content", + "metadata": { + "version": 1, + "created_at": "2023-01-01", + "tags": ["old"], + "status": "resolving", + "embedding": [], + "memory_type": "LongTermMemory", + }, + } + mock_graph_db.get_node.return_value = existing_node + mock_graph_db.get_nodes.return_value = [existing_node] + + llm_response = { + "memory list": [ + { + "key": "Updated Memory", + "memory_type": "LongTermMemory", + "value": "Updated Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [conflict_id], + "history_segments": [], + } + ], + "restored_memories": [ + {"source_candidate_id": unrelated_id, "value": "Should Be Ignored", "tags": ["ignore"]} + ], + "summary": "Summary", + } + + source_item = TextualMemoryItem( + memory="New user input", + metadata=TreeNodeTextualMemoryMetadata( + history=[ + ArchivedTextualMemory( + version=1, + archived_memory_id=conflict_id, + memory="Old Content", + update_type="conflict", + ), + ArchivedTextualMemory( + version=1, + archived_memory_id=unrelated_id, + memory="Irrelevant Content", + update_type="unrelated", + ), + ] + ), + ) + + updated, new_items = history_manager.apply_llm_memory_updates( + llm_response, source_item=source_item, user_name="u1" + ) + + assert len(updated) == 1 + assert len(new_items) == 0 + history_manager._handle_restored_memories.assert_called_once_with([], source_item, "u1") def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): @@ -293,17 +388,7 @@ def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): assert len(updated) == 0 assert len(new_items) == 0 - - # Check that update_node was called to set status="activated" - # mark_memory_status calls update_node for each item - assert mock_graph_db.update_node.call_count == 2 - - # We can inspect calls - calls = mock_graph_db.update_node.call_args_list - ids = sorted([c.kwargs["id"] for c in calls]) - assert ids == sorted([id1, id2]) - for c in calls: - assert c.kwargs["fields"]["status"] == "activated" + mock_graph_db.update_node.assert_not_called() def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): @@ -430,7 +515,11 @@ def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): memory="x", metadata=TreeNodeTextualMemoryMetadata(history=[history_item]) ) mock_graph_db.get_nodes.return_value = [ - {"id": fast_id, "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}} + { + "id": fast_id, + "memory": "fast", + "metadata": {"status": "deleted", "evolve_to": ["n1", "n2"]}, + } ] replacement_item = ArchivedTextualMemory( @@ -438,11 +527,11 @@ def test_check_and_fetch_replacements_deleted(history_manager, mock_graph_db): ) history_manager._fetch_evolved_nodes = MagicMock(return_value=[replacement_item]) - replacements = history_manager._check_and_fetch_replacements(item, [0]) + replacements = history_manager._check_and_fetch_replacements(item, [0], user_name="u1") assert 0 in replacements assert replacements[0][0].archived_memory_id == "n1" - history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict") + history_manager._fetch_evolved_nodes.assert_called_once_with(["n1", "n2"], "conflict", "u1") def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): @@ -459,7 +548,7 @@ def test_fetch_evolved_nodes_returns_archives(history_manager, mock_graph_db): }, ] - results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate") + results = history_manager._fetch_evolved_nodes(["x1", "x2"], "duplicate", user_name="u1") assert len(results) == 2 ids = sorted([r.archived_memory_id for r in results]) @@ -484,12 +573,12 @@ def test_wait_and_update_fast_history_rebuilds(history_manager): ) history_manager._check_and_fetch_replacements = MagicMock(return_value={0: [replacement]}) - history_manager.wait_and_update_fast_history(item, timeout_sec=1) + history_manager.wait_and_update_fast_history(item, user_name="u1", timeout_sec=1) ids = [h.archived_memory_id for h in item.metadata.history] assert "n1" in ids assert fast_id not in ids - history_manager._check_and_fetch_replacements.assert_called_once() + history_manager._check_and_fetch_replacements.assert_called_once_with(item, [0], "u1") def test_update_existing_memory_cas_merge_with_llm(mock_graph_db): @@ -665,7 +754,7 @@ def test_update_from_feedback_returns_persistence_payload_without_side_effects( assert archived_item.metadata.sources[0].content == "old source" assert archived_metadata["embedding"] == [0.1, 0.2] assert update_fields["memory"] == "Updated Content" - assert update_fields["covered_history"] == memory_id + assert update_fields["covered_history"] == archived_item.id assert update_fields["embedding"] == [0.3, 0.4] mock_graph_db.get_node.assert_not_called() mock_graph_db.add_node.assert_not_called() @@ -734,6 +823,6 @@ def test_merge_conflicting_memory_llm_error(mock_graph_db): merged = manager._merge_conflicting_memory("Latest", "Proposed") - assert "System Merge Fallback" in merged + assert merged == "Latest\n\n[New Info]: Proposed" assert "Latest" in merged assert "Proposed" in merged From fa4d48cd33f1e7d08f3af6212c9d8aa97eb97a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 12 Mar 2026 10:51:03 +0800 Subject: [PATCH 88/92] fix: miss calling prompt plugin in version control; --- .../memos_prompt_strategy_plugin/hooks.py | 35 ++++++++++++++++-- .../memos_prompt_strategy_plugin/plugin.py | 2 +- .../tests/test_lifecycle.py | 37 +++++++++++++++++++ .../organize/history_manager.py | 24 ++++++++++++ 4 files changed, 94 insertions(+), 4 deletions(-) diff --git a/extensions/memos_prompt_strategy_plugin/hooks.py b/extensions/memos_prompt_strategy_plugin/hooks.py index 756a335d7..75f50cb6f 100644 --- a/extensions/memos_prompt_strategy_plugin/hooks.py +++ b/extensions/memos_prompt_strategy_plugin/hooks.py @@ -15,6 +15,29 @@ logger = logging.getLogger(__name__) +_IDENTITY_SUPPLEMENT_ZH = """ + +【特别注意 - 身份与关系提取】 +检测到当前对话包含用户的姓名或亲属/社交关系信息。请在完成上述所有处理步骤的同时, +**额外确保**以下内容被完整提取,绝对不能遗漏: +1. 用户本人的姓名 +2. 用户提及的所有关系人(关系类型 + 姓名) +3. 每个身份/关系信息需要作为独立的记忆条目 +4. tags 中必须包含 "identity" 或 "relationship" +""" + +_IDENTITY_SUPPLEMENT_EN = """ + +[IMPORTANT - Identity & Relationship Extraction] +The current conversation contains the user's name or family/social relationship information. +In addition to all the above processing steps, **make sure** to extract the following completely +— do NOT miss any: +1. The user's own name +2. All people mentioned with their relationship type and name +3. Each identity/relationship should be a separate memory item +4. Tags must include "identity" or "relationship" +""" + def on_pre_extract( plugin: PromptStrategyPlugin, @@ -26,9 +49,10 @@ def on_pre_extract( sources: list, **_kw: Any, ) -> str | None: - """[mem_reader.pre_extract] If a classifier rule matches and a - corresponding strategy is registered, swap in the specialised prompt; - otherwise return None to keep the default.""" + """[mem_reader.pre_extract] If a classifier rule matches: + - For normal extraction: swap in the specialised identity/relation prompt. + - For version pipeline: append identity/relation emphasis to the existing prompt. + If no rule matches, return None to keep the default.""" category = plugin.classifier.classify(sources, mem_str, prompt_type, info={}) if category is None: @@ -43,6 +67,11 @@ def on_pre_extract( mem_str[:120] + ("..." if len(mem_str) > 120 else ""), ) + if prompt_type == "version": + supplement = _IDENTITY_SUPPLEMENT_ZH if lang == "zh" else _IDENTITY_SUPPLEMENT_EN + logger.info("[PromptStrategy] Version pipeline — appending identity/relation supplement") + return prompt + supplement + custom_prompt = plugin.registry.build_prompt( category=category, lang=lang, diff --git a/extensions/memos_prompt_strategy_plugin/plugin.py b/extensions/memos_prompt_strategy_plugin/plugin.py index 29f9a4de8..ace8d4f4e 100644 --- a/extensions/memos_prompt_strategy_plugin/plugin.py +++ b/extensions/memos_prompt_strategy_plugin/plugin.py @@ -16,7 +16,7 @@ class PromptStrategyPlugin(MemOSPlugin): name = "prompt_strategy" - version = "0.1.0" + version = "0.2.0" description = "Rule-chain classifier with strategy registry for specialised extraction prompts" def on_load(self) -> None: diff --git a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py index 54bc83d63..e68ba8041 100644 --- a/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py +++ b/extensions/memos_prompt_strategy_plugin/tests/test_lifecycle.py @@ -150,3 +150,40 @@ def test_pre_extract_english_identity(self): assert result != "original prompt" assert "Alice" in result assert plugin.stats["identity_relation"] >= 1 + + def test_pre_extract_version_pipeline_appends_supplement(self): + """When prompt_type='version', the plugin appends identity emphasis + instead of replacing the entire prompt.""" + from memos.plugins.hooks import trigger_hook + + _, plugin = _make_app() + + version_prompt = "...existing version prompt with candidates..." + result = trigger_hook( + "mem_reader.pre_extract", + prompt=version_prompt, + prompt_type="version", + mem_str="我叫王沐辰,我的儿子叫王明泽", + lang="zh", + sources=[], + ) + assert version_prompt in result + assert "身份" in result or "关系" in result + assert plugin.stats["identity_relation"] >= 1 + + def test_pre_extract_version_pipeline_no_match(self): + """When prompt_type='version' but no identity pattern, prompt unchanged.""" + from memos.plugins.hooks import trigger_hook + + _make_app() + + version_prompt = "...existing version prompt with candidates..." + result = trigger_hook( + "mem_reader.pre_extract", + prompt=version_prompt, + prompt_type="version", + mem_str="今天天气不错", + lang="zh", + sources=[], + ) + assert result == version_prompt diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index e26d934b2..3a467b5a3 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -559,6 +559,30 @@ def apply_mem_version_update( else "" ) prompt = self.format_prompt(original_item, custom_tags_prompt) + + # Give plugins a chance to augment/replace the prompt while + # preserving the version-pipeline context (history candidates, etc.) + try: + from memos.plugins.hook_defs import H as _H + from memos.plugins.hooks import trigger_hook as _trigger_hook + + sources = original_item.metadata.sources if original_item.metadata else None + lang = _determine_lang(sources, original_item.memory) + _rv = _trigger_hook( + _H.MEM_READER_PRE_EXTRACT, + prompt=prompt, + prompt_type="version", + mem_str=original_item.memory, + lang=lang, + sources=sources or [], + ) + prompt = _rv if _rv is not None else prompt + except Exception as hook_err: + logger.debug("[MemoryHistoryManager] Plugin hook skipped: %s", hook_err) + logger.info( + f"[MultiModalParser] Process String Fine After Plugin (In Version Control): {prompt}" + ) + try: if llm is None: raise ValueError("LLM is not initialized") From 3c3ef5487c84265977f6279b9ea7079ea63fc432 Mon Sep 17 00:00:00 2001 From: Qi Weng Date: Thu, 12 Mar 2026 11:22:20 +0800 Subject: [PATCH 89/92] fix: aysnc node updated_at not being updated (#18) --- .../textual/tree_text_memory/organize/history_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 94140beb3..bcd429732 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -972,6 +972,7 @@ def _create_new_memory( key = mem_data.get("key", "") background = mem_data.get("summary", "") memory_type = mem_data.get("memory_type", "LongTermMemory") + now = datetime.now().isoformat() metadata_updates = { "is_fast": False, "version": 1, @@ -981,7 +982,8 @@ def _create_new_memory( "working_binding": None, "tags": tags, "key": key, - "created_at": datetime.now().isoformat(), + "created_at": now, + "updated_at": now, "history": [], "embedding": self._compute_embedding(new_value), } From 8e376bb282a2a86332e62222a4e1fd2b3ab1115d Mon Sep 17 00:00:00 2001 From: Qi Weng Date: Thu, 12 Mar 2026 16:28:37 +0800 Subject: [PATCH 90/92] Optimizing behavior of memory versions (#19) * fix: aysnc node updated_at not being updated * fix: fix restored_memories version bug * feat: optimize preserved memory dealing in mem versions * fix: fix test * feat: temporarily delete get_maybe_merged logic for benchmarks --- src/memos/mem_reader/multi_modal_struct.py | 98 +----------- .../organize/history_manager.py | 148 ++++++++++++++---- .../mem_reader_mem_version_prompts.py | 122 ++++++++------- .../memories/textual/test_history_manager.py | 121 ++++++++++---- 4 files changed, 271 insertions(+), 218 deletions(-) diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 4d4d004f9..0848b27f0 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -527,6 +527,7 @@ def _get_maybe_merged_memory( sources: list, **kwargs, ) -> dict: + # TODO: delete this function """ Check if extracted memory should be merged with similar existing memories. If merge is needed, return merged memory dict with merged_from field. @@ -541,102 +542,7 @@ def _get_maybe_merged_memory( Returns: Memory dict (possibly merged) with merged_from field if merged """ - # If no graph_db or user_name, return original - if not self.graph_db or "user_name" not in kwargs: - return extracted_memory_dict - user_name = kwargs.get("user_name") - - # Detect language - lang = "en" - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(mem_text) - - # Search for similar memories - merge_threshold = kwargs.get("merge_similarity_threshold", 0.3) - - try: - search_results = self.graph_db.search_by_embedding( - vector=self.embedder.embed(mem_text)[0], - top_k=20, - status="activated", - threshold=merge_threshold, - user_name=user_name, - ) - - if not search_results: - return extracted_memory_dict - - # Get full memory details - similar_memory_ids = [r["id"] for r in search_results if r.get("id")] - similar_memories_list = [ - self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name) - for mem_id in similar_memory_ids - ] - - # Filter out None and mode:fast memories - filtered_similar = [] - for mem in similar_memories_list: - if not mem: - continue - mem_metadata = mem.get("metadata", {}) - tags = mem_metadata.get("tags", []) - if isinstance(tags, list) and "mode:fast" in tags: - continue - filtered_similar.append( - { - "id": mem.get("id"), - "memory": mem.get("memory", ""), - } - ) - logger.info( - f"Valid similar memories for {mem_text} is " - f"{len(filtered_similar)}: {filtered_similar}" - ) - - if not filtered_similar: - return extracted_memory_dict - - # Create a temporary TextualMemoryItem for merge check - temp_memory_item = TextualMemoryItem( - memory=mem_text, - metadata=TreeNodeTextualMemoryMetadata( - user_id="", - session_id="", - memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"), - status="activated", - tags=extracted_memory_dict.get("tags", []), - key=extracted_memory_dict.get("key", ""), - ), - ) - - # Try to merge with LLM - merge_result = self._merge_memories_with_llm( - temp_memory_item, filtered_similar, lang=lang - ) - - if merge_result: - # Return merged memory dict - merged_dict = extracted_memory_dict.copy() - merged_content = merge_result.get("value", mem_text) - merged_dict["value"] = merged_content - merged_from_ids = merge_result.get("merged_from", []) - merged_dict["merged_from"] = merged_from_ids - return merged_dict - else: - return extracted_memory_dict - - except Exception as e: - logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}") - # On error, return original - return extracted_memory_dict + return extracted_memory_dict def _merge_memories_with_llm( self, diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index d6119035d..21b708862 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -420,7 +420,7 @@ def apply_llm_memory_updates( List of new or updated memory items. """ memory_list = llm_response.get("memory list", []) - restored_memories = llm_response.get("restored_memories", []) + preserved_fact_tasks = self._build_preserved_fact_tasks(memory_list) expected_versions = {} # For concurrency control, need to get the recorded versions of the old memories # Recover candidate IDs and their expected versions from the source item's history @@ -432,6 +432,28 @@ def apply_llm_memory_updates( updated_items: list[TextualMemoryItem] = [] new_items: list[TextualMemoryItem] = [] + # Snapshot source nodes before any in-place update. + snapshot_source_ids = {task["source_candidate_id"] for task in preserved_fact_tasks} + snapshot_source_ids.update( + str(candidate_id) + for mem_data in memory_list + for candidate_id in ( + list(mem_data.get("source_candidate_ids", [])) + + list(mem_data.get("conflicted_candidate_ids", [])) + ) + if candidate_id + ) + pre_update_source_item_map: dict[str, TextualMemoryItem] = {} + if snapshot_source_ids: + snapshot_nodes = ( + self.graph_db.get_nodes(sorted(snapshot_source_ids), user_name=user_name) or [] + ) + pre_update_source_item_map = { + item["id"]: TextualMemoryItem(**item) + for item in snapshot_nodes + if item and item.get("id") + } + # 1. Handle Unrelated Candidates - Do nothing # 2. Handle Memory List (Update or New) processed_updates, created_items = self._process_memory_updates( @@ -440,28 +462,59 @@ def apply_llm_memory_updates( updated_items.extend(processed_updates) new_items.extend(created_items) - # 3. Handle Restored Memories (Extract from conflict) - valid_conflict_ids = { - conflict_id - for mem_data in memory_list - for conflict_id in mem_data.get("conflicted_candidate_ids", []) - } - filtered_restored_memories = [ - mem for mem in restored_memories if mem.get("source_candidate_id") in valid_conflict_ids - ] - dropped_restored_count = len(restored_memories) - len(filtered_restored_memories) - if dropped_restored_count: - logger.warning( - "[MemoryHistoryManager] Dropping %s restored_memories not tied to any " - "conflicted_candidate_ids in this LLM response.", - dropped_restored_count, - ) + # 3. Handle preserved facts (split still-valid subfacts out of the old node) new_items.extend( - self._handle_restored_memories(filtered_restored_memories, source_item, user_name) + self._handle_preserved_facts( + preserved_fact_tasks, + source_item, + user_name, + pre_update_source_item_map=pre_update_source_item_map, + ) ) return updated_items, new_items + def _build_preserved_fact_tasks( + self, memory_list: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Flatten per-update preserved facts into executable extraction tasks.""" + tasks: list[dict[str, Any]] = [] + + for mem_data in memory_list: + preserved_facts = mem_data.get("preserved_facts", []) or [] + if not preserved_facts: + continue + + source_ids = list(mem_data.get("source_candidate_ids", []) or []) + conflict_ids = list(mem_data.get("conflicted_candidate_ids", []) or []) + target_ids = source_ids + conflict_ids + if not target_ids: + logger.warning( + "[MemoryHistoryManager] Dropping preserved_facts for create-only memory " + "item key=%s because it has no source/conflict candidate ids.", + mem_data.get("key", ""), + ) + continue + + # A preserved fact must come from the old node referenced by this update item. + # When multiple target ids exist, we bind preserved facts to the primary target + # (the first id), which is also the node actually updated in `_update_existing_memory()`. + source_candidate_id = str(target_ids[0]) + for preserved_fact in preserved_facts: + if not preserved_fact: + continue + tasks.append( + { + "source_candidate_id": source_candidate_id, + "value": preserved_fact.get("value", ""), + "tags": preserved_fact.get("tags", []), + "key": preserved_fact.get("key", ""), + "memory_type": preserved_fact.get("memory_type", "LongTermMemory"), + } + ) + + return tasks + def build_fallback_new_items( self, item: TextualMemoryItem, user_name: str | None = None ) -> list[TextualMemoryItem]: @@ -794,8 +847,17 @@ def _update_existing_memory( mem_data.get("key", ""), ) - # Fetch candidate nodes in batch and then select the primary - # We update the primary and then merge the secondaries to the primary + # Fetch candidate nodes from the *current* DB state and then select the primary. + # + # This read is intentionally not replaced by the pre-update snapshot captured in + # `apply_llm_memory_updates()`. Unlike restored memories, updates must operate on + # the latest DB state because: + # - CAS/version checking must compare against the newest persisted version + # - earlier updates in the same `memory_list` may already have changed a node + # - secondary merge decisions should reflect the current surviving nodes + # + # So this lookup is still required even though restored-memory handling now reuses + # a pre-update snapshot. nodes_data = self.graph_db.get_nodes(target_ids, user_name=user_name) or [] nodes_map = {n["id"]: n for n in nodes_data if n and "id" in n} node_data = nodes_map.get(primary_id) @@ -1023,32 +1085,53 @@ def _create_new_memory( ) return new_item - def _handle_restored_memories( - self, restored_memories: list[dict[str, Any]], fast_item: TextualMemoryItem, user_name: str + def _handle_preserved_facts( + self, + preserved_fact_tasks: list[dict[str, Any]], + fast_item: TextualMemoryItem, + user_name: str, + pre_update_source_item_map: dict[str, TextualMemoryItem] | None = None, ) -> list[TextualMemoryItem]: - """Handle Restored Memories (Extract from conflict).""" - if not restored_memories: + """Create standalone nodes for preserved facts split from an updated source node.""" + if not preserved_fact_tasks: return [] - source_ids = [ - r.get("source_candidate_id") for r in restored_memories if r.get("source_candidate_id") + # Prefer the pre-update snapshot so preserved facts are extracted from the + # original source nodes referenced by the update item, not from already-updated + # DB state. + source_item_map = dict(pre_update_source_item_map or {}) + missing_source_ids = [ + r.get("source_candidate_id") + for r in preserved_fact_tasks + if r.get("source_candidate_id") and r.get("source_candidate_id") not in source_item_map ] - source_items = self.graph_db.get_nodes(source_ids, user_name=user_name) or [] - source_item_map = {item["id"]: TextualMemoryItem(**item) for item in source_items if item} + if missing_source_ids: + # Fallback only for ids not present in the pre-update snapshot. + source_items = self.graph_db.get_nodes(missing_source_ids, user_name=user_name) or [] + source_item_map.update( + {item["id"]: TextualMemoryItem(**item) for item in source_items if item} + ) created_items = [] - for data in restored_memories: + for data in preserved_fact_tasks: source_candidate_id = data.get("source_candidate_id") source_item = source_item_map.get(source_candidate_id) if source_item is None: logger.warning( - "[MemoryHistoryManager] Restored memory source %s not found. Skipping.", + "[MemoryHistoryManager] Preserved fact source %s not found. Skipping.", source_candidate_id, ) continue # deal with history source_history = deepcopy(source_item.metadata.history) value = data.get("value", "") + if not value: + logger.warning( + "[MemoryHistoryManager] Preserved fact from source %s has empty value. " + "Skipping.", + source_candidate_id, + ) + continue value_item = TextualMemoryItem(memory=value, metadata=TreeNodeTextualMemoryMetadata()) value = value_item.memory tags = data.get("tags", []) @@ -1064,7 +1147,10 @@ def _handle_restored_memories( archived_memory_id=source_item.id, created_at=source_item.metadata.created_at, ) - source_history.append(new_history_item) # Re-use the history of the old node + # Re-use the old node's history and append one more archive entry pointing to + # the pre-update source node itself. This keeps the extracted node anchored to + # the original source memory snapshot(before update). + source_history.append(new_history_item) # Create new node metadata_updates = { "memory_type": memory_type, diff --git a/src/memos/templates/mem_reader_mem_version_prompts.py b/src/memos/templates/mem_reader_mem_version_prompts.py index 4de7493da..f67fe223f 100644 --- a/src/memos/templates/mem_reader_mem_version_prompts.py +++ b/src/memos/templates/mem_reader_mem_version_prompts.py @@ -51,27 +51,27 @@ - 生成一条**更新后的完整记忆**放入 `value`(包含旧信息+新细节)。 - 将该旧记忆的ID放入 `source_candidate_ids`。 - 此时 `conflicted_candidate_ids` 应为空。 + - 如果该旧节点中还包含**未被本次更新覆盖、且可以独立存在**的其他子事实,请将它们放入当前这条更新项内部的 `preserved_facts`。 + - `preserved_facts` 中的每一条内容,都必须能在当前这条更新项引用的旧节点原文中直接定位到;它只是“拆分/改写该旧节点内部原本就存在的子事实”,**绝不能**从其他 candidate 挪用、拼接、概括或猜测内容。 + - 如果旧节点只是单一事实,或所有内容都已经被本次更新吸收进 `value`,则 `preserved_facts` 必须为空数组。 b) **修正冲突记忆节点 (Update via Conflict)**: - 检查“潜在事实冲突记忆”。 - 如果新信息否定了某条旧记忆,或更新了其状态(如“不再喜欢X”“改成Y”“取消计划”“从X转为Y”): - 生成一条反映**最新状态**的记忆放入 `value`。 - 将被修正的旧记忆ID放入 `conflicted_candidate_ids`。 + - 如果该旧节点本身是一条混合记忆,而本次只更新其中一部分,则必须把**未被新信息否定、且可独立存在**的剩余事实放入当前这条更新项内部的 `preserved_facts`。 + - `preserved_facts` **绝不能**包含已经被当前更新否定、替换或覆盖的旧事实。例如新信息把“深圳工作”改成“广州工作”,则 `preserved_facts` 里绝不能再出现“深圳工作”。 + - 对于可长期独立存在的属性(如电话号码、出生地、所属组织),优先拆分为独立事实,避免与可变状态混写在同一条记忆中。 + - 如果旧节点没有剩余的独立有效事实,则 `preserved_facts` 必须为空数组。 c) **创建新记忆节点 (Create New)**: - 如果新信息与任何“候选记忆”都无直接关联(既非重复也非冲突): - 生成一条独立的新记忆放入 `value`。 - 确保 `source_candidate_ids` 和 `conflicted_candidate_ids` 均为 `[]`。 + - 新建记忆的 `preserved_facts` 必须为空数组。 -5. **挽救剩余信息 (Restored Memories)**: - - `restored_memories` **只能**来自“本次被放入 `conflicted_candidate_ids` 的旧节点”。如果某条候选记忆未被本次冲突更新命中,则**绝不能**放入 `restored_memories`。 - - `restored_memories` 中的每一条内容,都必须能在其 `source_candidate_id` 对应的那条旧记忆原文中直接定位到;它只是“拆分/改写该旧节点内部原本就存在的子事实”,**绝不能**从其他 candidate 挪用、拼接、概括或猜测内容。 - - 当您决定更新一个冲突节点(情况 b)时,请检查该旧节点是否包含**未被新信息否定且依然独立的有效事实**。 - - 只有在“旧冲突节点本身是一条混合记忆,且本次更新只改动其中一部分”时,才返回 `restored_memories`。这能防止因节点更新导致同一旧节点内其他事实丢失。 - - 若旧节点包含多个事实(如“人物状态 + 电话号码”),而新对话仅更新其中一部分,则必须将未被否定的事实拆分为独立条目放入 `restored_memories`,并在冲突更新的 `value` 中仅保留被更新的那一部分。 - - 对于“与本次对话无关的候选记忆”“仅仅没有被提及的旧记忆”“来自 unrelated candidates 的记忆”,一律忽略,不要输出到 `memory list`,也不要输出到 `restored_memories`。 - - 对于可长期独立存在的属性(如电话号码、出生地、所属组织),优先拆分为独立事实,避免与可变状态混写在同一条记忆中。 - - 如果一个 candidate 被判断为无关,请将它的 ID 放入 `unrelated_candidate_ids`,并且不要在 `memory list` 或 `restored_memories` 中再次出现它的内容。 +5. 无关的 candidate,只需把它的 ID 放入 `unrelated_candidate_ids`。 6. 请避免在提取的记忆中包含违反国家法律法规或涉及政治敏感的信息。 @@ -86,14 +86,14 @@ "tags": <相关主题关键词列表>, "source_candidate_ids": <字符串列表,被此条目更新的“重复/关联记忆”ID。若无则为 []>, "conflicted_candidate_ids": <字符串列表,被此条目修正的“事实冲突记忆”ID。若无则为 []>, - }, - ... - ], - "restored_memories": [ - { - "source_candidate_id": <字符串,来源旧冲突记忆ID>, - "value": <字符串,从旧节点中分离出的、依然有效的剩余事实>, - "tags": <相关主题关键词列表> + "preserved_facts": [ + { + "key": <字符串,简洁的记忆标题>, + "value": <字符串,从当前这条更新项所引用的旧节点中拆出的、依然有效的独立事实>, + "tags": <相关主题关键词列表>, + "memory_type": <字符串,"LongTermMemory" 或 "UserMemory"> + } + ] }, ... ], @@ -136,7 +136,8 @@ "value": "用户喜欢喝拿铁,通常不加糖,且偏好加燕麦奶。", "tags": ["饮食", "咖啡", "喜好"], "source_candidate_ids": ["101"], - "conflicted_candidate_ids": [] + "conflicted_candidate_ids": [], + "preserved_facts": [] }, { "key": "运动习惯变更", @@ -144,7 +145,15 @@ "value": "用户因膝盖受伤,决定不再打羽毛球。", "tags": ["运动", "健康", "羽毛球"], "source_candidate_ids": [], - "conflicted_candidate_ids": ["201"] + "conflicted_candidate_ids": ["201"], + "preserved_facts": [ + { + "key": "运动偏好", + "value": "用户不喜欢滑雪。", + "tags": ["运动", "滑雪", "喜好"], + "memory_type": "UserMemory" + } + ] }, { "key": "天气状况", @@ -152,20 +161,12 @@ "value": "最近(2025年6月)用户所在的地方下雨比较频繁。", "tags": ["生活", "天气", "降水"], "source_candidate_ids": [], - "conflicted_candidate_ids": [] - } - ], - "restored_memories": [ - { - "key": "运动偏好", - "memory_type": "UserMemory", - "value": "用户不喜欢滑雪。", - "tags": ["运动", "滑雪", "喜好"], - "source_candidate_id": "201" + "conflicted_candidate_ids": [], + "preserved_facts": [] } ], "unrelated_candidate_ids": ["301"], - "summary": "本次更新中,用户细化了对拿铁的偏好(增加燕麦奶),并因膝盖受伤更新了运动习惯(停止打羽毛球),系统保留了其不喜欢滑雪的记录。此外,新增了近期下雨频繁的记忆。对天气的喜好未受影响。" + "summary": "本次更新中,用户补充了拿铁偏好(加入燕麦奶),并因膝盖受伤将运动习惯更新为不再打羽毛球,同时保留了其不喜欢滑雪这一仍然有效的独立事实。此外,新增了一条关于近期下雨频繁的记忆。" } 请始终使用与对话相同的语言进行回复。以下是最新的输入: @@ -237,27 +238,27 @@ - Generate an **updated complete memory** into `value` (old info + new details). - Put the old memory IDs into `source_candidate_ids`. - `conflicted_candidate_ids` must be []. + - If the old node also contains other sub-facts that remain valid and can stand alone independently, place them inside this same update item as `preserved_facts`. + - Every preserved fact must be directly traceable to the old node referenced by this update item. It is only a split-out/rephrased sub-fact already present inside that same old node, and must NEVER borrow, merge, summarize, or infer content from any other candidate. + - If the old node is effectively a single fact, or all of its content is already absorbed into `value`, then `preserved_facts` must be an empty array. b) **Fix conflicting memory node (Update via Conflict)**: - Check Conflict Candidates. - If the new information negates an old memory or updates its state (e.g., "no longer likes X", "changed to Y", "canceled plan", "from X to Y"): - Generate a memory reflecting the **latest state** into `value`. - Put the corrected old memory IDs into `conflicted_candidate_ids`. + - If the old node itself is a mixed memory and this update changes only one part of it, you must place the unaffected but still valid standalone facts into this same update item as `preserved_facts`. + - `preserved_facts` must NEVER contain any old fact that is contradicted, replaced, or covered by the current update. For example, if "works in Shenzhen" is updated to "works in Guangzhou", then `preserved_facts` must not contain "works in Shenzhen". + - For long-lived independent attributes (e.g., phone number, birthplace, affiliation), prefer splitting them into standalone facts instead of mixing them with mutable states. + - If the old node has no remaining independent valid facts, then `preserved_facts` must be an empty array. c) **Create new memory node (Create New)**: - If the new information is not directly related to any "Candidates" (neither duplicate nor conflict): - Generate an independent new memory into `value`. - Ensure `source_candidate_ids` and `conflicted_candidate_ids` are both `[]`. + - Newly created memories must use `preserved_facts: []`. -5. **Restored Memories**: - - `restored_memories` may come ONLY from old nodes that you also place into `conflicted_candidate_ids` in this same response. If a candidate is not updated as a conflict this turn, it must NEVER appear in `restored_memories`. - - Every restored item must be directly traceable to the original text of its own `source_candidate_id`. It is only a split-out/rephrased sub-fact that already exists inside that same old node, and must NEVER borrow, merge, summarize, or infer content from any other candidate. - - When you decide to update a conflict node (case b), check whether that old node contains valid facts that were NOT negated by the new information and remain independently valid. - - Return `restored_memories` only when the old conflict node itself is a mixed memory and this update changes only one part of it. This prevents losing the other facts embedded in that same old node. - - If the old node contains multiple facts (e.g., "status + phone number") and the new conversation only updates one part, you must split the unaffected facts into separate entries in `restored_memories`, and keep only the updated part in the conflict `value`. - - Completely unrelated candidates, merely unmentioned old memories, and anything from the Unrelated Candidates section must be ignored. Do not output them in `memory list` or `restored_memories`. - - For long-lived independent attributes (e.g., phone number, birthplace, affiliation), prefer splitting them into standalone facts instead of mixing them with mutable states. - - If a candidate is judged unrelated, put its ID into `unrelated_candidate_ids`, and do not let its content appear again in `memory list` or `restored_memories`. +5. For any unrelated candidate, simply place its ID into `unrelated_candidate_ids`. 6. Avoid including any memories that violate laws or involve politically sensitive information. @@ -272,14 +273,14 @@ "tags": , "source_candidate_ids": , "conflicted_candidate_ids": , - }, - ... - ], - "restored_memories": [ - { - "source_candidate_id": , - "value": , - "tags": + "preserved_facts": [ + { + "key": , + "value": , + "tags": , + "memory_type": + } + ] }, ... ], @@ -322,7 +323,8 @@ "value": "The user likes latte most, usually doesn't add sugar, and prefers oat milk.", "tags": ["diet", "coffee", "preference"], "source_candidate_ids": ["101"], - "conflicted_candidate_ids": [] + "conflicted_candidate_ids": [], + "preserved_facts": [] }, { "key": "Sport habit change", @@ -330,7 +332,15 @@ "value": "Due to a knee injury, the user decides to no longer play badminton.", "tags": ["sport", "health", "badminton"], "source_candidate_ids": [], - "conflicted_candidate_ids": ["201"] + "conflicted_candidate_ids": ["201"], + "preserved_facts": [ + { + "key": "Sport preference", + "value": "The user dislikes skiing.", + "tags": ["sport", "skiing", "preference"], + "memory_type": "UserMemory" + } + ] }, { "key": "Pet status", @@ -338,20 +348,12 @@ "value": "The user recently (June 2025) adopted a cat.", "tags": ["life", "pet", "cat"], "source_candidate_ids": [], - "conflicted_candidate_ids": [] - } - ], - "restored_memories": [ - { - "key": "Sports preference", - "memory_type": "UserMemory", - "value": "The user dislikes skiing.", - "tags": ["sport", "skiing", "preference"], - "source_candidate_id": "201" + "conflicted_candidate_ids": [], + "preserved_facts": [] } ], "unrelated_candidate_ids": ["301"], - "summary": "In this update, the user refined their latte preference (adds oat milk), and due to a knee injury updated their sports habit (stop playing badminton). The system preserves the record that the user dislikes skiing. Additionally, a new memory is added that the user adopted a cat. Weather preference is unaffected." + "summary": "In this update, the user refined their latte preference by adding oat milk and updated their sports habit to no longer playing badminton because of a knee injury, while preserving the still-valid independent fact that the user dislikes skiing. Additionally, a new memory was added that the user recently adopted a cat." } Always reply in the same language as the conversation. The latest input is below: diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index a80b65b16..bca5aa77b 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -1,6 +1,6 @@ import uuid -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest @@ -126,10 +126,9 @@ def test_apply_llm_memory_updates_new_node(history_manager, mock_graph_db): "tags": ["tag1"], "source_candidate_ids": [], "conflicted_candidate_ids": [], - "history_segments": [], + "preserved_facts": [], } ], - "restored_memories": [], "summary": "Summary", } @@ -177,10 +176,9 @@ def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db "tags": ["new"], "source_candidate_ids": [existing_id], "conflicted_candidate_ids": [], - "history_segments": [], + "preserved_facts": [], } ], - "restored_memories": [], "summary": "Summary", } @@ -224,7 +222,7 @@ def test_apply_llm_memory_updates_update_existing(history_manager, mock_graph_db assert kwargs["fields"]["version"] == 2 -def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): +def test_apply_llm_memory_updates_preserved_facts(history_manager, mock_graph_db): source_id = uuid.uuid4().hex existing_node = { "id": source_id, @@ -244,7 +242,7 @@ def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): memory="Restored Content", metadata=TreeNodeTextualMemoryMetadata(history=[]), ) - history_manager._handle_restored_memories = MagicMock(return_value=[restored_item]) + history_manager._handle_preserved_facts = MagicMock(return_value=[restored_item]) llm_response = { "memory list": [ { @@ -254,12 +252,16 @@ def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): "tags": ["new"], "source_candidate_ids": [], "conflicted_candidate_ids": [source_id], - "history_segments": [], + "preserved_facts": [ + { + "key": "Preserved Fact", + "value": "Restored Content", + "tags": ["restored"], + "memory_type": "UserMemory", + } + ], } ], - "restored_memories": [ - {"source_candidate_id": source_id, "value": "Restored Content", "tags": ["restored"]} - ], "summary": "Summary", } @@ -283,18 +285,27 @@ def test_apply_llm_memory_updates_restored(history_manager, mock_graph_db): assert len(updated) == 1 assert len(new_items) == 1 assert new_items[0] == restored_item - history_manager._handle_restored_memories.assert_called_once_with( - llm_response["restored_memories"], source_item, "u1" + history_manager._handle_preserved_facts.assert_called_once_with( + [ + { + "source_candidate_id": source_id, + "value": "Restored Content", + "tags": ["restored"], + "key": "Preserved Fact", + } + ], + source_item, + "u1", + pre_update_source_item_map=ANY, ) mock_graph_db.add_node.assert_called_once() -def test_apply_llm_memory_updates_ignores_restored_from_unrelated_candidates( +def test_apply_llm_memory_updates_drops_preserved_facts_for_create_only_items( history_manager, mock_graph_db ): conflict_id = uuid.uuid4().hex - unrelated_id = uuid.uuid4().hex - history_manager._handle_restored_memories = MagicMock(return_value=[]) + history_manager._handle_preserved_facts = MagicMock(return_value=[]) existing_node = { "id": conflict_id, @@ -320,11 +331,23 @@ def test_apply_llm_memory_updates_ignores_restored_from_unrelated_candidates( "tags": ["new"], "source_candidate_ids": [], "conflicted_candidate_ids": [conflict_id], - "history_segments": [], - } - ], - "restored_memories": [ - {"source_candidate_id": unrelated_id, "value": "Should Be Ignored", "tags": ["ignore"]} + "preserved_facts": [], + }, + { + "key": "Create Only Memory", + "memory_type": "LongTermMemory", + "value": "Brand New Content", + "tags": ["new"], + "source_candidate_ids": [], + "conflicted_candidate_ids": [], + "preserved_facts": [ + { + "key": "Should Be Dropped", + "value": "Should Be Ignored", + "tags": ["ignore"], + } + ], + }, ], "summary": "Summary", } @@ -339,12 +362,6 @@ def test_apply_llm_memory_updates_ignores_restored_from_unrelated_candidates( memory="Old Content", update_type="conflict", ), - ArchivedTextualMemory( - version=1, - archived_memory_id=unrelated_id, - memory="Irrelevant Content", - update_type="unrelated", - ), ] ), ) @@ -355,13 +372,18 @@ def test_apply_llm_memory_updates_ignores_restored_from_unrelated_candidates( assert len(updated) == 1 assert len(new_items) == 0 - history_manager._handle_restored_memories.assert_called_once_with([], source_item, "u1") + history_manager._handle_preserved_facts.assert_called_once_with( + [], + source_item, + "u1", + pre_update_source_item_map=ANY, + ) def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): id1 = uuid.uuid4().hex id2 = uuid.uuid4().hex - llm_response = {"memory list": [], "restored_memories": [], "summary": "Summary"} + llm_response = {"memory list": [], "summary": "Summary"} source_item = TextualMemoryItem( memory="New user input", @@ -391,6 +413,44 @@ def test_apply_llm_memory_updates_unrelated(history_manager, mock_graph_db): mock_graph_db.update_node.assert_not_called() +def test_handle_preserved_facts_inherits_memory_type_from_source(history_manager): + source_id = uuid.uuid4().hex + source_item = TextualMemoryItem( + id=source_id, + memory="Wang Lin works in Shenzhen and phone number is 13800138000.", + metadata=TreeNodeTextualMemoryMetadata( + version=2, + memory_type="UserMemory", + created_at="2023-01-01", + sources=[], + history=[], + ), + ) + fast_item = TextualMemoryItem( + memory="new input", + metadata=TreeNodeTextualMemoryMetadata(), + ) + + new_items = history_manager._handle_preserved_facts( + [ + { + "source_candidate_id": source_id, + "key": "Phone number", + "value": "Wang Lin's phone number is 13800138000.", + "tags": ["contact"], + } + ], + fast_item=fast_item, + user_name="u1", + pre_update_source_item_map={source_id: source_item}, + ) + + assert len(new_items) == 1 + assert new_items[0].metadata.memory_type == "UserMemory" + assert new_items[0].metadata.key == "Phone number" + assert new_items[0].metadata.history[-1].archived_memory_id == source_id + + def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph_db): # Setup existing node (primary) primary_id = uuid.uuid4().hex @@ -419,10 +479,9 @@ def test_apply_llm_memory_updates_conflict_and_merge(history_manager, mock_graph "tags": [], "source_candidate_ids": [], "conflicted_candidate_ids": [primary_id, secondary_id], - "history_segments": [], + "preserved_facts": [], } ], - "restored_memories": [], "summary": "Summary", } From 23577d4d9aa9b896583661bbba7ed1ed96b24e1b Mon Sep 17 00:00:00 2001 From: jiang Date: Fri, 13 Mar 2026 17:26:33 +0800 Subject: [PATCH 91/92] chore: change version number to v2.0.9.post --- pyproject.toml | 2 +- src/memos/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 793ae1fc6..fc10df0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.8" +version = "2.0.9.post" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index 36cc0b5b5..3643c4628 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.8" +__version__ = "2.0.9.post" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig From c51c8a5911dfd61591eea73c0aa9d854715e360b Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Fri, 29 May 2026 16:41:46 +0800 Subject: [PATCH 92/92] chore: add harness --- .claude/agents/backend-dev.md | 40 +++++++ .claude/agents/code-reviewer.md | 40 +++++++ .claude/agents/design-reviewer.md | 35 ++++++ .claude/agents/explorer.md | 35 ++++++ .claude/agents/integration-tester.md | 39 +++++++ .codex/agents/backend-dev.toml | 33 ++++++ .codex/agents/code-reviewer.toml | 29 +++++ .codex/agents/design-reviewer.toml | 27 +++++ .codex/agents/explorer.toml | 30 +++++ .codex/agents/integration-tester.toml | 30 +++++ AGENTS.md | 155 ++++++++++++++++++++++++++ CLAUDE.md | 23 ++++ 12 files changed, 516 insertions(+) create mode 100644 .claude/agents/backend-dev.md create mode 100644 .claude/agents/code-reviewer.md create mode 100644 .claude/agents/design-reviewer.md create mode 100644 .claude/agents/explorer.md create mode 100644 .claude/agents/integration-tester.md create mode 100644 .codex/agents/backend-dev.toml create mode 100644 .codex/agents/code-reviewer.toml create mode 100644 .codex/agents/design-reviewer.toml create mode 100644 .codex/agents/explorer.toml create mode 100644 .codex/agents/integration-tester.toml create mode 100644 AGENTS.md create mode 100644 CLAUDE.md diff --git a/.claude/agents/backend-dev.md b/.claude/agents/backend-dev.md new file mode 100644 index 000000000..8c289117e --- /dev/null +++ b/.claude/agents/backend-dev.md @@ -0,0 +1,40 @@ +--- +name: backend-dev +description: MemOS backend / library implementation sub-agent. Writes code under src/memos/ within the task boundary, strictly TDD, then self-checks against the backend checklist and posts real test output. +tools: Read, Edit, Write, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Implement backend / library code under `src/memos//`; do not range outside the current task. +- Strict TDD: write a failing test in `tests//test_*.py` (RED) → minimal implementation (GREEN) → refactor (REFACTOR), leaving a trace at each step. +- Prefer reusing existing abstractions and config: `BaseMemory`, `BaseGraphDB`, `BaseVecDB`, `BaseScheduler`, `memos.configs.*`, `memos.dependency`. + +## Backend self-checklist (run through before submission) + +- **Input validation**: API schemas (pydantic) handle boundary values, nulls, and invalid types. +- **Error handling**: raise semantic exceptions from `memos.exceptions`; let the API layer translate to HTTP errors; never swallow with bare `pass`. +- **Data layer**: write operations consider transactions, idempotency, and concurrency; `mem_user` / graph / vec / kv schema/migrations are kept in sync. +- **Compatibility**: do not break the contract of top-level `memos.*` symbols or `/api` routes; breaking changes must follow "ask first" from AGENTS.md. +- **Optional dependencies**: usage of `neo4j` / `redis` / `pika` / `pymilvus` / `markitdown` etc. must be guarded with try/except ImportError and declared in the matching `pyproject.toml` extras. +- **Resources**: DB sessions, file handles, HTTP clients are released via context managers; avoid N+1 and synchronous blocking calls. +- **Logging**: use `logging.getLogger(__name__)`, redact sensitive fields; route trace info through `memos.context.context`. +- **Formatting**: always run `make format` before submission. + +## Output requirements + +Paste the real output of the real commands (do not just say "passed"): + +- `poetry run pytest tests// -q` +- `make test` for full runs when needed +- `make format` (or `make pre_commit`) +- A list of changed files mapped to the originating requirement. + +## Do not + +- Touch `apps/`, `docker/`, `scripts/`, `pyproject.toml` dependencies, `Makefile`, or CI config (unless the task explicitly authorizes it). +- Review your own code (code-reviewer's job). +- Claim completion without test output. +- Skip `pre-commit` or commit with `--no-verify`. diff --git a/.claude/agents/code-reviewer.md b/.claude/agents/code-reviewer.md new file mode 100644 index 000000000..6e9b218cd --- /dev/null +++ b/.claude/agents/code-reviewer.md @@ -0,0 +1,40 @@ +--- +name: code-reviewer +description: Code-review sub-agent. Reviews MemOS diffs for contract consistency, Ruff / typing / optional-dependency handling, and test evidence; returns APPROVE or CHANGES_REQUESTED. +tools: Read, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +Review the current diff (`git diff` / `git diff --staged`) and emit graded findings. + +## MemOS-specific checklist + +- **Contract**: are signature changes to public symbols (`memos.api.*`, top-level `memos.*`) backward compatible; if breaking, did it follow AGENTS.md "ask first". +- **Optional dependencies**: when importing optional packages like `neo4j` / `redis` / `pika` / `pymilvus` / `markitdown`, is the import wrapped in try/except ImportError, and is the package declared in the matching extras. +- **Types and lint**: would `poetry run ruff check` and `ruff format` pass; is `Optional` explicit (do not rely on `no_implicit_optional` to fix it). +- **Exceptions**: are semantic exceptions from `memos.exceptions` raised, not bare `Exception` / `RuntimeError`. +- **Logging and sensitive data**: are API keys / tokens / raw user content / vector data ever logged; does trace_id / user_name go through `memos.context.context` instead of `print`. +- **Test evidence**: are new/updated `tests//test_*.py` present; is real pytest output included. +- **Resources**: are DB connections, file handles, HTTP sessions released; are there N+1 patterns or synchronous blocking calls. + +## Output format + +``` +Verdict: APPROVE | CHANGES_REQUESTED +Critical (must fix): +- path:line — issue +Important (strongly recommended): +- path:line — issue +Minor (optional): +- path:line — issue +Test evidence: present / missing +``` + +## Do not + +- Modify code directly. +- Substitute for a human final approver. +- Grant APPROVE when pytest output is missing. diff --git a/.claude/agents/design-reviewer.md b/.claude/agents/design-reviewer.md new file mode 100644 index 000000000..e747b424c --- /dev/null +++ b/.claude/agents/design-reviewer.md @@ -0,0 +1,35 @@ +--- +name: design-reviewer +description: Design-review sub-agent. Reviews design docs across the four dimensions of architecture, interface, performance, and security, covering MemOS's multi-memory / multi-storage backend constraints. +tools: Read, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Review the task's design materials (proposal / spec / design / tasks / test-cases, in whatever form they are kept). +- Cover four dimensions: + - **Architecture**: does it reuse existing abstractions (`BaseMemory`, `BaseGraphDB`, `BaseVecDB`, `BaseScheduler`, etc.), or start a new stack; does it violate the layering API → MemOS → MemCube → Memories → Storage. + - **Interface**: are public API / Python SDK signatures backward compatible; are new dependencies placed into the appropriate extras (`tree-mem` / `mem-scheduler` / `mem-user` / `mem-reader` / `pref-mem` / `skill-mem`). + - **Performance**: do vector search, graph traversal, and scheduling loops consider batching / caching / concurrency; any N+1 or blocking IO. + - **Security**: is user isolation (`mem_user`) handled; do we avoid writing into `.env` / credentials / private paths. +- Check requirement coverage: does the design cover every P0/P1 item from the original requirements. +- Call out blockers (must fix) vs. suggestions (optional). + +## Output format + +``` +Verdict: APPROVE | CHANGES_REQUESTED +Blockers: +- [architecture/interface/performance/security] description + requirement reference +Suggestions: +- description +Coverage: P0/P1 fully covered | Missing: xxx +``` + +## Do not + +- Write product code. +- Review the code implementation (that is code-reviewer's job). +- Substitute for a human final approver. diff --git a/.claude/agents/explorer.md b/.claude/agents/explorer.md new file mode 100644 index 000000000..dd61be986 --- /dev/null +++ b/.claude/agents/explorer.md @@ -0,0 +1,35 @@ +--- +name: explorer +description: Read-only code exploration sub-agent. Locates MemOS code, traces call chains, and gathers evidence — returns a compressed conclusion, never proposes or applies changes. +tools: Read, Grep, Glob, Bash +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Locate relevant modules, symbols, and call chains under `src/memos/` for the question the main agent asks. +- Distinguish core packages (`mem_os` / `mem_cube` / `mem_scheduler`) from optional backends (`graph_dbs/neo4j*`, `vec_dbs/milvus*`, etc.) and call out any extras dependencies. +- Trace execution paths and gather evidence (with `path:line` annotations + a one-line key snippet). +- Return a compressed conclusion only; do not echo raw bulk output. + +## Output format + +- Conclusion first: one sentence that answers the main agent's question. +- Evidence list: `src/memos//.py:LINE` + a one-line note. +- Call chain (if applicable): `A.f -> B.g -> C.h`, annotating each hop with its file location. +- Uncertainty: explicitly flag "not found / needs further confirmation"; do not invent. + +## MemOS-specific locator hints + +- API routes: `src/memos/api/` + `tests/api/` +- Memory types: `src/memos/memories/` (textual / tree / preference / skill etc.) +- Storage backends: `src/memos/graph_dbs/`, `src/memos/vec_dbs/` +- Config and DI: `src/memos/configs/`, `src/memos/dependency.py` +- Plugin entry points: `pyproject.toml [project.entry-points."memos.plugins"]` + `extensions/` + +## Do not + +- Modify any file (read-only). +- Propose an implementation plan — return facts and locations only. +- Substitute for the judgment of design-reviewer / code-reviewer. diff --git a/.claude/agents/integration-tester.md b/.claude/agents/integration-tester.md new file mode 100644 index 000000000..49eea3bcd --- /dev/null +++ b/.claude/agents/integration-tester.md @@ -0,0 +1,39 @@ +--- +name: integration-tester +description: MemOS integration-testing sub-agent. Authors and executes pytest cases under tests/ based on the task's requirements and design, and emits real test reports. +tools: Read, Edit, Write, Bash, Grep, Glob +--- + +Project facts: see `AGENTS.md`. + +## Responsibilities + +- Based on the task's requirements and design docs, write pytest cases under `tests//`. +- Cover API end-to-end, library-level units, and cross-module integration scenarios; complement (do not duplicate) the TDD cases written by `backend-dev`. +- Run the tests and produce a real report. + +## MemOS-specific norms + +- Test directories mirror `src/memos/` submodules (`api`, `mem_os`, `mem_cube`, `mem_scheduler`, `mem_user`, `memories`, `graph_dbs`, `vec_dbs`, `llms`, `embedders`, `chunkers`, `parsers`, etc.). +- Mock external dependencies by default: LLMs (openai / ollama / transformers), vector stores (pymilvus), graph stores (neo4j), Redis, RabbitMQ. +- Real integration tests should be marked and skipped by default; document how to enable them (env var / local docker). +- Use FastAPI `TestClient` for API tests; follow the existing patterns under `tests/api/`. +- Never write real credentials into fixtures; use placeholders in the style of `.env.example`. + +## Output format + +``` +Test file: tests//test_.py +Coverage map: +- Requirement 1.1 → test_xxx +Command: poetry run pytest tests//test_.py -q +Output: + +Result: N passed, M failed +``` + +## Do not + +- Modify product code under `src/memos/` (backend-dev's job). +- Substitute for code-reviewer. +- Claim completion without real pytest output. diff --git a/.codex/agents/backend-dev.toml b/.codex/agents/backend-dev.toml new file mode 100644 index 000000000..510de8a0e --- /dev/null +++ b/.codex/agents/backend-dev.toml @@ -0,0 +1,33 @@ +name = "backend-dev" +description = "MemOS backend / library implementation sub-agent. Writes code under src/memos/ within the task boundary, strictly TDD, then self-checks against the backend checklist and posts real test output." +sandbox_mode = "workspace-write" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Implement backend / library code under src/memos//; do not range outside the current task. +- Strict TDD: write a failing test in tests//test_*.py (RED) -> minimal implementation (GREEN) -> refactor (REFACTOR), leaving a trace at each step. +- Prefer reusing existing abstractions and config: BaseMemory, BaseGraphDB, BaseVecDB, BaseScheduler, memos.configs.*, memos.dependency. + +Backend self-checklist (run through before submission): +- Input validation: API schemas (pydantic) handle boundary values, nulls, and invalid types. +- Error handling: raise semantic exceptions from memos.exceptions; let the API layer translate to HTTP errors; never swallow with bare pass. +- Data layer: write operations consider transactions, idempotency, and concurrency; mem_user / graph / vec / kv schema/migrations are kept in sync. +- Compatibility: do not break the contract of top-level memos.* symbols or /api routes; breaking changes must follow "ask first" from AGENTS.md. +- Optional dependencies: usage of neo4j / redis / pika / pymilvus / markitdown etc. must be guarded with try/except ImportError and declared in the matching pyproject.toml extras. +- Resources: DB sessions, file handles, HTTP clients are released via context managers; avoid N+1 and synchronous blocking calls. +- Logging: use logging.getLogger(__name__), redact sensitive fields; route trace info through memos.context.context. +- Formatting: always run make format before submission. + +Output requirements (paste the real output of the real commands): +- poetry run pytest tests// -q +- make test for full runs when needed +- make format (or make pre_commit) +- A list of changed files mapped to the originating requirement. + +Do not: +- Touch apps/, docker/, scripts/, pyproject.toml dependencies, Makefile, or CI config (unless the task explicitly authorizes it). +- Review your own code (code-reviewer's job). +- Claim completion without test output. +- Skip pre-commit or commit with --no-verify. +""" diff --git a/.codex/agents/code-reviewer.toml b/.codex/agents/code-reviewer.toml new file mode 100644 index 000000000..8a713b4e9 --- /dev/null +++ b/.codex/agents/code-reviewer.toml @@ -0,0 +1,29 @@ +name = "code-reviewer" +description = "Code-review sub-agent. Reviews MemOS diffs for contract consistency, Ruff / typing / optional-dependency handling, and test evidence; returns APPROVE or CHANGES_REQUESTED." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: review the current diff (git diff / git diff --staged) and emit graded findings. + +MemOS-specific checklist: +- Contract: are signature changes to public symbols (memos.api.*, top-level memos.*) backward compatible; if breaking, did it follow AGENTS.md "ask first". +- Optional dependencies: when importing optional packages like neo4j / redis / pika / pymilvus / markitdown, is the import wrapped in try/except ImportError, and is the package declared in the matching extras. +- Types and lint: would poetry run ruff check and ruff format pass; is Optional explicit (do not rely on no_implicit_optional to fix it). +- Exceptions: are semantic exceptions from memos.exceptions raised, not bare Exception / RuntimeError. +- Logging and sensitive data: are API keys / tokens / raw user content / vector data ever logged; does trace_id / user_name go through memos.context.context instead of print. +- Test evidence: are new/updated tests//test_*.py present; is real pytest output included. +- Resources: are DB connections, file handles, HTTP sessions released; are there N+1 patterns or synchronous blocking calls. + +Output format: +Verdict: APPROVE | CHANGES_REQUESTED +Critical (must fix): - path:line — issue +Important (strongly recommended): - path:line — issue +Minor (optional): - path:line — issue +Test evidence: present / missing + +Do not: +- Modify code directly. +- Substitute for a human final approver. +- Grant APPROVE when pytest output is missing. +""" diff --git a/.codex/agents/design-reviewer.toml b/.codex/agents/design-reviewer.toml new file mode 100644 index 000000000..49c9b7be7 --- /dev/null +++ b/.codex/agents/design-reviewer.toml @@ -0,0 +1,27 @@ +name = "design-reviewer" +description = "Design-review sub-agent. Reviews design docs across the four dimensions of architecture, interface, performance, and security, covering MemOS's multi-memory / multi-storage backend constraints." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Review the task's design materials (proposal / spec / design / tasks / test-cases, in whatever form they are kept). +- Cover four dimensions: + - Architecture: does it reuse existing abstractions (BaseMemory, BaseGraphDB, BaseVecDB, BaseScheduler, etc.), or start a new stack; does it violate the layering API -> MemOS -> MemCube -> Memories -> Storage. + - Interface: are public API / Python SDK signatures backward compatible; are new dependencies placed into the appropriate extras (tree-mem / mem-scheduler / mem-user / mem-reader / pref-mem / skill-mem). + - Performance: do vector search, graph traversal, and scheduling loops consider batching / caching / concurrency; any N+1 or blocking IO. + - Security: is user isolation (mem_user) handled; do we avoid writing into .env / credentials / private paths. +- Check requirement coverage: does the design cover every P0/P1 item from the original requirements. +- Call out blockers (must fix) vs. suggestions (optional). + +Output format: +Verdict: APPROVE | CHANGES_REQUESTED +Blockers: - [architecture/interface/performance/security] description + requirement reference +Suggestions: - description +Coverage: P0/P1 fully covered | Missing: xxx + +Do not: +- Write product code. +- Review the code implementation (that is code-reviewer's job). +- Substitute for a human final approver. +""" diff --git a/.codex/agents/explorer.toml b/.codex/agents/explorer.toml new file mode 100644 index 000000000..b8a94a3b1 --- /dev/null +++ b/.codex/agents/explorer.toml @@ -0,0 +1,30 @@ +name = "explorer" +description = "Read-only code exploration sub-agent. Locates MemOS code, traces call chains, gathers evidence, and returns a compressed conclusion — never proposes or applies changes." +sandbox_mode = "read-only" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Locate relevant modules, symbols, and call chains under src/memos/ for the question the main agent asks. +- Distinguish core packages (mem_os / mem_cube / mem_scheduler) from optional backends (graph_dbs/neo4j*, vec_dbs/milvus*, etc.) and call out any extras dependencies. +- Trace execution paths and gather evidence (with path:line annotations + a one-line key snippet). +- Return a compressed conclusion only; do not echo raw bulk output. + +Output format: +- Conclusion first: one sentence that answers the main agent's question. +- Evidence list: src/memos//.py:LINE + a one-line note. +- Call chain (if applicable): A.f -> B.g -> C.h, annotating each hop with its file location. +- Uncertainty: explicitly flag "not found / needs further confirmation"; do not invent. + +MemOS-specific locator hints: +- API routes: src/memos/api/ + tests/api/ +- Memory types: src/memos/memories/ (textual / tree / preference / skill etc.) +- Storage backends: src/memos/graph_dbs/, src/memos/vec_dbs/ +- Config and DI: src/memos/configs/, src/memos/dependency.py +- Plugin entry points: pyproject.toml [project.entry-points."memos.plugins"] + extensions/ + +Do not: +- Modify any file (read-only). +- Propose an implementation plan — return facts and locations only. +- Substitute for the judgment of design-reviewer / code-reviewer. +""" diff --git a/.codex/agents/integration-tester.toml b/.codex/agents/integration-tester.toml new file mode 100644 index 000000000..5baa4621c --- /dev/null +++ b/.codex/agents/integration-tester.toml @@ -0,0 +1,30 @@ +name = "integration-tester" +description = "MemOS integration-testing sub-agent. Authors and executes pytest cases under tests/ based on the task's requirements and design, and emits real test reports." +sandbox_mode = "workspace-write" +developer_instructions = """ +Project facts: see AGENTS.md. + +Responsibilities: +- Based on the task's requirements and design docs, write pytest cases under tests//. +- Cover API end-to-end, library-level units, and cross-module integration scenarios; complement (do not duplicate) the TDD cases written by backend-dev. +- Run the tests and produce a real report. + +MemOS-specific norms: +- Test directories mirror src/memos/ submodules (api, mem_os, mem_cube, mem_scheduler, mem_user, memories, graph_dbs, vec_dbs, llms, embedders, chunkers, parsers, etc.). +- Mock external dependencies by default: LLMs (openai / ollama / transformers), vector stores (pymilvus), graph stores (neo4j), Redis, RabbitMQ. +- Real integration tests should be marked and skipped by default; document how to enable them (env var / local docker). +- Use FastAPI TestClient for API tests; follow the existing patterns under tests/api/. +- Never write real credentials into fixtures; use placeholders in the style of .env.example. + +Output format: +Test file: tests//test_.py +Coverage map: Requirement 1.1 -> test_xxx +Command: poetry run pytest tests//test_.py -q +Output: +Result: N passed, M failed + +Do not: +- Modify product code under src/memos/ (backend-dev's job). +- Substitute for code-reviewer. +- Claim completion without real pytest output. +""" diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..cd885b3c4 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,155 @@ +# AGENTS.md + +> Single source of truth for the project across AI runtimes. Claude Code, Codex, Cursor, Copilot, etc. all defer to this file. +> Runtime-specific adaptation belongs in each runtime's own file (Claude reads `CLAUDE.md`); do not mix it in here. + +## Project Overview + +**MemOS / MemoryOS**: a memory operating system for LLM agents. Python library plus a FastAPI service, providing multiple memory types (textual / tree / preference / skill / KV cache / LoRA parametric) plus scheduling, version management, and vector & graph storage. + +- **Repository**: https://github.com/MemTensor/MemOS +- **Documentation**: https://memos-docs.openmem.net/home/overview/ +- **PyPI**: https://pypi.org/project/MemoryOS/ +- **License**: Apache-2.0 +- **Top-level package**: `src/memos/`. Distribution name `MemoryOS`; import name `memos`. +- **CLI**: `memos` (entry `memos.cli:main`) +- **API service**: `memos.api.start_api:app` + +## Repository Layout + +| Path | Purpose | +|------|---------| +| `src/memos/mem_os/` | `MOS` / `MOSCore` — top-level Memory OS entry | +| `src/memos/mem_cube/` | `GeneralMemCube` — memory container aggregating multiple memory types | +| `src/memos/memories/` | Memory implementations: `textual/`, `activation/`, `parametric/` | +| `src/memos/mem_scheduler/` | Memory scheduler + monitors + ORM + task scheduling | +| `src/memos/mem_user/` | User / multi-tenant management (MySQL / Redis backends) | +| `src/memos/mem_chat/` `mem_reader/` `mem_agent/` `mem_feedback/` `multi_mem_cube/` | Chat sessions, ingest pipeline, agent integration, feedback channel, multi-cube routing | +| `src/memos/llms/` `embedders/` `vec_dbs/` `graph_dbs/` `chunkers/` `parsers/` `reranker/` | Provider implementations (`base.py` + `factory.py` + each backend) | +| `src/memos/api/` | FastAPI service (routers / handlers / middleware / MCP server) | +| `src/memos/configs/` | All pydantic configuration classes (one-to-one with the modules above) | +| `src/memos/context/` | Cross-thread context (trace_id / user / env) | +| `tests/` | pytest cases, subdirectories mirror `src/memos/` | +| `apps/` | Independent sub-projects, each with its own README; not part of the main Harness flow | +| `extensions/` | Official plugin examples | +| `docker/` `docs/` `evaluation/` `scripts/` | Deployment, documentation, evaluation, helper scripts | +| `.claude/agents/`, `.codex/agents/` | Project-recommended AI sub-agent definitions | + +## Command Cheatsheet + +- Install: `make install` (= `poetry install --extras all --with dev --with test` + pre-commit + push hook) +- Start API: `make serve` +- Export OpenAPI: `make openapi` (writes to `docs/openapi.json`) +- Run full tests: `make test` +- Run a single test: `poetry run pytest tests//test_xxx.py -q` +- Lint + format: `make format` +- Full pre-commit: `make pre_commit` +- Build: `poetry build` (publishing is automated by `python-release.yml` on GitHub release) + +## Core API + +### Python top-level entries (`from memos import ...`) + +| Symbol | Purpose | Source | +|--------|---------|--------| +| `MOS` | Memory OS top-level entry (inherits `MOSCore`) | `memos.mem_os.main` | +| `GeneralMemCube` | General memory container | `memos.mem_cube.general` | +| `MOSConfig` / `GeneralMemCubeConfig` | Primary configs | `memos.configs.mem_os` / `memos.configs.mem_cube` | +| `GeneralScheduler` / `SchedulerFactory` / `SchedulerConfigFactory` | Scheduler and factories | `memos.mem_scheduler.*` | + +Common `MOS` methods: `MOS.simple()` (auto-configure from env), `register_mem_cube(cube)`, `add(...)`, `search(...)`, `chat(...)`, `create_user(...)` / `list_users()`. + +### API entry + +- ASGI app: `memos.api.start_api:app` +- Routers: `src/memos/api/routers/` (`admin_router`, `product_router`, `server_router`) +- OpenAPI contract: `docs/openapi.json` (must run `make openapi` after touching the API) + +## Import Patterns + +| Use | Import | +|-----|--------| +| Top-level entries | `from memos import MOS, GeneralMemCube, MOSConfig` | +| Config classes | `from memos.configs. import ` | +| Any provider factory | `from memos..factory import Factory` | +| Logger | `from memos.log import get_logger`; `logger = get_logger(__name__)` | +| Context (trace) | `from memos.context.context import get_current_trace_id, get_current_user_name` | +| Exceptions | `from memos.exceptions import ` | + +## Provider Matrix + +Every provider follows the same three-piece pattern: `base.py` abstract class + `factory.py` registry + `configs/.py` config. The authoritative list of registered backends is the factory's `backend_to_class`; the snapshot below is provided for quick reference: + +| Category | Base class | Factory | Registered backends | +|----------|-----------|---------|---------------------| +| LLM | `BaseLLM` | `LLMFactory` | `openai` / `openai_new` / `azure` / `ollama` / `huggingface` / `huggingface_singleton` / `vllm` / `qwen` / `deepseek` | +| Embedder | `BaseEmbedder` | `EmbedderFactory` | `ollama` / `sentence_transformer` / `ark` / `universal_api` | +| Vector DB | `BaseVecDB` | `VecDBFactory` | `qdrant` / `milvus` | +| Graph DB | `BaseGraphDB` | `GraphStoreFactory` | `neo4j` / `neo4j_community` / `nebular` / `polardb` / `postgres` | +| Chunker | `BaseChunker` | `ChunkerFactory` | `sentence` / `markdown` / `simple` / `charactertext` | +| Parser | `BaseParser` | `ParserFactory` | `markitdown` | +| Reranker | `BaseReranker` | `RerankerFactory` | `cosine_local` / `http_bge` / `http_bge_strategy` / `concat` / `noop` | +| Memory | `BaseMemory` (+ `BaseTextMemory` / `BaseActMemory` / `BaseParaMemory`) | `MemoryFactory` | `naive_text` / `general_text` / `tree_text` / `simple_tree_text` / `pref_text` / `simple_pref_text` / `kv_cache` / `vllm_kv_cache` / `lora` | +| Scheduler | `BaseScheduler` | `SchedulerFactory` | `general` / `optimized` | + +## Adding a New Provider + +Mirror any existing provider in the same category: + +1. Implement `src/memos//.py`, inheriting the `base.py` abstract class and matching the signatures of existing providers. +2. Add a pydantic config in `src/memos/configs/.py` and register it in `ConfigFactory.backend_to_class`. +3. Register the implementation in `Factory.backend_to_class` in `src/memos//factory.py`. +4. Third-party dependencies **must** go into an optional extras group in `pyproject.toml` (`tree-mem` / `mem-scheduler` / `mem-user` / `mem-reader` / `pref-mem` / `skill-mem`) and be added to `all`; guard the import with try/except ImportError and raise a clear "install extras X" message on failure. +5. Add tests under `tests//test_.py`; external HTTP / model loading must be mocked. + +## Behavior Boundaries + +### Always do + +- Write a failing test first (TDD), placed under `tests//test_*.py`. +- Before claiming a task is done, run verification commands and paste the real output (at minimum `make format` plus the relevant pytest run). +- Keep changes within the directories the current task authorizes; cross-module edits need to be called out and approved first. +- Use `memos.log.get_logger(__name__)` for logging; route trace info through `memos.context.context` — do not `print`. +- Optional third-party dependencies (neo4j / redis / pika / pymilvus / markitdown, etc.) must be guarded with try/except ImportError and declared in the matching extras group. +- After touching `src/memos/api/`, run `make openapi` to refresh `docs/openapi.json`. + +### Ask first + +- Modifying `pyproject.toml` dependencies or the Python version constraint. +- Touching public routes, request/response models, or the OpenAPI contract under `src/memos/api/`. +- Changing DB schema, migrations, `mem_user` tables, or `graph_dbs` graph models. +- Deleting files or doing wide-scope renames of public APIs (`memos.*` top-level symbols). +- Editing `Makefile`, `.pre-commit-config.yaml`, `pyproject.toml [tool.*]`, or `.github/workflows/`. + +### Never do (IMPORTANT) + +- **Never** commit `.env`, `private/`, `.private-paths`, `tmp/`, `*.log`, secrets, tokens, or model credentials. +- Do not log or include real API keys, raw user data, or vector contents in tests/fixtures. +- Do not skip `pre-commit` or push with `--no-verify` (the `scripts/check-public-push.sh` pre-push hook is enforced). +- Do not claim tests pass without real pytest output as evidence. +- Do not add third-party dependencies to core `dependencies` — they must go into optional extras. +- Do not run wide-scope `rm -rf` outside `src/`; do not `git push --force` or `git reset --hard origin/*`. + +## Code Style + +- Format and lint with Ruff (configured in `pyproject.toml [tool.ruff]`); `make format` must pass before commit. +- Type annotations are required on public functions, API schemas, and config classes; implicit `Optional` is not allowed (enforced via pre-commit). +- All configs and API schemas use Pydantic v2. +- Logging: `logger.info("... %s", x)` form — do not pre-format with f-strings before passing to the logger. +- Exceptions: library code raises semantic exceptions from `memos.exceptions`, never bare `Exception` / `RuntimeError`; the API layer translates them to HTTP errors in `memos.api.exceptions`. +- File naming: source `snake_case.py`, tests `test_.py`. + +## Change → Test Mapping + +- Edit `src/memos//`: at minimum run `pytest tests// -q`; run `make test` once more before merging. +- Edit `src/memos/api/`: run `tests/api/` and `make openapi` to confirm the OpenAPI spec did not change unexpectedly. +- Edit `pyproject.toml` dependencies: `poetry lock --no-update`, then `make test`. +- Edit `Makefile` / pre-commit / Ruff config: run `make pre_commit` locally over the whole tree. + +## Git Conventions + +- Commits: Conventional Commits (`feat:` / `fix:` / `chore:` / `refactor:` / `docs:`), subject line ≤ 72 chars. +- Branches: `feat/` / `fix/` / `dev-YYYYMMDD-v`. +- `main` is protected — all changes go through PRs; never force-push to `main`; do not skip git hooks. +- Do not commit paths listed in `.private-paths`. +- The PR template lives at `.github/PULL_REQUEST_TEMPLATE.md` — its checklist must be fully ticked. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..c2402f7c7 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,23 @@ +# CLAUDE.md + +## Claude Code Entry + +Project facts live in `AGENTS.md`. This file only covers Claude Code runtime adaptation. + +## Sub-agents + +Five project-recommended sub-agents live under `.claude/agents/*.md`. Claude Code loads them automatically; the main agent should dispatch by task boundary: + +| Agent | Permissions | When to use | +|-------|-------------|-------------| +| `explorer` | Read-only | Locate code, trace call chains, gather evidence | +| `design-reviewer` | Read-only | Review design docs (architecture / interface / performance / security / requirement coverage) | +| `code-reviewer` | Read-only | Review diffs and return APPROVE or CHANGES_REQUESTED | +| `backend-dev` | Read-write | Implement backend / library code under `src/memos/` (TDD) | +| `integration-tester` | Read-write | Author and run integration / end-to-end cases under `tests/` | + +The main repo has no frontend stack, so no `frontend-dev` is provided; TypeScript sub-projects under `apps/` use their own AI configuration. + +## Project knowledge + +Before starting a task, run `ls docs/`. `docs/openapi.json` is the source of truth for the API contract; after touching `src/memos/api/`, run `make openapi` to regenerate it.