Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 106 additions & 14 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4763,7 +4763,7 @@ def process_condition(condition):
@timed
def delete_node_by_prams(
self,
writable_cube_ids: list[str],
writable_cube_ids: list[str] | None = None,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
Expand All @@ -4772,7 +4772,8 @@ def delete_node_by_prams(
Delete nodes by memory_ids, file_ids, or filter.

Args:
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
If not provided, no user_name filter will be applied.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
Expand All @@ -4785,17 +4786,15 @@ def delete_node_by_prams(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)

# Validate writable_cube_ids
if not writable_cube_ids or len(writable_cube_ids) == 0:
raise ValueError("writable_cube_ids is required and cannot be empty")

# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
# Only add user_name filter if writable_cube_ids is provided
user_name_conditions = []
for cube_id in writable_cube_ids:
# Use agtype_access_operator with VARIADIC ARRAY format for consistency
user_name_conditions.append(
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
)
if writable_cube_ids and len(writable_cube_ids) > 0:
for cube_id in writable_cube_ids:
# Use agtype_access_operator with VARIADIC ARRAY format for consistency
user_name_conditions.append(
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
)

# Build WHERE conditions separately for memory_ids and file_ids
where_conditions = []
Expand Down Expand Up @@ -4863,9 +4862,14 @@ def delete_node_by_prams(
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])

# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
user_name_where = " OR ".join(user_name_conditions)
where_clause = f"({user_name_where}) AND ({data_conditions})"
# Build final WHERE clause
# If user_name_conditions exist, combine with data_conditions using AND
# Otherwise, use only data_conditions
if user_name_conditions:
user_name_where = " OR ".join(user_name_conditions)
where_clause = f"({user_name_where}) AND ({data_conditions})"
else:
where_clause = f"({data_conditions})"

# Use SQL DELETE query for better performance
# First count matching nodes to get accurate count
Expand Down Expand Up @@ -4917,3 +4921,91 @@ def delete_node_by_prams(

logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count

@timed
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, list[str]]: Dictionary with one key:
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
"""
if not memory_ids:
return {"exist_user_names": []}

# Build OR conditions for each memory_id
id_conditions = []
for mid in memory_ids:
id_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype"
)

where_clause = f"({' OR '.join(id_conditions)})"

# Query to check which memory_ids exist
check_query = f"""
SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""

logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}")
conn = None
try:
conn = self._get_connection()
with conn.cursor() as cursor:
# Check which memory_ids exist
cursor.execute(check_query)
check_results = cursor.fetchall()
existing_ids = set()
for row in check_results:
node_id = row[0]
# Remove quotes if present
if isinstance(node_id, str):
node_id = node_id.strip('"').strip("'")
existing_ids.add(node_id)

# Check if any memory_ids are missing
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]

# If any memory_ids are missing, return no_exist_memory_ids
if no_exist_list:
logger.info(
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
)
return {"no_exist_memory_ids": no_exist_list}

# All memory_ids exist, query user_names
user_names_query = f"""
SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")

cursor.execute(user_names_query)
results = cursor.fetchall()
user_names = []
for row in results:
user_name = row[0]
# Remove quotes if present
if isinstance(user_name, str):
user_name = user_name.strip('"').strip("'")
user_names.append(user_name)

logger.info(
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
)

return {"exist_user_names": user_names}
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
)
raise
finally:
self._return_connection(conn)