From 6019904688eb38aeb22a2fa87c1338b75a584a6c Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Tue, 9 Dec 2025 19:06:48 +0800 Subject: [PATCH 1/3] fix: optimize scheduler memory fetching performance and fault tolerance --- src/memos/graph_dbs/neo4j.py | 9 +- src/memos/mem_scheduler/general_scheduler.py | 131 +++++++++++-------- src/memos/memories/textual/tree.py | 24 ++++ 3 files changed, 109 insertions(+), 55 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index a0a4c6a50..67ca901f7 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -544,7 +544,14 @@ def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]: with self.driver.session(database=self.db_name) as session: results = session.run(query, params) - return [self._parse_node(dict(record["n"])) for record in results] + nodes = [] + for record in results: + try: + nodes.append(self._parse_node(dict(record["n"]))) + except Exception as e: + logger.warning(f"Failed to parse node in get_nodes: {e}") + continue + return nodes def get_edges( self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 8f3eccecf..552bee5d2 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -334,60 +334,86 @@ def log_add_messages(self, msg: ScheduleMessageItem): prepared_update_items_with_original = [] missing_ids: list[str] = [] - for memory_id in userinput_memory_ids: + if not userinput_memory_ids: + return prepared_add_items, prepared_update_items_with_original + + # Batch fetch new items + try: + new_mem_items = self.mem_cube.text_mem.get_batch( + memory_ids=userinput_memory_ids, user_name=msg.mem_cube_id + ) + except Exception as e: + logger.error(f"Failed to batch get memories: {e}", exc_info=True) + new_mem_items = [] + + # Create a map for quick lookup and identify missing IDs + new_items_map = {item.id: item for item in new_mem_items} + for mid in userinput_memory_ids: + if mid not in new_items_map: + missing_ids.append(mid) + + # Collect keys to check existence + keys_to_check = [] + # Store items that have keys for quick access + items_with_keys = [] + + for item in new_mem_items: + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + if key: + keys_to_check.append(key) + items_with_keys.append((key, item)) + + existing_candidates_map = {} # Map (key, memory_type) -> original_item + + # Batch check existence if there are keys to check + if keys_to_check and hasattr(self.mem_cube.text_mem, "graph_store"): try: - # This mem_item represents the NEW content that was just added/processed - mem_item: TextualMemoryItem | None = None - mem_item = self.current_mem_cube.text_mem.get( - memory_id=memory_id, user_name=msg.mem_cube_id + # Use "in" operator to batch query candidate IDs by key + candidate_ids = self.mem_cube.text_mem.graph_store.get_by_metadata( + [ + {"field": "key", "op": "in", "value": list(set(keys_to_check))}, + ], + user_name=msg.mem_cube_id, ) - if mem_item is None: - raise ValueError(f"Memory {memory_id} not found after retries") - # Check if a memory with the same key already exists (determining if it's an update) - key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( - name=mem_item.memory - ) - exists = False - original_content = None - original_item_id = None - - # Only check graph_store if a key exists and the text_mem has a graph_store - if key and hasattr(self.current_mem_cube.text_mem, "graph_store"): - candidates = self.current_mem_cube.text_mem.graph_store.get_by_metadata( - [ - {"field": "key", "op": "=", "value": key}, - { - "field": "memory_type", - "op": "=", - "value": mem_item.metadata.memory_type, - }, - ] + + if candidate_ids: + # Batch fetch candidate memory details + candidate_items = self.mem_cube.text_mem.get_batch( + memory_ids=candidate_ids, user_name=msg.mem_cube_id ) - if candidates: - exists = True - original_item_id = candidates[0] - # Crucial step: Fetch the original content for updates - # This `get` is for the *existing* memory that will be updated - original_mem_item = self.current_mem_cube.text_mem.get( - memory_id=original_item_id, user_name=msg.mem_cube_id - ) - original_content = original_mem_item.memory - if exists: + # Map candidates by (key, memory_type) + for cand in candidate_items: + cand_key = getattr(cand.metadata, "key", None) + cand_type = getattr(cand.metadata, "memory_type", None) + if cand_key: + existing_candidates_map[(cand_key, cand_type)] = cand + except Exception as e: + logger.error(f"Failed to batch check existing keys: {e}", exc_info=True) + + # Process results + for item in new_mem_items: + try: + key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory) + mem_type = getattr(item.metadata, "memory_type", None) + + original_item = existing_candidates_map.get((key, mem_type)) + + if original_item: prepared_update_items_with_original.append( { - "new_item": mem_item, - "original_content": original_content, - "original_item_id": original_item_id, + "new_item": item, + "original_content": original_item.memory, + "original_item_id": original_item.id, } ) else: - prepared_add_items.append(mem_item) + prepared_add_items.append(item) except Exception: - missing_ids.append(memory_id) + missing_ids.append(item.id) logger.debug( - f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation." + f"Error processing item {item.id} during preparation.", exc_info=True ) if missing_ids: @@ -832,14 +858,11 @@ def _process_memories_with_reader( return # Get the original memory items - memory_items = [] - for mem_id in mem_ids: - try: - memory_item = text_mem.get(mem_id, user_name=user_name) - memory_items.append(memory_item) - except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}") - continue + try: + memory_items = text_mem.get_batch(mem_ids, user_name=user_name) + except Exception as e: + logger.warning(f"Failed to batch get memories: {e}") + memory_items = [] if not memory_items: logger.warning("No valid memory items found for processing") @@ -1089,10 +1112,10 @@ def process_message(message: ScheduleMessageItem): ) with contextlib.suppress(Exception): - mem_items: list[TextualMemoryItem] = [] - for mid in mem_ids: - with contextlib.suppress(Exception): - mem_items.append(text_mem.get(mid, user_name=user_name)) + try: + mem_items = text_mem.get_batch(mem_ids, user_name=user_name) + except Exception: + mem_items = [] if len(mem_items) > 1: keys: list[str] = [] memcube_content: list[dict] = [] diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 75eae30e8..92efb4e72 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -311,6 +311,30 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), ) + def get_batch( + self, memory_ids: list[str], user_name: str | None = None + ) -> list[TextualMemoryItem]: + """Batch get memories by IDs.""" + results = self.graph_store.get_nodes(memory_ids, user_name=user_name) + + items = [] + for result in results: + if result: + try: + metadata_dict = result.get("metadata", {}) + items.append( + TextualMemoryItem( + id=result["id"], + memory=result["memory"], + metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), + ) + ) + except Exception as e: + logger.warning( + f"Failed to create TextualMemoryItem for id {result.get('id')}: {e}" + ) + return items + def get_by_ids( self, memory_ids: list[str], user_name: str | None = None ) -> list[TextualMemoryItem]: From f3f953f37a7199a977c934ab401c701d40694d3f Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Tue, 9 Dec 2025 19:31:08 +0800 Subject: [PATCH 2/3] fix: enhance batch memory retrieval robustness and fix update logic --- src/memos/mem_scheduler/general_scheduler.py | 79 ++++++++++++++++---- src/memos/memories/textual/base.py | 12 +++ 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 552bee5d2..7ffc404d8 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -338,13 +338,23 @@ def log_add_messages(self, msg: ScheduleMessageItem): return prepared_add_items, prepared_update_items_with_original # Batch fetch new items + new_mem_items = [] try: new_mem_items = self.mem_cube.text_mem.get_batch( memory_ids=userinput_memory_ids, user_name=msg.mem_cube_id ) except Exception as e: - logger.error(f"Failed to batch get memories: {e}", exc_info=True) - new_mem_items = [] + logger.warning( + f"Failed to batch get memories in log_add_messages: {e}. Fallback to iterative fetching." + ) + # Fallback to iterative fetching + for mid in userinput_memory_ids: + try: + item = self.mem_cube.text_mem.get(memory_id=mid, user_name=msg.mem_cube_id) + if item: + new_mem_items.append(item) + except Exception as inner_e: + logger.warning(f"Failed to get memory {mid}: {inner_e}") # Create a map for quick lookup and identify missing IDs new_items_map = {item.id: item for item in new_mem_items} @@ -377,17 +387,36 @@ def log_add_messages(self, msg: ScheduleMessageItem): ) if candidate_ids: - # Batch fetch candidate memory details - candidate_items = self.mem_cube.text_mem.get_batch( - memory_ids=candidate_ids, user_name=msg.mem_cube_id - ) - - # Map candidates by (key, memory_type) - for cand in candidate_items: - cand_key = getattr(cand.metadata, "key", None) - cand_type = getattr(cand.metadata, "memory_type", None) - if cand_key: - existing_candidates_map[(cand_key, cand_type)] = cand + # Filter out current items from candidates to avoid self-match + filtered_candidate_ids = [ + cid for cid in candidate_ids if cid not in new_items_map + ] + + if filtered_candidate_ids: + # Batch fetch candidate memory details + try: + candidate_items = self.mem_cube.text_mem.get_batch( + memory_ids=filtered_candidate_ids, user_name=msg.mem_cube_id + ) + except Exception: + # Fallback if batch fetch fails for candidates + candidate_items = [] + for cid in filtered_candidate_ids: + try: + c_item = self.mem_cube.text_mem.get( + memory_id=cid, user_name=msg.mem_cube_id + ) + if c_item: + candidate_items.append(c_item) + except Exception: + pass + + # Map candidates by (key, memory_type) + for cand in candidate_items: + cand_key = getattr(cand.metadata, "key", None) + cand_type = getattr(cand.metadata, "memory_type", None) + if cand_key: + existing_candidates_map[(cand_key, cand_type)] = cand except Exception as e: logger.error(f"Failed to batch check existing keys: {e}", exc_info=True) @@ -399,7 +428,8 @@ def log_add_messages(self, msg: ScheduleMessageItem): original_item = existing_candidates_map.get((key, mem_type)) - if original_item: + # Ensure we are not comparing the item with itself + if original_item and original_item.id != item.id: prepared_update_items_with_original.append( { "new_item": item, @@ -858,11 +888,20 @@ def _process_memories_with_reader( return # Get the original memory items + memory_items = [] try: memory_items = text_mem.get_batch(mem_ids, user_name=user_name) except Exception as e: - logger.warning(f"Failed to batch get memories: {e}") - memory_items = [] + logger.warning( + f"Failed to batch get memories in _process_memories_with_reader: {e}. Fallback to iterative fetching." + ) + for mid in mem_ids: + try: + item = text_mem.get(mid, user_name=user_name) + if item: + memory_items.append(item) + except Exception: + pass if not memory_items: logger.warning("No valid memory items found for processing") @@ -1116,6 +1155,14 @@ def process_message(message: ScheduleMessageItem): mem_items = text_mem.get_batch(mem_ids, user_name=user_name) except Exception: mem_items = [] + # Fallback to iterative fetching + for mid in mem_ids: + try: + item = text_mem.get(mid, user_name=user_name) + if item: + mem_items.append(item) + except Exception: + pass if len(mem_items) > 1: keys: list[str] = [] memcube_content: list[dict] = [] diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 6b0b7e8a6..51233886d 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -58,6 +58,18 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem TextualMemoryItem: The memory with the given ID. """ + def get_batch( + self, memory_ids: list[str], user_name: str | None = None + ) -> list[TextualMemoryItem]: + """Batch get memories by IDs. + Args: + memory_ids (list[str]): List of memory IDs to retrieve. + user_name (str | None): Optional user name for multi-tenant retrieval. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs. + """ + return [self.get(mid, user_name=user_name) for mid in memory_ids] + @abstractmethod def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """Get memories by their IDs. From b268b3590f09e7f14eea642bc102cb279de010ec Mon Sep 17 00:00:00 2001 From: "glin1993@outlook.com" <> Date: Tue, 9 Dec 2025 19:31:08 +0800 Subject: [PATCH 3/3] fix: enhance batch memory retrieval robustness and fix update logic --- src/memos/mem_scheduler/general_scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 7ffc404d8..a790f2fb7 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -442,9 +442,7 @@ def log_add_messages(self, msg: ScheduleMessageItem): except Exception: missing_ids.append(item.id) - logger.debug( - f"Error processing item {item.id} during preparation.", exc_info=True - ) + logger.debug(f"Error processing item {item.id} during preparation.", exc_info=True) if missing_ids: content_preview = (