From ca16cc683adbdb62d1befc96fd3c268ccea4fcc1 Mon Sep 17 00:00:00 2001
From: nuoxiichn
Date: Mon, 16 Mar 2026 16:02:57 +0800
Subject: [PATCH 1/3] feat: enable specifying database and collection name in
 local API with Neo4j (Enterprise)

---
 src/memos/api/config.py                  |  42 +++-
 src/memos/api/handlers/add_handler.py    |  47 ++++-
 src/memos/api/handlers/component_init.py |  81 ++++++++
 src/memos/api/handlers/search_handler.py |  83 ++++++--
 src/memos/configs/graph_db.py            |   7 +
 src/memos/graph_dbs/neo4j.py             | 201 +++++++++++++++++-
 src/memos/vec_dbs/qdrant.py              | 252 ++++++++++++++++++-----
 7 files changed, 634 insertions(+), 79 deletions(-)

diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 06aa50c65..a5a2ee768 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -700,20 +700,38 @@ def get_neo4j_config(user_id: str | None = None) -> dict[str, Any]:
     @staticmethod
     def get_noshared_neo4j_config(user_id) -> dict[str, Any]:
         """Get Neo4j configuration."""
-        return {
+        config = {
             "uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
             "user": os.getenv("NEO4J_USER", "neo4j"),
-            "db_name": f"memos{user_id.replace('-', '')}",
+            "db_name": f"{user_id.replace('_', '-')}",
             "password": os.getenv("NEO4J_PASSWORD", "12345678"),
             "auto_create": True,
             "use_multi_db": True,
             "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
         }
 
+        enable_qdrant_sync = os.getenv("MOS_NEO4J_ENABLE_QDRANT_SYNC", "true").lower() == "true"
+        if enable_qdrant_sync:
+            config["vec_config"] = {
+                "backend": "qdrant",
+                "config": {
+                    "collection_name": "neo4j_vec_db",
+                    "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
+                    "distance_metric": "cosine",
+                    "host": os.getenv("QDRANT_HOST", "localhost"),
+                    "port": int(os.getenv("QDRANT_PORT", "6333")),
+                    "path": os.getenv("QDRANT_PATH"),
+                    "url": os.getenv("QDRANT_URL"),
+                    "api_key": os.getenv("QDRANT_API_KEY"),
+                },
+            }
+
+        return config
+
     @staticmethod
     def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
         """Get Neo4j configuration."""
-        return {
+        config = {
             "uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
             "user": os.getenv("NEO4J_USER", "neo4j"),
             "db_name": os.getenv("NEO4J_DB_NAME", "shared-tree-textual-memory"),
@@ -724,6 +742,24 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
             "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
         }
 
+        enable_qdrant_sync = os.getenv("MOS_NEO4J_ENABLE_QDRANT_SYNC", "true").lower() == "true"
+        if enable_qdrant_sync:
+            config["vec_config"] = {
+                "backend": "qdrant",
+                "config": {
+                    "collection_name": "neo4j_vec_db",
+                    "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
+                    "distance_metric": "cosine",
+                    "host": os.getenv("QDRANT_HOST", "localhost"),
+                    "port": int(os.getenv("QDRANT_PORT", "6333")),
+                    "path": os.getenv("QDRANT_PATH"),
+                    "url": os.getenv("QDRANT_URL"),
+                    "api_key": os.getenv("QDRANT_API_KEY"),
+                },
+            }
+
+        return config
+
     @staticmethod
     def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
         """Get Nebular configuration."""
diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py
index e9ed4f955..e72585fd8 100644
--- a/src/memos/api/handlers/add_handler.py
+++ b/src/memos/api/handlers/add_handler.py
@@ -5,9 +5,13 @@
 using dependency injection for better modularity and testability.
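+
+In Neo4j multi-db mode (MOS_NEO4J_SHARED_DB=false, one database per user),
+per-user components are created lazily and cached; see
+AddHandler._get_per_user_components below.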
""" +import os +import threading + from pydantic import validate_call from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.component_init import create_per_db_components from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse from memos.memories.textual.item import ( list_all_fields, @@ -36,6 +40,8 @@ def __init__(self, dependencies: HandlerDependencies): self._validate_dependencies( "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + self._per_user_cube_cache: dict[str, dict] = {} + self._cache_lock = threading.Lock() def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ @@ -113,6 +119,31 @@ def _check_messages(messages: MessageList) -> None: data=results, ) + @property + def _is_neo4j_multidb(self) -> bool: + """Return True when using Neo4j enterprise with one-database-per-user mode.""" + backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "")).lower() + shared_db = os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true" + return backend == "neo4j" and not shared_db + + def _get_per_user_components(self, user_id: str) -> dict: + """Return (creating on first access) per-user graph/mem components. + + Uses double-checked locking so the expensive component creation happens + only once per user even under concurrent requests. + """ + if user_id not in self._per_user_cube_cache: + with self._cache_lock: + if user_id not in self._per_user_cube_cache: + self.logger.info( + f"[AddHandler] Creating per-user components for user_id={user_id!r}" + ) + self._per_user_cube_cache[user_id] = create_per_db_components( + db_name=user_id, + base_components=vars(self.deps), + ) + return self._per_user_cube_cache[user_id] + def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: """ Normalize target cube ids from add_req. @@ -128,12 +159,20 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: cube_ids = self._resolve_cube_ids(add_req) + if self._is_neo4j_multidb: + per_user = self._get_per_user_components(add_req.user_id) + naive_mem_cube = per_user["naive_mem_cube"] + mem_reader = per_user["mem_reader"] + else: + naive_mem_cube = self.naive_mem_cube + mem_reader = self.mem_reader + if len(cube_ids) == 1: cube_id = cube_ids[0] return SingleCubeView( cube_id=cube_id, - naive_mem_cube=self.naive_mem_cube, - mem_reader=self.mem_reader, + naive_mem_cube=naive_mem_cube, + mem_reader=mem_reader, mem_scheduler=self.mem_scheduler, logger=self.logger, feedback_server=self.feedback_server, @@ -143,8 +182,8 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView: single_views = [ SingleCubeView( cube_id=cube_id, - naive_mem_cube=self.naive_mem_cube, - mem_reader=self.mem_reader, + naive_mem_cube=naive_mem_cube, + mem_reader=mem_reader, mem_scheduler=self.mem_scheduler, logger=self.logger, feedback_server=self.feedback_server, diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index aa2525878..893097de8 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -316,3 +316,84 @@ def init_server() -> dict[str, Any]: "nli_client": nli_client, "memory_history_manager": memory_history_manager, } + + +def create_per_db_components(db_name: str, base_components: dict[str, Any]) -> dict[str, Any]: + """Create a set of per-database components for multi-db isolation. 
+ + Reuses expensive shared singletons (LLM, embedder, reranker, etc.) but builds + a fresh graph_db, MemoryManager, SimpleTreeTextMemory, NaiveMemCube, and + searcher for the specified Neo4j database name. + + The returned ``mem_reader`` is a shallow copy of the shared one whose + ``searcher`` is overridden to point at the new database, so deduplication + during add operates against the correct graph. + + Args: + db_name: Target Neo4j database name (auto-created when ``auto_create=True``). + base_components: Shared component dict returned by :func:`init_server`. + + Returns: + Dict with keys: ``graph_db``, ``memory_manager``, ``text_mem``, + ``naive_mem_cube``, ``searcher``, ``mem_reader``. + """ + import copy + + from memos.api.config import APIConfig + from memos.configs.graph_db import GraphDBConfigFactory + from memos.graph_dbs.factory import GraphStoreFactory + + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j") + ).lower() + neo4j_cfg = APIConfig.get_neo4j_config(user_id=db_name) + new_graph_db = GraphStoreFactory.from_config( + GraphDBConfigFactory.model_validate( + {"backend": graph_db_backend, "config": neo4j_cfg} + ) + ) + + default_cube_config = base_components["default_cube_config"] + new_memory_manager = MemoryManager( + new_graph_db, + base_components["embedder"], + base_components["llm"], + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + new_text_mem = SimpleTreeTextMemory( + llm=base_components["llm"], + embedder=base_components["embedder"], + mem_reader=base_components["mem_reader"], + graph_db=new_graph_db, + reranker=base_components["reranker"], + memory_manager=new_memory_manager, + config=default_cube_config.text_mem.config, + internet_retriever=base_components["internet_retriever"], + tokenizer=FastTokenizer(), + include_embedding=bool(os.getenv("INCLUDE_EMBEDDING", "false") == "true"), + ) + + new_naive_mem_cube = NaiveMemCube(text_mem=new_text_mem, act_mem=None, para_mem=None) + + new_searcher = new_text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + process_llm=base_components["mem_reader"].llm, + ) + + # Shallow-copy the shared mem_reader and point its searcher at the new database + # so deduplication reads target the correct graph store. 
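+    # Note: copy.copy is shallow, so everything except the swapped searcher
+    # (the LLM, embedder, and any other reader state) remains shared across
+    # per-db readers and must stay effectively read-only.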
+ new_mem_reader = copy.copy(base_components["mem_reader"]) + new_mem_reader.set_searcher(new_searcher) + + logger.info(f"[create_per_db_components] Created components for db_name={db_name!r}") + return { + "graph_db": new_graph_db, + "memory_manager": new_memory_manager, + "text_mem": new_text_mem, + "naive_mem_cube": new_naive_mem_cube, + "searcher": new_searcher, + "mem_reader": new_mem_reader, + } diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index ba1c50b07..a440ce246 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -7,13 +7,17 @@ import copy import math +import os +import threading from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.component_init import create_per_db_components from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( cosine_similarity_matrix, ) @@ -43,6 +47,35 @@ def __init__(self, dependencies: HandlerDependencies): self._validate_dependencies( "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) + # Cache per-database components in Neo4j multi-db mode. + self._per_db_cube_cache: dict[str, dict[str, Any]] = {} + self._cache_lock = threading.Lock() + + @property + def _is_neo4j_multidb(self) -> bool: + """Return True when using Neo4j enterprise with one-database-per-cube mode.""" + backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "")).lower() + shared_db = os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true" + return backend == "neo4j" and not shared_db + + def _get_per_db_components(self, db_name: str) -> dict[str, Any]: + """Return cached per-db components, creating them on first access.""" + if db_name not in self._per_db_cube_cache: + with self._cache_lock: + if db_name not in self._per_db_cube_cache: + self.logger.info( + f"[SearchHandler] Creating per-db components for db_name={db_name!r}" + ) + per_db = create_per_db_components( + db_name=db_name, + base_components=vars(self.deps), + ) + per_db["deepsearch_agent"] = DeepSearchMemAgent( + llm=self.llm, + memory_retriever=per_db["text_mem"], + ) + self._per_db_cube_cache[db_name] = per_db + return self._per_db_cube_cache[db_name] def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -801,8 +834,28 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView: cube_ids = self._resolve_cube_ids(search_req) - searcher_to_use = searcher if searcher is not None else self.searcher + if self._is_neo4j_multidb: + single_views = [] + for cube_id in cube_ids: + per_db = self._get_per_db_components(cube_id) + searcher_to_use = searcher if searcher is not None else per_db["searcher"] + single_views.append( + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=per_db["naive_mem_cube"], + mem_reader=per_db["mem_reader"], + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=searcher_to_use, + deepsearch_agent=per_db["deepsearch_agent"], + ) + ) + + if len(single_views) == 1: + return single_views[0] + return CompositeCubeView(cube_views=single_views, logger=self.logger) + searcher_to_use = searcher if 
searcher is not None else self.searcher if len(cube_ids) == 1: cube_id = cube_ids[0] return SingleCubeView( @@ -814,17 +867,17 @@ def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCu searcher=searcher_to_use, deepsearch_agent=self.deepsearch_agent, ) - else: - single_views = [ - SingleCubeView( - cube_id=cube_id, - naive_mem_cube=self.naive_mem_cube, - mem_reader=self.mem_reader, - mem_scheduler=self.mem_scheduler, - logger=self.logger, - searcher=searcher_to_use, - deepsearch_agent=self.deepsearch_agent, - ) - for cube_id in cube_ids - ] - return CompositeCubeView(cube_views=single_views, logger=self.logger) + + single_views = [ + SingleCubeView( + cube_id=cube_id, + naive_mem_cube=self.naive_mem_cube, + mem_reader=self.mem_reader, + mem_scheduler=self.mem_scheduler, + logger=self.logger, + searcher=searcher_to_use, + deepsearch_agent=self.deepsearch_agent, + ) + for cube_id in cube_ids + ] + return CompositeCubeView(cube_views=single_views, logger=self.logger) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 5900d2357..6d7a9af11 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -69,6 +69,13 @@ class Neo4jGraphDBConfig(BaseGraphDBConfig): ) embedding_dimension: int = Field(default=768, description="Dimension of vector embedding") + vec_config: VectorDBConfigFactory | None = Field( + default=None, + description=( + "Optional external vector DB config for syncing embeddings (e.g., Qdrant). " + "When provided, graph writes can also sync to vector storage." + ), + ) @model_validator(mode="after") def validate_config(self): diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 33eb39692..3b5c63eda 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -8,6 +8,8 @@ from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB from memos.log import get_logger +from memos.vec_dbs.factory import VecDBFactory +from memos.vec_dbs.item import VecDBItem logger = get_logger(__name__) @@ -104,6 +106,16 @@ def __init__(self, config: Neo4jGraphDBConfig): self.driver = GraphDatabase.driver(config.uri, auth=(config.user, config.password)) self.db_name = config.db_name self.user_name = config.user_name + self.vec_db = None + + if config.vec_config is not None: + try: + self.vec_db = VecDBFactory.from_config(config.vec_config) + logger.info("[Neo4jGraphDB] External vector DB sync is enabled.") + except Exception as e: + logger.warning( + f"[Neo4jGraphDB] Failed to initialize external vector DB sync, disabling it: {e}" + ) self.system_db_name = "system" if config.use_multi_db else config.db_name if config.auto_create: @@ -230,6 +242,13 @@ def add_node( for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) + vec_item: VecDBItem | None = self._build_vec_item( + id=id, + memory=memory, + metadata=metadata, + user_name=user_name, + ) + with self.driver.session(database=self.db_name) as session: session.run( query, @@ -240,6 +259,9 @@ def add_node( metadata=metadata, ) + if vec_item is not None and self.vec_db is not None: + self.vec_db.add([vec_item]) + def add_nodes_batch( self, nodes: list[dict[str, Any]], @@ -267,6 +289,7 @@ def add_nodes_batch( # Prepare all nodes prepared_nodes = [] + vec_items: list[VecDBItem] = [] for node_data in nodes: try: id = node_data["id"] @@ -298,6 +321,15 @@ def add_nodes_batch( for idx in range(len(metadata["sources"])): metadata["sources"][idx] = 
json.dumps(metadata["sources"][idx]) + vec_item = self._build_vec_item( + id=id, + memory=memory, + metadata=metadata, + user_name=effective_user_name, + ) + if vec_item is not None: + vec_items.append(vec_item) + prepared_nodes.append( { "id": id, @@ -349,6 +381,55 @@ def add_nodes_batch( logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) raise + if vec_items and self.vec_db is not None: + self.vec_db.add(vec_items) + + def _build_vec_item( + self, + id: str, + memory: str, + metadata: dict[str, Any], + user_name: str | None, + ) -> VecDBItem | None: + """Build vector item for optional external vector DB sync.""" + if self.vec_db is None: + return None + + embedding = metadata.get("embedding") + if embedding is None: + logger.warning( + f"[Neo4jGraphDB] Missing embedding for node {id}, skip external vector sync." + ) + return None + + payload = { + "memory": memory, + "vector_sync": "success", + **metadata, + } + + if user_name and "user_name" not in payload: + payload["user_name"] = user_name + + # In multi-db mode, keep Neo4j db_name as-is (often '-') but route vector scope with + # underscore naming so Qdrant collection naming remains stable per logical user id. + if self.config.use_multi_db and "user_id" not in payload: + payload["user_id"] = self._get_vec_user_scope() + + return VecDBItem( + id=id, + vector=embedding, + payload=payload, + ) + + def _get_vec_user_scope(self) -> str: + """Return vector routing user scope name for multi-db mode. + + Neo4j database names may use '-' while Qdrant collections are expected to use + '_' by API user_id convention. + """ + return (self.db_name or "").replace("-", "_") + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present. @@ -511,10 +592,10 @@ def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[s Dictionary of node fields, or None if not found. """ logger.info(f"[get_node] id: {id}") - user_name = kwargs.get("user_name") + user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name where_user = "" params = {"id": id} - if user_name is not None: + if not self.config.use_multi_db and (self.config.user_name or user_name): where_user = " AND n.user_name = $user_name" params["user_name"] = user_name @@ -852,6 +933,122 @@ def search_by_embedding( matching archived or merged nodes. """ user_name = user_name if user_name else self.config.user_name + + # Prefer external vector DB path (e.g., Qdrant) when configured. 
+        if self.vec_db is not None:
+            vec_filter: dict[str, Any] = {"vector_sync": "success"}
+            if scope:
+                vec_filter["memory_type"] = scope
+            if status:
+                vec_filter["status"] = status
+
+            # Keep routing key consistent with write path:
+            # - multi-db: payload/filter carries user_id from db_name normalized to '_'
+            # - shared-db: payload carries user_name
+            if self.config.use_multi_db:
+                vec_filter["user_id"] = self._get_vec_user_scope()
+            elif kwargs.get("cube_name"):
+                vec_filter["user_name"] = kwargs["cube_name"]
+            elif user_name:
+                vec_filter["user_name"] = user_name
+
+            if search_filter:
+                vec_filter.update(search_filter)
+
+            vec_results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
+
+            if threshold is not None:
+                vec_results = [r for r in vec_results if r.score is None or r.score >= threshold]
+
+            vec_ids = [r.id for r in vec_results]
+            if not vec_ids:
+                return []
+
+            # Fast path: no extra graph-side filtering needed.
+            if not filter and not knowledgebase_ids and not return_fields:
+                return [{"id": r.id, "score": r.score} for r in vec_results]
+
+            # Build Neo4j post-filter query on vector result IDs.
+            where_clauses = ["n.id IN $vec_ids"]
+            params: dict[str, Any] = {"vec_ids": vec_ids}
+
+            user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
+                user_name=user_name,
+                knowledgebase_ids=knowledgebase_ids,
+                default_user_name=self.config.user_name,
+                node_alias="n",
+            )
+
+            if user_name_conditions:
+                if len(user_name_conditions) == 1:
+                    where_clauses.append(user_name_conditions[0])
+                else:
+                    where_clauses.append(f"({' OR '.join(user_name_conditions)})")
+
+            filter_conditions, filter_params = self._build_filter_conditions_cypher(
+                filter=filter,
+                param_counter_start=0,
+                node_alias="n",
+            )
+            where_clauses.extend(filter_conditions)
+
+            params.update(user_name_params)
+            if filter_params:
+                params.update(filter_params)
+
+            return_clause = "RETURN n.id AS id"
+            if return_fields:
+                validated_fields = self._validate_return_fields(return_fields)
+                extra_fields = ", ".join(
+                    f"n.{field} AS {field}" for field in validated_fields if field != "id"
+                )
+                if extra_fields:
+                    return_clause = f"RETURN n.id AS id, {extra_fields}"
+
+            query = f"""
+                MATCH (n:Memory)
+                WHERE {' AND '.join(where_clauses)}
+                {return_clause}
+            """
+
+            with self.driver.session(database=self.db_name) as session:
+                neo4j_results = list(session.run(query, params))
+
+            if return_fields:
+                neo4j_data = {}
+                for record in neo4j_results:
+                    node_id = record["id"]
+                    neo4j_data[node_id] = {
+                        field: record[field]
+                        for field in return_fields
+                        if field != "id" and field in record.keys()
+                    }
+                valid_ids = set(neo4j_data.keys())
+            else:
+                valid_ids = {record["id"] for record in neo4j_results}
+
+            filtered_results = []
+            for r in vec_results:
+                if r.id in valid_ids:
+                    item = {"id": r.id, "score": r.score}
+                    if return_fields:
+                        item.update(neo4j_data.get(r.id, {}))
+                    filtered_results.append(item)
+
+            return filtered_results
+
         # Build WHERE clause
dynamically where_clauses = [] if scope: diff --git a/src/memos/vec_dbs/qdrant.py b/src/memos/vec_dbs/qdrant.py index d0853c4af..6831e01da 100644 --- a/src/memos/vec_dbs/qdrant.py +++ b/src/memos/vec_dbs/qdrant.py @@ -1,3 +1,6 @@ +import re + +from collections import defaultdict from typing import Any from memos.configs.vec_db import QdrantVecDBConfig @@ -62,15 +65,110 @@ def __init__(self, config: QdrantVecDBConfig): except Exception as e: logger.warning(f"Failed to ensure default payload indexes: {e}") + def _sanitize_collection_name(self, name: str) -> str: + """Normalize user-scope names so they are safe as Qdrant collection names.""" + normalized = (name or "").strip() + if not normalized: + return self.config.collection_name + + # Keep a conservative charset to avoid backend-specific naming issues. + normalized = re.sub(r"[^a-zA-Z0-9_-]", "_", normalized).strip("_") + if not normalized: + return self.config.collection_name + + return normalized[:255] + + def _extract_user_scope(self, data: dict[str, Any] | None) -> str | None: + """Extract user scope from payload/filter for per-user collection routing.""" + if not isinstance(data, dict): + return None + + for key in ("user_id", "user_name"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + metadata = data.get("metadata") + if isinstance(metadata, dict): + for key in ("user_id", "user_name"): + value = metadata.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + metadata_info = metadata.get("info") + if isinstance(metadata_info, dict): + for key in ("user_id", "user_name"): + value = metadata_info.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + info = data.get("info") + if isinstance(info, dict): + for key in ("user_id", "user_name"): + value = info.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + return None + + def _resolve_collection_name( + self, + *, + payload: dict[str, Any] | None = None, + filter_dict: dict[str, Any] | None = None, + ) -> str: + """Resolve collection name from payload/filter, falling back to default config.""" + user_scope = self._extract_user_scope(payload) or self._extract_user_scope(filter_dict) + if user_scope: + return self._sanitize_collection_name(user_scope) + return self.config.collection_name + + def _strip_scope_filter(self, filter_dict: dict[str, Any] | None) -> dict[str, Any] | None: + """Drop user scope keys from filter when collection is already user-scoped.""" + if not filter_dict: + return filter_dict + + effective_filter = dict(filter_dict) + effective_filter.pop("user_id", None) + effective_filter.pop("user_name", None) + return effective_filter + + def _all_candidate_collections(self) -> list[str]: + """Return all collections with default collection first for compatibility.""" + collections = self.list_collections() + ordered = [self.config.collection_name] + ordered.extend(name for name in collections if name != self.config.collection_name) + return ordered + + def _ensure_collection_ready(self, collection_name: str) -> None: + """Create collection and payload indexes if missing.""" + if self.collection_exists(collection_name): + return + + self._create_collection_by_name(collection_name) + try: + self.ensure_payload_indexes( + self._default_payload_index_fields, + collection_name=collection_name, + ) + except Exception as e: + logger.warning( + f"Failed to ensure payload indexes for collection '{collection_name}': {e}" + ) + def 
create_collection(self) -> None: - """Create a new collection with specified parameters.""" + """Create the default configured collection with specified parameters.""" + self._create_collection_by_name(self.config.collection_name) + + def _create_collection_by_name(self, collection_name: str) -> None: + """Create a specific collection with configured vector parameters.""" from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse - if self.collection_exists(self.config.collection_name): - collection_info = self.client.get_collection(self.config.collection_name) + if self.collection_exists(collection_name): + collection_info = self.client.get_collection(collection_name) logger.warning( - f"Collection '{self.config.collection_name}' (vector dimension: {collection_info.config.params.vectors.size}) already exists. Skipping creation." + f"Collection '{collection_name}' (vector dimension: {collection_info.config.params.vectors.size}) already exists. Skipping creation." ) return @@ -84,7 +182,7 @@ def create_collection(self) -> None: try: self.client.create_collection( - collection_name=self.config.collection_name, + collection_name=collection_name, vectors_config=models.VectorParams( size=self.config.vector_dimension, distance=distance_map[self.config.distance_metric], @@ -94,7 +192,7 @@ def create_collection(self) -> None: # Cloud Qdrant returns 409 when the collection already exists; tolerate and continue. if getattr(err, "status_code", None) == 409 or "already exists" in str(err).lower(): logger.warning( - f"Collection '{self.config.collection_name}' already exists. Skipping creation." + f"Collection '{collection_name}' already exists. Skipping creation." ) return raise @@ -103,7 +201,7 @@ def create_collection(self) -> None: raise logger.info( - f"Collection '{self.config.collection_name}' created with {self.config.vector_dimension} dimensions." + f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions." ) def list_collections(self) -> list[str]: @@ -137,9 +235,15 @@ def search( Returns: List of search results with distance scores and payloads. 
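+
+            Example (per-user collection routing; assumes a collection named
+            "alice" exists):
+
+                results = db.search([0.1, 0.2, 0.3], top_k=5,
+                                    filter={"user_id": "alice", "status": "activated"})
+                # Routes to collection "alice"; "user_id" is stripped from the
+                # Qdrant filter because the collection itself is user-scoped.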
""" - qdrant_filter = self._dict_to_filter(filter) if filter else None + collection_name = self._resolve_collection_name(filter_dict=filter) + if not self.collection_exists(collection_name): + logger.info(f"Qdrant collection '{collection_name}' does not exist, returning empty search result.") + return [] + + effective_filter = self._strip_scope_filter(filter) + qdrant_filter = self._dict_to_filter(effective_filter) if effective_filter else None response = self.client.query_points( - collection_name=self.config.collection_name, + collection_name=collection_name, query=query_vector, limit=top_k, query_filter=qdrant_filter, @@ -174,43 +278,56 @@ def _dict_to_filter(self, filter_dict: dict[str, Any]) -> Any: def get_by_id(self, id: str) -> VecDBItem | None: """Get a single item by ID.""" - response = self.client.retrieve( - collection_name=self.config.collection_name, - ids=[id], - with_payload=True, - with_vectors=True, - ) - - if not response: - return None + for collection_name in self._all_candidate_collections(): + try: + response = self.client.retrieve( + collection_name=collection_name, + ids=[id], + with_payload=True, + with_vectors=True, + ) + except Exception: + continue + + if response: + point = response[0] + return VecDBItem( + id=point.id, + vector=point.vector, + payload=point.payload, + ) - point = response[0] - return VecDBItem( - id=point.id, - vector=point.vector, - payload=point.payload, - ) + return None def get_by_ids(self, ids: list[str]) -> list[VecDBItem]: """Get multiple items by their IDs.""" - response = self.client.retrieve( - collection_name=self.config.collection_name, - ids=ids, - with_payload=True, - with_vectors=True, - ) + remaining_ids = set(ids) + found_items: dict[str, VecDBItem] = {} - if not response: - return [] + for collection_name in self._all_candidate_collections(): + if not remaining_ids: + break - return [ - VecDBItem( - id=point.id, - vector=point.vector, - payload=point.payload, - ) - for point in response - ] + try: + response = self.client.retrieve( + collection_name=collection_name, + ids=list(remaining_ids), + with_payload=True, + with_vectors=True, + ) + except Exception: + continue + + for point in response: + item = VecDBItem( + id=point.id, + vector=point.vector, + payload=point.payload, + ) + found_items[item.id] = item + remaining_ids.discard(item.id) + + return [found_items[id] for id in ids if id in found_items] def get_by_filter(self, filter: dict[str, Any], scroll_limit: int = 100) -> list[VecDBItem]: """ @@ -223,14 +340,22 @@ def get_by_filter(self, filter: dict[str, Any], scroll_limit: int = 100) -> list Returns: List of items including vectors and payload that match the filter """ - qdrant_filter = self._dict_to_filter(filter) if filter else None + collection_name = self._resolve_collection_name(filter_dict=filter) + if not self.collection_exists(collection_name): + logger.info( + f"Qdrant collection '{collection_name}' does not exist, returning empty filter result." 
+ ) + return [] + + effective_filter = self._strip_scope_filter(filter) + qdrant_filter = self._dict_to_filter(effective_filter) if effective_filter else None all_points = [] offset = None # Use scroll to paginate through all matching points while True: points, offset = self.client.scroll( - collection_name=self.config.collection_name, + collection_name=collection_name, limit=scroll_limit, scroll_filter=qdrant_filter, offset=offset, @@ -263,12 +388,19 @@ def get_all(self, scroll_limit=100) -> list[VecDBItem]: def count(self, filter: dict[str, Any] | None = None) -> int: """Count items in the database, optionally with filter.""" + collection_name = self._resolve_collection_name(filter_dict=filter) + if not self.collection_exists(collection_name): + logger.info(f"Qdrant collection '{collection_name}' does not exist, count=0.") + return 0 + qdrant_filter = None if filter: - qdrant_filter = self._dict_to_filter(filter) + effective_filter = self._strip_scope_filter(filter) + qdrant_filter = self._dict_to_filter(effective_filter) if effective_filter else None response = self.client.count( - collection_name=self.config.collection_name, count_filter=qdrant_filter + collection_name=collection_name, + count_filter=qdrant_filter, ) return response.count @@ -285,15 +417,18 @@ def add(self, data: list[VecDBItem | dict[str, Any]]) -> None: - 'vector': embedding vector - 'payload': additional fields for filtering/retrieval """ - points = [] + points_by_collection: dict[str, list[Any]] = defaultdict(list) for item in data: if isinstance(item, dict): item = item.copy() item = VecDBItem.from_dict(item) point = models.PointStruct(id=item.id, vector=item.vector, payload=item.payload) - points.append(point) + collection_name = self._resolve_collection_name(payload=item.payload) + points_by_collection[collection_name].append(point) - self.client.upsert(collection_name=self.config.collection_name, points=points) + for collection_name, points in points_by_collection.items(): + self._ensure_collection_ready(collection_name) + self.client.upsert(collection_name=collection_name, points=points) def update(self, id: str, data: VecDBItem | dict[str, Any]) -> None: """Update an item in the vector database.""" @@ -303,19 +438,24 @@ def update(self, id: str, data: VecDBItem | dict[str, Any]) -> None: data = data.copy() data = VecDBItem.from_dict(data) + collection_name = self._resolve_collection_name(payload=data.payload) + self._ensure_collection_ready(collection_name) + if data.vector: # For vector updates (with or without payload), use upsert with the same ID self.client.upsert( - collection_name=self.config.collection_name, + collection_name=collection_name, points=[models.PointStruct(id=id, vector=data.vector, payload=data.payload)], ) else: # For payload-only updates self.client.set_payload( - collection_name=self.config.collection_name, payload=data.payload, points=[id] + collection_name=collection_name, + payload=data.payload, + points=[id], ) - def ensure_payload_indexes(self, fields: list[str]) -> None: + def ensure_payload_indexes(self, fields: list[str], collection_name: str | None = None) -> None: """ Create payload indexes for specified fields in the collection. This is idempotent: it will skip if index already exists. @@ -323,10 +463,11 @@ def ensure_payload_indexes(self, fields: list[str]) -> None: Args: fields (list[str]): List of field names to index (as keyword). 
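+            collection_name (str | None): Target collection; falls back to the
+                configured default collection when omitted.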
""" + collection_name = collection_name or self.config.collection_name for field in fields: try: self.client.create_payload_index( - collection_name=self.config.collection_name, + collection_name=collection_name, field_name=field, field_schema="keyword", # Could be extended in future ) @@ -349,7 +490,8 @@ def delete(self, ids: list[str]) -> None: """Delete items from the vector database.""" point_ids: list[str | int] = ids - self.client.delete( - collection_name=self.config.collection_name, - points_selector=models.PointIdsList(points=point_ids), - ) + for collection_name in self._all_candidate_collections(): + self.client.delete( + collection_name=collection_name, + points_selector=models.PointIdsList(points=point_ids), + ) From 36d2711fb6ebc919c809611e0815231df73b3293 Mon Sep 17 00:00:00 2001 From: nuoxiichn Date: Mon, 16 Mar 2026 18:25:56 +0800 Subject: [PATCH 2/3] test: test for multi database and qdrant collection naming --- tests/api/test_multidb_handlers.py | 134 +++++++++++++++++++++++++++++ tests/vec_dbs/test_qdrant.py | 46 ++++++++++ 2 files changed, 180 insertions(+) create mode 100644 tests/api/test_multidb_handlers.py diff --git a/tests/api/test_multidb_handlers.py b/tests/api/test_multidb_handlers.py new file mode 100644 index 000000000..d59dfc634 --- /dev/null +++ b/tests/api/test_multidb_handlers.py @@ -0,0 +1,134 @@ +from unittest.mock import Mock, patch + +from memos.api.handlers.add_handler import AddHandler +from memos.api.handlers.base_handler import HandlerDependencies +from memos.api.handlers.search_handler import SearchHandler +from memos.api.product_models import APIADDRequest, APISearchRequest +from memos.multi_mem_cube.composite_cube import CompositeCubeView +from memos.multi_mem_cube.single_cube import SingleCubeView + + +def _build_dependencies() -> HandlerDependencies: + return HandlerDependencies( + llm=Mock(name="llm"), + naive_mem_cube=Mock(name="naive_mem_cube"), + mem_reader=Mock(name="mem_reader"), + mem_scheduler=Mock(name="mem_scheduler"), + searcher=Mock(name="searcher"), + deepsearch_agent=Mock(name="deepsearch_agent"), + feedback_server=Mock(name="feedback_server"), + reranker=Mock(name="reranker"), + embedder=Mock(name="embedder"), + internet_retriever=Mock(name="internet_retriever"), + default_cube_config=Mock(name="default_cube_config"), + ) + + +def test_add_handler_build_cube_view_uses_per_user_components(monkeypatch): + monkeypatch.setenv("GRAPH_DB_BACKEND", "neo4j") + monkeypatch.setenv("MOS_NEO4J_SHARED_DB", "false") + + handler = AddHandler(_build_dependencies()) + per_user_naive = Mock(name="per_user_naive") + per_user_reader = Mock(name="per_user_reader") + + with patch("memos.api.handlers.add_handler.create_per_db_components") as create_components: + create_components.return_value = { + "naive_mem_cube": per_user_naive, + "mem_reader": per_user_reader, + } + req = APIADDRequest(user_id="alice", memory_content="hello") + + cube_view = handler._build_cube_view(req) + + assert isinstance(cube_view, SingleCubeView) + assert cube_view.naive_mem_cube is per_user_naive + assert cube_view.mem_reader is per_user_reader + create_components.assert_called_once() + assert create_components.call_args.kwargs["db_name"] == "alice" + + +def test_add_handler_per_user_component_cache(monkeypatch): + monkeypatch.setenv("GRAPH_DB_BACKEND", "neo4j") + monkeypatch.setenv("MOS_NEO4J_SHARED_DB", "false") + + handler = AddHandler(_build_dependencies()) + with patch("memos.api.handlers.add_handler.create_per_db_components") as create_components: + 
create_components.return_value = { + "naive_mem_cube": Mock(), + "mem_reader": Mock(), + } + + first = handler._get_per_user_components("alice") + second = handler._get_per_user_components("alice") + + assert first is second + create_components.assert_called_once() + + +def test_search_handler_build_cube_view_uses_per_db_components(monkeypatch): + monkeypatch.setenv("GRAPH_DB_BACKEND", "neo4j") + monkeypatch.setenv("MOS_NEO4J_SHARED_DB", "false") + + handler = SearchHandler(_build_dependencies()) + + per_db_components = { + "cube_a": { + "naive_mem_cube": Mock(name="naive_a"), + "mem_reader": Mock(name="reader_a"), + "searcher": Mock(name="searcher_a"), + "text_mem": Mock(name="text_mem_a"), + }, + "cube_b": { + "naive_mem_cube": Mock(name="naive_b"), + "mem_reader": Mock(name="reader_b"), + "searcher": Mock(name="searcher_b"), + "text_mem": Mock(name="text_mem_b"), + }, + } + + with ( + patch("memos.api.handlers.search_handler.create_per_db_components") as create_components, + patch("memos.api.handlers.search_handler.DeepSearchMemAgent") as deepsearch_agent_cls, + ): + create_components.side_effect = lambda db_name, base_components: per_db_components[db_name] + deepsearch_agent_cls.side_effect = [Mock(name="agent_a"), Mock(name="agent_b")] + + req = APISearchRequest(query="hello", user_id="alice", readable_cube_ids=["cube_a", "cube_b"]) + cube_view = handler._build_cube_view(req) + + assert isinstance(cube_view, CompositeCubeView) + assert len(cube_view.cube_views) == 2 + first, second = cube_view.cube_views + assert first.cube_id == "cube_a" + assert first.searcher is per_db_components["cube_a"]["searcher"] + assert first.deepsearch_agent is per_db_components["cube_a"]["deepsearch_agent"] + assert second.cube_id == "cube_b" + assert second.searcher is per_db_components["cube_b"]["searcher"] + assert second.deepsearch_agent is per_db_components["cube_b"]["deepsearch_agent"] + + +def test_search_handler_per_db_component_cache(monkeypatch): + monkeypatch.setenv("GRAPH_DB_BACKEND", "neo4j") + monkeypatch.setenv("MOS_NEO4J_SHARED_DB", "false") + + handler = SearchHandler(_build_dependencies()) + + with ( + patch("memos.api.handlers.search_handler.create_per_db_components") as create_components, + patch("memos.api.handlers.search_handler.DeepSearchMemAgent") as deepsearch_agent_cls, + ): + create_components.return_value = { + "naive_mem_cube": Mock(), + "mem_reader": Mock(), + "searcher": Mock(), + "text_mem": Mock(), + } + deepsearch_agent_cls.return_value = Mock(name="agent") + + first = handler._get_per_db_components("cube_a") + second = handler._get_per_db_components("cube_a") + + assert first is second + create_components.assert_called_once() + deepsearch_agent_cls.assert_called_once() diff --git a/tests/vec_dbs/test_qdrant.py b/tests/vec_dbs/test_qdrant.py index 67f76d463..948d85263 100644 --- a/tests/vec_dbs/test_qdrant.py +++ b/tests/vec_dbs/test_qdrant.py @@ -69,6 +69,7 @@ def test_add_and_get_by_id(vec_db): def test_search(vec_db): + vec_db.collection_exists = MagicMock(return_value=True) id = str(uuid.uuid4()) mock_response = type( "QueryResponse", @@ -113,6 +114,7 @@ def test_delete(vec_db): def test_count(vec_db): + vec_db.collection_exists = MagicMock(return_value=True) vec_db.client.count.return_value.count = 5 count = vec_db.count() assert count == 5 @@ -148,3 +150,47 @@ def test_qdrant_client_cloud_init(): VecDBFactory.from_config(config) mockclient.assert_called_once_with(url="https://cloud.qdrant.example", api_key="secret-key") + + +def 
test_search_routes_to_user_collection_and_strips_scope_filter(vec_db): + vec_db.collection_exists = MagicMock(return_value=True) + vec_db._dict_to_filter = MagicMock(return_value="mock_filter") + vec_db.client.query_points.return_value = type("QueryResponse", (object,), {"points": []})() + + vec_db.search( + [0.1, 0.2, 0.3], + top_k=2, + filter={"user_id": "alice", "status": "activated"}, + ) + + vec_db._dict_to_filter.assert_called_once_with({"status": "activated"}) + vec_db.client.query_points.assert_called_once() + assert vec_db.client.query_points.call_args.kwargs["collection_name"] == "alice" + + +def test_add_groups_points_into_multiple_user_collections(vec_db): + vec_db._ensure_collection_ready = MagicMock() + + vec_db.add( + [ + { + "id": str(uuid.uuid4()), + "vector": [0.1, 0.2, 0.3], + "payload": {"user_id": "alice", "tag": "a"}, + }, + { + "id": str(uuid.uuid4()), + "vector": [0.3, 0.2, 0.1], + "payload": {"user_name": "bob", "tag": "b"}, + }, + ] + ) + + upsert_collections = {c.kwargs["collection_name"] for c in vec_db.client.upsert.call_args_list} + assert upsert_collections == {"alice", "bob"} + + +def test_resolve_collection_name_sanitizes_scope(vec_db): + collection = vec_db._resolve_collection_name(payload={"user_id": " user@-01 "}) + + assert collection == "user_-01" From 5e0141cfb6db04428f1a86bd1b7f9ae854cf213f Mon Sep 17 00:00:00 2001 From: nuoxiichn Date: Wed, 18 Mar 2026 09:43:29 +0800 Subject: [PATCH 3/3] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/memos/vec_dbs/qdrant.py | 87 +++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/src/memos/vec_dbs/qdrant.py b/src/memos/vec_dbs/qdrant.py index 6831e01da..4f4af7d31 100644 --- a/src/memos/vec_dbs/qdrant.py +++ b/src/memos/vec_dbs/qdrant.py @@ -278,54 +278,57 @@ def _dict_to_filter(self, filter_dict: dict[str, Any]) -> Any: def get_by_id(self, id: str) -> VecDBItem | None: """Get a single item by ID.""" - for collection_name in self._all_candidate_collections(): - try: - response = self.client.retrieve( - collection_name=collection_name, - ids=[id], - with_payload=True, - with_vectors=True, - ) - except Exception: - continue - - if response: - point = response[0] - return VecDBItem( - id=point.id, - vector=point.vector, - payload=point.payload, - ) + # Resolve a single collection to avoid cross-collection leakage and + # rely on the same scoping logic used elsewhere (e.g., get_by_filter). + collection_name = self._resolve_collection_name(filter_dict={}) + if not self.collection_exists(collection_name): + logger.info( + f"Qdrant collection '{collection_name}' does not exist, returning None for get_by_id." + ) + return None - return None + response = self.client.retrieve( + collection_name=collection_name, + ids=[id], + with_payload=True, + with_vectors=True, + ) + + if not response: + return None + + point = response[0] + return VecDBItem( + id=point.id, + vector=point.vector, + payload=point.payload, + ) def get_by_ids(self, ids: list[str]) -> list[VecDBItem]: """Get multiple items by their IDs.""" - remaining_ids = set(ids) - found_items: dict[str, VecDBItem] = {} + # Resolve a single collection using the same scoping logic as other methods. + collection_name = self._resolve_collection_name(filter_dict={}) + if not self.collection_exists(collection_name): + logger.info( + f"Qdrant collection '{collection_name}' does not exist, returning empty result for get_by_ids." 
+ ) + return [] - for collection_name in self._all_candidate_collections(): - if not remaining_ids: - break + response = self.client.retrieve( + collection_name=collection_name, + ids=ids, + with_payload=True, + with_vectors=True, + ) - try: - response = self.client.retrieve( - collection_name=collection_name, - ids=list(remaining_ids), - with_payload=True, - with_vectors=True, - ) - except Exception: - continue - - for point in response: - item = VecDBItem( - id=point.id, - vector=point.vector, - payload=point.payload, - ) - found_items[item.id] = item - remaining_ids.discard(item.id) + found_items: dict[str, VecDBItem] = {} + for point in response: + item = VecDBItem( + id=point.id, + vector=point.vector, + payload=point.payload, + ) + found_items[item.id] = item return [found_items[id] for id in ids if id in found_items]