Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import json
import os
import time
import traceback

from memos.configs.mem_scheduler import GeneralSchedulerConfig
Expand Down Expand Up @@ -337,9 +338,20 @@ def log_add_messages(self, msg: ScheduleMessageItem):
for memory_id in userinput_memory_ids:
try:
# This mem_item represents the NEW content that was just added/processed
mem_item: TextualMemoryItem = self.current_mem_cube.text_mem.get(
memory_id=memory_id
)
mem_item: TextualMemoryItem | None = None
for attempt in range(3):
try:
mem_item = self.current_mem_cube.text_mem.get(
memory_id=memory_id, user_name=msg.mem_cube_id
)
break
except Exception:
if attempt < 2:
time.sleep(0.5)
else:
raise
if mem_item is None:
raise ValueError(f"Memory {memory_id} not found after retries")
# Check if a memory with the same key already exists (determining if it's an update)
key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
name=mem_item.memory
Expand All @@ -366,7 +378,7 @@ def log_add_messages(self, msg: ScheduleMessageItem):
# Crucial step: Fetch the original content for updates
# This `get` is for the *existing* memory that will be updated
original_mem_item = self.current_mem_cube.text_mem.get(
memory_id=original_item_id
memory_id=original_item_id, user_name=msg.mem_cube_id
)
original_content = original_mem_item.memory

Expand Down Expand Up @@ -825,7 +837,7 @@ def _process_memories_with_reader(
memory_items = []
for mem_id in mem_ids:
try:
memory_item = text_mem.get(mem_id)
memory_item = text_mem.get(mem_id, user_name=user_name)
memory_items.append(memory_item)
except Exception as e:
logger.warning(f"Failed to get memory {mem_id}: {e}")
Expand Down Expand Up @@ -1077,7 +1089,7 @@ def process_message(message: ScheduleMessageItem):
mem_items: list[TextualMemoryItem] = []
for mid in mem_ids:
with contextlib.suppress(Exception):
mem_items.append(text_mem.get(mid))
mem_items.append(text_mem.get(mid, user_name=user_name))
if len(mem_items) > 1:
keys: list[str] = []
memcube_content: list[dict] = []
Expand Down Expand Up @@ -1133,7 +1145,7 @@ def process_message(message: ScheduleMessageItem):
if merged_target_ids:
post_ref_id = next(iter(merged_target_ids))
with contextlib.suppress(Exception):
merged_item = text_mem.get(post_ref_id)
merged_item = text_mem.get(post_ref_id, user_name=user_name)
combined_key = (
getattr(getattr(merged_item, "metadata", {}), "key", None)
or combined_key
Expand Down Expand Up @@ -1242,7 +1254,7 @@ def _process_memories_with_reorganize(
memory_items = []
for mem_id in mem_ids:
try:
memory_item = text_mem.get(mem_id)
memory_item = text_mem.get(mem_id, user_name=user_name)
memory_items.append(memory_item)
except Exception as e:
logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}")
Expand Down Expand Up @@ -1357,7 +1369,9 @@ def process_session_turn(
f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
)

cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
user_name=mem_cube_id
)
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
intent_result = self.monitor.detect_intent(
q_list=queries, text_working_memory=text_working_memory
Expand Down
2 changes: 1 addition & 1 deletion src/memos/memories/textual/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
"""

@abstractmethod
def get(self, memory_id: str) -> TextualMemoryItem:
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID.
Args:
memory_id (str): The ID of the memory to retrieve.
Expand Down
2 changes: 1 addition & 1 deletion src/memos/memories/textual/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
]
return result_memories

def get(self, memory_id: str) -> TextualMemoryItem:
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID."""
result = self.vector_db.get_by_id(memory_id)
if result is None:
Expand Down
2 changes: 1 addition & 1 deletion src/memos/memories/textual/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]:
# Convert search results to TextualMemoryItem objects
return [TextualMemoryItem(**memory) for memory, _ in sims[:top_k]]

def get(self, memory_id: str) -> TextualMemoryItem:
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID."""
for memory in self.memories:
if memory["id"] == memory_id:
Expand Down
2 changes: 1 addition & 1 deletion src/memos/memories/textual/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any])
"""Update a memory by memory_id."""
raise NotImplementedError

def get(self, memory_id: str) -> TextualMemoryItem:
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID.
Args:
memory_id (str): The ID of the memory to retrieve.
Expand Down
4 changes: 2 additions & 2 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]:
def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
raise NotImplementedError

def get(self, memory_id: str) -> TextualMemoryItem:
def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID."""
result = self.graph_store.get_node(memory_id)
result = self.graph_store.get_node(memory_id, user_name=user_name)
if result is None:
raise ValueError(f"Memory with ID {memory_id} not found")
metadata_dict = result.get("metadata", {})
Expand Down
Loading