diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 9aa4dba5d..80efadf13 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -887,6 +887,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, + "include_embedding": bool( + os.getenv("INCLUDE_EMBEDDING", "false") == "true" + ), }, }, "act_mem": {} @@ -960,6 +963,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, "mode": os.getenv("ASYNC_MODE", "sync"), + "include_embedding": bool( + os.getenv("INCLUDE_EMBEDDING", "false") == "true" + ), }, }, "act_mem": {} diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 670a19110..ac50bba47 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -210,6 +210,7 @@ def init_server() -> dict[str, Any]: config=default_cube_config.text_mem.config, internet_retriever=internet_retriever, tokenizer=tokenizer, + include_embedding=bool(os.getenv("INCLUDE_EMBEDDING", "false") == "true"), ) logger.debug("Text memory initialized") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 04fc58ad6..fa71a40d8 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -196,6 +196,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): default="sync", description=("whether use asynchronous mode in memory add"), ) + include_embedding: bool | None = Field( + default=False, + description="Whether to include embedding in the memory retrieval", + ) class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 018911db2..025c0de3c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3602,6 +3602,11 @@ 
def _build_node_from_agtype(self, node_agtype, embedding=None): return None if embedding is not None: + if isinstance(embedding, str): + try: + embedding = json.loads(embedding) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse embedding for node"); embedding = None props["embedding"] = embedding # Return standard format directly diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index c67271f76..2df819f3a 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -37,6 +37,7 @@ def __init__( internet_retriever: None = None, is_reorganize: bool = False, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): """Initialize memory with the given configuration.""" self.config: TreeTextMemoryConfig = config @@ -65,3 +66,4 @@ def __init__( ) else: logger.info("No internet retriever configured") + self.include_embedding = include_embedding diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 75eae30e8..a51f80ff8 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -92,6 +92,7 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") self.tokenizer = None + self.include_embedding = config.include_embedding or False def add( self, @@ -192,6 +193,7 @@ def search( search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, tokenizer=self.tokenizer, + include_embedding=self.include_embedding, ) return searcher.search( query, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 6a10087f9..e58ebcdd1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -35,6 +35,7 @@ def 
__init__( manual_close_internet: bool = True, process_llm: Any | None = None, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): super().__init__( dispatcher_llm=dispatcher_llm, @@ -46,6 +47,7 @@ def __init__( search_strategy=search_strategy, manual_close_internet=manual_close_internet, tokenizer=tokenizer, + include_embedding=include_embedding, ) self.stage_retrieve_top = 3 diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 0b86b4ab2..9a6e2ddb4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -22,6 +22,7 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, bm25_retriever: EnhancedBM25 | None = None, + include_embedding: bool = False, ): self.graph_store = graph_store self.embedder = embedder @@ -29,6 +30,7 @@ def __init__( self.max_workers = 10 self.filter_weight = 0.6 self.use_bm25 = bool(self.bm25_retriever) + self.include_embedding = include_embedding def retrieve( self, @@ -72,7 +74,7 @@ def retrieve( # For working memory, retrieve all entries (no session-oriented filtering) working_memories = self.graph_store.get_all_memory_items( scope="WorkingMemory", - include_embedding=False, + include_embedding=self.include_embedding, user_name=user_name, filter=search_filter, ) @@ -244,7 +246,9 @@ def process_node(node): return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=self.include_embedding + ) final_nodes = [] for node in node_dicts: @@ -291,7 +295,7 @@ def process_node(node): # Load nodes and post-filter node_dicts = self.graph_store.get_nodes( - list(candidate_ids), include_embedding=False, user_name=user_name + list(candidate_ids), 
include_embedding=self.include_embedding, user_name=user_name ) final_nodes = [] @@ -385,7 +389,10 @@ def search_path_b(): unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + list(unique_ids), + include_embedding=self.include_embedding, + cube_name=cube_name, + user_name=user_name, ) or [] ) @@ -416,7 +423,9 @@ def _bm25_recall( key_filters.append({"field": key, "op": "=", "value": value}) corpus_name += "".join(list(search_filter.values())) candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=self.include_embedding + ) bm25_query = " ".join(list({query, *parsed_goal.keys})) bm25_results = self.bm25_retriever.search( @@ -471,7 +480,10 @@ def _fulltext_recall( unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + list(unique_ids), + include_embedding=self.include_embedding, + cube_name=cube_name, + user_name=user_name, ) or [] ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 843dce142..05a13c939 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -47,13 +47,16 @@ def __init__( search_strategy: dict | None = None, manual_close_internet: bool = True, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): self.graph_store = graph_store self.embedder = embedder self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = 
GraphMemoryRetriever(graph_store, embedder, bm25_retriever) + self.graph_retriever = GraphMemoryRetriever( + graph_store, embedder, bm25_retriever, include_embedding=include_embedding + ) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm)