diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 339b9a330..c81e46804 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4763,7 +4763,7 @@ def process_condition(condition): @timed def delete_node_by_prams( self, - writable_cube_ids: list[str], + writable_cube_ids: list[str] | None = None, memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, @@ -4772,7 +4772,8 @@ def delete_node_by_prams( Delete nodes by memory_ids, file_ids, or filter. Args: - writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. + writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. + If not provided, no user_name filter will be applied. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. filter (dict, optional): Filter dictionary to query matching nodes for deletion. 
@@ -4785,17 +4786,15 @@ def delete_node_by_prams( f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" ) - # Validate writable_cube_ids - if not writable_cube_ids or len(writable_cube_ids) == 0: - raise ValueError("writable_cube_ids is required and cannot be empty") - # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + # Only add user_name filter if writable_cube_ids is provided user_name_conditions = [] - for cube_id in writable_cube_ids: - # Use agtype_access_operator with VARIADIC ARRAY format for consistency - user_name_conditions.append( - f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" - ) + if writable_cube_ids and len(writable_cube_ids) > 0: + for cube_id in writable_cube_ids: + # Use agtype_access_operator with VARIADIC ARRAY format for consistency + user_name_conditions.append( + f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" + ) # Build WHERE conditions separately for memory_ids and file_ids where_conditions = [] @@ -4863,9 +4862,14 @@ def delete_node_by_prams( # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) data_conditions = " OR ".join([f"({cond})" for cond in where_conditions]) - # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({data_conditions})" + # Build final WHERE clause + # If user_name_conditions exist, combine with data_conditions using AND + # Otherwise, use only data_conditions + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({data_conditions})" + else: + where_clause = f"({data_conditions})" # Use SQL DELETE query for better performance # 
@timed
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
    """Resolve the distinct ``user_name`` values that own the given memory nodes.

    The lookup is all-or-nothing: user names are only reported when every
    requested memory id exists in the ``Memory`` table.

    Args:
        memory_ids: Memory node IDs to look up. Duplicates are tolerated.

    Returns:
        A dictionary holding exactly one of the following keys:

        - ``"no_exist_memory_ids"``: the requested ids that were not found
          (returned as soon as at least one id is missing).
        - ``"exist_user_names"``: the distinct user names owning the nodes
          (returned when every id exists, and as an empty list when
          ``memory_ids`` is empty).

    Raises:
        Exception: Any error raised while obtaining the connection or running
            the query is logged and re-raised unchanged.
    """
    # Function-scope import keeps this block self-contained regardless of
    # what the module already imports.
    import json

    if not memory_ids:
        return {"exist_user_names": []}

    # De-duplicate (order-preserving) so repeated ids do not bloat the WHERE clause.
    unique_ids = list(dict.fromkeys(memory_ids))

    # Bind the ids as parameters instead of interpolating them into the SQL
    # text: memory ids are caller-supplied, and a quote character in one of
    # them would otherwise break out of the literal (SQL injection).
    # NOTE(review): assumes the DB-API driver uses the ``%s`` paramstyle
    # (psycopg2) -- confirm against _get_connection().
    id_predicate = (
        "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype"
    )
    where_clause = " OR ".join([id_predicate] * len(unique_ids))
    # An agtype string literal is a JSON string, so json.dumps supplies the
    # surrounding double quotes and escapes embedded quotes/backslashes.
    params = [json.dumps(str(mid), ensure_ascii=False) for mid in unique_ids]

    # One query returns both the id (existence check) and the owning
    # user_name, so both answers come from the same snapshot of the table.
    query = f"""
        SELECT
            ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text,
            ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
        FROM "{self.db_name}_graph"."Memory"
        WHERE ({where_clause})
    """
    logger.info(f"[get_user_names_by_memory_ids] query: {query}, params: {params}")

    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            rows = cursor.fetchall()
    except Exception as e:
        logger.error(
            f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
        )
        raise
    finally:
        self._return_connection(conn)

    def _unquote(value):
        # agtype values cast to text keep their surrounding quotes; strip them.
        if isinstance(value, str):
            return value.strip('"').strip("'")
        return value

    existing_ids = set()
    distinct_user_names = {}  # dict used as an insertion-ordered set
    for row in rows:
        existing_ids.add(_unquote(row[0]))
        distinct_user_names[_unquote(row[1])] = None

    # Report missing ids in the caller's original order.
    no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
    if no_exist_list:
        logger.info(
            f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
        )
        return {"no_exist_memory_ids": no_exist_list}

    user_names = list(distinct_user_names)
    logger.info(
        f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
    )
    return {"exist_user_names": user_names}