Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 25 additions & 39 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,67 +1811,53 @@ def delete_node_by_prams(
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count

def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, list[str]]: Dictionary with one key:
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
dict[str, str | None]: Dictionary mapping memory_id to user_name.
- Key: memory_id
- Value: user_name if exists, None if memory_id does not exist
Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
"""
if not memory_ids:
return {"exist_user_names": []}
return {}

logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids")
logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}")

try:
with self.driver.session(database=self.db_name) as session:
# Query to check which memory_ids exist
check_query = """
# Query to get memory_id and user_name pairs
query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN n.id AS id
RETURN n.id AS memory_id, n.user_name AS user_name
"""
logger.info(f"[get_user_names_by_memory_ids] query: {query}")

check_result = session.run(check_query, memory_ids=memory_ids)
existing_ids = set()
for record in check_result:
node_id = record["id"]
existing_ids.add(node_id)
result = session.run(query, memory_ids=memory_ids)
result_dict = {}

# Check if any memory_ids are missing
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]

# If any memory_ids are missing, return no_exist_memory_ids
if no_exist_list:
logger.info(
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
)
return {"no_exist_memory_ids": no_exist_list}

# All memory_ids exist, query user_names
user_names_query = """
MATCH (n:Memory)
WHERE n.id IN $memory_ids
RETURN DISTINCT n.user_name AS user_name
"""
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")

user_names_result = session.run(user_names_query, memory_ids=memory_ids)
user_names = []
for record in user_names_result:
# Build result dictionary from query results
for record in result:
memory_id = record["memory_id"]
user_name = record["user_name"]
if user_name:
user_names.append(user_name)
result_dict[memory_id] = user_name if user_name else None

# Set None for memory_ids that were not found
for mid in memory_ids:
if mid not in result_dict:
result_dict[mid] = None

logger.info(
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
)

return {"exist_user_names": user_names}
return result_dict
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
Expand Down
112 changes: 65 additions & 47 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5065,86 +5065,104 @@ def delete_node_by_prams(
return total_deleted_count

@timed
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.

Args:
memory_ids: List of memory node IDs to query.

Returns:
dict[str, list[str]]: Dictionary with one key:
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
dict[str, str | None]: Dictionary mapping memory_id to user_name.
- Key: memory_id
- Value: user_name if exists, None if memory_id does not exist
Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
"""
logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}")
if not memory_ids:
return {"exist_user_names": []}
return {}

# Validate and normalize memory_ids
# Ensure all items are strings
normalized_memory_ids = []
for mid in memory_ids:
if not isinstance(mid, str):
mid = str(mid)
# Remove any whitespace
mid = mid.strip()
if mid:
normalized_memory_ids.append(mid)

if not normalized_memory_ids:
return {}

# Escape special characters for JSON string format in agtype
def escape_memory_id(mid: str) -> str:
"""Escape special characters in memory_id for JSON string format."""
# Escape backslashes first, then double quotes
mid_str = mid.replace("\\", "\\\\")
mid_str = mid_str.replace('"', '\\"')
return mid_str

# Build OR conditions for each memory_id
id_conditions = []
for mid in memory_ids:
for mid in normalized_memory_ids:
# Escape special characters
escaped_mid = escape_memory_id(mid)
id_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype"
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype"
)

where_clause = f"({' OR '.join(id_conditions)})"

# Query to check which memory_ids exist
check_query = f"""
SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text
# Query to get memory_id and user_name pairs
query = f"""
SELECT
ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text AS memory_id,
ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text AS user_name
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""

logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}")
logger.info(f"[get_user_names_by_memory_ids] query: {query}")
conn = None
result_dict = {}
try:
conn = self._get_connection()
with conn.cursor() as cursor:
# Check which memory_ids exist
cursor.execute(check_query)
check_results = cursor.fetchall()
existing_ids = set()
for row in check_results:
node_id = row[0]
# Remove quotes if present
if isinstance(node_id, str):
node_id = node_id.strip('"').strip("'")
existing_ids.add(node_id)
cursor.execute(query)
results = cursor.fetchall()

# Check if any memory_ids are missing
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
# Build result dictionary from query results
for row in results:
memory_id_raw = row[0]
user_name_raw = row[1]

# If any memory_ids are missing, return no_exist_memory_ids
if no_exist_list:
logger.info(
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
)
return {"no_exist_memory_ids": no_exist_list}
# Remove quotes if present
if isinstance(memory_id_raw, str):
memory_id = memory_id_raw.strip('"').strip("'")
else:
memory_id = str(memory_id_raw).strip('"').strip("'")

# All memory_ids exist, query user_names
user_names_query = f"""
SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
if isinstance(user_name_raw, str):
user_name = user_name_raw.strip('"').strip("'")
else:
user_name = (
str(user_name_raw).strip('"').strip("'") if user_name_raw else None
)

cursor.execute(user_names_query)
results = cursor.fetchall()
user_names = []
for row in results:
user_name = row[0]
# Remove quotes if present
if isinstance(user_name, str):
user_name = user_name.strip('"').strip("'")
user_names.append(user_name)
result_dict[memory_id] = user_name if user_name else None

# Set None for memory_ids that were not found
for mid in normalized_memory_ids:
if mid not in result_dict:
result_dict[mid] = None

logger.info(
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
)

return {"exist_user_names": user_names}
return result_dict
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
Expand Down