From f947d6c4030dc7e8754d0c15f1a8e21344b21a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 29 Dec 2025 17:24:05 +0800 Subject: [PATCH] feat: add get_user_names_by_memory_ids for polardb && neo4j --- src/memos/graph_dbs/neo4j.py | 64 ++++++++----------- src/memos/graph_dbs/polardb.py | 112 +++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 86 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 2b3859252..d57e7c596 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1811,67 +1811,53 @@ def delete_node_by_prams( logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") return deleted_count - def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: """Get user names by memory ids. Args: memory_ids: List of memory node IDs to query. Returns: - dict[str, list[str]]: Dictionary with one key: - - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) - - 'exist_user_names': List of distinct user names (if all memory_ids exist) + dict[str, str | None]: Dictionary mapping memory_id to user_name. + - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} """ if not memory_ids: - return {"exist_user_names": []} + return {} - logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids") + logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") try: with self.driver.session(database=self.db_name) as session: - # Query to check which memory_ids exist - check_query = """ + # Query to get memory_id and user_name pairs + query = """ MATCH (n:Memory) WHERE n.id IN $memory_ids - RETURN n.id AS id + RETURN n.id AS memory_id, n.user_name AS user_name """ + logger.info(f"[get_user_names_by_memory_ids] query: {query}") - check_result = session.run(check_query, memory_ids=memory_ids) - existing_ids = set() - for record in check_result: - node_id = record["id"] - existing_ids.add(node_id) + result = session.run(query, memory_ids=memory_ids) + result_dict = {} - # Check if any memory_ids are missing - no_exist_list = [mid for mid in memory_ids if mid not in existing_ids] - - # If any memory_ids are missing, return no_exist_memory_ids - if no_exist_list: - logger.info( - f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}" - ) - return {"no_exist_memory_ids": no_exist_list} - - # All memory_ids exist, query user_names - user_names_query = """ - MATCH (n:Memory) - WHERE n.id IN $memory_ids - RETURN DISTINCT n.user_name AS user_name - """ - logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}") - - user_names_result = session.run(user_names_query, memory_ids=memory_ids) - user_names = [] - for record in user_names_result: + # Build result dictionary from query results + for record in result: + memory_id = record["memory_id"] user_name = record["user_name"] - if user_name: - user_names.append(user_name) + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in memory_ids: + if mid not in result_dict: + result_dict[mid] = None logger.info( - f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names" + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" ) - return {"exist_user_names": user_names} + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f88824493..8eb3e4ece 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5065,86 +5065,104 @@ def delete_node_by_prams( return total_deleted_count @timed - def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: """Get user names by memory ids. Args: memory_ids: List of memory node IDs to query. Returns: - dict[str, list[str]]: Dictionary with one key: - - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) - - 'exist_user_names': List of distinct user names (if all memory_ids exist) + dict[str, str | None]: Dictionary mapping memory_id to user_name. + - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} """ logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") if not memory_ids: - return {"exist_user_names": []} + return {} + + # Validate and normalize memory_ids + # Ensure all items are strings + normalized_memory_ids = [] + for mid in memory_ids: + if not isinstance(mid, str): + mid = str(mid) + # Remove any whitespace + mid = mid.strip() + if mid: + normalized_memory_ids.append(mid) + + if not normalized_memory_ids: + return {} + + # Escape special characters for JSON string format in agtype + def escape_memory_id(mid: str) -> str: + """Escape special characters in memory_id for JSON string format.""" + # Escape backslashes first, then double quotes + mid_str = mid.replace("\\", "\\\\") + mid_str = mid_str.replace('"', '\\"') + return mid_str # Build OR conditions for each memory_id id_conditions = [] - for mid in memory_ids: + for mid in normalized_memory_ids: + # Escape special characters + escaped_mid = escape_memory_id(mid) id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype" ) where_clause = f"({' OR '.join(id_conditions)})" - # Query to check which memory_ids exist - check_query = f""" - SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text + # Query to get memory_id and user_name pairs + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text AS memory_id, + ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text AS user_name FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ - logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}") + logger.info(f"[get_user_names_by_memory_ids] query: {query}") conn = None + result_dict = {} try: conn = self._get_connection() with conn.cursor() as cursor: - # Check which memory_ids exist - cursor.execute(check_query) - check_results = cursor.fetchall() - existing_ids = set() - for row in check_results: - node_id = row[0] - # Remove quotes if present - if isinstance(node_id, str): - node_id = node_id.strip('"').strip("'") - existing_ids.add(node_id) + cursor.execute(query) + results = cursor.fetchall() - # Check if any memory_ids are missing - no_exist_list = [mid for mid in memory_ids if mid not in existing_ids] + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] - # If any memory_ids are missing, return no_exist_memory_ids - if no_exist_list: - logger.info( - f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}" - ) - return {"no_exist_memory_ids": no_exist_list} + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") - # All memory_ids exist, query user_names - user_names_query = f""" - SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}") + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) - cursor.execute(user_names_query) - results = cursor.fetchall() - user_names = [] - for row in results: - user_name = row[0] - # Remove quotes if present - if isinstance(user_name, str): - user_name = user_name.strip('"').strip("'") - user_names.append(user_name) + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None logger.info( - f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names" + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" ) - return {"exist_user_names": user_names} + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True