diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index b3ad8f085..080a76389 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -2,6 +2,7 @@ import contextlib import json import os +import time import traceback from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -337,9 +338,20 @@ def log_add_messages(self, msg: ScheduleMessageItem): for memory_id in userinput_memory_ids: try: # This mem_item represents the NEW content that was just added/processed - mem_item: TextualMemoryItem = self.current_mem_cube.text_mem.get( - memory_id=memory_id - ) + mem_item: TextualMemoryItem | None = None + for attempt in range(3): + try: + mem_item = self.current_mem_cube.text_mem.get( + memory_id=memory_id, user_name=msg.mem_cube_id + ) + break + except Exception: + if attempt < 2: + time.sleep(0.5) + else: + raise + if mem_item is None: + raise ValueError(f"Memory {memory_id} not found after retries") # Check if a memory with the same key already exists (determining if it's an update) key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( name=mem_item.memory @@ -366,7 +378,7 @@ def log_add_messages(self, msg: ScheduleMessageItem): # Crucial step: Fetch the original content for updates # This `get` is for the *existing* memory that will be updated original_mem_item = self.current_mem_cube.text_mem.get( - memory_id=original_item_id + memory_id=original_item_id, user_name=msg.mem_cube_id ) original_content = original_mem_item.memory @@ -825,7 +837,7 @@ def _process_memories_with_reader( memory_items = [] for mem_id in mem_ids: try: - memory_item = text_mem.get(mem_id) + memory_item = text_mem.get(mem_id, user_name=user_name) memory_items.append(memory_item) except Exception as e: logger.warning(f"Failed to get memory {mem_id}: {e}") @@ -1077,7 +1089,7 @@ def process_message(message: ScheduleMessageItem): mem_items: list[TextualMemoryItem] = [] for mid in mem_ids: with contextlib.suppress(Exception): - mem_items.append(text_mem.get(mid)) + mem_items.append(text_mem.get(mid, user_name=user_name)) if len(mem_items) > 1: keys: list[str] = [] memcube_content: list[dict] = [] @@ -1133,7 +1145,7 @@ def process_message(message: ScheduleMessageItem): if merged_target_ids: post_ref_id = next(iter(merged_target_ids)) with contextlib.suppress(Exception): - merged_item = text_mem.get(post_ref_id) + merged_item = text_mem.get(post_ref_id, user_name=user_name) combined_key = ( getattr(getattr(merged_item, "metadata", {}), "key", None) or combined_key @@ -1242,7 +1254,7 @@ def _process_memories_with_reorganize( memory_items = [] for mem_id in mem_ids: try: - memory_item = text_mem.get(mem_id) + memory_item = text_mem.get(mem_id, user_name=user_name) memory_items.append(memory_item) except Exception as e: logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") @@ -1357,7 +1369,9 @@ def process_session_turn( f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" ) - cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory( + user_name=mem_cube_id + ) text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] intent_result = self.monitor.detect_intent( q_list=queries, text_working_memory=text_working_memory diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 8a6113345..6b0b7e8a6 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -50,7 +50,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem """ @abstractmethod - def get(self, memory_id: str) -> TextualMemoryItem: + def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID. Args: memory_id (str): The ID of the memory to retrieve. diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index f56b2028d..b90f2a6ab 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -136,7 +136,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem ] return result_memories - def get(self, memory_id: str) -> TextualMemoryItem: + def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID.""" result = self.vector_db.get_by_id(memory_id) if result is None: diff --git a/src/memos/memories/textual/naive.py b/src/memos/memories/textual/naive.py index 7bc49e767..14c86b036 100644 --- a/src/memos/memories/textual/naive.py +++ b/src/memos/memories/textual/naive.py @@ -127,7 +127,7 @@ def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]: # Convert search results to TextualMemoryItem objects return [TextualMemoryItem(**memory) for memory, _ in sims[:top_k]] - def get(self, memory_id: str) -> TextualMemoryItem: + def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID.""" for memory in self.memories: if memory["id"] == memory_id: diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index c0ed1217d..e1bc0e72b 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -168,7 +168,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) """Update a memory by memory_id.""" raise NotImplementedError - def get(self, memory_id: str) -> TextualMemoryItem: + def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID. Args: memory_id (str): The ID of the memory to retrieve. diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 25e6276d9..1d0c344b4 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -296,9 +296,9 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: raise NotImplementedError - def get(self, memory_id: str) -> TextualMemoryItem: + def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID.""" - result = self.graph_store.get_node(memory_id) + result = self.graph_store.get_node(memory_id, user_name=user_name) if result is None: raise ValueError(f"Memory with ID {memory_id} not found") metadata_dict = result.get("metadata", {})