diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 88b95b536..126e974a3 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1588,7 +1588,7 @@ def delete_node_by_prams( file_id_and_conditions.append(f"${param_name} IN n.file_ids") if file_id_and_conditions: # Use AND to require all file_ids to be present - where_clauses.append(f"({' AND '.join(file_id_and_conditions)})") + where_clauses.append(f"({' OR '.join(file_id_and_conditions)})") # Query nodes by filter if provided filter_ids = [] diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index ff7d5f50b..e943616da 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -706,6 +706,136 @@ def build_filter_condition( result = session.run(query, params) return [record["id"] for record in result] + def delete_node_by_prams( + self, + writable_cube_ids: list[str], + memory_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, + ) -> int: + """ + Delete nodes by memory_ids, file_ids, or filter. + + Args: + writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. + memory_ids (list[str], optional): List of memory node IDs to delete. + file_ids (list[str], optional): List of file node IDs to delete. + filter (dict, optional): Filter dictionary to query matching nodes for deletion. + + Returns: + int: Number of nodes deleted. + """ + logger.info( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + print( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + + # Validate writable_cube_ids + if not writable_cube_ids or len(writable_cube_ids) == 0: + raise ValueError("writable_cube_ids is required and cannot be empty") + + # Build WHERE conditions separately for memory_ids and file_ids + where_clauses = [] + params = {} + + # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + user_name_conditions = [] + for idx, cube_id in enumerate(writable_cube_ids): + param_name = f"cube_id_{idx}" + user_name_conditions.append(f"n.user_name = ${param_name}") + params[param_name] = cube_id + + # Handle memory_ids: query n.id + if memory_ids and len(memory_ids) > 0: + where_clauses.append("n.id IN $memory_ids") + params["memory_ids"] = memory_ids + + # Handle file_ids: query n.file_ids field + # All file_ids must be present in the array field (AND relationship) + if file_ids and len(file_ids) > 0: + file_id_and_conditions = [] + for idx, file_id in enumerate(file_ids): + param_name = f"file_id_{idx}" + params[param_name] = file_id + # Check if this file_id is in the file_ids array field + file_id_and_conditions.append(f"${param_name} IN n.file_ids") + if file_id_and_conditions: + # Use AND to require all file_ids to be present + where_clauses.append(f"({' AND '.join(file_id_and_conditions)})") + + # Query nodes by filter if provided + filter_ids = [] + if filter: + # Use get_by_metadata with empty filters list and filter + filter_ids = self.get_by_metadata( + filters=[], + user_name=None, + filter=filter, + knowledgebase_ids=writable_cube_ids, + ) + + # If filter returned IDs, add condition for them + if filter_ids: + where_clauses.append("n.id IN $filter_ids") + params["filter_ids"] = filter_ids + + # If no conditions (except user_name), return 0 + if not where_clauses: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) + return 0 + + # Build WHERE clause + # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) + data_conditions = " OR ".join([f"({clause})" for clause in where_clauses]) + + # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) + user_name_where = " OR ".join(user_name_conditions) + ids_where = f"({user_name_where}) AND ({data_conditions})" + + logger.info( + f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + ) + print( + f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + ) + + # First count matching nodes to get accurate count + count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count" + logger.info(f"[delete_node_by_prams] count_query: {count_query}") + print(f"[delete_node_by_prams] count_query: {count_query}") + + # Then delete nodes + delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n" + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + print(f"[delete_node_by_prams] delete_query: {delete_query}") + print(f"[delete_node_by_prams] params: {params}") + + deleted_count = 0 + try: + with self.driver.session(database=self.db_name) as session: + # Count nodes before deletion + count_result = session.run(count_query, **params) + count_record = count_result.single() + expected_count = 0 + if count_record: + expected_count = count_record["node_count"] or 0 + + # Delete nodes + session.run(delete_query, **params) + # Use the count from before deletion as the actual deleted count + deleted_count = expected_count + + except Exception as e: + logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) + raise + + logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") + return deleted_count + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 638eac9c2..d3dc1b4f9 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4113,6 +4113,7 @@ def parse_filter( "memory_type", "node_type", "info", + "source", } def process_condition(condition): @@ -4216,7 +4217,7 @@ def delete_node_by_prams( file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids") if file_id_and_conditions: # Use AND to require all file_ids to be present - where_conditions.append(f"({' AND '.join(file_id_and_conditions)})") + where_conditions.append(f"({' OR '.join(file_id_and_conditions)})") # Query nodes by filter if provided filter_ids = set()