Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,7 @@ def delete_node_by_prams(
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")
where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = []
Expand Down
130 changes: 130 additions & 0 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,136 @@ def build_filter_condition(
result = session.run(query, params)
return [record["id"] for record in result]

def delete_node_by_prams(
self,
writable_cube_ids: list[str],
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
) -> int:
"""
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.

Returns:
int: Number of nodes deleted.
"""
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
print(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")

# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
params = {}

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
where_clauses.append("n.id IN $memory_ids")
params["memory_ids"] = memory_ids

# Handle file_ids: query n.file_ids field
# All file_ids must be present in the array field (AND relationship)
if file_ids and len(file_ids) > 0:
file_id_and_conditions = []
for idx, file_id in enumerate(file_ids):
param_name = f"file_id_{idx}"
params[param_name] = file_id
# Check if this file_id is in the file_ids array field
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = []
if filter:
# Use get_by_metadata with empty filters list and filter
filter_ids = self.get_by_metadata(
filters=[],
user_name=None,
filter=filter,
knowledgebase_ids=writable_cube_ids,
)

# If filter returned IDs, add condition for them
if filter_ids:
where_clauses.append("n.id IN $filter_ids")
params["filter_ids"] = filter_ids

# If no conditions (except user_name), return 0
if not where_clauses:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0

# Build WHERE clause
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])

# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"

logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
print(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)

# First count matching nodes to get accurate count
count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
print(f"[delete_node_by_prams] count_query: {count_query}")

# Then delete nodes
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] params: {params}")

deleted_count = 0
try:
with self.driver.session(database=self.db_name) as session:
# Count nodes before deletion
count_result = session.run(count_query, **params)
count_record = count_result.single()
expected_count = 0
if count_record:
expected_count = count_record["node_count"] or 0

# Delete nodes
session.run(delete_query, **params)
# Use the count from before deletion as the actual deleted count
deleted_count = expected_count

except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
raise

logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count

def clear(self, user_name: str | None = None) -> None:
"""
Clear the entire graph if the target database exists.
Expand Down
3 changes: 2 additions & 1 deletion src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4113,6 +4113,7 @@ def parse_filter(
"memory_type",
"node_type",
"info",
"source",
}

def process_condition(condition):
Expand Down Expand Up @@ -4216,7 +4217,7 @@ def delete_node_by_prams(
file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_conditions.append(f"({' AND '.join(file_id_and_conditions)})")
where_conditions.append(f"({' OR '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = set()
Expand Down
Loading