Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:

def delete_node_by_prams(
self,
writable_cube_ids: list[str],
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
Expand All @@ -1538,56 +1539,90 @@ def delete_node_by_prams(
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.

Returns:
int: Number of nodes deleted.
"""
# Collect all node IDs to delete
ids_to_delete = set()
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
print(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Add memory_ids if provided
# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")

# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
params = {}

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
for idx, cube_id in enumerate(writable_cube_ids):
param_name = f"cube_id_{idx}"
user_name_conditions.append(f"n.user_name = ${param_name}")
params[param_name] = cube_id

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
ids_to_delete.update(memory_ids)
where_clauses.append("n.id IN $memory_ids")
params["memory_ids"] = memory_ids

# Add file_ids if provided (treating them as node IDs)
# Handle file_ids: query n.file_ids field
# All file_ids must be present in the array field (AND relationship)
if file_ids and len(file_ids) > 0:
ids_to_delete.update(file_ids)
file_id_and_conditions = []
for idx, file_id in enumerate(file_ids):
param_name = f"file_id_{idx}"
params[param_name] = file_id
# Check if this file_id is in the file_ids array field
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = []
if filter:
# Use get_by_metadata with empty filters list and filter
filter_ids = self.get_by_metadata(
filters=[],
user_name=None,
filter=filter,
knowledgebase_ids=None,
user_name_flag=False,
knowledgebase_ids=writable_cube_ids,
)
ids_to_delete.update(filter_ids)

# If no IDs to delete, return 0
if not ids_to_delete:
logger.warning("[delete_node_by_prams] No nodes to delete")
# If filter returned IDs, add condition for them
if filter_ids:
where_clauses.append("n.id IN $filter_ids")
params["filter_ids"] = filter_ids

# If no conditions (except user_name), return 0
if not where_clauses:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0

# Convert to list for easier handling
ids_list = list(ids_to_delete)
logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}")
# Build WHERE clause
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])

# Build WHERE condition for collected IDs (query n.id)
ids_where = "n.id IN $ids_to_delete"
params = {"ids_to_delete": ids_list}
# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"

# Calculate total count for logging
total_count = len(ids_list)
logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
print(
f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)

# First count matching nodes to get accurate count
Expand All @@ -1599,16 +1634,17 @@ def delete_node_by_prams(
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] params: {params}")

deleted_count = 0
try:
with self.driver.session(database=self.db_name) as session:
# Count nodes before deletion
count_result = session.run(count_query, **params)
count_record = count_result.single()
expected_count = total_count
expected_count = 0
if count_record:
expected_count = count_record["node_count"] or total_count
expected_count = count_record["node_count"] or 0

# Delete nodes
session.run(delete_query, **params)
Expand Down
148 changes: 89 additions & 59 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,14 +1818,14 @@ def get_by_metadata(
raise ValueError(f"Unsupported operator: {op}")

# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = []
if user_name_flag:
user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self._get_config_value("user_name"),
)
print(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")
# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self._get_config_value("user_name"),
)
print(f"[111get_by_metadata] user_name_conditions: {user_name_conditions}")

# Add user_name WHERE clause
if user_name_conditions:
Expand All @@ -1837,26 +1837,16 @@ def get_by_metadata(
# Build filter conditions using common method
filter_where_clause = self._build_filter_conditions_cypher(filter)

# Build WHERE clause: if where_conditions is empty, filter_where_clause should not have " AND " prefix
if where_conditions:
where_str = " AND ".join(where_conditions) + filter_where_clause
else:
# If no other conditions, remove " AND " prefix from filter_where_clause if present
if filter_where_clause.startswith(" AND "):
where_str = filter_where_clause[5:] # Remove " AND " prefix
else:
where_str = filter_where_clause
where_str = " AND ".join(where_conditions) + filter_where_clause

# Use cypher query
# Only include WHERE clause if where_str is not empty
where_clause = f"WHERE {where_str}" if where_str else ""
cypher_query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (n:Memory)
{where_clause}
RETURN n.id AS id
$$) AS (id agtype)
"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (n:Memory)
WHERE {where_str}
RETURN n.id AS id
$$) AS (id agtype)
"""

ids = []
conn = self._get_connection()
Expand Down Expand Up @@ -4008,6 +3998,7 @@ def process_condition(condition):
@timed
def delete_node_by_prams(
self,
writable_cube_ids: list[str],
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
Expand All @@ -4016,61 +4007,102 @@ def delete_node_by_prams(
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.

Returns:
int: Number of nodes deleted.
"""
# Collect all node IDs to delete
ids_to_delete = set()
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
print(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Add memory_ids if provided
if memory_ids and len(memory_ids) > 0:
ids_to_delete.update(memory_ids)
# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
for cube_id in writable_cube_ids:
# Escape single quotes in cube IDs
escaped_cube_id = str(cube_id).replace("'", "\\'")
user_name_conditions.append(f"n.user_name = '{escaped_cube_id}'")

# Add file_ids if provided (treating them as node IDs)
# Build WHERE conditions separately for memory_ids and file_ids
where_conditions = []

# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
memory_id_conditions = []
for node_id in memory_ids:
# Escape single quotes in node IDs
escaped_id = str(node_id).replace("'", "\\'")
memory_id_conditions.append(f"'{escaped_id}'")
if memory_id_conditions:
where_conditions.append(f"n.id IN [{', '.join(memory_id_conditions)}]")

# Handle file_ids: query n.file_ids field
# All file_ids must be present in the array field (AND relationship)
if file_ids and len(file_ids) > 0:
ids_to_delete.update(file_ids)
file_id_and_conditions = []
for file_id in file_ids:
# Escape single quotes in file IDs
escaped_id = str(file_id).replace("'", "\\'")
# Check if this file_id is in the file_ids array field
file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
if file_id_and_conditions:
# Use AND to require all file_ids to be present
where_conditions.append(f"({' AND '.join(file_id_and_conditions)})")

# Query nodes by filter if provided
filter_ids = set()
if filter:
# Parse filter to validate and transform field names (e.g., add "info." prefix if needed)
parsed_filter = self.parse_filter(filter)
if parsed_filter:
# Use get_by_metadata with empty filters list and parsed filter
filter_ids = self.get_by_metadata(
filters=[],
user_name=None,
filter=parsed_filter,
knowledgebase_ids=None,
user_name_flag=False,
filter_ids = set(
self.get_by_metadata(
filters=[],
user_name=None,
filter=parsed_filter,
knowledgebase_ids=writable_cube_ids,
)
)
ids_to_delete.update(filter_ids)
else:
logger.warning(
"[delete_node_by_prams] Filter parsed to None, skipping filter query"
)

# If no IDs to delete, return 0
if not ids_to_delete:
logger.warning("[delete_node_by_prams] No nodes to delete")
# If filter returned IDs, add condition for them
if filter_ids:
filter_id_conditions = []
for node_id in filter_ids:
# Escape single quotes in node IDs
escaped_id = str(node_id).replace("'", "\\'")
filter_id_conditions.append(f"'{escaped_id}'")
if filter_id_conditions:
where_conditions.append(f"n.id IN [{', '.join(filter_id_conditions)}]")

# If no conditions (except user_name), return 0
if not where_conditions:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0

# Convert to list for easier handling
ids_list = list(ids_to_delete)
logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}")

# Build WHERE condition for collected IDs (query n.id)
id_conditions = []
for node_id in ids_list:
# Escape single quotes in node IDs
escaped_id = str(node_id).replace("'", "\\'")
id_conditions.append(f"'{escaped_id}'")
# Build WHERE clause
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])

# Build WHERE clause for IDs
ids_where = f"n.id IN [{', '.join(id_conditions)}]"
# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
ids_where = f"({user_name_where}) AND ({data_conditions})"

# Use Cypher DELETE query
# First count matching nodes to get accurate count
Expand All @@ -4093,13 +4125,11 @@ def delete_node_by_prams(
$$) AS (result agtype)
"""

# Calculate total count for logging
total_count = len(ids_list)
logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
print(
f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] delete_query: {delete_query}")
Expand All @@ -4111,11 +4141,11 @@ def delete_node_by_prams(
# Count nodes before deletion
cursor.execute(count_query)
count_results = cursor.fetchall()
expected_count = total_count
expected_count = 0
if count_results and len(count_results) > 0:
count_str = str(count_results[0][0])
count_str = count_str.strip('"').strip("'")
expected_count = int(count_str) if count_str.isdigit() else total_count
expected_count = int(count_str) if count_str.isdigit() else 0

# Delete nodes
cursor.execute(delete_query)
Expand Down
Loading