Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,11 +1150,17 @@ def export_graph(
Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
"total_nodes": int, # Total number of nodes matching the filter criteria
"total_edges": int, # Total number of edges matching the filter criteria
}
"""
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name

# Initialize total counts
total_nodes = 0
total_edges = 0

# Determine if pagination is needed
use_pagination = page is not None and page_size is not None

Expand All @@ -1167,28 +1173,38 @@ def export_graph(
skip = (page - 1) * page_size

with self.driver.session(database=self.db_name) as session:
# Export nodes
node_query = "MATCH (n:Memory)"
edge_query = "MATCH (a:Memory)-[r]->(b:Memory)"
# Build base queries
node_base_query = "MATCH (n:Memory)"
edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)"
params = {}

if not self.config.use_multi_db and (self.config.user_name or user_name):
node_query += " WHERE n.user_name = $user_name"
edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
node_base_query += " WHERE n.user_name = $user_name"
edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
params["user_name"] = user_name

# Add ORDER BY and pagination for nodes
node_query += " RETURN n ORDER BY n.id"
# Get total count of nodes before pagination
count_node_query = node_base_query + " RETURN COUNT(n) AS count"
count_node_result = session.run(count_node_query, params)
total_nodes = count_node_result.single()["count"]

# Export nodes newest-first (created_at DESC), using id DESC as a tie-breaker so paging is deterministic
node_query = node_base_query + " RETURN n ORDER BY n.created_at DESC, n.id DESC"
if use_pagination:
node_query += f" SKIP {skip} LIMIT {page_size}"

node_result = session.run(node_query, params)
nodes = [self._parse_node(dict(record["n"])) for record in node_result]

# Export edges
# Add ORDER BY and pagination for edges
edge_query += (
" RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id"
# Get total count of edges before pagination
count_edge_query = edge_base_query + " RETURN COUNT(r) AS count"
count_edge_result = session.run(count_edge_query, params)
total_edges = count_edge_result.single()["count"]

# Export edges ordered by source then target created_at DESC, with a.id/b.id DESC as tie-breakers
edge_query = (
edge_base_query
+ " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.created_at DESC, b.created_at DESC, a.id DESC, b.id DESC"
)
if use_pagination:
edge_query += f" SKIP {skip} LIMIT {page_size}"
Expand All @@ -1199,7 +1215,12 @@ def export_graph(
for record in edge_result
]

return {"nodes": nodes, "edges": edges}
return {
"nodes": nodes,
"edges": edges,
"total_nodes": total_nodes,
"total_edges": total_edges,
}

def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
"""
Expand Down
101 changes: 69 additions & 32 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2522,14 +2522,20 @@ def export_graph(
Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
"total_nodes": int, # Total number of nodes matching the filter criteria
"total_edges": int, # Total number of edges matching the filter criteria
}
"""
logger.info(
f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}"
)
user_id = user_id if user_id else self._get_config_value("user_id")

# Initialize total counts
total_nodes = 0
total_edges = 0

# Determine if pagination is needed
use_pagination = page is not None and page_size is not None

Expand All @@ -2546,12 +2552,6 @@ def export_graph(
conn = None
try:
conn = self._get_connection()
# Export nodes
# Build pagination clause if needed
pagination_clause = ""
if use_pagination:
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

# Build WHERE conditions
where_conditions = []
if user_name:
Expand All @@ -2567,20 +2567,39 @@ def export_graph(
if where_conditions:
where_clause = f"WHERE {' AND '.join(where_conditions)}"

# Get total count of nodes before pagination
count_node_query = f"""
SELECT COUNT(*)
FROM "{self.db_name}_graph"."Memory"
{where_clause}
"""
logger.info(f"[export_graph nodes count] Query: {count_node_query}")
with conn.cursor() as cursor:
cursor.execute(count_node_query)
total_nodes = cursor.fetchone()[0]

# Export nodes
# Build pagination clause if needed
pagination_clause = ""
if use_pagination:
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

if include_embedding:
node_query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY id
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
id DESC
{pagination_clause}
"""
else:
node_query = f"""
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY id
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
id DESC
{pagination_clause}
"""
logger.info(f"[export_graph nodes] Query: {node_query}")
Expand All @@ -2591,9 +2610,11 @@ def export_graph(

for row in node_results:
if include_embedding:
properties_json, embedding_json = row
"""row is (id, properties, embedding)"""
_, properties_json, embedding_json = row
else:
properties_json = row
"""row is (id, properties)"""
_, properties_json = row
embedding_json = None

# Parse properties from JSONB if it's a string
Expand All @@ -2605,20 +2626,13 @@ def export_graph(
else:
properties = properties_json if properties_json else {}

# # Build node data

"""
# node_data = {
# "id": properties.get("id", node_id),
# "memory": properties.get("memory", ""),
# "metadata": properties
# }
"""

if include_embedding and embedding_json is not None:
# Drop any embedding stored in properties when include_embedding is False;
# otherwise attach the separately selected embedding column when it is present
if not include_embedding:
properties.pop("embedding", None)
elif include_embedding and embedding_json is not None:
properties["embedding"] = embedding_json

nodes.append(self._parse_node(json.loads(properties[1])))
nodes.append(self._parse_node(properties))

except Exception as e:
logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
Expand All @@ -2629,13 +2643,6 @@ def export_graph(
conn = None
try:
conn = self._get_connection()
# Export edges using cypher query
# Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
# Build pagination clause if needed
edge_pagination_clause = ""
if use_pagination:
edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

# Build Cypher WHERE conditions for edges
cypher_where_conditions = []
if user_name:
Expand All @@ -2649,13 +2656,38 @@ def export_graph(
if cypher_where_conditions:
cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}"

# Get total count of edges before pagination
count_edge_query = f"""
SELECT COUNT(*)
FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (a:Memory)-[r]->(b:Memory)
{cypher_where_clause}
RETURN a.id AS source, b.id AS target, type(r) as edge
$$) AS (source agtype, target agtype, edge agtype)
) AS edges
"""
logger.info(f"[export_graph edges count] Query: {count_edge_query}")
with conn.cursor() as cursor:
cursor.execute(count_edge_query)
total_edges = cursor.fetchone()[0]

# Export edges using cypher query
# Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
# Build pagination clause if needed
edge_pagination_clause = ""
if use_pagination:
edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"

edge_query = f"""
SELECT source, target, edge FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (a:Memory)-[r]->(b:Memory)
{cypher_where_clause}
RETURN a.id AS source, b.id AS target, type(r) as edge
ORDER BY a.id, b.id
ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC,
COALESCE(b.created_at, '1970-01-01T00:00:00') DESC,
a.id DESC, b.id DESC
$$) AS (source agtype, target agtype, edge agtype)
) AS edges
{edge_pagination_clause}
Expand Down Expand Up @@ -2726,7 +2758,12 @@ def export_graph(
finally:
self._return_connection(conn)

return {"nodes": nodes, "edges": edges}
return {
"nodes": nodes,
"edges": edges,
"total_nodes": total_nodes,
"total_edges": total_edges,
}

@timed
def count_nodes(self, scope: str, user_name: str | None = None) -> int:
Expand Down