Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:

with self.driver.session(database=self.db_name) as session:
results = session.run(query, params)
return [self._parse_node(dict(record["n"])) for record in results]
nodes = []
for record in results:
try:
nodes.append(self._parse_node(dict(record["n"])))
except Exception as e:
logger.warning(f"Failed to parse node in get_nodes: {e}")
continue
return nodes

def get_edges(
self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
Expand Down
180 changes: 124 additions & 56 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,61 +334,115 @@ def log_add_messages(self, msg: ScheduleMessageItem):
prepared_update_items_with_original = []
missing_ids: list[str] = []

for memory_id in userinput_memory_ids:
if not userinput_memory_ids:
return prepared_add_items, prepared_update_items_with_original

# Batch fetch new items
new_mem_items = []
try:
new_mem_items = self.mem_cube.text_mem.get_batch(
memory_ids=userinput_memory_ids, user_name=msg.mem_cube_id
)
except Exception as e:
logger.warning(
f"Failed to batch get memories in log_add_messages: {e}. Fallback to iterative fetching."
)
# Fallback to iterative fetching
for mid in userinput_memory_ids:
try:
item = self.mem_cube.text_mem.get(memory_id=mid, user_name=msg.mem_cube_id)
if item:
new_mem_items.append(item)
except Exception as inner_e:
logger.warning(f"Failed to get memory {mid}: {inner_e}")

# Create a map for quick lookup and identify missing IDs
new_items_map = {item.id: item for item in new_mem_items}
for mid in userinput_memory_ids:
if mid not in new_items_map:
missing_ids.append(mid)

# Collect keys to check existence
keys_to_check = []
# Store items that have keys for quick access
items_with_keys = []

for item in new_mem_items:
key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
if key:
keys_to_check.append(key)
items_with_keys.append((key, item))

existing_candidates_map = {} # Map (key, memory_type) -> original_item

# Batch check existence if there are keys to check
if keys_to_check and hasattr(self.mem_cube.text_mem, "graph_store"):
try:
# This mem_item represents the NEW content that was just added/processed
mem_item: TextualMemoryItem | None = None
mem_item = self.current_mem_cube.text_mem.get(
memory_id=memory_id, user_name=msg.mem_cube_id
# Use "in" operator to batch query candidate IDs by key
candidate_ids = self.mem_cube.text_mem.graph_store.get_by_metadata(
[
{"field": "key", "op": "in", "value": list(set(keys_to_check))},
],
user_name=msg.mem_cube_id,
)
if mem_item is None:
raise ValueError(f"Memory {memory_id} not found after retries")
# Check if a memory with the same key already exists (determining if it's an update)
key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
name=mem_item.memory
)
exists = False
original_content = None
original_item_id = None

# Only check graph_store if a key exists and the text_mem has a graph_store
if key and hasattr(self.current_mem_cube.text_mem, "graph_store"):
candidates = self.current_mem_cube.text_mem.graph_store.get_by_metadata(
[
{"field": "key", "op": "=", "value": key},
{
"field": "memory_type",
"op": "=",
"value": mem_item.metadata.memory_type,
},
]
)
if candidates:
exists = True
original_item_id = candidates[0]
# Crucial step: Fetch the original content for updates
# This `get` is for the *existing* memory that will be updated
original_mem_item = self.current_mem_cube.text_mem.get(
memory_id=original_item_id, user_name=msg.mem_cube_id
)
original_content = original_mem_item.memory

if exists:
if candidate_ids:
# Filter out current items from candidates to avoid self-match
filtered_candidate_ids = [
cid for cid in candidate_ids if cid not in new_items_map
]

if filtered_candidate_ids:
# Batch fetch candidate memory details
try:
candidate_items = self.mem_cube.text_mem.get_batch(
memory_ids=filtered_candidate_ids, user_name=msg.mem_cube_id
)
except Exception:
# Fallback if batch fetch fails for candidates
candidate_items = []
for cid in filtered_candidate_ids:
try:
c_item = self.mem_cube.text_mem.get(
memory_id=cid, user_name=msg.mem_cube_id
)
if c_item:
candidate_items.append(c_item)
except Exception:
pass

# Map candidates by (key, memory_type)
for cand in candidate_items:
cand_key = getattr(cand.metadata, "key", None)
cand_type = getattr(cand.metadata, "memory_type", None)
if cand_key:
existing_candidates_map[(cand_key, cand_type)] = cand
except Exception as e:
logger.error(f"Failed to batch check existing keys: {e}", exc_info=True)

# Process results
for item in new_mem_items:
try:
key = getattr(item.metadata, "key", None) or transform_name_to_key(name=item.memory)
mem_type = getattr(item.metadata, "memory_type", None)

original_item = existing_candidates_map.get((key, mem_type))

# Ensure we are not comparing the item with itself
if original_item and original_item.id != item.id:
prepared_update_items_with_original.append(
{
"new_item": mem_item,
"original_content": original_content,
"original_item_id": original_item_id,
"new_item": item,
"original_content": original_item.memory,
"original_item_id": original_item.id,
}
)
else:
prepared_add_items.append(mem_item)
prepared_add_items.append(item)

except Exception:
missing_ids.append(memory_id)
logger.debug(
f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation."
)
missing_ids.append(item.id)
logger.debug(f"Error processing item {item.id} during preparation.", exc_info=True)

if missing_ids:
content_preview = (
Expand Down Expand Up @@ -833,13 +887,19 @@ def _process_memories_with_reader(

# Get the original memory items
memory_items = []
for mem_id in mem_ids:
try:
memory_item = text_mem.get(mem_id, user_name=user_name)
memory_items.append(memory_item)
except Exception as e:
logger.warning(f"Failed to get memory {mem_id}: {e}")
continue
try:
memory_items = text_mem.get_batch(mem_ids, user_name=user_name)
except Exception as e:
logger.warning(
f"Failed to batch get memories in _process_memories_with_reader: {e}. Fallback to iterative fetching."
)
for mid in mem_ids:
try:
item = text_mem.get(mid, user_name=user_name)
if item:
memory_items.append(item)
except Exception:
pass

if not memory_items:
logger.warning("No valid memory items found for processing")
Expand Down Expand Up @@ -1089,10 +1149,18 @@ def process_message(message: ScheduleMessageItem):
)

with contextlib.suppress(Exception):
mem_items: list[TextualMemoryItem] = []
for mid in mem_ids:
with contextlib.suppress(Exception):
mem_items.append(text_mem.get(mid, user_name=user_name))
try:
mem_items = text_mem.get_batch(mem_ids, user_name=user_name)
except Exception:
mem_items = []
# Fallback to iterative fetching
for mid in mem_ids:
try:
item = text_mem.get(mid, user_name=user_name)
if item:
mem_items.append(item)
except Exception:
pass
if len(mem_items) > 1:
keys: list[str] = []
memcube_content: list[dict] = []
Expand Down
12 changes: 12 additions & 0 deletions src/memos/memories/textual/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem
TextualMemoryItem: The memory with the given ID.
"""

def get_batch(
    self, memory_ids: list[str], user_name: str | None = None
) -> list[TextualMemoryItem]:
    """Batch get memories by IDs.

    Default implementation: calls ``self.get`` once per ID, in the order
    given. Subclasses may override this with a true batch query.

    Args:
        memory_ids (list[str]): List of memory IDs to retrieve.
        user_name (str | None): Optional user name for multi-tenant retrieval.
    Returns:
        list[TextualMemoryItem]: Memories resolved for the given IDs, in
            request order. IDs for which ``self.get`` returns ``None`` are
            omitted, so the result may be shorter than ``memory_ids``.
    Raises:
        Exception: Whatever ``self.get`` raises for an individual ID
            propagates unchanged; nothing is swallowed here.
    """
    # Lazily fetch one item per ID; an exception from any single ``get``
    # aborts the whole batch so the caller can decide how to recover.
    fetched = (self.get(mid, user_name=user_name) for mid in memory_ids)
    # Drop unresolved lookups so every element is a real item and callers
    # can read ``.id`` / ``.metadata`` without a None check.
    return [item for item in fetched if item is not None]

@abstractmethod
def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
"""Get memories by their IDs.
Expand Down
24 changes: 24 additions & 0 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,30 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem
metadata=TreeNodeTextualMemoryMetadata(**metadata_dict),
)

def get_batch(
    self, memory_ids: list[str], user_name: str | None = None
) -> list[TextualMemoryItem]:
    """Batch get memories by IDs.

    Fetches all nodes with a single ``graph_store.get_nodes`` call and
    converts each returned node into a ``TextualMemoryItem``.

    Args:
        memory_ids (list[str]): List of memory IDs to retrieve.
        user_name (str | None): Optional user name forwarded to the graph store.
    Returns:
        list[TextualMemoryItem]: Successfully converted items, in the order
            the graph store returned them. Falsy nodes and nodes that fail
            conversion are skipped (the latter with a warning), so the
            result may be shorter than ``memory_ids``.
    """
    # Nothing requested: skip the graph-store round trip entirely.
    if not memory_ids:
        return []

    # Tolerate a store that yields None instead of an empty list.
    results = self.graph_store.get_nodes(memory_ids, user_name=user_name) or []

    items = []
    for result in results:
        if not result:
            continue
        try:
            metadata_dict = result.get("metadata", {})
            items.append(
                TextualMemoryItem(
                    id=result["id"],
                    memory=result["memory"],
                    metadata=TreeNodeTextualMemoryMetadata(**metadata_dict),
                )
            )
        except Exception as e:
            # One malformed node must not discard the rest of the batch.
            logger.warning(
                f"Failed to create TextualMemoryItem for id {result.get('id')}: {e}"
            )
    return items

def get_by_ids(
self, memory_ids: list[str], user_name: str | None = None
) -> list[TextualMemoryItem]:
Expand Down
Loading