109 changes: 83 additions & 26 deletions src/memos/api/handlers/chat_handler.py
@@ -388,22 +388,6 @@ def generate_chat_response() -> Generator[str, None, None]:
[chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id]
)

search_req = APISearchRequest(
query=chat_req.query,
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
mode=chat_req.mode,
internet_search=chat_req.internet_search,
top_k=chat_req.top_k,
chat_history=chat_req.history,
session_id=chat_req.session_id,
include_preference=chat_req.include_preference,
pref_top_k=chat_req.pref_top_k,
filter=chat_req.filter,
playground_search_goal_parser=True,
)

search_response = self.search_handler.handle_search_memories(search_req)
# For playground, add the query to memory without a response
self._start_add_to_memory(
user_id=chat_req.user_id,
@@ -414,7 +398,6 @@ def generate_chat_response() -> Generator[str, None, None]:
async_mode="sync",
)

yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
# Use first readable cube ID for scheduler (backward compatibility)
scheduler_cube_id = (
readable_cube_ids[0] if readable_cube_ids else chat_req.user_id
@@ -425,22 +408,40 @@
query=chat_req.query,
label=QUERY_TASK_LABEL,
)
# Extract memories from search results

# ====== first search without parse goal ======
search_req = APISearchRequest(
query=chat_req.query,
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
mode=chat_req.mode,
internet_search=False,
top_k=chat_req.top_k,
chat_history=chat_req.history,
session_id=chat_req.session_id,
include_preference=chat_req.include_preference,
pref_top_k=chat_req.pref_top_k,
filter=chat_req.filter,
)
search_response = self.search_handler.handle_search_memories(search_req)

yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"

# Extract memories from search results (first search)
memories_list = []
if search_response.data and search_response.data.get("text_mem"):
text_mem_results = search_response.data["text_mem"]
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]

# Filter memories by threshold
filtered_memories = self._filter_memories_by_threshold(memories_list)
first_filtered_memories = self._filter_memories_by_threshold(memories_list)

# Prepare reference data (first search)
reference = prepare_reference_data(first_filtered_memories)
# get preference string
pref_string = search_response.data.get("pref_string", "")

# Prepare reference data
reference = prepare_reference_data(filtered_memories)
# get internet reference
internet_reference = self._get_internet_reference(
search_response.data.get("text_mem")[0]["memories"]
)
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"

# Prepare preference markdown string
@@ -450,9 +451,52 @@
pref_md_string = self._build_pref_md_string_for_playground(pref_memories)
yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"

# Signal the start of internet search
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"

# ====== second search with parse goal ======
search_req = APISearchRequest(
query=chat_req.query,
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
mode=chat_req.mode,
internet_search=chat_req.internet_search,
top_k=chat_req.top_k,
chat_history=chat_req.history,
session_id=chat_req.session_id,
include_preference=False,
filter=chat_req.filter,
playground_search_goal_parser=True,
)
search_response = self.search_handler.handle_search_memories(search_req)

# Extract memories from search results (second search)
memories_list = []
if search_response.data and search_response.data.get("text_mem"):
text_mem_results = search_response.data["text_mem"]
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]

# Filter memories by threshold
second_filtered_memories = self._filter_memories_by_threshold(memories_list)

# dedup and supplement memories
filtered_memories = self._dedup_and_supplement_memories(
first_filtered_memories, second_filtered_memories
)

# Prepare remain reference data (second search)
reference = prepare_reference_data(filtered_memories)
# get internet reference
internet_reference = self._get_internet_reference(
search_response.data.get("text_mem")[0]["memories"]
)

yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"

# Step 2: Build system prompt with memories
system_prompt = self._build_enhance_system_prompt(
filtered_memories, search_response.data.get("pref_string", "")
filtered_memories, pref_string
)

# Prepare messages
@@ -588,6 +632,19 @@ def generate_chat_response() -> Generator[str, None, None]:
self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

def _dedup_and_supplement_memories(
self, first_filtered_memories: list, second_filtered_memories: list
) -> list:
"""Remove memory from second_filtered_memories that already exists in first_filtered_memories, return remaining memories"""
# Create a set of IDs from first_filtered_memories for efficient lookup
first_memory_ids = {memory["id"] for memory in first_filtered_memories}

remaining_memories = []
for memory in second_filtered_memories:
if memory["id"] not in first_memory_ids:
remaining_memories.append(memory)
return remaining_memories

def _get_internet_reference(
self, search_response: list[dict[str, Any]]
) -> list[dict[str, Any]]:
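The `_dedup_and_supplement_memories` helper above keeps only second-pass memories that the first pass has not already surfaced, so the second `reference` event streams just the supplement. A minimal standalone sketch of that dedup-by-id behavior (the sample memory dicts are hypothetical):

```python
# Standalone sketch of the dedup-by-id logic; the sample memories are hypothetical.
def dedup_and_supplement(first: list[dict], second: list[dict]) -> list[dict]:
    seen_ids = {memory["id"] for memory in first}
    return [memory for memory in second if memory["id"] not in seen_ids]

first_pass = [{"id": "m1", "memory": "prefers tea"}, {"id": "m2", "memory": "based in Paris"}]
second_pass = [{"id": "m2", "memory": "based in Paris"}, {"id": "m3", "memory": "works remotely"}]

print(dedup_and_supplement(first_pass, second_pass))
# -> [{'id': 'm3', 'memory': 'works remotely'}]
```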
8 changes: 2 additions & 6 deletions src/memos/api/handlers/memory_handler.py
@@ -209,12 +209,8 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
if naive_mem_cube.pref_mem is not None:
naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
elif delete_mem_req.file_ids is not None:
# TODO: Implement deletion by file_ids
# Need to find memory_ids associated with file_ids and delete them
logger.warning("Deletion by file_ids not implemented yet")
return DeleteMemoryResponse(
message="Deletion by file_ids not implemented yet",
data={"status": "failure"},
naive_mem_cube.text_mem.delete_by_filter(
writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids
)
elif delete_mem_req.filter is not None:
# TODO: Implement deletion by filter
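This change replaces the not-implemented response for `file_ids` with a call to `text_mem.delete_by_filter`. A hedged sketch of the resulting dispatch, with a stubbed text memory in place of the real cube (the `StubTextMem` class and sample IDs are hypothetical):

```python
# Hypothetical stub illustrating the new delete-by-file_ids dispatch path.
class StubTextMem:
    def delete_by_filter(self, writable_cube_ids, memory_ids=None, file_ids=None, filter=None):
        print(f"deleting nodes in {writable_cube_ids} with file_ids {file_ids}")
        return 2  # pretend two nodes matched

def handle_delete(text_mem, memory_ids=None, file_ids=None, writable_cube_ids=None):
    if memory_ids is not None:
        pass  # existing per-id deletion path (unchanged)
    elif file_ids is not None:
        # New path: delegate to the tree memory's filter-based deletion.
        text_mem.delete_by_filter(writable_cube_ids=writable_cube_ids, file_ids=file_ids)

handle_delete(StubTextMem(), file_ids=["doc-1"], writable_cube_ids=["cube-a"])
```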
22 changes: 22 additions & 0 deletions src/memos/memories/textual/tree.py
@@ -339,6 +339,28 @@ def delete_all(self) -> None:
logger.error(f"An error occurred while deleting all memories: {e}")
raise

def delete_by_filter(
self,
writable_cube_ids: list[str],
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
) -> int:
"""Delete memories by filter.
Returns:
int: Number of nodes deleted.
"""
try:
return self.graph_store.delete_node_by_prams(
writable_cube_ids=writable_cube_ids,
memory_ids=memory_ids,
file_ids=file_ids,
filter=filter,
)
except Exception as e:
logger.error(f"An error occurred while deleting memories by filter: {e}")
raise

def load(self, dir: str) -> None:
try:
memory_file = os.path.join(dir, self.config.memory_filename)
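The new `delete_by_filter` is a thin delegator: it forwards all filter parameters to the graph store's existing `delete_node_by_prams` method and re-raises on failure so callers see the original error. A minimal sketch of that pattern against a fake graph store (`FakeGraphStore` and `TreeMemorySketch` are hypothetical stand-ins, not the real classes):

```python
# Hypothetical fake graph store demonstrating the delegation/re-raise pattern.
class FakeGraphStore:
    def delete_node_by_prams(self, writable_cube_ids, memory_ids=None, file_ids=None, filter=None):
        return 1 if file_ids else 0  # pretend one node matched the file filter

class TreeMemorySketch:
    def __init__(self, graph_store):
        self.graph_store = graph_store

    def delete_by_filter(self, writable_cube_ids, memory_ids=None, file_ids=None, filter=None):
        try:
            return self.graph_store.delete_node_by_prams(
                writable_cube_ids=writable_cube_ids,
                memory_ids=memory_ids,
                file_ids=file_ids,
                filter=filter,
            )
        except Exception:
            # The real method logs before re-raising; the caller decides how to recover.
            raise

print(TreeMemorySketch(FakeGraphStore()).delete_by_filter(["cube-a"], file_ids=["doc-1"]))  # -> 1
```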
2 changes: 2 additions & 0 deletions src/memos/multi_mem_cube/composite_cube.py
@@ -43,6 +43,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
"para_mem": [],
"pref_mem": [],
"pref_note": "",
"tool_mem": [],
}

for view in self.cube_views:
@@ -52,6 +53,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
merged_results["act_mem"].extend(cube_result.get("act_mem", []))
merged_results["para_mem"].extend(cube_result.get("para_mem", []))
merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))

note = cube_result.get("pref_note")
if note:
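The composite cube change threads the new `tool_mem` channel through the same extend-per-view merge used for the other memory types. A compact sketch of that merge over stubbed per-view results (the sample results and the last-note-wins choice for `pref_note` are hypothetical):

```python
# Sketch of the per-view merge, now including tool_mem; sample results are hypothetical.
def merge_cube_results(view_results: list[dict]) -> dict:
    merged = {
        "text_mem": [], "act_mem": [], "para_mem": [],
        "pref_mem": [], "pref_note": "", "tool_mem": [],
    }
    for result in view_results:
        for key in ("text_mem", "act_mem", "para_mem", "pref_mem", "tool_mem"):
            merged[key].extend(result.get(key, []))
        if result.get("pref_note"):
            merged["pref_note"] = result["pref_note"]  # assumption: later views override

    return merged

print(merge_cube_results([
    {"text_mem": ["t1"], "tool_mem": ["tool-a"]},
    {"tool_mem": ["tool-b"], "pref_note": "concise answers"},
]))
```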