Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 86 additions & 7 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,94 @@ def handle_get_memories(
return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results)


def _build_quick_delete_constraints(delete_mem_req: DeleteMemoryRequest) -> dict[str, Any]:
"""Build fast-delete constraints from request-level fields."""
constraints: dict[str, Any] = {}
if delete_mem_req.user_id is not None:
constraints["user_id"] = delete_mem_req.user_id
if delete_mem_req.session_id is not None:
constraints["session_id"] = delete_mem_req.session_id
return constraints


def _merge_delete_filter(
base_filter: dict[str, Any] | None,
constraints: dict[str, Any],
) -> dict[str, Any]:
"""Merge user/session constraints into an existing filter."""
if not constraints:
return base_filter or {}
if base_filter is None:
return {"and": [constraints.copy()]}

if not base_filter:
return {"and": [constraints.copy()]}

if "and" in base_filter:
and_conditions = base_filter.get("and")
if not isinstance(and_conditions, list):
raise ValueError("Invalid filter format: 'and' must be a list")
return {"and": [*and_conditions, constraints.copy()]}

if "or" in base_filter:
or_conditions = base_filter.get("or")
if not isinstance(or_conditions, list):
raise ValueError("Invalid filter format: 'or' must be a list")

merged_or_conditions: list[dict[str, Any]] = []
for condition in or_conditions:
if not isinstance(condition, dict):
raise ValueError("Invalid filter format: each 'or' condition must be a dict")
merged_condition = condition.copy()
for key, value in constraints.items():
if key in merged_condition and merged_condition[key] != value:
raise ValueError(
f"Conflicting filter condition for '{key}'. "
"Please merge it manually into request.filter."
)
merged_condition[key] = value
merged_or_conditions.append(merged_condition)

return {"or": merged_or_conditions}

# For plain dict filters, keep strict AND semantics explicitly.
return {"and": [base_filter.copy(), constraints.copy()]}


def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube):
logger.info(
f"[Delete memory request] writable_cube_ids: {delete_mem_req.writable_cube_ids}, memory_ids: {delete_mem_req.memory_ids}"
"[Delete memory request] writable_cube_ids: %s, memory_ids: %s, file_ids: %s, "
"has_filter: %s, user_id: %s, session_id: %s",
delete_mem_req.writable_cube_ids,
delete_mem_req.memory_ids,
delete_mem_req.file_ids,
delete_mem_req.filter is not None,
delete_mem_req.user_id,
delete_mem_req.session_id,
)
# Validate that only one of memory_ids, file_ids, or filter is provided
quick_constraints = _build_quick_delete_constraints(delete_mem_req)
has_non_empty_filter = bool(delete_mem_req.filter)
has_filter_mode = has_non_empty_filter or bool(quick_constraints)

# Reject empty filter dict when no quick constraints are provided.
if delete_mem_req.filter is not None and not has_non_empty_filter and not quick_constraints:
return DeleteMemoryResponse(
message="filter cannot be empty. Provide a non-empty filter or user_id/session_id.",
data={"status": "failure"},
)

# Validate that only one mode is provided: memory_ids, file_ids, or filter-mode.
provided_params = [
delete_mem_req.memory_ids is not None,
delete_mem_req.file_ids is not None,
delete_mem_req.filter is not None,
has_filter_mode,
]
if sum(provided_params) != 1:
return DeleteMemoryResponse(
message="Exactly one of memory_ids, file_ids, or filter must be provided",
message=(
"Exactly one delete mode must be provided: "
"memory_ids, file_ids, or filter/user_id/session_id."
),
data={"status": "failure"},
)

Expand All @@ -439,10 +514,14 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
naive_mem_cube.text_mem.delete_by_filter(
writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids
)
elif delete_mem_req.filter is not None:
naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter)
elif has_filter_mode:
merged_filter = _merge_delete_filter(delete_mem_req.filter, quick_constraints)
naive_mem_cube.text_mem.delete_by_filter(
writable_cube_ids=delete_mem_req.writable_cube_ids,
filter=merged_filter,
)
if naive_mem_cube.pref_mem is not None:
naive_mem_cube.pref_mem.delete_by_filter(filter=delete_mem_req.filter)
naive_mem_cube.pref_mem.delete_by_filter(filter=merged_filter)
except Exception as e:
logger.error(f"Failed to delete memories: {e}", exc_info=True)
return DeleteMemoryResponse(
Expand Down
23 changes: 22 additions & 1 deletion src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,10 +854,31 @@ class GetMemoryDashboardRequest(GetMemoryRequest):
class DeleteMemoryRequest(BaseRequest):
"""Request model for deleting memories."""

writable_cube_ids: list[str] = Field(None, description="Writable cube IDs")
writable_cube_ids: list[str] | None = Field(None, description="Writable cube IDs")
memory_ids: list[str] | None = Field(None, description="Memory IDs")
file_ids: list[str] | None = Field(None, description="File IDs")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
user_id: str | None = Field(
None,
description="Quick delete condition: remove memories for this user_id.",
)
session_id: str | None = Field(
None,
description="Quick delete condition: remove memories for this session_id.",
)
conversation_id: str | None = Field(
None,
description="Alias of session_id for backward compatibility.",
)

@model_validator(mode="after")
def normalize_session_alias(self) -> "DeleteMemoryRequest":
"""Normalize conversation_id to session_id."""
if self.conversation_id and self.session_id and self.conversation_id != self.session_id:
raise ValueError("conversation_id and session_id must be the same when both are set")
if self.session_id is None and self.conversation_id is not None:
self.session_id = self.conversation_id
return self


class SuggestionRequest(BaseRequest):
Expand Down
36 changes: 24 additions & 12 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,14 @@ def build_filter_condition(
if condition_str:
where_clauses.append(f"({condition_str})")
filter_params.update(filter_params_inner)
else:
# Simple dict syntax: {"user_id": "...", "session_id": "..."}
condition_str, filter_params_inner = build_filter_condition(
filter, param_counter
)
if condition_str:
where_clauses.append(f"({condition_str})")
filter_params.update(filter_params_inner)

where_str = " AND ".join(where_clauses) if where_clauses else ""
if where_str:
Expand Down Expand Up @@ -841,7 +849,7 @@ def build_filter_condition(

def delete_node_by_prams(
self,
writable_cube_ids: list[str],
writable_cube_ids: list[str] | None = None,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
Expand All @@ -850,7 +858,7 @@ def delete_node_by_prams(
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to scope deletion.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
Expand All @@ -865,20 +873,21 @@ def delete_node_by_prams(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")
# file_ids deletion must be scoped by writable_cube_ids.
if file_ids and (not writable_cube_ids or len(writable_cube_ids) == 0):
raise ValueError("writable_cube_ids is required when deleting by file_ids")

# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
params = {}

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id
if writable_cube_ids:
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
Expand Down Expand Up @@ -925,9 +934,12 @@ def delete_node_by_prams(
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])

# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"
# Then, combine with user_name condition using AND when scope is provided.
if user_name_conditions:
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"
else:
ids_where = data_conditions

logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
Expand Down
Loading