From ba01760c34b329743784db472ea0dc05c06ed4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Tue, 30 Dec 2025 09:51:46 +0800 Subject: [PATCH 1/2] feat: add export_graph total --- src/memos/graph_dbs/polardb.py | 101 ++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 32 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8eb3e4ece..8eabda6d8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2522,7 +2522,9 @@ def export_graph( Returns: { "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] + "edges": [ { "source": ..., "target": ..., "type": ... }, ... ], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria } """ logger.info( @@ -2530,6 +2532,10 @@ def export_graph( ) user_id = user_id if user_id else self._get_config_value("user_id") + # Initialize total counts + total_nodes = 0 + total_edges = 0 + # Determine if pagination is needed use_pagination = page is not None and page_size is not None @@ -2546,12 +2552,6 @@ def export_graph( conn = None try: conn = self._get_connection() - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - # Build WHERE conditions where_conditions = [] if user_name: @@ -2567,12 +2567,30 @@ def export_graph( if where_conditions: where_clause = f"WHERE {' AND '.join(where_conditions)}" + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + if include_embedding: node_query = f""" SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY id + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC {pagination_clause} """ else: @@ -2580,7 +2598,8 @@ def export_graph( SELECT id, properties FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY id + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC {pagination_clause} """ logger.info(f"[export_graph nodes] Query: {node_query}") @@ -2591,9 +2610,11 @@ def export_graph( for row in node_results: if include_embedding: - properties_json, embedding_json = row + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row else: - properties_json = row + """row is (id, properties)""" + _, properties_json = row embedding_json = None # Parse properties from JSONB if it's a string @@ -2605,20 +2626,13 @@ def export_graph( else: properties = properties_json if properties_json else {} - # # Build node data - - """ - # node_data = { - # "id": properties.get("id", node_id), - # "memory": properties.get("memory", ""), - # "metadata": properties - # } - """ - - if include_embedding and embedding_json is not None: + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: properties["embedding"] = embedding_json - nodes.append(self._parse_node(json.loads(properties[1]))) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) @@ -2629,13 +2643,6 @@ def export_graph( conn = None try: conn = self._get_connection() - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - # Build Cypher WHERE conditions for edges cypher_where_conditions = [] if user_name: @@ -2649,13 +2656,38 @@ def export_graph( if cypher_where_conditions: cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" + # Get total count of edges before pagination + count_edge_query = f""" + SELECT COUNT(*) + FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + """ + logger.info(f"[export_graph edges count] Query: {count_edge_query}") + with conn.cursor() as cursor: + cursor.execute(count_edge_query) + total_edges = cursor.fetchone()[0] + + # Export edges using cypher query + # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery + # Build pagination clause if needed + edge_pagination_clause = "" + if use_pagination: + edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + edge_query = f""" SELECT source, target, edge FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (a:Memory)-[r]->(b:Memory) {cypher_where_clause} RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY a.id, b.id + ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, + COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, + a.id DESC, b.id DESC $$) AS (source agtype, target agtype, edge agtype) ) AS edges {edge_pagination_clause} @@ -2726,7 +2758,12 @@ def export_graph( finally: self._return_connection(conn) - return {"nodes": nodes, "edges": edges} + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } @timed def count_nodes(self, scope: str, user_name: str | None = None) -> int: From b841633e24677e92622f4de1740024e18626ce37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Tue, 30 Dec 2025 09:57:19 +0800 Subject: [PATCH 2/2] feat: add export_graph total --- src/memos/graph_dbs/neo4j.py | 47 ++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index d57e7c596..c2dc4a629 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1150,11 +1150,17 @@ def export_graph( Returns: { "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] + "edges": [ { "source": ..., "target": ..., "type": ... }, ... ], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria } """ user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name + # Initialize total counts + total_nodes = 0 + total_edges = 0 + # Determine if pagination is needed use_pagination = page is not None and page_size is not None @@ -1167,28 +1173,38 @@ def export_graph( skip = (page - 1) * page_size with self.driver.session(database=self.db_name) as session: - # Export nodes - node_query = "MATCH (n:Memory)" - edge_query = "MATCH (a:Memory)-[r]->(b:Memory)" + # Build base queries + node_base_query = "MATCH (n:Memory)" + edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)" params = {} if not self.config.use_multi_db and (self.config.user_name or user_name): - node_query += " WHERE n.user_name = $user_name" - edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" + node_base_query += " WHERE n.user_name = $user_name" + edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" params["user_name"] = user_name - # Add ORDER BY and pagination for nodes - node_query += " RETURN n ORDER BY n.id" + # Get total count of nodes before pagination + count_node_query = node_base_query + " RETURN COUNT(n) AS count" + count_node_result = session.run(count_node_query, params) + total_nodes = count_node_result.single()["count"] + + # Export nodes with ORDER BY created_at DESC + node_query = node_base_query + " RETURN n ORDER BY n.created_at DESC, n.id DESC" if use_pagination: node_query += f" SKIP {skip} LIMIT {page_size}" node_result = session.run(node_query, params) nodes = [self._parse_node(dict(record["n"])) for record in node_result] - # Export edges - # Add ORDER BY and pagination for edges - edge_query += ( - " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id" + # Get total count of edges before pagination + count_edge_query = edge_base_query + " RETURN COUNT(r) AS count" + count_edge_result = session.run(count_edge_query, params) + total_edges = count_edge_result.single()["count"] + + # Export edges with ORDER BY created_at DESC + edge_query = ( + edge_base_query + + " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.created_at DESC, b.created_at DESC, a.id DESC, b.id DESC" ) if use_pagination: edge_query += f" SKIP {skip} LIMIT {page_size}" @@ -1199,7 +1215,12 @@ def export_graph( for record in edge_result ] - return {"nodes": nodes, "edges": edges} + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """