diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 018911db2..c7c32a388 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -3333,11 +3333,96 @@ def add_node(
         logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
         self._return_connection(conn)
 
+    def _write_batch_insert_sql_to_file(
+        self,
+        insert_query: str,
+        data_tuples: list[tuple],
+        embedding_column: str | None = None,
+    ) -> None:
+        """
+        Generate complete SQL statement and write to file.
+
+        Args:
+            insert_query: The INSERT query template
+            data_tuples: List of data tuples to insert
+            embedding_column: Optional embedding column name
+        """
+        import os
+
+        from datetime import datetime
+
+        try:
+            # Build complete SQL statement
+            if embedding_column:
+                complete_sql = f'INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})\nVALUES\n'
+            else:
+                complete_sql = (
+                    f'INSERT INTO {self.db_name}_graph."Memory"(id, properties)\nVALUES\n'
+                )
+
+            values_list = []
+            for data_tuple in data_tuples:
+                node_id = data_tuple[0]
+                properties_json = data_tuple[1]
+
+                # Escape single quotes in JSON strings
+                properties_escaped = properties_json.replace("'", "''")
+
+                if embedding_column and len(data_tuple) > 2:
+                    embedding_json = data_tuple[2]
+                    # Handle embedding vector - convert JSON array to PostgreSQL vector format
+                    if embedding_json:
+                        try:
+                            embedding_list = json.loads(embedding_json)
+                            # Convert to PostgreSQL vector format: [1,2,3]
+                            embedding_vector_str = (
+                                "[" + ",".join(str(float(x)) for x in embedding_list) + "]"
+                            )
+                            embedding_value = f"'{embedding_vector_str}'::vector"
+                        except Exception:
+                            embedding_value = "NULL::vector"
+                    else:
+                        embedding_value = "NULL::vector"
+
+                    value_str = f" (ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{node_id}'::text::cstring), '{properties_escaped}'::text::agtype, {embedding_value})"
+                else:
+                    value_str = f" (ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{node_id}'::text::cstring), '{properties_escaped}'::text::agtype)"
+
+                values_list.append(value_str)
+
+            complete_sql += ",\n".join(values_list) + ";\n"
+
+            # Write SQL to file
+            sql_file_path = os.path.join(
+                os.path.dirname(__file__),
+                "..",
+                "..",
+                "examples",
+                "basic_modules",
+                f"batch_insert_sql_{datetime.now().strftime('%Y%m%d_%H%M%S')}.sql",
+            )
+            sql_file_path = os.path.abspath(sql_file_path)
+            os.makedirs(os.path.dirname(sql_file_path), exist_ok=True)
+
+            with open(sql_file_path, "w", encoding="utf-8") as f:
+                f.write(f"-- Batch insert SQL generated at {datetime.now().isoformat()}\n")
+                f.write(f"-- Total nodes: {len(data_tuples)}\n")
+                if embedding_column:
+                    f.write(f"-- Embedding column: {embedding_column}\n\n")
+                else:
+                    f.write("-- No embedding column\n\n")
+                f.write(complete_sql)
+
+            logger.info(f"[add_nodes_batch] Complete SQL written to: {sql_file_path}")
+        except Exception as e:
+            logger.warning(f"[add_nodes_batch] Failed to write SQL to file: {e}")
+
     @timed
     def add_nodes_batch(
         self,
         nodes: list[dict[str, Any]],
         user_name: str | None = None,
+        write_sql_to_file: bool = False,
     ) -> None:
         """
         Batch add multiple memory nodes to the graph.
@@ -3348,12 +3433,14 @@ def add_nodes_batch(
                 - memory: str - Memory content
                 - metadata: dict[str, Any] - Node metadata
             user_name: Optional user name (will use config default if not provided)
+            write_sql_to_file: If True, write the complete SQL statement to a file for debugging
         """
         if not nodes:
             logger.warning("[add_nodes_batch] Empty nodes list, skipping")
             return
 
-        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+        # Nodes are inserted below in batches of 5 to avoid long-running statements
+        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes in small batches")
 
         # user_name comes from parameter; fallback to config if missing
         effective_user_name = user_name if user_name else self.config.user_name
@@ -3517,21 +3604,95 @@ def add_nodes_batch(
                             %s::vector
                         )
                     """
+                    # Write SQL to file if flag is set
+                    if write_sql_to_file:
+                        self._write_batch_insert_sql_to_file(
+                            insert_query,
+                            data_tuples,
+                            embedding_column=embedding_column,
+                        )
+
+                    # Set statement timeout to prevent hanging (30 minutes)
+                    try:
+                        cursor.execute("SET statement_timeout = '1800s'")
+                        logger.info("[add_nodes_batch] Set statement_timeout to 1800s")
+                    except Exception as e:
+                        logger.warning(
+                            f"[add_nodes_batch] Failed to set statement_timeout: {e}"
+                        )
+
+                    # Execute batch insert with detailed logging
                     logger.info(
-                        f"[add_nodes_batch] embedding_column Inserting insert_query:{insert_query}"
-                    )
-                    logger.info(
-                        f"[add_nodes_batch] embedding_column Inserting data_tuples:{data_tuples}"
+                        f"[add_nodes_batch] Starting execute_values - "
+                        f"nodes: {len(data_tuples)}, batch_size: 5"
                     )
+                    start_time = time.time()
+
+                    # Check for lock waits before executing
+                    try:
+                        cursor.execute("""
+                            SELECT count(*) FROM pg_locks
+                            WHERE NOT granted AND locktype = 'relation'
+                        """)
+                        lock_count = cursor.fetchone()[0]
+                        if lock_count > 0:
+                            logger.warning(
+                                f"[add_nodes_batch] Found {lock_count} pending locks before insert"
+                            )
+                    except Exception as e:
+                        logger.debug(f"[add_nodes_batch] Could not check locks: {e}")
+
+                    try:
+                        # Execute in smaller batches to avoid blocking
+                        batch_size = 5  # Process 5 at a time
+                        total_batches = (len(data_tuples) + batch_size - 1) // batch_size
+
+                        for batch_idx in range(total_batches):
+                            batch_start = batch_idx * batch_size
+                            batch_end = min(batch_start + batch_size, len(data_tuples))
+                            batch_data = data_tuples[batch_start:batch_end]
+
+                            logger.info(
+                                f"[add_nodes_batch] Processing batch {batch_idx + 1}/{total_batches} "
+                                f"(nodes {batch_start + 1}-{batch_end} of {len(data_tuples)})"
+                            )
+                            batch_start_time = time.time()
+
+                            execute_values(
+                                cursor,
+                                insert_query,
+                                batch_data,
+                                template=template,
+                                page_size=len(batch_data),  # Use actual batch size
+                            )
+
+                            # Explicitly commit after each batch if not autocommit
+                            if not conn.autocommit:
+                                conn.commit()
+                                logger.debug(
+                                    f"[add_nodes_batch] Committed batch {batch_idx + 1}"
+                                )
+
+                            batch_elapsed = time.time() - batch_start_time
+                            logger.info(
+                                f"[add_nodes_batch] Batch {batch_idx + 1} completed in {batch_elapsed:.2f}s"
+                            )
+
+                        elapsed_time = time.time() - start_time
+                        logger.info(
+                            f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
+                        )
+                    except Exception as e:
+                        elapsed_time = time.time() - start_time
+                        logger.error(
+                            f"[add_nodes_batch] execute_values failed after {elapsed_time:.2f}s: {e}",
+                            exc_info=True,
+                        )
+                        if not conn.autocommit:
+                            conn.rollback()
+                            logger.warning("[add_nodes_batch] Rolled back transaction")
+                        raise
-                    # Execute batch insert
-                    execute_values(
-                        cursor,
-                        insert_query,
-                        data_tuples,
-                        template=template,
-                        page_size=100,  # Insert in batches of 100
-                    )
                 else:
                     # Prepare data tuples for batch insert without embedding
                     data_tuples = []
@@ -3558,16 +3719,86 @@ def add_nodes_batch(
                             %s::text::agtype
                         )
                     """
-                    logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}")
-                    logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}")
-                    # Execute batch insert
-                    execute_values(
-                        cursor,
-                        insert_query,
-                        data_tuples,
-                        template=template,
-                        page_size=100,  # Insert in batches of 100
+                    # Write SQL to file if flag is set
+                    if write_sql_to_file:
+                        self._write_batch_insert_sql_to_file(
+                            insert_query,
+                            data_tuples,
+                            embedding_column=None,
+                        )
+
+                    # Set statement timeout to prevent hanging (30 minutes)
+                    try:
+                        cursor.execute("SET statement_timeout = '1800s'")
+                        logger.info(
+                            "[add_nodes_batch] Set statement_timeout to 1800s (no embedding)"
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"[add_nodes_batch] Failed to set statement_timeout: {e}"
+                        )
+
+                    # Execute batch insert with detailed logging
+                    # Split into smaller batches to avoid blocking (5 at a time)
+                    batch_size = 5
+                    total_batches = (len(data_tuples) + batch_size - 1) // batch_size
+
+                    logger.info(
+                        f"[add_nodes_batch] Starting execute_values (no embedding) - "
+                        f"nodes: {len(data_tuples)}, will process in {total_batches} batches of {batch_size}"
                     )
+                    start_time = time.time()
+
+                    # Check for lock waits before executing
+                    try:
+                        cursor.execute("""
+                            SELECT count(*) FROM pg_locks
+                            WHERE NOT granted AND locktype = 'relation'
+                        """)
+                        lock_count = cursor.fetchone()[0]
+                        if lock_count > 0:
+                            logger.warning(
+                                f"[add_nodes_batch] Found {lock_count} pending locks before insert"
+                            )
+                    except Exception as e:
+                        logger.debug(f"[add_nodes_batch] Could not check locks: {e}")
+
+                    try:
+                        for batch_idx in range(total_batches):
+                            batch_start = batch_idx * batch_size
+                            batch_end = min(batch_start + batch_size, len(data_tuples))
+                            batch_data = data_tuples[batch_start:batch_end]
+
+                            logger.info(
+                                f"[add_nodes_batch] Processing batch {batch_idx + 1}/{total_batches} "
+                                f"(nodes {batch_start + 1}-{batch_end} of {len(data_tuples)})"
+                            )
+                            batch_start_time = time.time()
+
+                            execute_values(
+                                cursor,
+                                insert_query,
+                                batch_data,
+                                template=template,
+                                page_size=len(batch_data),  # Use actual batch size
+                            )
+
+                            batch_elapsed = time.time() - batch_start_time
+                            logger.info(
+                                f"[add_nodes_batch] Batch {batch_idx + 1} completed in {batch_elapsed:.2f}s"
+                            )
+
+                        elapsed_time = time.time() - start_time
+                        logger.info(
+                            f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
+                        )
+                    except Exception as e:
+                        elapsed_time = time.time() - start_time
+                        logger.error(
+                            f"[add_nodes_batch] execute_values failed after {elapsed_time:.2f}s at batch {batch_idx + 1}: {e}",
+                            exc_info=True,
+                        )
+                        raise
 
                 logger.info(
                     f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
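
Usage sketch for the new write_sql_to_file flag. This is a minimal, hypothetical
example: the PolarDB class name, its constructor arguments, and the node fields
shown are assumptions for illustration; only add_nodes_batch() and its
parameters appear in this patch.

    from memos.graph_dbs.polardb import PolarDB  # assumed class name

    db = PolarDB(config)  # assumed constructor; config supplies db_name and a default user_name
    nodes = [
        {
            "id": "node-001",
            "memory": "User prefers dark mode",
            "metadata": {"type": "preference"},
        },
    ]
    # Writes the full INSERT statement to
    # examples/basic_modules/batch_insert_sql_<timestamp>.sql for inspection,
    # then performs the real insert in batches of 5 rows with a 30-minute
    # statement_timeout set on the session.
    db.add_nodes_batch(nodes, user_name="alice", write_sql_to_file=True)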