diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 8da6a2890..ba7b558fd 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -160,6 +160,16 @@ def build_reranker_config() -> dict[str, Any]: return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) +def build_feedback_reranker_config() -> dict[str, Any]: + """ + Build reranker configuration. + + Returns: + Validated reranker configuration dictionary + """ + return RerankerConfigFactory.model_validate(APIConfig.get_feedback_reranker_config()) + + def build_internet_retriever_config() -> dict[str, Any]: """ Build internet retriever configuration. @@ -277,6 +287,7 @@ def init_components() -> dict[str, Any]: embedder_config = build_embedder_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() + feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() vector_db_config = build_vec_db_config() pref_extractor_config = build_pref_extractor_config() @@ -296,6 +307,7 @@ def init_components() -> dict[str, Any]: embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) reranker = RerankerFactory.from_config(reranker_config) + feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) @@ -359,7 +371,7 @@ def init_components() -> dict[str, Any]: config_factory=pref_retriever_config, llm_provider=llm, embedder=embedder, - reranker=reranker, + reranker=feedback_reranker, vector_db=vector_db, ) if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" @@ -374,7 +386,7 @@ def init_components() -> dict[str, Any]: extractor_llm=llm, vector_db=vector_db, embedder=embedder, - reranker=reranker, + reranker=feedback_reranker, extractor=pref_extractor, adder=pref_adder, retriever=pref_retriever, @@ -405,6 +417,7 @@ def init_components() -> dict[str, Any]: memory_manager=memory_manager, mem_reader=mem_reader, searcher=searcher, + reranker=feedback_reranker, ) # Return all components as a dictionary for easy access and extension return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server}