From efebcde906525506d55a4d5821ac71404e2db7b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Sun, 7 Dec 2025 15:16:12 +0800 Subject: [PATCH] fix get_subgraph --- src/memos/graph_dbs/polardb.py | 200 ++++++++++++++++++++++----------- 1 file changed, 134 insertions(+), 66 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 7db840082..657caf054 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1361,89 +1361,157 @@ def get_subgraph( r) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) - WHERE - center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - RETURN - collect(DISTINCT - center), collect(DISTINCT - neighbor), collect(DISTINCT - r) - $$ ) as (centers agtype, neighbors agtype, rels agtype); - """ + # Use UNION ALL for better performance: separate queries for depth 1 and depth 2 + if depth == 1: + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ + else: + # For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries + query = f""" + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH(center: Memory)-[r]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r) + UNION ALL + MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory) + WHERE + center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' + RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) + $$ ) as (centers agtype, neighbors agtype, rels agtype); + """ conn = self._get_connection() logger.info(f"[get_subgraph] Query: {query}") try: with conn.cursor() as cursor: cursor.execute(query) - result = cursor.fetchone() + results = cursor.fetchall() - if not result or not result[0]: + if not results: return {"core_node": None, "neighbors": [], "edges": []} - # Parse center node - centers_data = result[0] if result[0] else "[]" - neighbors_data = result[1] if result[1] else "[]" - edges_data = result[2] if result[2] else "[]" + # Merge results from all UNION ALL rows + all_centers_list = [] + all_neighbors_list = [] + all_edges_list = [] - # Parse JSON data - try: - # Clean ::vertex and ::edge suffixes in data - if isinstance(centers_data, str): - centers_data = centers_data.replace("::vertex", "") - if isinstance(neighbors_data, str): - neighbors_data = neighbors_data.replace("::vertex", "") - if isinstance(edges_data, str): - edges_data = edges_data.replace("::edge", "") - - centers_list = ( - json.loads(centers_data) if isinstance(centers_data, str) else centers_data - ) - neighbors_list = ( - json.loads(neighbors_data) - if isinstance(neighbors_data, str) - else neighbors_data - ) - edges_list = ( - json.loads(edges_data) if isinstance(edges_data, str) else edges_data - ) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON data: {e}") - return {"core_node": None, "neighbors": [], "edges": []} + for result in results: + if not result or not result[0]: + continue + + centers_data = result[0] if result[0] else "[]" + neighbors_data = result[1] if result[1] else "[]" + edges_data = result[2] if result[2] else "[]" + + # Parse JSON data + try: + # Clean ::vertex and ::edge suffixes in data + if isinstance(centers_data, str): + centers_data = centers_data.replace("::vertex", "") + if isinstance(neighbors_data, str): + neighbors_data = neighbors_data.replace("::vertex", "") + if isinstance(edges_data, str): + edges_data = edges_data.replace("::edge", "") + + centers_list = ( + json.loads(centers_data) + if isinstance(centers_data, str) + else centers_data + ) + neighbors_list = ( + json.loads(neighbors_data) + if isinstance(neighbors_data, str) + else neighbors_data + ) + edges_list = ( + json.loads(edges_data) if isinstance(edges_data, str) else edges_data + ) + + # Collect data from this row + if isinstance(centers_list, list): + all_centers_list.extend(centers_list) + if isinstance(neighbors_list, list): + all_neighbors_list.extend(neighbors_list) + if isinstance(edges_list, list): + all_edges_list.extend(edges_list) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON data: {e}") + continue - # Parse center node + # Deduplicate centers by ID + centers_dict = {} + for center_data in all_centers_list: + if isinstance(center_data, dict) and "properties" in center_data: + center_id_key = center_data["properties"].get("id") + if center_id_key and center_id_key not in centers_dict: + centers_dict[center_id_key] = center_data + + # Parse center node (use first center) core_node = None - if centers_list and len(centers_list) > 0: - center_data = centers_list[0] + if centers_dict: + center_data = next(iter(centers_dict.values())) if isinstance(center_data, dict) and "properties" in center_data: core_node = self._parse_node(center_data["properties"]) + # Deduplicate neighbors by ID + neighbors_dict = {} + for neighbor_data in all_neighbors_list: + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_id = neighbor_data["properties"].get("id") + if neighbor_id and neighbor_id not in neighbors_dict: + neighbors_dict[neighbor_id] = neighbor_data + # Parse neighbor nodes neighbors = [] - if isinstance(neighbors_list, list): - for neighbor_data in neighbors_list: - if isinstance(neighbor_data, dict) and "properties" in neighbor_data: - neighbor_parsed = self._parse_node(neighbor_data["properties"]) - neighbors.append(neighbor_parsed) + for neighbor_data in neighbors_dict.values(): + if isinstance(neighbor_data, dict) and "properties" in neighbor_data: + neighbor_parsed = self._parse_node(neighbor_data["properties"]) + neighbors.append(neighbor_parsed) + + # Deduplicate edges by (source, target, type) + edges_dict = {} + for edge_group in all_edges_list: + if isinstance(edge_group, list): + for edge_data in edge_group: + if isinstance(edge_data, dict): + edge_key = ( + edge_data.get("start_id", ""), + edge_data.get("end_id", ""), + edge_data.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_data.get("label", ""), + "source": edge_data.get("start_id", ""), + "target": edge_data.get("end_id", ""), + } + elif isinstance(edge_group, dict): + # Handle single edge (not in a list) + edge_key = ( + edge_group.get("start_id", ""), + edge_group.get("end_id", ""), + edge_group.get("label", ""), + ) + if edge_key not in edges_dict: + edges_dict[edge_key] = { + "type": edge_group.get("label", ""), + "source": edge_group.get("start_id", ""), + "target": edge_group.get("end_id", ""), + } - # Parse edges - edges = [] - if isinstance(edges_list, list): - for edge_group in edges_list: - if isinstance(edge_group, list): - for edge_data in edge_group: - if isinstance(edge_data, dict): - edges.append( - { - "type": edge_data.get("label", ""), - "source": edge_data.get("start_id", ""), - "target": edge_data.get("end_id", ""), - } - ) + edges = list(edges_dict.values()) return self._convert_graph_edges( {"core_node": core_node, "neighbors": neighbors, "edges": edges}