diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index a686ac8f9..7a47f05e3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -161,6 +161,11 @@ def mos_server(self): """Get MOS server instance.""" return self.deps.mos_server + @property + def deepsearch_agent(self): + """Get deepsearch agent instance.""" + return self.deps.deepsearch_agent + def _validate_dependencies(self, *required_deps: str) -> None: """ Validate that required dependencies are available. diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 3ef1d529d..7b34fcfae 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -307,6 +308,10 @@ def init_server() -> dict[str, Any]: online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") + deepsearch_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=tree_mem, + ) # Return all components as a dictionary for easy access and extension return { "graph_db": graph_db, @@ -330,4 +335,5 @@ def init_server() -> dict[str, Any]: "text_mem": text_mem, "pref_mem": pref_mem, "online_bot": online_bot, + "deepsearch_agent": deepsearch_agent, } diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index ece89909b..827f61b13 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -31,7 +31,9 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") + self._validate_dependencies( + "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" + ) def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -52,10 +54,10 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse results = cube_view.search_memories(search_req) - self.logger.info(f"[AddHandler] Final add results count={len(results)}") + self.logger.info(f"[SearchHandler] Final search results count={len(results)}") return SearchResponse( - message="Memory searched successfully", + message="Search completed successfully", data=results, ) @@ -83,6 +85,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) else: single_views = [ @@ -93,6 +96,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) for cube_id in cube_ids ] diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5a070c6ad..5e51aec44 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from memos.types import MessageList +logger = get_logger(__name__) + class JSONResponseParser: """Elegant JSON response parser for LLM outputs""" @@ -48,9 +50,6 @@ def parse(response: str) -> dict[str, Any]: raise ValueError(f"Cannot parse JSON response: {response[:100]}...") -logger = get_logger(__name__) - - class QueryRewriter(BaseMemAgent): """Specialized agent for rewriting queries based on conversation history""" @@ -141,7 +140,7 @@ def __init__( memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) config: Configuration for deep search behavior """ - self.config = config or DeepSearchAgentConfig() + self.config = config or DeepSearchAgentConfig(agent_name="DeepSearchMemAgent") self.max_iterations = self.config.max_iterations self.timeout = self.config.timeout self.llm: BaseLLM = llm @@ -219,7 +218,7 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: return self._remove_duplicate_memories(accumulated_memories) else: return self._generate_final_answer( - query, accumulated_memories, accumulated_context, "", history + query, accumulated_memories, accumulated_context, history ) def _remove_duplicate_memories( @@ -248,9 +247,9 @@ def _generate_final_answer( original_query: str, search_results: list[TextualMemoryItem], context: list[str], - missing_info: str = "", history: list[str] | None = None, sources: list[str] | None = None, + missing_info: str | None = None, ) -> str: """ Generate the final answer. diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2055615d2..d2fde36a3 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,6 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any + deepsearch_agent: Any def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ @@ -247,8 +248,11 @@ def _fast_search( def _deep_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: - logger.error("waiting to be implemented") - return [] + deepsearch_results = self.deepsearch_agent.run( + search_req.query, user_id=user_context.mem_cube_id + ) + formatted_memories = [format_memory_item(data) for data in deepsearch_results] + return formatted_memories def _fine_search( self, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 477cd2409..eb624ef89 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -22,12 +22,14 @@ {context} Analyze the context and determine the next step. Return your response in JSON format with the following structure: -{{ + ```json + {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", "missing_entities": ["entity1", "entity2"], "new_search_query": "new search query", }} +``` Status definitions: - "sufficient": Context fully answers the query diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 853a271f6..2aa96257b 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -48,6 +48,7 @@ def mock_init_server(): "pref_mem": None, "online_bot": None, "chat_llms": Mock(), + "deepsearch_agent": Mock(), } with patch("memos.api.handlers.init_server", return_value=mock_components): diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py new file mode 100644 index 000000000..a80dd10ea --- /dev/null +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -0,0 +1,234 @@ +"""Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" + +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import ( + DeepSearchMemAgent, + JSONResponseParser, +) +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata + + +class TestJSONResponseParser: + """Test JSONResponseParser class.""" + + def test_parse_clean_json(self): + """Test parsing clean JSON response.""" + response = '{"status": "sufficient", "reasoning": "test"}' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_json_with_code_blocks(self): + """Test parsing JSON wrapped in code blocks.""" + response = '```json\n{"status": "sufficient", "reasoning": "test"}\n```' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_invalid_json_raises_error(self): + """Test that invalid JSON raises ValueError.""" + with pytest.raises(ValueError, match="Cannot parse JSON response"): + JSONResponseParser.parse("This is not JSON at all") + + +class TestDeepSearchMemAgent: + """Test DeepSearchMemAgent core functionality.""" + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + mock = MagicMock() + mock.generate.return_value = "Generated answer" + return mock + + @pytest.fixture + def mock_memory_retriever(self): + """Create a mock memory retriever.""" + mock = MagicMock() + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a programming language", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python was created by Guido van Rossum", + metadata=TextualMemoryMetadata(type="fact"), + ), + ] + mock.search.return_value = memory_items + return mock + + @pytest.fixture + def config(self): + """Create DeepSearchAgentConfig.""" + return DeepSearchAgentConfig(agent_name="TestDeepSearch", max_iterations=3, timeout=30) + + @pytest.fixture + def agent(self, mock_llm, mock_memory_retriever, config): + """Create DeepSearchMemAgent instance.""" + agent = DeepSearchMemAgent( + llm=mock_llm, memory_retriever=mock_memory_retriever, config=config + ) + # Mock the sub-agents to avoid complex interactions + agent.query_rewriter.run = MagicMock(return_value="Rewritten query") + agent.reflector.run = MagicMock( + return_value={ + "status": "sufficient", + "reasoning": "Enough info", + "missing_entities": [], + } + ) + return agent + + def test_init_with_config(self, mock_llm, mock_memory_retriever, config): + """Test DeepSearchMemAgent initialization with config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + assert agent.llm == mock_llm + assert agent.memory_retriever == mock_memory_retriever + assert agent.config == config + assert agent.max_iterations == 3 + assert agent.timeout == 30 + + def test_init_without_config(self, mock_llm, mock_memory_retriever): + """Test DeepSearchMemAgent initialization without config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever) + assert isinstance(agent.config, DeepSearchAgentConfig) + assert agent.config.agent_name == "DeepSearchMemAgent" + + def test_run_no_llm_raises_error(self, config): + """Test that running without LLM raises RuntimeError.""" + agent = DeepSearchMemAgent(llm=None, config=config) + with pytest.raises(RuntimeError, match="LLM not initialized"): + agent.run("test query") + + def test_run_returns_memories_when_no_generated_answer(self, agent, mock_memory_retriever): + """Test run returns memories when generated_answer is not requested.""" + result = agent.run("What is Python?", generated_answer=False) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, TextualMemoryItem) for item in result) + agent.query_rewriter.run.assert_called_once() + + def test_run_returns_answer_when_generated_answer(self, agent, mock_llm): + """Test run returns generated answer when requested.""" + result = agent.run("What is Python?", generated_answer=True) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_run_with_user_id(self, agent, mock_memory_retriever): + """Test run with user_id.""" + agent.run("What is Python?", user_id="user123", generated_answer=False) + + # Check that user_id was passed to search + call_kwargs = mock_memory_retriever.search.call_args[1] + assert call_kwargs.get("user_name") == "user123" + + def test_run_no_search_results(self, agent, mock_memory_retriever): + """Test behavior when search returns no results.""" + mock_memory_retriever.search.return_value = [] + + result = agent.run("What is Python?", generated_answer=False) + + assert result == [] + + def test_remove_duplicate_memories(self, agent): + """Test removing duplicate memories.""" + mem_id1 = str(uuid.uuid4()) + mem_id2 = str(uuid.uuid4()) + mem_id3 = str(uuid.uuid4()) + + memories = [ + TextualMemoryItem( + id=mem_id1, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + TextualMemoryItem( + id=mem_id2, + memory="Different content", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=mem_id3, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + ] + + result = agent._remove_duplicate_memories(memories) + + assert len(result) == 2 + assert result[0].id == mem_id1 + assert result[1].id == mem_id2 + + def test_generate_final_answer(self, agent, mock_llm): + """Test final answer generation.""" + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a language", + metadata=TextualMemoryMetadata(type="fact"), + ) + ] + context = ["Python is a programming language"] + + result = agent._generate_final_answer("What is Python?", memory_items, context) + + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_generate_final_answer_with_missing_info(self, agent, mock_llm): + """Test final answer generation with missing info.""" + result = agent._generate_final_answer( + "What is Python?", [], [], missing_info="Version details not found" + ) + + assert result == "Generated answer" + call_args = mock_llm.generate.call_args[0][0] + assert "Version details not found" in call_args[0]["content"] + + def test_generate_final_answer_llm_error(self, agent, mock_llm): + """Test final answer generation handles LLM errors.""" + mock_llm.generate.side_effect = Exception("LLM error") + + result = agent._generate_final_answer("What is Python?", [], []) + + assert "error" in result.lower() + assert "What is Python?" in result + + def test_perform_memory_search_no_retriever(self, mock_llm, config): + """Test memory search when retriever is not configured.""" + agent = DeepSearchMemAgent(mock_llm, memory_retriever=None, config=config) + result = agent._perform_memory_search("test query") + + assert result == [] + + def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config): + """Test full pipeline integration.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + + with ( + patch.object(agent.query_rewriter, "run", return_value="Rewritten query"), + patch.object( + agent.reflector, + "run", + return_value={ + "status": "sufficient", + "reasoning": "Info is sufficient", + "missing_entities": [], + }, + ), + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called()