Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/memos/graph_dbs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,16 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> l
Returns:
list[dict]: Full list of memory items under this scope.
"""

@abstractmethod
def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
"""
Batch add multiple memory nodes to the graph.

Args:
nodes: List of node dictionaries, each containing:
- id: str - Node ID
- memory: str - Memory content
- metadata: dict[str, Any] - Node metadata
user_name: Optional user name (will use config default if not provided)
"""
104 changes: 104 additions & 0 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,110 @@ def add_node(
metadata=metadata,
)

def add_nodes_batch(
    self,
    nodes: list[dict[str, Any]],
    user_name: str | None = None,
) -> None:
    """
    Batch add multiple memory nodes to the graph in a single UNWIND query.

    Nodes are prepared one by one (a node that fails preparation is logged
    and skipped), then all surviving nodes are upserted in one round-trip
    via ``UNWIND ... MERGE`` keyed on the node id.

    Args:
        nodes: List of node dictionaries, each containing:
            - id: str - Node ID
            - memory: str - Memory content
            - metadata: dict[str, Any] - Node metadata
        user_name: Optional user name (will use config default if not provided)

    Raises:
        Exception: Re-raises any driver error from the batch insert.
    """
    if not nodes:
        logger.warning("[add_nodes_batch] Empty nodes list, skipping")
        return

    logger.info("[add_nodes_batch] Adding %d nodes", len(nodes))

    # user_name comes from parameter; fallback to config if missing
    effective_user_name = user_name if user_name else self.config.user_name

    # Prepare all nodes; a failure in one node must not abort the batch.
    prepared_nodes = []
    for node_data in nodes:
        try:
            node_id = node_data["id"]  # avoid shadowing builtin `id`
            memory = node_data["memory"]
            # Copy so the caller's metadata dict is never mutated.
            metadata = dict(node_data.get("metadata", {}))

            logger.debug("[add_nodes_batch] Processing node id: %s", node_id)

            # In single-db mode, scope the node to a user.
            if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
                metadata["user_name"] = effective_user_name

            # Safely process metadata
            metadata = _prepare_node_metadata(metadata)

            # Flatten info fields to top level (for Neo4j flat structure)
            metadata = _flatten_info_fields(metadata)

            # Timestamps become native datetime() properties in the query,
            # so pull them out of the generic metadata map.
            created_at = metadata.pop("created_at")
            updated_at = metadata.pop("updated_at")

            # Neo4j properties must be primitives: JSON-encode each source.
            if metadata.get("sources"):
                metadata["sources"] = [json.dumps(src) for src in metadata["sources"]]

            prepared_nodes.append(
                {
                    "id": node_id,
                    "memory": memory,
                    "created_at": created_at,
                    "updated_at": updated_at,
                    "metadata": metadata,
                }
            )
        except Exception as e:
            logger.error(
                "[add_nodes_batch] Failed to prepare node %s: %s",
                node_data.get("id", "unknown"),
                e,
                exc_info=True,
            )
            # Continue with other nodes
            continue

    if not prepared_nodes:
        logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
        return

    # Batch insert using Neo4j UNWIND for better performance; MERGE on id
    # makes the operation idempotent for already-existing nodes.
    query = """
    UNWIND $nodes AS node
    MERGE (n:Memory {id: node.id})
    SET n.memory = node.memory,
        n.created_at = datetime(node.created_at),
        n.updated_at = datetime(node.updated_at),
        n += node.metadata
    """

    try:
        with self.driver.session(database=self.db_name) as session:
            # prepared_nodes already has exactly the shape the query expects;
            # no intermediate re-copy of the list is needed.
            session.run(query, nodes=prepared_nodes)
            logger.info(
                "[add_nodes_batch] Successfully inserted %d nodes", len(prepared_nodes)
            )
    except Exception as e:
        logger.error("[add_nodes_batch] Failed to add nodes: %s", e, exc_info=True)
        raise

def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
"""
Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present.
Expand Down
184 changes: 184 additions & 0 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3226,6 +3226,190 @@ def add_node(
logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
self._return_connection(conn)

@timed
def add_nodes_batch(
    self,
    nodes: list[dict[str, Any]],
    user_name: str | None = None,
) -> None:
    """
    Batch add multiple memory nodes to the graph.

    Nodes are prepared individually (a node that fails preparation is
    logged and skipped), grouped by the physical embedding column their
    vector dimension maps to, then upserted (DELETE + INSERT) through a
    single pooled connection.

    Args:
        nodes: List of node dictionaries, each containing:
            - id: str - Node ID
            - memory: str - Memory content
            - metadata: dict[str, Any] - Node metadata
        user_name: Optional user name (will use config default if not provided)

    Raises:
        Exception: Re-raises any database error from the insert phase.
    """
    if not nodes:
        logger.warning("[add_nodes_batch] Empty nodes list, skipping")
        return

    logger.info("[add_nodes_batch] Adding %d nodes", len(nodes))

    # user_name comes from parameter; fallback to config if missing
    effective_user_name = user_name if user_name else self.config.user_name

    # Vector length -> physical column; unknown sizes fall back to the
    # default "embedding" column.
    dimension_columns = {3072: "embedding_3072", 1024: "embedding", 768: "embedding_768"}

    # ---- Phase 1: prepare all nodes --------------------------------------
    prepared_nodes = []
    for node_data in nodes:
        try:
            node_id = node_data["id"]  # avoid shadowing builtin `id`
            memory = node_data["memory"]
            # Copy so the caller's metadata dict is never mutated.
            metadata = dict(node_data.get("metadata", {}))

            logger.debug("[add_nodes_batch] Processing node id: %s", node_id)

            # Set user_name in metadata
            metadata["user_name"] = effective_user_name

            metadata = _prepare_node_metadata(metadata)

            # Timestamps live as top-level properties; default to "now" when
            # the caller did not supply them.
            # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12;
            # kept for string-format compatibility with add_node — confirm.
            created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
            updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())

            properties = {
                "id": node_id,
                "memory": memory,
                "created_at": created_at,
                "updated_at": updated_at,
                **metadata,
            }

            # Generate a placeholder embedding if none was provided.
            if not properties.get("embedding"):
                properties["embedding"] = generate_vector(
                    self._get_config_value("embedding_dimension", 1024)
                )

            # JSON-serialize structured elements of sources/usage; elements
            # that are already strings are kept as-is.
            for field_name in ("sources", "usage"):
                value = properties.get(field_name)
                if isinstance(value, list) and value:
                    properties[field_name] = [
                        item if isinstance(item, str) else json.dumps(item)
                        for item in value
                    ]

            # The embedding is stored in its own column, not in properties;
            # anything that is not a list is treated as "no embedding".
            embedding_vector = properties.pop("embedding", [])
            if not isinstance(embedding_vector, list):
                embedding_vector = []

            prepared_nodes.append(
                {
                    "id": node_id,
                    "memory": memory,
                    "properties": properties,
                    "embedding_vector": embedding_vector,
                    "embedding_column": dimension_columns.get(
                        len(embedding_vector), "embedding"
                    ),
                }
            )
        except Exception as e:
            logger.error(
                "[add_nodes_batch] Failed to prepare node %s: %s",
                node_data.get("id", "unknown"),
                e,
                exc_info=True,
            )
            # Continue with other nodes
            continue

    if not prepared_nodes:
        logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
        return

    # ---- Phase 2: group nodes by target embedding column -----------------
    nodes_by_embedding_column: dict[str, list[dict[str, Any]]] = {}
    for node in prepared_nodes:
        nodes_by_embedding_column.setdefault(node["embedding_column"], []).append(node)

    # ---- Phase 3: insert -------------------------------------------------
    # NOTE(review): graph/table names are interpolated via f-strings; this is
    # acceptable only because self.db_name comes from trusted config, never
    # from user input. All data values go through query parameters.
    conn = None
    try:
        conn = self._get_connection()
        with conn.cursor() as cursor:
            for embedding_column, nodes_group in nodes_by_embedding_column.items():
                # Hoist loop-invariant SQL text out of the per-node loops.
                delete_query = f"""
                DELETE FROM {self.db_name}_graph."Memory"
                WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
                """
                get_graph_id_query = f"""
                SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
                """
                insert_with_embedding = f"""
                INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
                VALUES (
                    ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
                    %s,
                    %s
                )
                """
                insert_without_embedding = f"""
                INSERT INTO {self.db_name}_graph."Memory"(id, properties)
                VALUES (
                    ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
                    %s
                )
                """

                # Upsert semantics: remove any existing rows first.
                for node in nodes_group:
                    cursor.execute(delete_query, (node["id"],))

                # Insert each node, mirroring the resolved graph_id into the
                # properties payload so it can be read back without a join.
                for node in nodes_group:
                    cursor.execute(get_graph_id_query, (node["id"],))
                    graph_id = cursor.fetchone()[0]
                    node["properties"]["graph_id"] = str(graph_id)

                    if node["embedding_vector"]:
                        # Full SQL text is debug-level detail, not per-row INFO.
                        logger.debug(
                            "[add_nodes_batch] Inserting node insert_query=%s",
                            insert_with_embedding,
                        )
                        cursor.execute(
                            insert_with_embedding,
                            (
                                node["id"],
                                json.dumps(node["properties"]),
                                json.dumps(node["embedding_vector"]),
                            ),
                        )
                    else:
                        cursor.execute(
                            insert_without_embedding,
                            (node["id"], json.dumps(node["properties"])),
                        )

                logger.info(
                    "[add_nodes_batch] Inserted %d nodes with embedding_column=%s",
                    len(nodes_group),
                    embedding_column,
                )

    except Exception as e:
        logger.error("[add_nodes_batch] Failed to add nodes: %s", e, exc_info=True)
        raise
    finally:
        self._return_connection(conn)

def _build_node_from_agtype(self, node_agtype, embedding=None):
"""
Parse the cypher-returned column `n` (agtype or JSON string)
Expand Down
Loading