diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index c101eece4..44ecbe531 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -21,7 +21,9 @@ from memos.api.product_models import ( APIADDRequest, APIChatCompleteRequest, + APISearchPlaygroundRequest, APISearchRequest, + ChatPlaygroundRequest, ChatRequest, ) from memos.context.context import ContextThread @@ -91,6 +93,7 @@ def __init__( self.enable_mem_scheduler = ( hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler ) + self.dependencies = dependencies def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]: """ @@ -356,7 +359,7 @@ def generate_chat_response() -> Generator[str, None, None]: self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse: + def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> StreamingResponse: """ Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. 
@@ -413,8 +416,8 @@ def generate_chat_response() -> Generator[str, None, None]: label=QUERY_TASK_LABEL, ) - # ====== first search without parse goal ====== - search_req = APISearchRequest( + # ====== first search text mem with parse goal ====== + search_req = APISearchPlaygroundRequest( query=chat_req.query, user_id=chat_req.user_id, readable_cube_ids=readable_cube_ids, @@ -426,6 +429,7 @@ def generate_chat_response() -> Generator[str, None, None]: include_preference=chat_req.include_preference, pref_top_k=chat_req.pref_top_k, filter=chat_req.filter, + playground_search_goal_parser=True, ) search_response = self.search_handler.handle_search_memories(search_req) @@ -439,10 +443,10 @@ def generate_chat_response() -> Generator[str, None, None]: memories_list = text_mem_results[0]["memories"] # Filter memories by threshold - first_filtered_memories = self._filter_memories_by_threshold(memories_list) + filtered_memories = self._filter_memories_by_threshold(memories_list) # Prepare reference data (first search) - reference = prepare_reference_data(first_filtered_memories) + reference = prepare_reference_data(filtered_memories) # get preference string pref_string = search_response.data.get("pref_string", "") @@ -455,48 +459,68 @@ def generate_chat_response() -> Generator[str, None, None]: pref_md_string = self._build_pref_md_string_for_playground(pref_memories) yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n" - # internet status - yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n" - - # ====== second search with parse goal ====== - search_req = APISearchRequest( - query=chat_req.query, - user_id=chat_req.user_id, - readable_cube_ids=readable_cube_ids, - mode=chat_req.mode, - internet_search=chat_req.internet_search, - top_k=chat_req.top_k, - chat_history=chat_req.history, - session_id=chat_req.session_id, - include_preference=False, - filter=chat_req.filter, - playground_search_goal_parser=True, + # 
parse goal for internet search + searcher = self.dependencies.searcher + parsed_goal = searcher.task_goal_parser.parse( + task_description=chat_req.query, + context="\n".join( + [memory.get("memory", "") for memory in filtered_memories] + ), + conversation=chat_req.history, + mode="fine", ) - search_response = self.search_handler.handle_search_memories(search_req) - # Extract memories from search results (second search) - memories_list = [] - if search_response.data and search_response.data.get("text_mem"): - text_mem_results = search_response.data["text_mem"] - if text_mem_results and text_mem_results[0].get("memories"): - memories_list = text_mem_results[0]["memories"] + if chat_req.beginner_guide_step == "first": + chat_req.internet_search = False + parsed_goal.internet_search = False + elif chat_req.beginner_guide_step == "second": + chat_req.internet_search = True + parsed_goal.internet_search = True + + if chat_req.internet_search or parsed_goal.internet_search: + # internet status + yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n" + + # ====== internet search with parse goal ====== + search_req = APISearchPlaygroundRequest( + query=chat_req.query + + (f"{parsed_goal.tags}" if parsed_goal.tags else ""), + user_id=chat_req.user_id, + readable_cube_ids=readable_cube_ids, + mode=chat_req.mode, + internet_search=True, + top_k=chat_req.top_k, + chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=False, + filter=chat_req.filter, + search_memory_type="OuterMemory", + ) + search_response = self.search_handler.handle_search_memories(search_req) - # Filter memories by threshold - second_filtered_memories = self._filter_memories_by_threshold(memories_list) + # Extract memories from search results (second search) + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and 
text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] - # dedup and supplement memories - filtered_memories = self._dedup_and_supplement_memories( - first_filtered_memories, second_filtered_memories - ) + # Filter memories by threshold + second_filtered_memories = self._filter_memories_by_threshold(memories_list) - # Prepare remain reference data (second search) - reference = prepare_reference_data(filtered_memories) - # get internet reference - internet_reference = self._get_internet_reference( - search_response.data.get("text_mem")[0]["memories"] - ) + # dedup and supplement memories + filtered_memories = self._dedup_and_supplement_memories( + filtered_memories, second_filtered_memories + ) - yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + # Prepare remain reference data (second search) + reference = prepare_reference_data(filtered_memories) + # get internet reference + internet_reference = self._get_internet_reference( + memories_list + ) + + yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" # Step 2: Build system prompt with memories system_prompt = self._build_enhance_system_prompt( @@ -571,8 +595,9 @@ def generate_chat_response() -> Generator[str, None, None]: chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" yield chunk_data - # Yield internet reference after text response - yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" + if chat_req.internet_search or parsed_goal.internet_search: + # Yield internet reference after text response + yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" # Calculate timing time_end = time.time() diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index e77aee755..1f5a582fc 100644 --- a/src/memos/api/product_models.py +++ 
b/src/memos/api/product_models.py @@ -159,6 +159,14 @@ def _convert_deprecated_fields(self): return self +class ChatPlaygroundRequest(ChatRequest): + """Request model for chat operations in playground.""" + + beginner_guide_step: str | None = Field( + None, description="Whether to use beginner guide, option: [first, second]" + ) + + class ChatCompleteRequest(BaseRequest): """Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest.""" @@ -373,9 +381,11 @@ class APISearchRequest(BaseRequest): "If None, default thresholds will be applied." ), ) - - # TODO: tmp field for playground search goal parser, will be removed later - playground_search_goal_parser: bool = Field(False, description="Playground search goal parser") + # Internal field for search memory type + search_memory_type: str = Field( + "All", + description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory", + ) # ==== Context ==== chat_history: MessageList | None = Field( @@ -448,6 +458,13 @@ def _convert_deprecated_fields(self) -> "APISearchRequest": return self +class APISearchPlaygroundRequest(APISearchRequest): + """Request model for searching memories in playground.""" + + # TODO: tmp field for playground search goal parser, will be removed later + playground_search_goal_parser: bool = Field(False, description="Playground search goal parser") + + class APIADDRequest(BaseRequest): """Request model for creating memories.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 576cca55e..e8acf2e38 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -29,6 +29,7 @@ APIChatCompleteRequest, APIFeedbackRequest, APISearchRequest, + ChatPlaygroundRequest, ChatRequest, DeleteMemoryRequest, DeleteMemoryResponse, @@ -200,7 +201,7 @@ def chat_stream(chat_req: ChatRequest): @router.post("/chat/stream/playground", 
summary="Chat with MemOS playground") -def chat_stream_playground(chat_req: ChatRequest): +def chat_stream_playground(chat_req: ChatPlaygroundRequest): """ Chat with MemOS for a specific user. Returns SSE stream. diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 1d0c344b4..813142826 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -137,9 +137,12 @@ def get_searcher( self.graph_store, self.embedder, self.reranker, + bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, + search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, process_llm=process_llm, + tokenizer=self.tokenizer, ) return searcher diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 3e769e424..4225ed99b 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -275,6 +275,10 @@ def _parse_task( **kwargs, ) + # TODO: tmp field playground_search_goal_parser for playground, will be removed later + if kwargs.get("playground_search_goal_parser", False): + parsed_goal.internet_search = False + query = parsed_goal.rephrased_query or query # if goal has extra memories, embed them too if parsed_goal.memories: @@ -527,7 +531,8 @@ def _retrieve_from_internet( if self.manual_close_internet and not parsed_goal.internet_search: logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)") return [] - if memory_type not in ["All"]: + if memory_type not in ["All", "OuterMemory"]: + logger.info(f"[PATH-C] '{query}' Skipped (memory_type does not match)") return [] logger.info(f"[PATH-C] '{query}' Retrieving from internet...") items = self.internet_retriever.retrieve_from_internet( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py 
b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index f75f8d045..6b96d7e98 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -48,7 +48,7 @@ def parse( elif mode == "fine": if not self.llm: raise ValueError("LLM not provided for slow mode.") - return self._parse_fine(task_description, context, conversation) + return self._parse_fine(task_description, context, conversation, **kwargs) else: raise ValueError(f"Unknown mode: {mode}") @@ -81,7 +81,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: ) def _parse_fine( - self, query: str, context: str = "", conversation: list[dict] | None = None + self, query: str, context: str = "", conversation: list[dict] | None = None, **kwargs ) -> ParsedTaskGoal: """ Slow mode: LLM structured parse. diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py index 1b7b28949..55c6243d8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py @@ -4,7 +4,7 @@ 1. Keys: the high-level keywords directly relevant to the user’s task. 2. Tags: thematic tags to help categorize and retrieve related memories. 3. Goal Type: retrieval | qa | generation -4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string. +4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. Make full use of information related to the query. 
If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string. 5. Need for internet search: If the user's task instruction only involves objective facts or can be completed without introducing external knowledge, set "internet_search" to False. Otherwise, set it to True. 6. Memories: Provide 2–5 short semantic expansions or rephrasings of the rephrased/original user task instruction. These are used for improved embedding search coverage. Each should be clear, concise, and meaningful for retrieval. diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index b51429376..15bcb99af 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -425,6 +425,7 @@ def _fast_search( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, + memory_type=search_req.search_memory_type, search_filter=search_filter, search_priority=search_priority, info={ @@ -436,7 +437,7 @@ def _fast_search( search_tool_memory=search_req.search_tool_memory, tool_mem_top_k=search_req.tool_mem_top_k, # TODO: tmp field for playground search goal parser, will be removed later - playground_search_goal_parser=search_req.playground_search_goal_parser, + playground_search_goal_parser=getattr(search_req, "playground_search_goal_parser", False), ) formatted_memories = [format_memory_item(data) for data in search_results]