From 1a9bb4a38f411fdd0df5e532db2fc6fbe9a9c429 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 15:47:52 +0800 Subject: [PATCH 1/6] feat: update include embedding --- src/memos/api/config.py | 6 +++++ src/memos/api/handlers/component_init.py | 1 + src/memos/configs/memory.py | 4 ++++ src/memos/graph_dbs/polardb.py | 5 ++++ src/memos/memories/textual/simple_tree.py | 2 ++ src/memos/memories/textual/tree.py | 2 ++ .../retrieve/advanced_searcher.py | 2 ++ .../tree_text_memory/retrieve/recall.py | 24 ++++++++++++++----- .../tree_text_memory/retrieve/searcher.py | 5 +++- 9 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 9aa4dba5d..80efadf13 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -887,6 +887,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, + "include_embedding": bool( + os.getenv("INCLUDE_EMBEDDING", "false") == "true" + ), }, }, "act_mem": {} @@ -960,6 +963,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), }, "mode": os.getenv("ASYNC_MODE", "sync"), + "include_embedding": bool( + os.getenv("INCLUDE_EMBEDDING", "false") == "true" + ), }, }, "act_mem": {} diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 670a19110..ac50bba47 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -210,6 +210,7 @@ def init_server() -> dict[str, Any]: config=default_cube_config.text_mem.config, internet_retriever=internet_retriever, tokenizer=tokenizer, + include_embedding=bool(os.getenv("INCLUDE_EMBEDDING", "false") == "true"), ) logger.debug("Text memory initialized") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 
04fc58ad6..fa71a40d8 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -196,6 +196,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): default="sync", description=("whether use asynchronous mode in memory add"), ) + include_embedding: bool | None = Field( + default=False, + description="Whether to include embedding in the memory retrieval", + ) class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 018911db2..025c0de3c 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3602,6 +3602,11 @@ def _build_node_from_agtype(self, node_agtype, embedding=None): return None if embedding is not None: + if isinstance(embedding, str): + try: + embedding = json.loads(embedding) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse embedding for node") props["embedding"] = embedding # Return standard format directly diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index c67271f76..2df819f3a 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -37,6 +37,7 @@ def __init__( internet_retriever: None = None, is_reorganize: bool = False, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): """Initialize memory with the given configuration.""" self.config: TreeTextMemoryConfig = config @@ -65,3 +66,4 @@ def __init__( ) else: logger.info("No internet retriever configured") + self.include_embedding = include_embedding diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 75eae30e8..a51f80ff8 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -92,6 +92,7 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") self.tokenizer = None + self.include_embedding = 
config.include_embedding or False def add( self, @@ -192,6 +193,7 @@ def search( search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, tokenizer=self.tokenizer, + include_embedding=self.include_embedding, ) return searcher.search( query, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 6a10087f9..e58ebcdd1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -35,6 +35,7 @@ def __init__( manual_close_internet: bool = True, process_llm: Any | None = None, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): super().__init__( dispatcher_llm=dispatcher_llm, @@ -46,6 +47,7 @@ def __init__( search_strategy=search_strategy, manual_close_internet=manual_close_internet, tokenizer=tokenizer, + include_embedding=include_embedding, ) self.stage_retrieve_top = 3 diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 0b86b4ab2..9a6e2ddb4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -22,6 +22,7 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, bm25_retriever: EnhancedBM25 | None = None, + include_embedding: bool = False, ): self.graph_store = graph_store self.embedder = embedder @@ -29,6 +30,7 @@ def __init__( self.max_workers = 10 self.filter_weight = 0.6 self.use_bm25 = bool(self.bm25_retriever) + self.include_embedding = include_embedding def retrieve( self, @@ -72,7 +74,7 @@ def retrieve( # For working memory, retrieve all entries (no session-oriented filtering) working_memories = self.graph_store.get_all_memory_items( scope="WorkingMemory", - include_embedding=False, + 
include_embedding=self.include_embedding, user_name=user_name, filter=search_filter, ) @@ -244,7 +246,9 @@ def process_node(node): return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=self.include_embedding + ) final_nodes = [] for node in node_dicts: @@ -291,7 +295,7 @@ def process_node(node): # Load nodes and post-filter node_dicts = self.graph_store.get_nodes( - list(candidate_ids), include_embedding=False, user_name=user_name + list(candidate_ids), include_embedding=self.include_embedding, user_name=user_name ) final_nodes = [] @@ -385,7 +389,10 @@ def search_path_b(): unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + list(unique_ids), + include_embedding=self.include_embedding, + cube_name=cube_name, + user_name=user_name, ) or [] ) @@ -416,7 +423,9 @@ def _bm25_recall( key_filters.append({"field": key, "op": "=", "value": value}) corpus_name += "".join(list(search_filter.values())) candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=self.include_embedding + ) bm25_query = " ".join(list({query, *parsed_goal.keys})) bm25_results = self.bm25_retriever.search( @@ -471,7 +480,10 @@ def _fulltext_recall( unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + list(unique_ids), + include_embedding=self.include_embedding, + cube_name=cube_name, + user_name=user_name, ) or [] ) diff --git 
a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 843dce142..05a13c939 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -47,13 +47,16 @@ def __init__( search_strategy: dict | None = None, manual_close_internet: bool = True, tokenizer: FastTokenizer | None = None, + include_embedding: bool = False, ): self.graph_store = graph_store self.embedder = embedder self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever) + self.graph_retriever = GraphMemoryRetriever( + graph_store, embedder, bm25_retriever, include_embedding=include_embedding + ) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm) From d77fdf30012e9c9d6c663ac11f4a7e36e0a8f588 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 20:34:25 +0800 Subject: [PATCH 2/6] feat: update init --- src/memos/memories/textual/tree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index a51f80ff8..22545496a 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -144,6 +144,7 @@ def get_searcher( manual_close_internet=manual_close_internet, process_llm=process_llm, tokenizer=self.tokenizer, + include_embedding=self.include_embedding, ) return searcher From 65863890aebc6914e38f37b3c77224f5b02ab33f Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 20:42:10 +0800 Subject: [PATCH 3/6] feat: update embedding --- src/memos/graph_dbs/polardb.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 266084a17..265e0be25 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py 
@@ -1160,13 +1160,11 @@ def get_nodes( properties = properties_json if properties_json else {} # Parse embedding from JSONB if it exists - if embedding_json is not None: + if embedding_json is not None and kwargs.get("include_embedding"): try: # remove embedding - """ embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json - # properties["embedding"] = embedding - """ + properties["embedding"] = embedding except (json.JSONDecodeError, TypeError): logger.warning(f"Failed to parse embedding for node {node_id}") nodes.append( From 0af4f193bbcc9b13236cb4df9991b8772d4a14d3 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 20:45:54 +0800 Subject: [PATCH 4/6] fix: code --- src/memos/graph_dbs/polardb.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 265e0be25..ee9af485f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1163,7 +1163,11 @@ def get_nodes( if embedding_json is not None and kwargs.get("include_embedding"): try: # remove embedding - embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) properties["embedding"] = embedding except (json.JSONDecodeError, TypeError): logger.warning(f"Failed to parse embedding for node {node_id}") From 6da4a33c2dfef3b03fd226796da204bd30acbf51 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 21:02:34 +0800 Subject: [PATCH 5/6] feat: update feedback --- src/memos/api/config.py | 26 +++++++++++++++++++++++ src/memos/api/handlers/component_init.py | 5 ++++- src/memos/api/handlers/config_builders.py | 10 +++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 80efadf13..b795c2be6 100644 --- a/src/memos/api/config.py +++ 
b/src/memos/api/config.py @@ -395,6 +395,32 @@ def get_reranker_config() -> dict[str, Any]: }, } + @staticmethod + def get_feedback_reranker_config() -> dict[str, Any]: + """Get feedback reranker configuration.""" + embedder_backend = os.getenv("MOS_FEEDBACK_RERANKER_BACKEND", "http_bge") + + if embedder_backend in ["http_bge", "http_bge_strategy"]: + return { + "backend": embedder_backend, + "config": { + "url": os.getenv("MOS_RERANKER_URL"), + "model": os.getenv("MOS_FEEDBACK_RERANKER_MODEL", "bge-reranker-v2-m3"), + "timeout": 10, + "headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")), + "rerank_source": os.getenv("MOS_RERANK_SOURCE"), + "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), + }, + } + else: + return { + "backend": "cosine_local", + "config": { + "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0}, + "level_field": "background", + }, + } + @staticmethod def get_embedder_config() -> dict[str, Any]: """Get embedder configuration.""" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index ac50bba47..8d7250a68 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -13,6 +13,7 @@ from memos.api.handlers.config_builders import ( build_chat_llm_config, build_embedder_config, + build_feedback_reranker_config, build_graph_db_config, build_internet_retriever_config, build_llm_config, @@ -159,6 +160,7 @@ def init_server() -> dict[str, Any]: embedder_config = build_embedder_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() + feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() vector_db_config = build_vec_db_config() pref_extractor_config = build_pref_extractor_config() @@ -179,6 +181,7 @@ def init_server() -> dict[str, Any]: embedder = EmbedderFactory.from_config(embedder_config) mem_reader = 
MemReaderFactory.from_config(mem_reader_config) reranker = RerankerFactory.from_config(reranker_config) + feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) @@ -305,7 +308,7 @@ def init_server() -> dict[str, Any]: memory_manager=memory_manager, mem_reader=mem_reader, searcher=searcher, - reranker=reranker, + reranker=feedback_reranker, ) # Initialize Scheduler diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index 4a83700d0..fb3df80c2 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -140,6 +140,16 @@ def build_reranker_config() -> dict[str, Any]: return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) +def build_feedback_reranker_config() -> dict[str, Any]: + """ + Build feedback reranker configuration. + + Returns: + Validated feedback reranker configuration dictionary + """ + return RerankerConfigFactory.model_validate(APIConfig.get_feedback_reranker_config()) + + def build_internet_retriever_config() -> dict[str, Any]: """ Build internet retriever configuration. 
From 50ac7d1e59d94e3a2d4af1f62df825df6b34c3a2 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 17 Dec 2025 21:52:01 +0800 Subject: [PATCH 6/6] feat: update prefdata --- src/memos/api/handlers/component_init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 8d7250a68..9c1212fe0 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -247,7 +247,7 @@ def init_server() -> dict[str, Any]: config_factory=pref_retriever_config, llm_provider=llm, embedder=embedder, - reranker=reranker, + reranker=feedback_reranker, vector_db=vector_db, ) if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" @@ -262,7 +262,7 @@ def init_server() -> dict[str, Any]: extractor_llm=llm, vector_db=vector_db, embedder=embedder, - reranker=reranker, + reranker=feedback_reranker, extractor=pref_extractor, adder=pref_adder, retriever=pref_retriever,