diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8eabda6d8..12f2c2ca9 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4960,7 +4960,8 @@ def delete_node_by_prams( If not provided, no user_name filter will be applied. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. - filter (dict, optional): Filter dictionary to query matching nodes for deletion. + filter (dict, optional): Filter dictionary for metadata filtering. + Filter conditions are directly used in DELETE WHERE clause without pre-querying. Returns: int: Number of nodes deleted. @@ -4980,35 +4981,14 @@ def delete_node_by_prams( f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" ) - # Query nodes by filter if provided - filter_ids = set() + # Build filter conditions using common method (no query, direct use in WHERE clause) + filter_conditions = [] if filter: - # Parse filter to validate and transform field names (e.g., add "info." prefix if needed) - parsed_filter = self.parse_filter(filter) - if parsed_filter: - # Use get_by_metadata with empty filters list and parsed filter - filter_ids = set( - self.get_by_metadata( - filters=[], - user_name=None, - filter=parsed_filter, - knowledgebase_ids=writable_cube_ids, - ) - ) - else: - logger.warning( - "[delete_node_by_prams] Filter parsed to None, skipping filter query" - ) - - # Combine all IDs that need to be deleted - all_memory_ids = set() - if memory_ids: - all_memory_ids.update(memory_ids) - if filter_ids: - all_memory_ids.update(filter_ids) + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}") # If no conditions to delete, return 0 - if not all_memory_ids and not file_ids: + if not memory_ids and not file_ids and not filter_conditions: logger.warning( "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" ) @@ -5019,74 +4999,58 @@ def delete_node_by_prams( try: conn = self._get_connection() with conn.cursor() as cursor: - # Process memory_ids and filter_ids (all at once, no batching) - if all_memory_ids: - memory_ids_list = list(all_memory_ids) - logger.info( - f"[delete_node_by_prams] Processing {len(memory_ids_list)} memory_ids" - ) + # Build WHERE conditions list + where_conditions = [] - # Build conditions for all memory_ids + # Add memory_ids conditions + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") id_conditions = [] - for node_id in memory_ids_list: + for node_id in memory_ids: id_conditions.append( f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" ) - id_where = f"({' OR '.join(id_conditions)})" - - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({id_where})" - else: - where_clause = id_where - - # Delete directly without counting - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] memory_ids delete_query: {delete_query}") - - cursor.execute(delete_query) - deleted_count = cursor.rowcount - total_deleted_count += deleted_count + where_conditions.append(f"({' OR '.join(id_conditions)})") - logger.info( - f"[delete_node_by_prams] Deleted {deleted_count} nodes by memory_ids" - ) - - # Process file_ids (all at once, no batching) + # Add file_ids conditions if file_ids: logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") - - # Build conditions for all file_ids file_id_conditions = [] for file_id in file_ids: file_id_conditions.append( f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" ) - file_id_where = f"({' OR '.join(file_id_conditions)})" + where_conditions.append(f"({' OR '.join(file_id_conditions)})") - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({file_id_where})" - else: - where_clause = file_id_where + # Add filter conditions + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + where_conditions.extend(filter_conditions) - # Delete directly without counting - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] file_ids delete_query: {delete_query}") + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_conditions.append(f"({user_name_where})") + + # Build final WHERE clause + if not where_conditions: + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") + return 0 - cursor.execute(delete_query) - deleted_count = cursor.rowcount - total_deleted_count += deleted_count + where_clause = " AND ".join(where_conditions) + + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count = deleted_count - logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes by file_ids") + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes") elapsed_time = time.time() - batch_start_time logger.info(