275 changes: 253 additions & 22 deletions src/memos/graph_dbs/polardb.py
@@ -3333,11 +3333,96 @@ def add_node(
logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
self._return_connection(conn)

def _write_batch_insert_sql_to_file(
self,
insert_query: str,
data_tuples: list[tuple],
embedding_column: str | None = None,
) -> None:
"""
Generate the complete batch INSERT statement and write it to a file for debugging.

Args:
insert_query: The INSERT query template
data_tuples: List of data tuples to insert
embedding_column: Optional embedding column name
"""
import os

from datetime import datetime

try:
# Build complete SQL statement
if embedding_column:
complete_sql = f'INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})\nVALUES\n'
else:
complete_sql = (
f'INSERT INTO {self.db_name}_graph."Memory"(id, properties)\nVALUES\n'
)

values_list = []
for data_tuple in data_tuples:
node_id = data_tuple[0]
properties_json = data_tuple[1]

# Escape single quotes in JSON strings
properties_escaped = properties_json.replace("'", "''")

if embedding_column and len(data_tuple) > 2:
embedding_json = data_tuple[2]
# Handle embedding vector - convert JSON array to PostgreSQL vector format
if embedding_json:
try:
embedding_list = json.loads(embedding_json)
# Convert to PostgreSQL vector format: [1,2,3]
embedding_vector_str = (
"[" + ",".join(str(float(x)) for x in embedding_list) + "]"
)
embedding_value = f"'{embedding_vector_str}'::vector"
except Exception:
embedding_value = "NULL::vector"
else:
embedding_value = "NULL::vector"

value_str = f" (ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{node_id}'::text::cstring), '{properties_escaped}'::text::agtype, {embedding_value})"
else:
value_str = f" (ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{node_id}'::text::cstring), '{properties_escaped}'::text::agtype)"

values_list.append(value_str)

complete_sql += ",\n".join(values_list) + ";\n"

# Write SQL to file
sql_file_path = os.path.join(
os.path.dirname(__file__),
"..",
"..",
"examples",
"basic_modules",
f"batch_insert_sql_{datetime.now().strftime('%Y%m%d_%H%M%S')}.sql",
)
sql_file_path = os.path.abspath(sql_file_path)
os.makedirs(os.path.dirname(sql_file_path), exist_ok=True)

with open(sql_file_path, "w", encoding="utf-8") as f:
f.write(f"-- Batch insert SQL generated at {datetime.now().isoformat()}\n")
f.write(f"-- Total nodes: {len(data_tuples)}\n")
if embedding_column:
f.write(f"-- Embedding column: {embedding_column}\n\n")
else:
f.write("-- No embedding column\n\n")
f.write(complete_sql)

logger.info(f"[add_nodes_batch] Complete SQL written to: {sql_file_path}")
except Exception as e:
logger.warning(f"[add_nodes_batch] Failed to write SQL to file: {e}")

@timed
def add_nodes_batch(
self,
nodes: list[dict[str, Any]],
user_name: str | None = None,
write_sql_to_file: bool = False,
) -> None:
"""
Batch add multiple memory nodes to the graph.
@@ -3348,12 +3433,14 @@ def add_nodes_batch(
- memory: str - Memory content
- metadata: dict[str, Any] - Node metadata
user_name: Optional user name (will use config default if not provided)
write_sql_to_file: If True, write the complete SQL statement to a file for debugging
"""
if not nodes:
logger.warning("[add_nodes_batch] Empty nodes list, skipping")
return

logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
# Only process the first node for debugging
logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})")

# user_name comes from parameter; fallback to config if missing
effective_user_name = user_name if user_name else self.config.user_name
@@ -3517,21 +3604,95 @@
%s::vector
)
"""
# Write SQL to file if flag is set
if write_sql_to_file:
self._write_batch_insert_sql_to_file(
insert_query,
data_tuples,
embedding_column=embedding_column,
)

# Set statement timeout to prevent hanging (30 minutes)
try:
cursor.execute("SET statement_timeout = '1800s'")
logger.info("[add_nodes_batch] Set statement_timeout to 1800s")
except Exception as e:
logger.warning(
f"[add_nodes_batch] Failed to set statement_timeout: {e}"
)
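# Note: a plain SET (without LOCAL) is session-scoped in PostgreSQL, so
# this 1800s limit stays in effect for later statements on this connection.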

# Execute batch insert with detailed logging
logger.info(
f"[add_nodes_batch] Starting execute_values - "
f"nodes: {len(data_tuples)}"
)
start_time = time.time()

# Check for lock waits before executing
try:
cursor.execute("""
SELECT count(*) FROM pg_locks
WHERE NOT granted AND locktype = 'relation'
""")
lock_count = cursor.fetchone()[0]
if lock_count > 0:
logger.warning(
f"[add_nodes_batch] Found {lock_count} pending locks before insert"
)
except Exception as e:
logger.debug(f"[add_nodes_batch] Could not check locks: {e}")

try:
# Execute in smaller batches to avoid blocking
batch_size = 5 # Process 5 at a time
total_batches = (len(data_tuples) + batch_size - 1) // batch_size
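# Ceiling division: e.g. 12 tuples -> 3 batches (5, 5, 2).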

for batch_idx in range(total_batches):
batch_start = batch_idx * batch_size
batch_end = min(batch_start + batch_size, len(data_tuples))
batch_data = data_tuples[batch_start:batch_end]

logger.info(
f"[add_nodes_batch] Processing batch {batch_idx + 1}/{total_batches} "
f"(nodes {batch_start + 1}-{batch_end} of {len(data_tuples)})"
)
batch_start_time = time.time()

execute_values(
cursor,
insert_query,
batch_data,
template=template,
page_size=len(batch_data), # Use actual batch size
)

# Explicitly commit after each batch if not autocommit
if not conn.autocommit:
conn.commit()
logger.debug(
f"[add_nodes_batch] Committed batch {batch_idx + 1}"
)
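# Committing per batch keeps lock hold times short, at the cost of
# atomicity: a failure mid-way leaves earlier batches already committed.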

batch_elapsed = time.time() - batch_start_time
logger.info(
f"[add_nodes_batch] Batch {batch_idx + 1} completed in {batch_elapsed:.2f}s"
)

elapsed_time = time.time() - start_time
logger.info(
f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(
f"[add_nodes_batch] execute_values failed after {elapsed_time:.2f}s: {e}",
exc_info=True,
)
if not conn.autocommit:
conn.rollback()
logger.warning("[add_nodes_batch] Rolled back transaction")
raise

else:
# Prepare data tuples for batch insert without embedding
data_tuples = []
@@ -3558,16 +3719,86 @@
%s::text::agtype
)
"""
logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}")
logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}")
# Execute batch insert
execute_values(
cursor,
insert_query,
data_tuples,
template=template,
page_size=100, # Insert in batches of 100
# Write SQL to file if flag is set
if write_sql_to_file:
self._write_batch_insert_sql_to_file(
insert_query,
data_tuples,
embedding_column=None,
)

# Set statement timeout to prevent hanging (30 minutes)
try:
cursor.execute("SET statement_timeout = '1800s'")
logger.info(
"[add_nodes_batch] Set statement_timeout to 1800s (no embedding)"
)
except Exception as e:
logger.warning(
f"[add_nodes_batch] Failed to set statement_timeout: {e}"
)

# Execute batch insert with detailed logging
# Split into smaller batches to avoid blocking (5 at a time)
batch_size = 5
total_batches = (len(data_tuples) + batch_size - 1) // batch_size

logger.info(
f"[add_nodes_batch] Starting execute_values (no embedding) - "
f"nodes: {len(data_tuples)}, will process in {total_batches} batches of {batch_size}"
)
start_time = time.time()

# Check for lock waits before executing
try:
cursor.execute("""
SELECT count(*) FROM pg_locks
WHERE NOT granted AND locktype = 'relation'
""")
lock_count = cursor.fetchone()[0]
if lock_count > 0:
logger.warning(
f"[add_nodes_batch] Found {lock_count} pending locks before insert"
)
except Exception as e:
logger.debug(f"[add_nodes_batch] Could not check locks: {e}")

try:
for batch_idx in range(total_batches):
batch_start = batch_idx * batch_size
batch_end = min(batch_start + batch_size, len(data_tuples))
batch_data = data_tuples[batch_start:batch_end]

logger.info(
f"[add_nodes_batch] Processing batch {batch_idx + 1}/{total_batches} "
f"(nodes {batch_start + 1}-{batch_end} of {len(data_tuples)})"
)
batch_start_time = time.time()

execute_values(
cursor,
insert_query,
batch_data,
template=template,
page_size=len(batch_data), # Use actual batch size
)

batch_elapsed = time.time() - batch_start_time
logger.info(
f"[add_nodes_batch] Batch {batch_idx + 1} completed in {batch_elapsed:.2f}s"
)

elapsed_time = time.time() - start_time
logger.info(
f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(
f"[add_nodes_batch] execute_values failed after {elapsed_time:.2f}s at batch {batch_idx + 1}: {e}",
exc_info=True,
)
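# Unlike the embedding branch above, there is no explicit rollback here;
# transaction cleanup is presumably left to the caller or the
# connection-return path.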
raise

logger.info(
f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
)