From 3c7154cc936c2bd2e0c011b44e97c5b33122ee2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 8 Dec 2025 09:48:35 +0800 Subject: [PATCH 1/3] optimize --- src/memos/graph_dbs/polardb.py | 101 +++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 37 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index ddcbfe285..4502197a5 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2,6 +2,7 @@ import random import textwrap +from contextlib import suppress from datetime import datetime from typing import Any, Literal @@ -211,15 +212,19 @@ def _get_connection(self): # Check if connection is closed if conn.closed != 0: - # Connection is closed, close it explicitly and try again + # Connection is closed, return it to pool with close flag and try again try: - conn.close() + self.connection_pool.putconn(conn, close=True) except Exception as e: - logger.warning(f"Failed to close connection: {e}") + logger.warning(f"Failed to return closed connection to pool: {e}") + with suppress(Exception): + conn.close() + + conn = None if attempt < max_retries - 1: continue else: - raise RuntimeError("Pool returned a closed connection") + raise RuntimeError("Pool returned a closed connection after all retries") # Set autocommit for PolarDB compatibility conn.autocommit = True @@ -231,20 +236,18 @@ def _get_connection(self): cursor.fetchone() cursor.close() except Exception as health_check_error: - # Connection is not usable, close it and try again + # Connection is not usable, return it to pool with close flag and try again logger.warning( - f"Connection health check failed: {health_check_error}, closing connection and retrying..." + f"Connection health check failed: {health_check_error}, returning connection to pool and retrying..." ) - try: - conn.close() - except Exception as close_error: - logger.warning(f"Failed to close unhealthy connection: {close_error}") - - # Return connection to pool if it's still valid try: self.connection_pool.putconn(conn, close=True) - except Exception as close_error: - logger.warning(f"Failed to connection_pool.putconn: {close_error}") + except Exception as putconn_error: + logger.warning( + f"Failed to return unhealthy connection to pool: {putconn_error}" + ) + with suppress(Exception): + conn.close() conn = None if attempt < max_retries - 1: @@ -257,14 +260,20 @@ def _get_connection(self): # Connection is healthy, return it return conn except Exception as e: - # If we have a connection that failed, try to return it to pool + # Only try to return connection if we actually got one + # If getconn() failed (e.g., pool exhausted), conn will be None if conn is not None: try: - self.connection_pool.putconn(conn, close=True) + # If it's a PoolError or similar, close the connection instead of returning + if "pool" in str(e).lower() or "exhausted" in str(e).lower(): + with suppress(Exception): + conn.close() + else: + self.connection_pool.putconn(conn, close=True) except Exception as putconn_error: - logger.warning( - f"Failed to connection_pool.putconn to pool: {putconn_error}" - ) + logger.warning(f"Failed to handle connection after error: {putconn_error}") + with suppress(Exception): + conn.close() if attempt >= max_retries - 1: raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e @@ -272,26 +281,38 @@ def _get_connection(self): def _return_connection(self, connection): """Return a connection to the pool.""" - if not self._pool_closed and connection: - try: - # Check if connection is closed - if hasattr(connection, "closed") and connection.closed != 0: - # Connection is closed, just close it and don't return to pool - try: - connection.close() - except Exception as e: - logger.warning(f"Failed to close connection: {e}") - return + if self._pool_closed: + # Pool is closed, just close the connection if it exists + if connection: + try: + connection.close() + except Exception as e: + logger.warning(f"Failed to close connection after pool closed: {e}") + return - # Connection is valid, return to pool - self.connection_pool.putconn(connection) - except Exception as e: - # If putconn fails, close the connection - logger.warning(f"Failed to return connection to pool: {e}") + if not connection: + # No connection to return + return + + try: + # Check if connection is closed + if hasattr(connection, "closed") and connection.closed != 0: + # Connection is closed, just close it explicitly and don't return to pool try: connection.close() except Exception as e: - logger.warning(f"Failed to close connection: {e}") + logger.warning(f"Failed to close closed connection: {e}") + return + + # Connection is valid, return to pool + self.connection_pool.putconn(connection) + except Exception as e: + # If putconn fails, try to close the connection + logger.warning(f"Failed to return connection to pool: {e}") + try: + connection.close() + except Exception as close_error: + logger.warning(f"Failed to close connection after putconn error: {close_error}") def _return_connection_old(self, connection): """Return a connection to the pool.""" @@ -3116,8 +3137,10 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - conn = self._get_connection() + conn = None + insert_query = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Delete existing record first (if any) delete_query = f""" @@ -3161,8 +3184,12 @@ def add_node( logger.info( f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" ) + except Exception as e: + logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) + raise finally: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") + if insert_query: + logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): From d33869baede190ecd5937c6a5c91e59b9fb246ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 8 Dec 2025 10:07:06 +0800 Subject: [PATCH 2/3] optimize --- src/memos/graph_dbs/polardb.py | 102 ++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 34 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 4502197a5..57688f5ff 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -333,8 +333,9 @@ def _ensure_database_exists(self): def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') @@ -398,8 +399,9 @@ def create_index( Note: This creates PostgreSQL indexes on the underlying tables. """ # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables @@ -435,8 +437,9 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in params = [self.format_param_value(memory_type), self.format_param_value(user_name)] # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -461,8 +464,9 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: params = [self.format_param_value(scope), self.format_param_value(user_name)] # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -501,8 +505,9 @@ def remove_oldest_memory( self.format_param_value(user_name), keep_latest, ] - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) @@ -595,8 +600,9 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N params.append(self.format_param_value(user_name)) # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: @@ -625,8 +631,9 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: params.append(self.format_param_value(user_name)) # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: @@ -639,8 +646,9 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") @@ -670,8 +678,9 @@ def create_extension(self): @timed def create_graph(self): # Get a connection from the pool - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(f""" SELECT COUNT(*) FROM ag_catalog.ag_graph @@ -697,9 +706,10 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - conn = self._get_connection() + conn = None logger.info(f"Creating elabel: {label_name}") try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") logger.info(f"Successfully created elabel: {label_name}") @@ -746,8 +756,9 @@ def add_edge( ); """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") @@ -770,8 +781,9 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type)) logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") @@ -831,8 +843,9 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -886,8 +899,9 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query) result = cursor.fetchone() @@ -925,8 +939,9 @@ def get_node( query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -1015,8 +1030,9 @@ def get_nodes( query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1273,8 +1289,9 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1413,9 +1430,10 @@ def get_subgraph( RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - conn = self._get_connection() + conn = None logger.info(f"[get_subgraph] Query: {query}") try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1621,8 +1639,9 @@ def seach_by_keywords_like( logger.info( f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1717,8 +1736,9 @@ def seach_by_keywords_tfidf( logger.info( f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1838,8 +1858,9 @@ def search_by_fulltext( params = [tsquery_string, tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1978,8 +1999,9 @@ def search_by_embedding( logger.info(f"[search_by_embedding] query: {query}, params: {params}") - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: try: # If params is empty, execute query directly without parameters @@ -2130,9 +2152,10 @@ def get_by_metadata( """ ids = [] - conn = self._get_connection() + conn = None logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -2292,8 +2315,9 @@ def get_grouped_counts( {where_clause} GROUP BY {", ".join(group_by_fields)} """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): @@ -2352,8 +2376,9 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query) logger.info("Cleared all nodes from database.") @@ -2380,8 +2405,9 @@ def export_graph( } """ user_name = user_name if user_name else self._get_config_value("user_name") - conn = self._get_connection() + conn = None try: + conn = self._get_connection() # Export nodes if include_embedding: node_query = f""" @@ -2438,8 +2464,9 @@ def export_graph( finally: self._return_connection(conn) - conn = self._get_connection() + conn = None try: + conn = self._get_connection() # Export edges using cypher query edge_query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ @@ -2528,8 +2555,9 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() result = self.execute_query(query, conn) return int(result.one_or_none()["count"].value) finally: @@ -2614,9 +2642,10 @@ def get_all_memory_items( """ nodes = [] node_ids = set() - conn = self._get_connection() + conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -2665,9 +2694,10 @@ def get_all_memory_items( """ nodes = [] - conn = self._get_connection() + conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -2889,8 +2919,9 @@ def get_structure_optimization_candidates( candidates = [] node_ids = set() - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() @@ -3297,8 +3328,9 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -3595,8 +3627,9 @@ def get_edges( RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ - conn = self._get_connection() + conn = None try: + conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -4494,9 +4527,10 @@ def delete_node_by_prams( logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") print(f"[delete_node_by_prams] delete_query: {delete_query}") - conn = self._get_connection() + conn = None deleted_count = 0 try: + conn = self._get_connection() with conn.cursor() as cursor: # Count nodes before deletion cursor.execute(count_query) From cb5b99b58b179f942ce12079eb8783098c3c8063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E5=A4=A7=E6=B4=8B?= <714403855@qq.com> Date: Mon, 8 Dec 2025 12:36:01 +0800 Subject: [PATCH 3/3] optimize --- src/memos/graph_dbs/polardb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 57688f5ff..1d8a25b67 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -152,7 +152,7 @@ def __init__(self, config: PolarDBGraphDBConfig): # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( minconn=5, - maxconn=100, + maxconn=500, host=host, port=port, user=user,