diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py
index b26db5afa..b76ed9d08 100644
--- a/src/memos/graph_dbs/base.py
+++ b/src/memos/graph_dbs/base.py
@@ -250,3 +250,16 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> l
         Returns:
             list[dict]: Full list of memory items under this scope.
         """
+
+    @abstractmethod
+    def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """
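Note: `add_nodes_batch` expects each node dict to carry `id`, `memory`, and a `metadata` mapping (the Neo4j implementation below pops `created_at`/`updated_at` from metadata without defaults, so callers should supply them). A caller-side sketch of the payload shape; `graph_db`, the IDs, and the contents are illustrative, not project fixtures:

```python
# Hypothetical caller-side sketch of the payload add_nodes_batch() expects.
# `graph_db` stands in for any concrete GraphDB backend.
nodes = [
    {
        "id": "mem-001",
        "memory": "User prefers dark mode.",
        "metadata": {
            "created_at": "2024-01-01T00:00:00",
            "updated_at": "2024-01-01T00:00:00",
        },
    },
    {
        "id": "mem-002",
        "memory": "User lives in Berlin.",
        "metadata": {
            "created_at": "2024-01-01T00:00:00",
            "updated_at": "2024-01-01T00:00:00",
        },
    },
]
graph_db.add_nodes_batch(nodes, user_name="alice")
```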
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index 126e974a3..a0a4c6a50 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -236,6 +236,110 @@ def add_node(
             metadata=metadata,
         )
 
+    def add_nodes_batch(
+        self,
+        nodes: list[dict[str, Any]],
+        user_name: str | None = None,
+    ) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """
+        if not nodes:
+            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
+            return
+
+        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+
+        # user_name comes from the parameter; fall back to config if missing
+        effective_user_name = user_name if user_name else self.config.user_name
+
+        # Prepare all nodes
+        prepared_nodes = []
+        for node_data in nodes:
+            try:
+                id = node_data["id"]
+                memory = node_data["memory"]
+                metadata = node_data.get("metadata", {})
+
+                logger.debug(f"[add_nodes_batch] Processing node id: {id}")
+
+                # Set user_name in metadata if needed
+                if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
+                    metadata["user_name"] = effective_user_name
+
+                # Safely process metadata
+                metadata = _prepare_node_metadata(metadata)
+
+                # Flatten info fields to top level (for Neo4j flat structure)
+                metadata = _flatten_info_fields(metadata)
+
+                # Pop timestamps so they can be stored as Neo4j datetime values
+                created_at = metadata.pop("created_at")
+                updated_at = metadata.pop("updated_at")
+
+                # Serialize sources entries to JSON strings
+                if metadata.get("sources"):
+                    for idx in range(len(metadata["sources"])):
+                        metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
+
+                prepared_nodes.append(
+                    {
+                        "id": id,
+                        "memory": memory,
+                        "created_at": created_at,
+                        "updated_at": updated_at,
+                        "metadata": metadata,
+                    }
+                )
+            except Exception as e:
+                logger.error(
+                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
+                    exc_info=True,
+                )
+                # Continue with other nodes
+                continue
+
+        if not prepared_nodes:
+            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
+            return
+
+        # Batch insert using Neo4j UNWIND for better performance
+        query = """
+        UNWIND $nodes AS node
+        MERGE (n:Memory {id: node.id})
+        SET n.memory = node.memory,
+            n.created_at = datetime(node.created_at),
+            n.updated_at = datetime(node.updated_at),
+            n += node.metadata
+        """
+
+        # Prepare nodes data for UNWIND
+        nodes_data = [
+            {
+                "id": node["id"],
+                "memory": node["memory"],
+                "created_at": node["created_at"],
+                "updated_at": node["updated_at"],
+                "metadata": node["metadata"],
+            }
+            for node in prepared_nodes
+        ]
+
+        try:
+            with self.driver.session(database=self.db_name) as session:
+                session.run(query, nodes=nodes_data)
+            logger.info(f"[add_nodes_batch] Successfully inserted {len(prepared_nodes)} nodes")
+        except Exception as e:
+            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
+            raise
+
     def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
         """
         Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present.
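The Neo4j implementation folds the whole batch into a single `UNWIND ... MERGE` query instead of issuing one `MERGE` per node, so the batch costs one round trip. A self-contained sketch of the same pattern with the official `neo4j` Python driver; the URI, credentials, and sample data are assumptions, not project config:

```python
from neo4j import GraphDatabase  # official Neo4j Python driver

# Connection details are placeholders for illustration.
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

# Same shape as the diff's query: UNWIND iterates the parameter list
# server-side, and MERGE keeps the write idempotent per node id.
query = """
UNWIND $nodes AS node
MERGE (n:Memory {id: node.id})
SET n.memory = node.memory,
    n += node.metadata
"""

nodes = [
    {"id": "mem-001", "memory": "User prefers dark mode.", "metadata": {"user_name": "alice"}},
    {"id": "mem-002", "memory": "User lives in Berlin.", "metadata": {"user_name": "alice"}},
]

with driver.session(database="neo4j") as session:
    # One round trip for the whole batch.
    session.run(query, nodes=nodes)
driver.close()
```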
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index fd0c0d4dd..a5599643e 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -3226,6 +3226,190 @@ def add_node(
         logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
         self._return_connection(conn)
 
+    @timed
+    def add_nodes_batch(
+        self,
+        nodes: list[dict[str, Any]],
+        user_name: str | None = None,
+    ) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """
+        if not nodes:
+            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
+            return
+
+        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+
+        # user_name comes from the parameter; fall back to config if missing
+        effective_user_name = user_name if user_name else self.config.user_name
+
+        # Prepare all nodes
+        prepared_nodes = []
+        for node_data in nodes:
+            try:
+                id = node_data["id"]
+                memory = node_data["memory"]
+                metadata = node_data.get("metadata", {})
+
+                logger.debug(f"[add_nodes_batch] Processing node id: {id}")
+
+                # Set user_name in metadata
+                metadata["user_name"] = effective_user_name
+
+                metadata = _prepare_node_metadata(metadata)
+
+                # Pop timestamps, defaulting to now when absent
+                created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
+                updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())
+
+                # Prepare properties
+                properties = {
+                    "id": id,
+                    "memory": memory,
+                    "created_at": created_at,
+                    "updated_at": updated_at,
+                    **metadata,
+                }
+
+                # Generate embedding if not provided
+                if "embedding" not in properties or not properties["embedding"]:
+                    properties["embedding"] = generate_vector(
+                        self._get_config_value("embedding_dimension", 1024)
+                    )
+
+                # JSON-serialize sources and usage fields
+                for field_name in ["sources", "usage"]:
+                    if properties.get(field_name):
+                        if isinstance(properties[field_name], list):
+                            for idx in range(len(properties[field_name])):
+                                # Serialize only when the element is not already a string
+                                if not isinstance(properties[field_name][idx], str):
+                                    properties[field_name][idx] = json.dumps(
+                                        properties[field_name][idx]
+                                    )
+                        elif isinstance(properties[field_name], str):
+                            # Already a string, leave as-is
+                            pass
+
+                # Extract embedding for separate column
+                embedding_vector = properties.pop("embedding", [])
+                if not isinstance(embedding_vector, list):
+                    embedding_vector = []
+
+                # Select column name based on embedding dimension
+                embedding_column = "embedding"  # default column (1024-dim)
+                if len(embedding_vector) == 3072:
+                    embedding_column = "embedding_3072"
+                elif len(embedding_vector) == 1024:
+                    embedding_column = "embedding"
+                elif len(embedding_vector) == 768:
+                    embedding_column = "embedding_768"
+
+                prepared_nodes.append(
+                    {
+                        "id": id,
+                        "memory": memory,
+                        "properties": properties,
+                        "embedding_vector": embedding_vector,
+                        "embedding_column": embedding_column,
+                    }
+                )
+            except Exception as e:
+                logger.error(
+                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
+                    exc_info=True,
+                )
+                # Continue with other nodes
+                continue
+
+        if not prepared_nodes:
+            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
+            return
+
+        # Group nodes by embedding column so each group shares one INSERT shape
+        nodes_by_embedding_column = {}
+        for node in prepared_nodes:
+            col = node["embedding_column"]
+            if col not in nodes_by_embedding_column:
+                nodes_by_embedding_column[col] = []
+            nodes_by_embedding_column[col].append(node)
+
+        conn = None
+        try:
+            conn = self._get_connection()
+            with conn.cursor() as cursor:
+                # Process each group separately
+                for embedding_column, nodes_group in nodes_by_embedding_column.items():
+                    # Delete existing records one by one so re-inserted nodes do not duplicate
+                    for node in nodes_group:
+                        delete_query = f"""
+                        DELETE FROM {self.db_name}_graph."Memory"
+                        WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
+                        """
+                        cursor.execute(delete_query, (node["id"],))
+
+                    # Insert nodes row by row; each group shares the same statement shape
+                    for node in nodes_group:
+                        # Get graph_id for this node
+                        get_graph_id_query = f"""
+                        SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
+                        """
+                        cursor.execute(get_graph_id_query, (node["id"],))
+                        graph_id = cursor.fetchone()[0]
+                        node["properties"]["graph_id"] = str(graph_id)
+
+                        # Insert node
+                        if node["embedding_vector"]:
+                            insert_query = f"""
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
+                            VALUES (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
+                                %s,
+                                %s
+                            )
+                            """
+                            logger.info(
+                                f"[add_nodes_batch] Inserting node insert_query={insert_query}"
+                            )
+                            cursor.execute(
+                                insert_query,
+                                (
+                                    node["id"],
+                                    json.dumps(node["properties"]),
+                                    json.dumps(node["embedding_vector"]),
+                                ),
+                            )
+                        else:
+                            insert_query = f"""
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties)
+                            VALUES (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
+                                %s
+                            )
+                            """
+                            cursor.execute(
+                                insert_query,
+                                (node["id"], json.dumps(node["properties"])),
+                            )
+
+                    logger.info(
+                        f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
+                    )
+
+        except Exception as e:
+            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
+            raise
+        finally:
+            self._return_connection(conn)
+
     def _build_node_from_agtype(self, node_agtype, embedding=None):
         """
         Parse the cypher-returned column `n` (agtype or JSON string)
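The PolarDB implementation groups prepared nodes by target embedding column, then still executes one DELETE and one INSERT per node inside each group. A sketch of how that grouping could feed a true multi-row insert via `psycopg2.extras.execute_values`; the DSN, table name, and row data are illustrative, and the Apache AGE `_make_graph_id` plumbing is omitted:

```python
from collections import defaultdict

import psycopg2
from psycopg2.extras import execute_values

# Hypothetical prepared rows: (id, properties_json, embedding_json, target_column).
rows = [
    ("mem-001", '{"memory": "a"}', "[0.1, 0.2]", "embedding"),
    ("mem-002", '{"memory": "b"}', "[0.3, 0.4]", "embedding_768"),
]

# Group rows by embedding column, mirroring the diff's nodes_by_embedding_column.
by_column: dict[str, list] = defaultdict(list)
for node_id, props, emb, col in rows:
    by_column[col].append((node_id, props, emb))

conn = psycopg2.connect("dbname=memos")  # placeholder DSN
with conn, conn.cursor() as cur:
    for col, group in by_column.items():
        # One multi-row INSERT per group instead of one execute() per node.
        execute_values(
            cur,
            f'INSERT INTO "Memory" (id, properties, {col}) VALUES %s',  # illustrative table
            group,
        )
conn.close()
```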