Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 38 additions & 43 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4792,35 +4792,35 @@ def delete_node_by_prams(
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
for cube_id in writable_cube_ids:
# Escape single quotes in cube IDs
escaped_cube_id = str(cube_id).replace("'", "\\'")
user_name_conditions.append(f"n.user_name = '{escaped_cube_id}'")
# Use agtype_access_operator with VARIADIC ARRAY format for consistency
user_name_conditions.append(
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
)

# Build WHERE conditions separately for memory_ids and file_ids
where_conditions = []

# Handle memory_ids: query n.id
# Handle memory_ids: query properties.id
if memory_ids and len(memory_ids) > 0:
memory_id_conditions = []
for node_id in memory_ids:
# Escape single quotes in node IDs
escaped_id = str(node_id).replace("'", "\\'")
memory_id_conditions.append(f"'{escaped_id}'")
memory_id_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
)
if memory_id_conditions:
where_conditions.append(f"n.id IN [{', '.join(memory_id_conditions)}]")
where_conditions.append(f"({' OR '.join(memory_id_conditions)})")

# Handle file_ids: query n.file_ids field
# All file_ids must be present in the array field (AND relationship)
# Check if any file_id is in the file_ids array field (OR relationship)
if file_ids and len(file_ids) > 0:
file_id_and_conditions = []
file_id_conditions = []
for file_id in file_ids:
# Escape single quotes in file IDs
escaped_id = str(file_id).replace("'", "\\'")
# Check if this file_id is in the file_ids array field
file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_conditions.append(f"({' OR '.join(file_id_and_conditions)})")
# Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
file_id_conditions.append(
f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
)
if file_id_conditions:
# Use OR to match any file_id in the array
where_conditions.append(f"({' OR '.join(file_id_conditions)})")

# Query nodes by filter if provided
filter_ids = set()
Expand All @@ -4846,11 +4846,11 @@ def delete_node_by_prams(
if filter_ids:
filter_id_conditions = []
for node_id in filter_ids:
# Escape single quotes in node IDs
escaped_id = str(node_id).replace("'", "\\'")
filter_id_conditions.append(f"'{escaped_id}'")
filter_id_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
)
if filter_id_conditions:
where_conditions.append(f"n.id IN [{', '.join(filter_id_conditions)}]")
where_conditions.append(f"({' OR '.join(filter_id_conditions)})")

# If no conditions (except user_name), return 0
if not where_conditions:
Expand All @@ -4865,26 +4865,21 @@ def delete_node_by_prams(

# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"{user_name_where} AND ({data_conditions})"
where_clause = f"({user_name_where}) AND ({data_conditions})"

# Use Cypher DELETE query
# Use SQL DELETE query for better performance
# First count matching nodes to get accurate count
count_query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (n:Memory)
WHERE {ids_where}
RETURN count(n) AS node_count
$$) AS (node_count agtype)
SELECT COUNT(*)
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""
logger.info(f"[delete_node_by_prams] count_query: {count_query}")

# Then delete nodes
delete_query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (n:Memory)
WHERE {ids_where}
DELETE n
$$) AS (result agtype)
DELETE FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""

logger.info(
Expand All @@ -4899,20 +4894,20 @@ def delete_node_by_prams(
with conn.cursor() as cursor:
# Count nodes before deletion
cursor.execute(count_query)
count_results = cursor.fetchall()
expected_count = 0
if count_results and len(count_results) > 0:
count_str = str(count_results[0][0])
count_str = count_str.strip('"').strip("'")
expected_count = int(count_str) if count_str.isdigit() else 0
count_result = cursor.fetchone()
expected_count = count_result[0] if count_result else 0

logger.info(
f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
)

# Delete nodes
cursor.execute(delete_query)
# Use the count from before deletion as the actual deleted count
deleted_count = expected_count
# Use rowcount to get actual deleted count
deleted_count = cursor.rowcount
elapsed_time = time.time() - batch_start_time
logger.info(
f"[delete_node_by_prams] execute_values completed successfully in {elapsed_time:.2f}s"
f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
)
except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
Expand Down