@staticmethod
def get_feedback_reranker_config() -> dict[str, Any]:
    """Get feedback reranker configuration.

    The backend is read from ``MOS_FEEDBACK_RERANKER_BACKEND`` (default
    ``"http_bge"``). For ``"http_bge"`` and ``"http_bge_strategy"`` an HTTP
    configuration is built. Only the backend and the model name
    (``MOS_FEEDBACK_RERANKER_MODEL``) have feedback-specific variables; the
    URL, extra headers, rerank source and strategy are read from
    ``MOS_RERANKER_URL``, ``MOS_RERANKER_HEADERS_EXTRA``,
    ``MOS_RERANK_SOURCE`` and ``MOS_RERANKER_STRATEGY``. Any other backend
    value yields the fixed ``"cosine_local"`` configuration.

    Returns:
        A dictionary with a ``"backend"`` name and a ``"config"`` mapping.

    Raises:
        json.JSONDecodeError: If ``MOS_RERANKER_HEADERS_EXTRA`` holds a
            non-empty value that is not valid JSON.
    """
    backend = os.getenv("MOS_FEEDBACK_RERANKER_BACKEND", "http_bge")

    if backend not in ("http_bge", "http_bge_strategy"):
        # Unknown or local backend: fall back to the cosine reranker with
        # equal weights for every level.
        return {
            "backend": "cosine_local",
            "config": {
                "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
                "level_field": "background",
            },
        }

    # `os.getenv(name, default)` only applies the default when the variable is
    # unset. A variable that is set but empty (e.g. `NAME=` in an env file)
    # would reach `json.loads("")` and raise, so treat empty like unset.
    headers_raw = os.getenv("MOS_RERANKER_HEADERS_EXTRA") or "{}"

    return {
        "backend": backend,
        "config": {
            "url": os.getenv("MOS_RERANKER_URL"),
            "model": os.getenv("MOS_FEEDBACK_RERANKER_MODEL", "bge-reranker-v2-m3"),
            "timeout": 10,
            "headers_extra": json.loads(headers_raw),
            "rerank_source": os.getenv("MOS_RERANK_SOURCE"),
            "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"),
        },
    }
def build_feedback_reranker_config() -> dict[str, Any]:
    """
    Build feedback reranker configuration.

    Reads the raw settings from ``APIConfig.get_feedback_reranker_config()``
    and passes them through ``RerankerConfigFactory.model_validate``.

    Returns:
        The validated feedback reranker configuration.
    """
    # NOTE(review): the return annotation mirrors the sibling builders, but
    # `model_validate` presumably returns a `RerankerConfigFactory` instance
    # rather than a plain dict - confirm before relying on dict access.
    raw_config = APIConfig.get_feedback_reranker_config()
    return RerankerConfigFactory.model_validate(raw_config)
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 266084a17..ee9af485f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1160,13 +1160,15 @@ def get_nodes( properties = properties_json if properties_json else {} # Parse embedding from JSONB if it exists - if embedding_json is not None: + if embedding_json is not None and kwargs.get("include_embedding"): try: # remove embedding - """ - embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json - # properties["embedding"] = embedding - """ + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json + ) + properties["embedding"] = embedding except (json.JSONDecodeError, TypeError): logger.warning(f"Failed to parse embedding for node {node_id}") nodes.append( diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index a51f80ff8..22545496a 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -144,6 +144,7 @@ def get_searcher( manual_close_internet=manual_close_internet, process_llm=process_llm, tokenizer=self.tokenizer, + include_embedding=self.include_embedding, ) return searcher