diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index c8a1f5144..9de06cd90 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1530,6 +1530,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: def delete_node_by_prams( self, + writable_cube_ids: list[str], memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, @@ -1538,6 +1539,7 @@ def delete_node_by_prams( Delete nodes by memory_ids, file_ids, or filter. Args: + writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. filter (dict, optional): Filter dictionary to query matching nodes for deletion. @@ -1545,49 +1547,82 @@ def delete_node_by_prams( Returns: int: Number of nodes deleted. """ - # Collect all node IDs to delete - ids_to_delete = set() + logger.info( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + print( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) - # Add memory_ids if provided + # Validate writable_cube_ids + if not writable_cube_ids or len(writable_cube_ids) == 0: + raise ValueError("writable_cube_ids is required and cannot be empty") + + # Build WHERE conditions separately for memory_ids and file_ids + where_clauses = [] + params = {} + + # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + user_name_conditions = [] + for idx, cube_id in enumerate(writable_cube_ids): + param_name = f"cube_id_{idx}" + user_name_conditions.append(f"n.user_name = ${param_name}") + params[param_name] = cube_id + + # Handle memory_ids: query n.id if memory_ids and len(memory_ids) > 0: - ids_to_delete.update(memory_ids) + where_clauses.append("n.id IN $memory_ids") + params["memory_ids"] = memory_ids - # Add file_ids if provided (treating them as node IDs) + # Handle file_ids: query n.file_ids field + # All file_ids must be present in the array field (AND relationship) if file_ids and len(file_ids) > 0: - ids_to_delete.update(file_ids) + file_id_and_conditions = [] + for idx, file_id in enumerate(file_ids): + param_name = f"file_id_{idx}" + params[param_name] = file_id + # Check if this file_id is in the file_ids array field + file_id_and_conditions.append(f"${param_name} IN n.file_ids") + if file_id_and_conditions: + # Use AND to require all file_ids to be present + where_clauses.append(f"({' AND '.join(file_id_and_conditions)})") # Query nodes by filter if provided + filter_ids = [] if filter: # Use get_by_metadata with empty filters list and filter filter_ids = self.get_by_metadata( filters=[], user_name=None, filter=filter, - knowledgebase_ids=None, - user_name_flag=False, + knowledgebase_ids=writable_cube_ids, ) - ids_to_delete.update(filter_ids) - # If no IDs to delete, return 0 - if not ids_to_delete: - logger.warning("[delete_node_by_prams] No nodes to delete") + # If filter returned IDs, add condition for them + if filter_ids: + where_clauses.append("n.id IN $filter_ids") + params["filter_ids"] = filter_ids + + # If no conditions (except user_name), return 0 + if not where_clauses: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) return 0 - # Convert to list for easier handling - ids_list = list(ids_to_delete) - logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}") + # Build WHERE clause + # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) + data_conditions = " OR ".join([f"({clause})" for clause in where_clauses]) - # Build WHERE condition for collected IDs (query n.id) - ids_where = "n.id IN $ids_to_delete" - params = {"ids_to_delete": ids_list} + # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) + user_name_where = " OR ".join(user_name_conditions) + ids_where = f"({user_name_where}) AND ({data_conditions})" - # Calculate total count for logging - total_count = len(ids_list) logger.info( f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" ) print( - f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" ) # First count matching nodes to get accurate count @@ -1599,6 +1634,7 @@ def delete_node_by_prams( delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n" logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") print(f"[delete_node_by_prams] delete_query: {delete_query}") + print(f"[delete_node_by_prams] params: {params}") deleted_count = 0 try: @@ -1606,9 +1642,9 @@ def delete_node_by_prams( # Count nodes before deletion count_result = session.run(count_query, **params) count_record = count_result.single() - expected_count = total_count + expected_count = 0 if count_record: - expected_count = count_record["node_count"] or total_count + expected_count = count_record["node_count"] or 0 # Delete nodes session.run(delete_query, **params) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index e731ef138..d2d69c768 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1818,14 +1818,14 @@ def get_by_metadata( raise ValueError(f"Unsupported operator: {op}") # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - user_name_conditions = [] - if user_name_flag: - user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( - user_name=user_name, - knowledgebase_ids=knowledgebase_ids, - default_user_name=self._get_config_value("user_name"), - ) - print(f"[get_by_metadata] user_name_conditions: {user_name_conditions}") + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self._get_config_value("user_name"), + ) + print(f"[111get_by_metadata] user_name_conditions: {user_name_conditions}") # Add user_name WHERE clause if user_name_conditions: @@ -1837,26 +1837,16 @@ def get_by_metadata( # Build filter conditions using common method filter_where_clause = self._build_filter_conditions_cypher(filter) - # Build WHERE clause: if where_conditions is empty, filter_where_clause should not have " AND " prefix - if where_conditions: - where_str = " AND ".join(where_conditions) + filter_where_clause - else: - # If no other conditions, remove " AND " prefix from filter_where_clause if present - if filter_where_clause.startswith(" AND "): - where_str = filter_where_clause[5:] # Remove " AND " prefix - else: - where_str = filter_where_clause + where_str = " AND ".join(where_conditions) + filter_where_clause # Use cypher query - # Only include WHERE clause if where_str is not empty - where_clause = f"WHERE {where_str}" if where_str else "" cypher_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - {where_clause} - RETURN n.id AS id - $$) AS (id agtype) - """ + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (n:Memory) + WHERE {where_str} + RETURN n.id AS id + $$) AS (id agtype) + """ ids = [] conn = self._get_connection() @@ -4008,6 +3998,7 @@ def process_condition(condition): @timed def delete_node_by_prams( self, + writable_cube_ids: list[str], memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, @@ -4016,6 +4007,7 @@ def delete_node_by_prams( Delete nodes by memory_ids, file_ids, or filter. Args: + writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. filter (dict, optional): Filter dictionary to query matching nodes for deletion. @@ -4023,54 +4015,94 @@ def delete_node_by_prams( Returns: int: Number of nodes deleted. """ - # Collect all node IDs to delete - ids_to_delete = set() + logger.info( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) + print( + f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" + ) - # Add memory_ids if provided - if memory_ids and len(memory_ids) > 0: - ids_to_delete.update(memory_ids) + # Validate writable_cube_ids + if not writable_cube_ids or len(writable_cube_ids) == 0: + raise ValueError("writable_cube_ids is required and cannot be empty") + + # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + user_name_conditions = [] + for cube_id in writable_cube_ids: + # Escape single quotes in cube IDs + escaped_cube_id = str(cube_id).replace("'", "\\'") + user_name_conditions.append(f"n.user_name = '{escaped_cube_id}'") - # Add file_ids if provided (treating them as node IDs) + # Build WHERE conditions separately for memory_ids and file_ids + where_conditions = [] + + # Handle memory_ids: query n.id + if memory_ids and len(memory_ids) > 0: + memory_id_conditions = [] + for node_id in memory_ids: + # Escape single quotes in node IDs + escaped_id = str(node_id).replace("'", "\\'") + memory_id_conditions.append(f"'{escaped_id}'") + if memory_id_conditions: + where_conditions.append(f"n.id IN [{', '.join(memory_id_conditions)}]") + + # Handle file_ids: query n.file_ids field + # All file_ids must be present in the array field (AND relationship) if file_ids and len(file_ids) > 0: - ids_to_delete.update(file_ids) + file_id_and_conditions = [] + for file_id in file_ids: + # Escape single quotes in file IDs + escaped_id = str(file_id).replace("'", "\\'") + # Check if this file_id is in the file_ids array field + file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids") + if file_id_and_conditions: + # Use AND to require all file_ids to be present + where_conditions.append(f"({' AND '.join(file_id_and_conditions)})") # Query nodes by filter if provided + filter_ids = set() if filter: # Parse filter to validate and transform field names (e.g., add "info." prefix if needed) parsed_filter = self.parse_filter(filter) if parsed_filter: # Use get_by_metadata with empty filters list and parsed filter - filter_ids = self.get_by_metadata( - filters=[], - user_name=None, - filter=parsed_filter, - knowledgebase_ids=None, - user_name_flag=False, + filter_ids = set( + self.get_by_metadata( + filters=[], + user_name=None, + filter=parsed_filter, + knowledgebase_ids=writable_cube_ids, + ) ) - ids_to_delete.update(filter_ids) else: logger.warning( "[delete_node_by_prams] Filter parsed to None, skipping filter query" ) - # If no IDs to delete, return 0 - if not ids_to_delete: - logger.warning("[delete_node_by_prams] No nodes to delete") + # If filter returned IDs, add condition for them + if filter_ids: + filter_id_conditions = [] + for node_id in filter_ids: + # Escape single quotes in node IDs + escaped_id = str(node_id).replace("'", "\\'") + filter_id_conditions.append(f"'{escaped_id}'") + if filter_id_conditions: + where_conditions.append(f"n.id IN [{', '.join(filter_id_conditions)}]") + + # If no conditions (except user_name), return 0 + if not where_conditions: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) return 0 - # Convert to list for easier handling - ids_list = list(ids_to_delete) - logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}") - - # Build WHERE condition for collected IDs (query n.id) - id_conditions = [] - for node_id in ids_list: - # Escape single quotes in node IDs - escaped_id = str(node_id).replace("'", "\\'") - id_conditions.append(f"'{escaped_id}'") + # Build WHERE clause + # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) + data_conditions = " OR ".join([f"({cond})" for cond in where_conditions]) - # Build WHERE clause for IDs - ids_where = f"n.id IN [{', '.join(id_conditions)}]" + # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) + user_name_where = " OR ".join(user_name_conditions) + ids_where = f"({user_name_where}) AND ({data_conditions})" # Use Cypher DELETE query # First count matching nodes to get accurate count @@ -4093,13 +4125,11 @@ def delete_node_by_prams( $$) AS (result agtype) """ - # Calculate total count for logging - total_count = len(ids_list) logger.info( f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" ) print( - f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" + f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" ) logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") print(f"[delete_node_by_prams] delete_query: {delete_query}") @@ -4111,11 +4141,11 @@ def delete_node_by_prams( # Count nodes before deletion cursor.execute(count_query) count_results = cursor.fetchall() - expected_count = total_count + expected_count = 0 if count_results and len(count_results) > 0: count_str = str(count_results[0][0]) count_str = count_str.strip('"').strip("'") - expected_count = int(count_str) if count_str.isdigit() else total_count + expected_count = int(count_str) if count_str.isdigit() else 0 # Delete nodes cursor.execute(delete_query)