Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 134 additions & 66 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,89 +1361,157 @@ def get_subgraph(
r)
$$ ) as (centers agtype, neighbors agtype, rels agtype);
"""
query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory)
WHERE
center.id = '{center_id}'
AND center.status = '{center_status}'
AND center.user_name = '{user_name}'
RETURN
collect(DISTINCT
center), collect(DISTINCT
neighbor), collect(DISTINCT
r)
$$ ) as (centers agtype, neighbors agtype, rels agtype);
"""
# Use UNION ALL for better performance: separate queries for depth 1 and depth 2
if depth == 1:
query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH(center: Memory)-[r]->(neighbor:Memory)
WHERE
center.id = '{center_id}'
AND center.status = '{center_status}'
AND center.user_name = '{user_name}'
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
$$ ) as (centers agtype, neighbors agtype, rels agtype);
"""
else:
# For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries
query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH(center: Memory)-[r]->(neighbor:Memory)
WHERE
center.id = '{center_id}'
AND center.status = '{center_status}'
AND center.user_name = '{user_name}'
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
UNION ALL
MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory)
WHERE
center.id = '{center_id}'
AND center.status = '{center_status}'
AND center.user_name = '{user_name}'
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1)
$$ ) as (centers agtype, neighbors agtype, rels agtype);
"""
conn = self._get_connection()
logger.info(f"[get_subgraph] Query: {query}")
try:
with conn.cursor() as cursor:
cursor.execute(query)
result = cursor.fetchone()
results = cursor.fetchall()

if not result or not result[0]:
if not results:
return {"core_node": None, "neighbors": [], "edges": []}

# Parse center node
centers_data = result[0] if result[0] else "[]"
neighbors_data = result[1] if result[1] else "[]"
edges_data = result[2] if result[2] else "[]"
# Merge results from all UNION ALL rows
all_centers_list = []
all_neighbors_list = []
all_edges_list = []

# Parse JSON data
try:
# Clean ::vertex and ::edge suffixes in data
if isinstance(centers_data, str):
centers_data = centers_data.replace("::vertex", "")
if isinstance(neighbors_data, str):
neighbors_data = neighbors_data.replace("::vertex", "")
if isinstance(edges_data, str):
edges_data = edges_data.replace("::edge", "")

centers_list = (
json.loads(centers_data) if isinstance(centers_data, str) else centers_data
)
neighbors_list = (
json.loads(neighbors_data)
if isinstance(neighbors_data, str)
else neighbors_data
)
edges_list = (
json.loads(edges_data) if isinstance(edges_data, str) else edges_data
)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON data: {e}")
return {"core_node": None, "neighbors": [], "edges": []}
for result in results:
if not result or not result[0]:
continue

centers_data = result[0] if result[0] else "[]"
neighbors_data = result[1] if result[1] else "[]"
edges_data = result[2] if result[2] else "[]"

# Parse JSON data
try:
# Clean ::vertex and ::edge suffixes in data
if isinstance(centers_data, str):
centers_data = centers_data.replace("::vertex", "")
if isinstance(neighbors_data, str):
neighbors_data = neighbors_data.replace("::vertex", "")
if isinstance(edges_data, str):
edges_data = edges_data.replace("::edge", "")

centers_list = (
json.loads(centers_data)
if isinstance(centers_data, str)
else centers_data
)
neighbors_list = (
json.loads(neighbors_data)
if isinstance(neighbors_data, str)
else neighbors_data
)
edges_list = (
json.loads(edges_data) if isinstance(edges_data, str) else edges_data
)

# Collect data from this row
if isinstance(centers_list, list):
all_centers_list.extend(centers_list)
if isinstance(neighbors_list, list):
all_neighbors_list.extend(neighbors_list)
if isinstance(edges_list, list):
all_edges_list.extend(edges_list)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON data: {e}")
continue

# Parse center node
# Deduplicate centers by ID
centers_dict = {}
for center_data in all_centers_list:
if isinstance(center_data, dict) and "properties" in center_data:
center_id_key = center_data["properties"].get("id")
if center_id_key and center_id_key not in centers_dict:
centers_dict[center_id_key] = center_data

# Parse center node (use first center)
core_node = None
if centers_list and len(centers_list) > 0:
center_data = centers_list[0]
if centers_dict:
center_data = next(iter(centers_dict.values()))
if isinstance(center_data, dict) and "properties" in center_data:
core_node = self._parse_node(center_data["properties"])

# Deduplicate neighbors by ID
neighbors_dict = {}
for neighbor_data in all_neighbors_list:
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
neighbor_id = neighbor_data["properties"].get("id")
if neighbor_id and neighbor_id not in neighbors_dict:
neighbors_dict[neighbor_id] = neighbor_data

# Parse neighbor nodes
neighbors = []
if isinstance(neighbors_list, list):
for neighbor_data in neighbors_list:
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
neighbor_parsed = self._parse_node(neighbor_data["properties"])
neighbors.append(neighbor_parsed)
for neighbor_data in neighbors_dict.values():
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
neighbor_parsed = self._parse_node(neighbor_data["properties"])
neighbors.append(neighbor_parsed)

# Deduplicate edges by (source, target, type)
edges_dict = {}
for edge_group in all_edges_list:
if isinstance(edge_group, list):
for edge_data in edge_group:
if isinstance(edge_data, dict):
edge_key = (
edge_data.get("start_id", ""),
edge_data.get("end_id", ""),
edge_data.get("label", ""),
)
if edge_key not in edges_dict:
edges_dict[edge_key] = {
"type": edge_data.get("label", ""),
"source": edge_data.get("start_id", ""),
"target": edge_data.get("end_id", ""),
}
elif isinstance(edge_group, dict):
# Handle single edge (not in a list)
edge_key = (
edge_group.get("start_id", ""),
edge_group.get("end_id", ""),
edge_group.get("label", ""),
)
if edge_key not in edges_dict:
edges_dict[edge_key] = {
"type": edge_group.get("label", ""),
"source": edge_group.get("start_id", ""),
"target": edge_group.get("end_id", ""),
}

# Parse edges
edges = []
if isinstance(edges_list, list):
for edge_group in edges_list:
if isinstance(edge_group, list):
for edge_data in edge_group:
if isinstance(edge_data, dict):
edges.append(
{
"type": edge_data.get("label", ""),
"source": edge_data.get("start_id", ""),
"target": edge_data.get("end_id", ""),
}
)
edges = list(edges_dict.values())

return self._convert_graph_edges(
{"core_node": core_node, "neighbors": neighbors, "edges": edges}
Expand Down
Loading