From 587e5a7204d415882fbd5c98d41dae45ad8666d8 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:02:40 +0800 Subject: [PATCH 01/17] feat: update memos headers --- src/memos/api/config.py | 3 ++- src/memos/configs/embedder.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a276fa63d..3236c6705 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -379,7 +379,7 @@ def get_reranker_config() -> dict[str, Any]: "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, - "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), + "headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, @@ -406,6 +406,7 @@ def get_embedder_config() -> dict[str, Any]: "api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"), "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), + "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 70095a194..d88b6005e 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig): embedding_dims: int | None = Field( default=None, description="Number of dimensions for the embedding" ) + headers_extra: dict[str, Any] | None = Field( + default=None, + description="Extra headers for the embedding model, only for universal_api backend", + ) class OllamaEmbedderConfig(BaseEmbedderConfig): From 5637c9d16adeff6a0a8439dae9c8f1f452b02ede Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:03:07 +0800 Subject: [PATCH 02/17] feat: headers add --- src/memos/embedders/universal_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index fc51cf073..6cbd06443 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -16,7 +16,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig): self.config = config if self.provider == "openai": - self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url) + self.client = OpenAIClient( + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.headers_extra if config.headers_extra else None, + ) elif self.provider == "azure": self.client = AzureClient( azure_endpoint=config.base_url, From 68831c080e9db6633136f0ed7424b8b4bc275ccb Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 20 Nov 2025 16:45:03 +0800 Subject: [PATCH 03/17] feat: update search agent --- examples/deepsearch_example.py | 76 ++++++ src/memos/configs/mem_agent.py | 60 +++++ src/memos/mem_agent/base.py | 17 ++ src/memos/mem_agent/deepsearch_agent.py | 320 +++++++++++++++++++++++ src/memos/mem_agent/factory.py | 21 ++ src/memos/templates/mem_agent_prompts.py | 76 ++++++ 6 files changed, 570 insertions(+) create mode 100644 examples/deepsearch_example.py create mode 100644 src/memos/configs/mem_agent.py create mode 100644 src/memos/mem_agent/base.py create mode 100644 src/memos/mem_agent/deepsearch_agent.py create mode 100644 src/memos/mem_agent/factory.py create mode 100644 src/memos/templates/mem_agent_prompts.py diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py new file mode 100644 index 000000000..43bf4c012 --- /dev/null +++ b/examples/deepsearch_example.py @@ -0,0 +1,76 @@ +""" +Example usage of DeepSearchAgent in MemOS. + +This example demonstrates how to initialize and use the DeepSearchAgent +for comprehensive information retrieval and synthesis. +""" + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.llms.factory import LLMFactory +from memos.configs.llm import LLMConfigFactory + + +def main(): + """Example usage of DeepSearchAgent.""" + + # 1. Configure the LLM + llm_config = LLMConfigFactory( + backend="openai", # or "ollama", "azure", etc. + config={ + "api_key": "your-api-key-here", + "model_name_or_path": "gpt-4", + "temperature": 0.7, + "max_tokens": 2048 + } + ) + + # 2. Create LLM instance + llm = LLMFactory.from_config(llm_config) + + # 3. Configure DeepSearchAgent + agent_config = DeepSearchAgentConfig( + agent_name="DeepSearchAgent", + description="Advanced deep search agent for comprehensive information retrieval", + max_iterations=3, + timeout=60 + ) + + # 4. Create DeepSearchAgent instance + agent = DeepSearchAgent(agent_config) + + # 5. Initialize the agent with LLM + agent.set_llm(llm) + + # 6. Set up memory retriever (this would typically be injected by the framework) + # agent.set_memory_retriever(your_memory_retriever) + + # 7. Example queries + queries = [ + "What are the latest developments in AI research?", + "Tell me about my recent project meetings and their outcomes", + "What are the key trends in machine learning this year?" + ] + + # 8. Process queries + for query in queries: + print(f"\n{'='*60}") + print(f"Query: {query}") + print(f"{'='*60}") + + try: + # Run the deep search pipeline + response = agent.run( + input=query, + history=["Previous conversation context"], + user_id="example_user" + ) + + print(f"Response: {response}") + + except Exception as e: + print(f"Error processing query: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py new file mode 100644 index 000000000..e58f8e67c --- /dev/null +++ b/src/memos/configs/mem_agent.py @@ -0,0 +1,60 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAgentConfig(BaseConfig): + """Base configuration class for agents.""" + + agent_name: str = Field(..., description="Name of the agent") + description: str | None = Field( + default=None, description="Description of the agent" + ) + + +class SimpleAgentConfig(BaseAgentConfig): + """Simple agent configuration class.""" + + max_iterations: int = Field( + default=10, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class DeepSearchAgentConfig(BaseAgentConfig): + """Deep search agent configuration class.""" + + max_iterations: int = Field( + default=3, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class MemAgentConfigFactory(BaseConfig): + """Factory class for creating agent configurations.""" + + backend: str = Field(..., description="Backend for agent") + config: dict[str, Any] = Field(..., description="Configuration for the agent backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "simple": SimpleAgentConfig, + "deep_search": DeepSearchAgentConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AgentConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self \ No newline at end of file diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py new file mode 100644 index 000000000..3b54a846f --- /dev/null +++ b/src/memos/mem_agent/base.py @@ -0,0 +1,17 @@ +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod +from pydantic import BaseModel +from memos.configs.mem_agent import BaseAgentConfig + +class BaseMemAgent(ABC): + """ + Base class for all agents. + """ + def __init__(self, config: BaseAgentConfig): + """Initialize the BaseMemAgent with the given configuration.""" + + @abstractmethod + def run(self, input: str) -> str: + """ + Run the agent. + """ \ No newline at end of file diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py new file mode 100644 index 000000000..eb67fabba --- /dev/null +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -0,0 +1,320 @@ +""" +Deep Search Agent implementation for MemOS. + +This module implements a sophisticated deep search agent that performs iterative +query refinement and memory retrieval to provide comprehensive answers. +""" + +import json +import logging +from memos.log import get_logger +from typing import Any, Dict, List, Optional, Tuple + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.llms.base import BaseLLM +from memos.mem_agent.base import BaseMemAgent +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageDict, MessageList +from memos.templates.mem_agent_prompts import ( + QUERY_REWRITE_PROMPT, + REFLECTION_PROMPT, + KEYWORD_EXTRACTION_PROMPT, + FINAL_GENERATION_PROMPT + ) + +logger = get_logger(__name__) + + +class QueryRewriter(BaseMemAgent): + """ + Specialized agent for rewriting queries based on conversation history. + Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): + self.llm = llm + self.name = name + + def run(self, query: str, history: list[str]| None = None) -> str: + """ + Rewrite the query to be standalone and more searchable. + + Args: + query: Original user query + history: List of previous conversation messages + + Returns: + Rewritten query string + """ + if history is None: + history = [] + + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages + + prompt = QUERY_REWRITE_PROMPT.format( + history=history_str if history_str else "No previous conversation", + query=query + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + logger.info(f"[{self.name}] Rewritten query: {response.strip()}") + return response.strip() + except Exception as e: + logger.error(f"[{self.name}] Error rewriting query: {e}") + return query # Fallback to original query + + +class ReflectionAgent: + """ + Specialized agent for analyzing information sufficiency. + Corresponds to the decision diamond in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, name: str = "Reflector"): + self.llm = llm + self.name = name + + def run(self, query: str, context: List[str]) -> Dict[str, Any]: + """ + Analyze whether retrieved context is sufficient to answer the query. + + Args: + query: User query + context: List of retrieved context strings + + Returns: + Dictionary with status, reasoning, and missing entities + """ + context_str = "\n".join([f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" + for ctx in context[:10]]) # Limit context size + + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + result = json.loads(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") + return result + except (json.JSONDecodeError, Exception) as e: + logger.error(f"[{self.name}] Error in reflection analysis: {e}") + # Fallback response + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [] + } + + +class DeepSearchMemAgent(BaseMemAgent): + """ + Main orchestrator agent implementing the deep search pipeline. + + This agent coordinates multiple sub-agents to perform iterative query refinement, + memory retrieval, and information synthesis as shown in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): + super().__init__(config) + self.config = config + self.max_iterations = config.max_iterations + self.timeout = config.timeout + self.llm: Optional[BaseLLM] = llm + self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") + self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + self.memory_retriever = memory_retriever + + def _set_llm(self, llm: BaseLLM) -> None: + """Set the LLM and initialize sub-agents.""" + self.llm = llm + self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") + self.reflector = ReflectionAgent(llm, "Reflector") + self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") + logger.info("LLM and sub-agents initialized") + + def _set_memory_retriever(self, retriever) -> None: + """Set the memory retrieval interface.""" + self.memory_retriever = retriever + logger.info("Memory retriever interface set") + + def run(self, input: str, **kwargs) -> str: + """ + Main execution method implementing the deep search pipeline. + + Args: + input: User query string + Returns: + Comprehensive response string + """ + if not self.llm: + raise RuntimeError("LLM not initialized. Call set_llm() first.") + + query = input + history = kwargs.get("history", []) + user_id = kwargs.get("user_id") + + # Step 1: Query Rewriting + current_query = self.query_rewriter.run(query, history) + + # Step 2: Keyword Extraction and Planning + keyword_analysis = self.keyword_extractor.run(current_query) + search_keywords = keyword_analysis.get("keywords", [current_query]) + + accumulated_context = [] + accumulated_memories = [] + + # Step 3: Iterative Search and Reflection Loop + for iteration in range(self.max_iterations): + search_results = self._perform_memory_search( + current_query, + keywords=search_keywords, + user_id=user_id + ) + + if search_results: + context_batch = [self._extract_context_from_memory(mem) for mem in search_results] + accumulated_context.extend(context_batch) + accumulated_memories.extend(search_results) + + reflection_result = self.reflector.run(current_query, context_batch) + status = reflection_result.get("status", "sufficient") + reasoning = reflection_result.get("reasoning", "") + + logger.info(f"Reflection status: {status} - {reasoning}") + + if status == "sufficient": + logger.info("Sufficient information collected") + break + elif status == "needs_raw": + logger.info("Need original sources, retrieving raw content") + break + elif status == "missing_info": + missing_entities = reflection_result.get("missing_entities", []) + logger.info(f"Missing information: {missing_entities}") + if missing_entities: + refined_query = self._refine_query_for_missing_info( + current_query, missing_entities + ) + current_query = refined_query + logger.info(f"Refined query: {current_query}") + else: + logger.warning(f"No search results for iteration {iteration + 1}") + if iteration == 0: + current_query = query + else: + break + final_answer = self._generate_final_answer( + original_query=query, + search_results=accumulated_memories, + context=accumulated_context, + missing_info=keyword_analysis.get("search_strategy", "") + ) + + logger.info("Deep search pipeline completed") + return final_answer + + def _perform_memory_search( + self, + query: str, + keywords: List[str] = None, + user_id: str = None, + top_k: int = 10 + ) -> List[TextualMemoryItem]: + """ + Perform memory search using the configured retriever. + + Args: + query: Search query + keywords: Additional keywords for search + user_id: User identifier + top_k: Number of results to retrieve + + Returns: + List of retrieved memory items + """ + if not self.memory_retriever: + logger.warning("Memory retriever not configured, returning empty results") + return [] + + try: + # Use the memory retriever interface + # This is a placeholder - actual implementation depends on the retriever interface + search_query = query + if keywords and len(keywords) > 1: + search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords + + # Assuming the retriever has a search method similar to TreeTextMemory + results = self.memory_retriever.search( + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id + ) + + return results if isinstance(results, list) else [] + + except Exception as e: + logger.error(f"Error performing memory search: {e}") + return [] + + def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: + """Extract readable context from a memory item.""" + if hasattr(memory_item, 'memory'): + return str(memory_item.memory) + elif hasattr(memory_item, 'content'): + return str(memory_item.content) + else: + return str(memory_item) + + def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + """Refine the query to search for missing information.""" + if not missing_entities: + return query + + # Simple refinement strategy - append missing entities + entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities + refined_query = f"{query} {entities_str}" + + return refined_query + + def _generate_final_answer( + self, + original_query: str, + search_results: List[TextualMemoryItem], + context: List[str], + missing_info: str = "" + ) -> str: + """ + Generate the final comprehensive answer. + + Args: + original_query: Original user query + search_results: All retrieved memory items + context: Extracted context strings + missing_info: Information about missing data + + Returns: + Final answer string + """ + # Prepare context for the prompt + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context + sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" + + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified" + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"Error generating final answer: {e}") + return f"I apologize, but I encountered an error while processing your query: {original_query}. Please try again." diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py new file mode 100644 index 000000000..be656acc0 --- /dev/null +++ b/src/memos/mem_agent/factory.py @@ -0,0 +1,21 @@ +from typing import Any, ClassVar + +from memos.configs.mem_agent import MemAgentConfigFactory +from memos.mem_agent.base import BaseMemAgent +from memos.mem_agent.deepsearch_agent import DeepSearchAgent + + +class MemAgentFactory: + """Factory class for creating MemAgent instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "deep_search": DeepSearchAgent, + } + + @classmethod + def from_config(cls, config_factory: MemAgentConfigFactory) -> BaseMemAgent: + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + mem_agent_class = cls.backend_to_class[backend] + return mem_agent_class(config_factory.config) \ No newline at end of file diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py new file mode 100644 index 000000000..399b2cd75 --- /dev/null +++ b/src/memos/templates/mem_agent_prompts.py @@ -0,0 +1,76 @@ +QUERY_REWRITE_PROMPT = """ +You are a query rewriting specialist. Your task is to rewrite user queries to be more standalone and searchable. + +Given the conversation history and current user query, rewrite the query to: +1. Be self-contained and independent of conversation context +2. Include relevant context from history when necessary +3. Maintain the original intent and scope +4. Use clear, specific terminology + +Conversation History: +{history} + +Current Query: {query} + +Rewritten Query:""" + +REFLECTION_PROMPT = """ +You are an information sufficiency analyst. Evaluate whether the retrieved context is sufficient to answer the user's query. + +Query: {query} +Retrieved Context: +{context} + +Analyze the context and determine the next step. Return your response in JSON format with the following structure: +{{ + "status": "sufficient|missing_info|needs_raw", + "reasoning": "Brief explanation of your decision", + "missing_entities": ["entity1", "entity2"] // Only if status is "missing_info", + +}} + +Status definitions: +- "sufficient": Context fully answers the query +- "missing_info": Key information is missing (e.g., specific dates, locations, details) +- "needs_raw": Content is relevant but too summarized/vague, need original sources + +Response:""" + +KEYWORD_EXTRACTION_PROMPT = """ +Analyze the user query and extract key search terms and identify optimal data sources. + +Query: {query} + +Extract: +1. Key search terms and concepts +2. Important entities (people, places, dates, etc.) +3. Suggested data sources or memory types to search + +Return response in JSON format: +{{ + "keywords": ["keyword1", "keyword2"], + "entities": ["entity1", "entity2"], + "search_strategy": "Brief strategy description" +}} + +Response:""" + + +FINAL_GENERATION_PROMPT = """ +You are a comprehensive information synthesizer. Generate a complete answer based on the retrieved information. + +User Query: {query} +Search Sources: {sources} +Retrieved Information: +{context} + +Missing Information (if any): {missing_info} + +Instructions: +1. Synthesize all relevant information to answer the query comprehensively +2. If information is missing, acknowledge gaps and suggest next steps +3. Maintain accuracy and cite sources when possible +4. Provide a well-structured, coherent response +5. Use natural, conversational tone + +Response:""" \ No newline at end of file From 58c512d8a7086e0e723bf232257067f8c7516876 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:08:50 +0800 Subject: [PATCH 04/17] feat: upadte mem story --- examples/mem_agent/deepsearch_example.py | 350 ++++++++++++++++++ .../deepsearch_with_api_components.py | 223 +++++++++++ src/memos/mem_agent/deepsearch_agent.py | 57 +-- src/memos/mem_agent/factory.py | 28 +- 4 files changed, 628 insertions(+), 30 deletions(-) create mode 100644 examples/mem_agent/deepsearch_example.py create mode 100644 examples/mem_agent/deepsearch_with_api_components.py diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py new file mode 100644 index 000000000..62f26f0b1 --- /dev/null +++ b/examples/mem_agent/deepsearch_example.py @@ -0,0 +1,350 @@ +""" +DeepSearch Agent Usage Examples + +This example demonstrates two ways to initialize DeepSearchMemAgent: +1. Using Factory pattern (recommended) +2. Direct initialization + +DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through: +- Query rewriting: Optimize queries based on conversation history +- Iterative retrieval: Collect information through multiple search rounds +- Reflective analysis: Determine if information is sufficient +- Comprehensive response: Generate complete final answers +""" + +import os +import sys + +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphConfigFactory +from memos.configs.llms import LLMConfigFactory +from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.textual_memory import TreeTextMemoryConfig +from memos.embedders.base import BaseEmbedder +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent +from memos.mem_agent.factory import MemAgentFactory +from memos.mem_cube.naive_cube import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.memory_manager.memory_manager import MemoryManager +from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory +from memos.log import get_logger + +logger = get_logger(__name__) + + +def build_minimal_components(): + """ + Build the minimal component set required for DeepSearchMemAgent. + + Only need to initialize: + 1. LLM - Used for query rewriting, reflection, and final answer generation + 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval + + Returns: + dict: Dictionary containing llm and naive_mem_cube + """ + logger.info("Starting to build minimal component set...") + + # 1. Initialize LLM + llm_config = LLMConfigFactory( + backend="openai", + config={ + "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + "temperature": 0.7, + } + ) + llm = LLMFactory.from_config(llm_config) + logger.info("LLM initialization completed") + + # 2. Initialize Embedder (required by text_mem) + embedder_config = EmbedderConfigFactory( + backend="universal_api", + config={ + "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), + "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), + "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), + } + ) + embedder = EmbedderFactory.from_config(embedder_config) + logger.info("Embedder initialization completed") + + # 3. Initialize GraphDB (required by text_mem) + graph_db_config = GraphConfigFactory( + backend="polardb", + config={ + "host": os.getenv("POLAR_DB_HOST", "localhost"), + "port": int(os.getenv("POLAR_DB_PORT", "5432")), + "user": os.getenv("POLAR_DB_USER", "root"), + "password": os.getenv("POLAR_DB_PASSWORD", "123456"), + "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), + "user_name": "memos_default", + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + } + ) + graph_db = GraphStoreFactory.from_config(graph_db_config) + logger.info("✓ GraphDB init Done") + + # 4. Initialize MemReader (required by text_mem) + mem_reader_config = MemReaderConfigFactory( + backend="simple_struct", + config={ + "llm": { + "backend": "openai", + "config": { + "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + } + } + } + ) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + logger.info("✓ MemReader init Done") + + # 5. Initialize MemoryManager + memory_manager = MemoryManager( + graph_db=graph_db, + embedder=embedder, + llm=llm, + memory_size=1000, + is_reorganize=False, + ) + logger.info("✓ MemoryManager init Done") + + # 6. Initialize TextMemory + text_mem_config = TreeTextMemoryConfig( + reorganize=False, + max_depth=3, + ) + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=None, + memory_manager=memory_manager, + config=text_mem_config, + internet_retriever=None, + ) + logger.info("✓ TextMemory initialization completed") + + # 7. Create NaiveMemCube + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=None, + act_mem=None, + para_mem=None, + ) + logger.info("✓ NaiveMemCube creation completed") + + logger.info("All components initialized!") + + return { + "llm": llm, + "naive_mem_cube": naive_mem_cube, + "embedder": embedder, + "graph_db": graph_db, + "mem_reader": mem_reader, + } + + +def example_1_factory_initialization(): + """ + Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) + + Advantages: + - Separation of configuration and code + - Easy to manage and modify + - Support loading from configuration files + """ + logger.info("\n" + "="*60) + logger.info("Example 1: Initialize using Factory pattern") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Create configuration Factory + agent_config_factory = MemAgentConfigFactory( + backend="deep_search", + config={ + "agent_name": "MyDeepSearchAgent", + "description": "Intelligent agent for deep search", + "max_iterations": 3, # Maximum number of iterations + "timeout": 60, # Timeout in seconds + } + ) + + # Create Agent using Factory + # Pass text_mem as memory_retriever, it provides search method + deep_search_agent = MemAgentFactory.from_config( + config_factory=agent_config_factory, + llm=llm, + memory_retriever=naive_mem_cube.text_mem + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully") + logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def example_3_usage(deep_search_agent, components): + """ + Example 3: Using DeepSearchMemAgent for search + + Demonstrates how to: + 1. Add memories to the system + 2. Use Agent for deep search + 3. Get comprehensive answers + """ + logger.info("\n" + "="*60) + logger.info("Example 3: Using DeepSearchMemAgent") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # Simulate adding some memories + logger.info("1. Adding test memories...") + test_memories = [ + { + "user_name": "test_user", + "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], + "source": "manual" + } + ] + + for memory in test_memories: + try: + text_mem.add( + user_name=memory["user_name"], + messages=memory["messages"], + source=memory.get("source", "manual") + ) + logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") + except Exception as e: + logger.warning(f" ✗ Failed to add memory: {e}") + + # Use Agent for search + logger.info("\n2. Executing deep search...") + query = "What is the relationship between artificial intelligence and machine learning?" + logger.info(f" Query: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="test_user", + history=[ + "Hello", + "I want to learn about artificial intelligence" + ] + ) + + logger.info("\n3. Search results:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + except Exception as e: + logger.error(f"Error during search: {e}") + import traceback + traceback.print_exc() + + +def example_4_minimal_initialization(): + """ + Example 4: Minimal initialization (using default configuration) + + Use cases: + - Quick testing + - Prototype development + - No need for custom configuration + """ + logger.info("\n" + "="*60) + logger.info("Example 4: Minimal initialization (default configuration)") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Direct initialization using default configuration + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem + # config parameter omitted, will use default configuration + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def main(): + """Main function: Run all examples""" + logger.info("DeepSearch Agent Usage Examples") + logger.info("="*60) + + # Check environment variables + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set the following environment variables:") + logger.error(" - OPENAI_API_KEY: OpenAI API key") + logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") + logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") + return + + try: + # Run Example 1: Factory pattern + agent_factory, components_factory = example_1_factory_initialization() + + # Run Example 2: Direct initialization + agent_direct, components_direct = example_2_direct_initialization() + + # Run Example 4: Minimal initialization + agent_minimal, components_minimal = example_4_minimal_initialization() + + # Run Example 3: Actual usage (using agent created with factory method) + example_3_usage(agent_factory, components_factory) + + logger.info("\n" + "="*60) + logger.info("All examples completed!") + logger.info("="*60) + + except Exception as e: + logger.error(f"Error running examples: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py new file mode 100644 index 000000000..48e4df43c --- /dev/null +++ b/examples/mem_agent/deepsearch_with_api_components.py @@ -0,0 +1,223 @@ +""" +DeepSearch Agent 与 API 组件集成示例 + +本示例展示如何直接使用 API 服务器初始化的组件来创建 DeepSearch Agent。 +这种方式可以避免重复初始化,直接复用已有的组件。 + +适用场景: +- 在已有的 API 服务器中添加 DeepSearch 功能 +- 使用统一的组件配置 +- 避免重复初始化开销 +""" + +import os +import sys + +# 确保可以导入 memos 模块 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from memos.api.handlers.component_init import init_server +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent +from memos.log import get_logger + +logger = get_logger(__name__) + + +def create_deepsearch_from_api_components(): + """ + 从 API 服务器组件创建 DeepSearch Agent + + 这个函数演示了如何: + 1. 使用 init_server() 初始化所有组件 + 2. 提取需要的组件(llm 和 naive_mem_cube) + 3. 创建 DeepSearch Agent + """ + logger.info("="*60) + logger.info("使用 API 组件初始化 DeepSearch Agent") + logger.info("="*60 + "\n") + + # 步骤 1: 初始化所有服务器组件 + # 这会创建 llm, naive_mem_cube, scheduler 等所有组件 + logger.info("1. 初始化服务器组件...") + components = init_server() + logger.info(" ✓ 服务器组件初始化完成") + + # 步骤 2: 提取需要的组件 + logger.info("\n2. 提取必需组件...") + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + logger.info(f" ✓ LLM: {type(llm).__name__}") + logger.info(f" ✓ MemCube: {type(naive_mem_cube).__name__}") + logger.info(f" ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}") + + # 步骤 3: 创建 DeepSearch Agent 配置 + logger.info("\n3. 创建 DeepSearch Agent...") + config = DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + description="基于 API 组件的深度搜索代理", + max_iterations=3, + timeout=60, + ) + + # 步骤 4: 初始化 DeepSearch Agent + # memory_retriever 使用 naive_mem_cube.text_mem + # 它提供了 search() 方法用于检索记忆 + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=config + ) + + logger.info(f" ✓ DeepSearch Agent 创建成功") + logger.info(f" ✓ Agent 名称: {config.agent_name}") + logger.info(f" ✓ 最大迭代次数: {config.max_iterations}") + + return deep_search_agent, components + + +def demo_usage(deep_search_agent, components): + """ + 演示如何使用 DeepSearch Agent + """ + logger.info("\n" + "="*60) + logger.info("DeepSearch Agent 使用演示") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # 添加一些测试记忆 + logger.info("1. 添加测试记忆...") + test_data = [ + "MemOS 是一个先进的记忆操作系统,专门为 AI 系统设计。", + "MemOS 支持多种记忆类型:文本记忆、偏好记忆、行为记忆等。", + "DeepSearch Agent 是 MemOS 中的一个重要组件,用于深度搜索和信息检索。", + ] + + for i, content in enumerate(test_data, 1): + try: + text_mem.add( + user_name="demo_user", + messages=[content], + source="demo" + ) + logger.info(f" ✓ 记忆 {i}: {content[:40]}...") + except Exception as e: + logger.warning(f" ✗ 添加记忆失败: {e}") + + # 执行深度搜索 + logger.info("\n2. 执行深度搜索...") + query = "MemOS 支持哪些功能?" + logger.info(f" 查询: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="demo_user", + history=[] + ) + + logger.info("\n3. 搜索结果:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + return response + + except Exception as e: + logger.error(f" ✗ 搜索失败: {e}") + import traceback + traceback.print_exc() + return None + + +def integrate_with_api_router(): + """ + 展示如何将 DeepSearch Agent 集成到 API Router 中 + + 这段代码展示了在 server_router.py 中如何添加一个新的端点 + 来使用 DeepSearch Agent。 + """ + logger.info("\n" + "="*60) + logger.info("API Router 集成示例") + logger.info("="*60 + "\n") + + logger.info("在 server_router.py 中添加以下代码:\n") + + integration_code = ''' +# 在 server_router.py 顶部导入 +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent + +# 在初始化 handlers 之后,创建 DeepSearch Agent +deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + max_iterations=3, + timeout=60, + ) +) + +# 添加新的 API 端点 +@router.post("/deepsearch", summary="Deep search with memory") +def deep_search( + user_id: str, + query: str, + history: list[str] = None +): + """Execute deep search with iterative memory retrieval.""" + try: + response = deep_search_agent.run( + query=query, + user_id=user_id, + history=history or [] + ) + return {"status": "success", "response": response} + except Exception as e: + return {"status": "error", "message": str(e)} +''' + + print(integration_code) + + +def main(): + """主函数""" + logger.info("DeepSearch Agent 与 API 组件集成示例\n") + + # 检查环境变量 + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"❌ 缺少必要的环境变量: {', '.join(missing_vars)}") + logger.error("\n请设置以下环境变量:") + logger.error(" export OPENAI_API_KEY='your-api-key'") + logger.error(" export OPENAI_BASE_URL='your-base-url' # 可选") + return + + try: + # 创建 DeepSearch Agent + agent, components = create_deepsearch_from_api_components() + + # 演示使用 + demo_usage(agent, components) + + # 展示集成方法 + integrate_with_api_router() + + logger.info("\n" + "="*60) + logger.info("示例运行完成!") + logger.info("="*60) + + except Exception as e: + logger.error(f"\n❌ 运行失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() + diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index eb67fabba..6dd072728 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -117,57 +117,60 @@ class DeepSearchMemAgent(BaseMemAgent): memory retrieval, and information synthesis as shown in the architecture diagram. """ - def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): - super().__init__(config) - self.config = config - self.max_iterations = config.max_iterations - self.timeout = config.timeout - self.llm: Optional[BaseLLM] = llm - self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") - self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + def __init__( + self, + llm: BaseLLM, + memory_retriever: Any = None, + config: Optional[DeepSearchAgentConfig] = None + ): + """ + Initialize DeepSearchMemAgent. + + Args: + llm: Language model for query rewriting and response generation + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + config: Configuration for deep search behavior + """ + self.config = config or DeepSearchAgentConfig() + self.max_iterations = self.config.max_iterations + self.timeout = self.config.timeout + self.llm: BaseLLM = llm + self.query_rewriter: QueryRewriter = QueryRewriter(llm, "QueryRewriter") + self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_llm(self, llm: BaseLLM) -> None: - """Set the LLM and initialize sub-agents.""" - self.llm = llm - self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") - self.reflector = ReflectionAgent(llm, "Reflector") - self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") - logger.info("LLM and sub-agents initialized") - def _set_memory_retriever(self, retriever) -> None: """Set the memory retrieval interface.""" self.memory_retriever = retriever logger.info("Memory retriever interface set") - def run(self, input: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. Args: - input: User query string + query: User query string + **kwargs: Additional arguments (history, user_id, etc.) Returns: Comprehensive response string """ if not self.llm: - raise RuntimeError("LLM not initialized. Call set_llm() first.") + raise RuntimeError("LLM not initialized.") - query = input history = kwargs.get("history", []) user_id = kwargs.get("user_id") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - # Step 2: Keyword Extraction and Planning - keyword_analysis = self.keyword_extractor.run(current_query) - search_keywords = keyword_analysis.get("keywords", [current_query]) - accumulated_context = [] accumulated_memories = [] + search_keywords = [] # Can be extended with keyword extraction - # Step 3: Iterative Search and Reflection Loop + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): + logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") + search_results = self._perform_memory_search( current_query, keywords=search_keywords, @@ -206,11 +209,13 @@ def run(self, input: str, **kwargs) -> str: current_query = query else: break + + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info=keyword_analysis.get("search_strategy", "") + missing_info="" ) logger.info("Deep search pipeline completed") diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index be656acc0..33e7b6a6b 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -2,20 +2,40 @@ from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent -from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent class MemAgentFactory: """Factory class for creating MemAgent instances.""" backend_to_class: ClassVar[dict[str, Any]] = { - "deep_search": DeepSearchAgent, + "deep_search": DeepSearchMemAgent, } @classmethod - def from_config(cls, config_factory: MemAgentConfigFactory) -> BaseMemAgent: + def from_config( + cls, + config_factory: MemAgentConfigFactory, + llm: Any, + memory_retriever: Any = None + ) -> BaseMemAgent: + """ + Create a MemAgent instance from configuration. + + Args: + config_factory: Configuration factory for the agent + llm: Language model instance + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + + Returns: + Initialized MemAgent instance + """ backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") mem_agent_class = cls.backend_to_class[backend] - return mem_agent_class(config_factory.config) \ No newline at end of file + return mem_agent_class( + llm=llm, + memory_retriever=memory_retriever, + config=config_factory.config + ) \ No newline at end of file From a497d4614c8bc39b2f868a28952e7add5692e74e Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:12:07 +0800 Subject: [PATCH 05/17] feat: update mem scehduler --- examples/deepsearch_example.py | 76 ------ examples/mem_agent/deepsearch_example.py | 156 ++++++------ .../deepsearch_with_api_components.py | 223 ------------------ src/memos/configs/mem_agent.py | 20 +- src/memos/mem_agent/base.py | 7 +- src/memos/mem_agent/deepsearch_agent.py | 149 ++++++------ src/memos/mem_agent/factory.py | 17 +- src/memos/templates/mem_agent_prompts.py | 2 +- 8 files changed, 169 insertions(+), 481 deletions(-) delete mode 100644 examples/deepsearch_example.py delete mode 100644 examples/mem_agent/deepsearch_with_api_components.py diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py deleted file mode 100644 index 43bf4c012..000000000 --- a/examples/deepsearch_example.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Example usage of DeepSearchAgent in MemOS. - -This example demonstrates how to initialize and use the DeepSearchAgent -for comprehensive information retrieval and synthesis. -""" - -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchAgent -from memos.llms.factory import LLMFactory -from memos.configs.llm import LLMConfigFactory - - -def main(): - """Example usage of DeepSearchAgent.""" - - # 1. Configure the LLM - llm_config = LLMConfigFactory( - backend="openai", # or "ollama", "azure", etc. - config={ - "api_key": "your-api-key-here", - "model_name_or_path": "gpt-4", - "temperature": 0.7, - "max_tokens": 2048 - } - ) - - # 2. Create LLM instance - llm = LLMFactory.from_config(llm_config) - - # 3. Configure DeepSearchAgent - agent_config = DeepSearchAgentConfig( - agent_name="DeepSearchAgent", - description="Advanced deep search agent for comprehensive information retrieval", - max_iterations=3, - timeout=60 - ) - - # 4. Create DeepSearchAgent instance - agent = DeepSearchAgent(agent_config) - - # 5. Initialize the agent with LLM - agent.set_llm(llm) - - # 6. Set up memory retriever (this would typically be injected by the framework) - # agent.set_memory_retriever(your_memory_retriever) - - # 7. Example queries - queries = [ - "What are the latest developments in AI research?", - "Tell me about my recent project meetings and their outcomes", - "What are the key trends in machine learning this year?" - ] - - # 8. Process queries - for query in queries: - print(f"\n{'='*60}") - print(f"Query: {query}") - print(f"{'='*60}") - - try: - # Run the deep search pipeline - response = agent.run( - input=query, - history=["Previous conversation context"], - user_id="example_user" - ) - - print(f"Response: {response}") - - except Exception as e: - print(f"Error processing query: {e}") - - -if __name__ == "__main__": - main() diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 62f26f0b1..dacfe7f1e 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -13,25 +13,24 @@ """ import os -import sys from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphConfigFactory from memos.configs.llms import LLMConfigFactory -from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory +from memos.configs.mem_agent import MemAgentConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.textual_memory import TreeTextMemoryConfig -from memos.embedders.base import BaseEmbedder from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory +from memos.log import get_logger from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.mem_agent.factory import MemAgentFactory from memos.mem_cube.naive_cube import NaiveMemCube from memos.mem_reader.factory import MemReaderFactory from memos.memories.memory_manager.memory_manager import MemoryManager from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory -from memos.log import get_logger + logger = get_logger(__name__) @@ -39,16 +38,16 @@ def build_minimal_components(): """ Build the minimal component set required for DeepSearchMemAgent. - + Only need to initialize: 1. LLM - Used for query rewriting, reflection, and final answer generation 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval - + Returns: dict: Dictionary containing llm and naive_mem_cube """ logger.info("Starting to build minimal component set...") - + # 1. Initialize LLM llm_config = LLMConfigFactory( backend="openai", @@ -57,11 +56,11 @@ def build_minimal_components(): "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), "temperature": 0.7, - } + }, ) llm = LLMFactory.from_config(llm_config) logger.info("LLM initialization completed") - + # 2. Initialize Embedder (required by text_mem) embedder_config = EmbedderConfigFactory( backend="universal_api", @@ -69,11 +68,11 @@ def build_minimal_components(): "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - } + }, ) embedder = EmbedderFactory.from_config(embedder_config) logger.info("Embedder initialization completed") - + # 3. Initialize GraphDB (required by text_mem) graph_db_config = GraphConfigFactory( backend="polardb", @@ -87,11 +86,11 @@ def build_minimal_components(): "use_multi_db": False, "auto_create": True, "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - } + }, ) graph_db = GraphStoreFactory.from_config(graph_db_config) logger.info("✓ GraphDB init Done") - + # 4. Initialize MemReader (required by text_mem) mem_reader_config = MemReaderConfigFactory( backend="simple_struct", @@ -102,13 +101,13 @@ def build_minimal_components(): "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), - } + }, } - } + }, ) mem_reader = MemReaderFactory.from_config(mem_reader_config) logger.info("✓ MemReader init Done") - + # 5. Initialize MemoryManager memory_manager = MemoryManager( graph_db=graph_db, @@ -118,7 +117,7 @@ def build_minimal_components(): is_reorganize=False, ) logger.info("✓ MemoryManager init Done") - + # 6. Initialize TextMemory text_mem_config = TreeTextMemoryConfig( reorganize=False, @@ -135,7 +134,7 @@ def build_minimal_components(): internet_retriever=None, ) logger.info("✓ TextMemory initialization completed") - + # 7. Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, @@ -144,9 +143,9 @@ def build_minimal_components(): para_mem=None, ) logger.info("✓ NaiveMemCube creation completed") - + logger.info("All components initialized!") - + return { "llm": llm, "naive_mem_cube": naive_mem_cube, @@ -159,21 +158,21 @@ def build_minimal_components(): def example_1_factory_initialization(): """ Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - + Advantages: - Separation of configuration and code - Easy to manage and modify - Support loading from configuration files """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 1: Initialize using Factory pattern") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + # Build necessary components components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - + # Create configuration Factory agent_config_factory = MemAgentConfigFactory( backend="deep_search", @@ -182,139 +181,141 @@ def example_1_factory_initialization(): "description": "Intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds - } + }, ) - + # Create Agent using Factory # Pass text_mem as memory_retriever, it provides search method deep_search_agent = MemAgentFactory.from_config( - config_factory=agent_config_factory, - llm=llm, - memory_retriever=naive_mem_cube.text_mem + config_factory=agent_config_factory, llm=llm, memory_retriever=naive_mem_cube.text_mem ) - - logger.info(f"✓ DeepSearchMemAgent created successfully") + + logger.info("✓ DeepSearchMemAgent created successfully") logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - + return deep_search_agent, components def example_3_usage(deep_search_agent, components): """ Example 3: Using DeepSearchMemAgent for search - + Demonstrates how to: 1. Add memories to the system 2. Use Agent for deep search 3. Get comprehensive answers """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + naive_mem_cube = components["naive_mem_cube"] text_mem = naive_mem_cube.text_mem - + # Simulate adding some memories logger.info("1. Adding test memories...") test_memories = [ { "user_name": "test_user", - "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], - "source": "manual" + "messages": [ + "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." + ], + "source": "manual", }, { - "user_name": "test_user", - "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], - "source": "manual" + "user_name": "test_user", + "messages": [ + "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." + ], + "source": "manual", }, { "user_name": "test_user", - "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], - "source": "manual" - } + "messages": [ + "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." + ], + "source": "manual", + }, ] - + for memory in test_memories: try: text_mem.add( user_name=memory["user_name"], messages=memory["messages"], - source=memory.get("source", "manual") + source=memory.get("source", "manual"), ) logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") except Exception as e: logger.warning(f" ✗ Failed to add memory: {e}") - + # Use Agent for search logger.info("\n2. Executing deep search...") query = "What is the relationship between artificial intelligence and machine learning?" logger.info(f" Query: {query}") - + try: response = deep_search_agent.run( query=query, user_id="test_user", - history=[ - "Hello", - "I want to learn about artificial intelligence" - ] + history=["Hello", "I want to learn about artificial intelligence"], ) - + logger.info("\n3. Search results:") logger.info("-" * 60) logger.info(response) logger.info("-" * 60) - + except Exception as e: logger.error(f"Error during search: {e}") import traceback + traceback.print_exc() def example_4_minimal_initialization(): """ Example 4: Minimal initialization (using default configuration) - + Use cases: - Quick testing - Prototype development - No need for custom configuration """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 4: Minimal initialization (default configuration)") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + # Build necessary components components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - + # Direct initialization using default configuration deep_search_agent = DeepSearchMemAgent( llm=llm, - memory_retriever=naive_mem_cube.text_mem + memory_retriever=naive_mem_cube.text_mem, # config parameter omitted, will use default configuration ) - - logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)") + + logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)") logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - + return deep_search_agent, components def main(): """Main function: Run all examples""" logger.info("DeepSearch Agent Usage Examples") - logger.info("="*60) - + logger.info("=" * 60) + # Check environment variables required_env_vars = ["OPENAI_API_KEY"] missing_vars = [var for var in required_env_vars if not os.getenv(var)] - + if missing_vars: logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") logger.error("Please set the following environment variables:") @@ -322,29 +323,30 @@ def main(): logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") return - + try: # Run Example 1: Factory pattern agent_factory, components_factory = example_1_factory_initialization() - + # Run Example 2: Direct initialization agent_direct, components_direct = example_2_direct_initialization() - + # Run Example 4: Minimal initialization agent_minimal, components_minimal = example_4_minimal_initialization() - + # Run Example 3: Actual usage (using agent created with factory method) example_3_usage(agent_factory, components_factory) - - logger.info("\n" + "="*60) + + logger.info("\n" + "=" * 60) logger.info("All examples completed!") - logger.info("="*60) - + logger.info("=" * 60) + except Exception as e: logger.error(f"Error running examples: {e}") import traceback + traceback.print_exc() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py deleted file mode 100644 index 48e4df43c..000000000 --- a/examples/mem_agent/deepsearch_with_api_components.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -DeepSearch Agent 与 API 组件集成示例 - -本示例展示如何直接使用 API 服务器初始化的组件来创建 DeepSearch Agent。 -这种方式可以避免重复初始化,直接复用已有的组件。 - -适用场景: -- 在已有的 API 服务器中添加 DeepSearch 功能 -- 使用统一的组件配置 -- 避免重复初始化开销 -""" - -import os -import sys - -# 确保可以导入 memos 模块 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) - -from memos.api.handlers.component_init import init_server -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent -from memos.log import get_logger - -logger = get_logger(__name__) - - -def create_deepsearch_from_api_components(): - """ - 从 API 服务器组件创建 DeepSearch Agent - - 这个函数演示了如何: - 1. 使用 init_server() 初始化所有组件 - 2. 提取需要的组件(llm 和 naive_mem_cube) - 3. 创建 DeepSearch Agent - """ - logger.info("="*60) - logger.info("使用 API 组件初始化 DeepSearch Agent") - logger.info("="*60 + "\n") - - # 步骤 1: 初始化所有服务器组件 - # 这会创建 llm, naive_mem_cube, scheduler 等所有组件 - logger.info("1. 初始化服务器组件...") - components = init_server() - logger.info(" ✓ 服务器组件初始化完成") - - # 步骤 2: 提取需要的组件 - logger.info("\n2. 提取必需组件...") - llm = components["llm"] - naive_mem_cube = components["naive_mem_cube"] - logger.info(f" ✓ LLM: {type(llm).__name__}") - logger.info(f" ✓ MemCube: {type(naive_mem_cube).__name__}") - logger.info(f" ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}") - - # 步骤 3: 创建 DeepSearch Agent 配置 - logger.info("\n3. 创建 DeepSearch Agent...") - config = DeepSearchAgentConfig( - agent_name="APIDeepSearchAgent", - description="基于 API 组件的深度搜索代理", - max_iterations=3, - timeout=60, - ) - - # 步骤 4: 初始化 DeepSearch Agent - # memory_retriever 使用 naive_mem_cube.text_mem - # 它提供了 search() 方法用于检索记忆 - deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - config=config - ) - - logger.info(f" ✓ DeepSearch Agent 创建成功") - logger.info(f" ✓ Agent 名称: {config.agent_name}") - logger.info(f" ✓ 最大迭代次数: {config.max_iterations}") - - return deep_search_agent, components - - -def demo_usage(deep_search_agent, components): - """ - 演示如何使用 DeepSearch Agent - """ - logger.info("\n" + "="*60) - logger.info("DeepSearch Agent 使用演示") - logger.info("="*60 + "\n") - - naive_mem_cube = components["naive_mem_cube"] - text_mem = naive_mem_cube.text_mem - - # 添加一些测试记忆 - logger.info("1. 添加测试记忆...") - test_data = [ - "MemOS 是一个先进的记忆操作系统,专门为 AI 系统设计。", - "MemOS 支持多种记忆类型:文本记忆、偏好记忆、行为记忆等。", - "DeepSearch Agent 是 MemOS 中的一个重要组件,用于深度搜索和信息检索。", - ] - - for i, content in enumerate(test_data, 1): - try: - text_mem.add( - user_name="demo_user", - messages=[content], - source="demo" - ) - logger.info(f" ✓ 记忆 {i}: {content[:40]}...") - except Exception as e: - logger.warning(f" ✗ 添加记忆失败: {e}") - - # 执行深度搜索 - logger.info("\n2. 执行深度搜索...") - query = "MemOS 支持哪些功能?" - logger.info(f" 查询: {query}") - - try: - response = deep_search_agent.run( - query=query, - user_id="demo_user", - history=[] - ) - - logger.info("\n3. 搜索结果:") - logger.info("-" * 60) - logger.info(response) - logger.info("-" * 60) - - return response - - except Exception as e: - logger.error(f" ✗ 搜索失败: {e}") - import traceback - traceback.print_exc() - return None - - -def integrate_with_api_router(): - """ - 展示如何将 DeepSearch Agent 集成到 API Router 中 - - 这段代码展示了在 server_router.py 中如何添加一个新的端点 - 来使用 DeepSearch Agent。 - """ - logger.info("\n" + "="*60) - logger.info("API Router 集成示例") - logger.info("="*60 + "\n") - - logger.info("在 server_router.py 中添加以下代码:\n") - - integration_code = ''' -# 在 server_router.py 顶部导入 -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent - -# 在初始化 handlers 之后,创建 DeepSearch Agent -deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - config=DeepSearchAgentConfig( - agent_name="APIDeepSearchAgent", - max_iterations=3, - timeout=60, - ) -) - -# 添加新的 API 端点 -@router.post("/deepsearch", summary="Deep search with memory") -def deep_search( - user_id: str, - query: str, - history: list[str] = None -): - """Execute deep search with iterative memory retrieval.""" - try: - response = deep_search_agent.run( - query=query, - user_id=user_id, - history=history or [] - ) - return {"status": "success", "response": response} - except Exception as e: - return {"status": "error", "message": str(e)} -''' - - print(integration_code) - - -def main(): - """主函数""" - logger.info("DeepSearch Agent 与 API 组件集成示例\n") - - # 检查环境变量 - required_env_vars = ["OPENAI_API_KEY"] - missing_vars = [var for var in required_env_vars if not os.getenv(var)] - - if missing_vars: - logger.error(f"❌ 缺少必要的环境变量: {', '.join(missing_vars)}") - logger.error("\n请设置以下环境变量:") - logger.error(" export OPENAI_API_KEY='your-api-key'") - logger.error(" export OPENAI_BASE_URL='your-base-url' # 可选") - return - - try: - # 创建 DeepSearch Agent - agent, components = create_deepsearch_from_api_components() - - # 演示使用 - demo_usage(agent, components) - - # 展示集成方法 - integrate_with_api_router() - - logger.info("\n" + "="*60) - logger.info("示例运行完成!") - logger.info("="*60) - - except Exception as e: - logger.error(f"\n❌ 运行失败: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - main() - diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py index e58f8e67c..3abe5d64d 100644 --- a/src/memos/configs/mem_agent.py +++ b/src/memos/configs/mem_agent.py @@ -9,9 +9,7 @@ class BaseAgentConfig(BaseConfig): """Base configuration class for agents.""" agent_name: str = Field(..., description="Name of the agent") - description: str | None = Field( - default=None, description="Description of the agent" - ) + description: str | None = Field(default=None, description="Description of the agent") class SimpleAgentConfig(BaseAgentConfig): @@ -20,19 +18,15 @@ class SimpleAgentConfig(BaseAgentConfig): max_iterations: int = Field( default=10, description="Maximum number of iterations for the agent" ) - timeout: int = Field( - default=30, description="Timeout in seconds for agent execution" - ) + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + class DeepSearchAgentConfig(BaseAgentConfig): """Deep search agent configuration class.""" - max_iterations: int = Field( - default=3, description="Maximum number of iterations for the agent" - ) - timeout: int = Field( - default=30, description="Timeout in seconds for agent execution" - ) + max_iterations: int = Field(default=3, description="Maximum number of iterations for the agent") + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + class MemAgentConfigFactory(BaseConfig): """Factory class for creating agent configurations.""" @@ -57,4 +51,4 @@ def validate_backend(cls, backend: str) -> str: def create_config(self) -> "AgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) - return self \ No newline at end of file + return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index 3b54a846f..e4746cad6 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -1,12 +1,13 @@ -from typing import List, Dict, Any, Optional from abc import ABC, abstractmethod -from pydantic import BaseModel + from memos.configs.mem_agent import BaseAgentConfig + class BaseMemAgent(ABC): """ Base class for all agents. """ + def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" @@ -14,4 +15,4 @@ def __init__(self, config: BaseAgentConfig): def run(self, input: str) -> str: """ Run the agent. - """ \ No newline at end of file + """ diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6dd072728..70212921a 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,21 +6,21 @@ """ import json -import logging -from memos.log import get_logger -from typing import Any, Dict, List, Optional, Tuple + +from typing import Optional, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM +from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem -from memos.types import MessageDict, MessageList from memos.templates.mem_agent_prompts import ( + FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, - KEYWORD_EXTRACTION_PROMPT, - FINAL_GENERATION_PROMPT - ) +) +from memos.types import MessageList + logger = get_logger(__name__) @@ -35,29 +35,28 @@ def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name - def run(self, query: str, history: list[str]| None = None) -> str: + def run(self, query: str, history: list[str] | None = None) -> str: """ Rewrite the query to be standalone and more searchable. - + Args: query: Original user query history: List of previous conversation messages - + Returns: Rewritten query string """ if history is None: history = [] - + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - + prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", - query=query + history=history_str if history_str else "No previous conversation", query=query ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") @@ -77,23 +76,24 @@ def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name - def run(self, query: str, context: List[str]) -> Dict[str, Any]: + def run(self, query: str, context: list[str]) -> dict[str, Any]: """ Analyze whether retrieved context is sufficient to answer the query. - + Args: query: User query context: List of retrieved context strings - + Returns: Dictionary with status, reasoning, and missing entities """ - context_str = "\n".join([f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" - for ctx in context[:10]]) # Limit context size - + context_str = "\n".join( + [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] + ) # Limit context size + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) result = json.loads(response.strip()) @@ -105,27 +105,27 @@ def run(self, query: str, context: List[str]) -> Dict[str, Any]: return { "status": "sufficient", "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [] + "missing_entities": [], } class DeepSearchMemAgent(BaseMemAgent): """ Main orchestrator agent implementing the deep search pipeline. - + This agent coordinates multiple sub-agents to perform iterative query refinement, memory retrieval, and information synthesis as shown in the architecture diagram. """ def __init__( - self, - llm: BaseLLM, - memory_retriever: Any = None, - config: Optional[DeepSearchAgentConfig] = None + self, + llm: BaseLLM, + memory_retriever: Optional[Any] = None, + config: DeepSearchAgentConfig | None = None, ): """ Initialize DeepSearchMemAgent. - + Args: llm: Language model for query rewriting and response generation memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) @@ -147,7 +147,7 @@ def _set_memory_retriever(self, retriever) -> None: def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. - + Args: query: User query string **kwargs: Additional arguments (history, user_id, etc.) @@ -156,38 +156,36 @@ def run(self, query: str, **kwargs) -> str: """ if not self.llm: raise RuntimeError("LLM not initialized.") - + history = kwargs.get("history", []) user_id = kwargs.get("user_id") - + # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - + accumulated_context = [] accumulated_memories = [] search_keywords = [] # Can be extended with keyword extraction - + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") - + search_results = self._perform_memory_search( - current_query, - keywords=search_keywords, - user_id=user_id + current_query, keywords=search_keywords, user_id=user_id ) - + if search_results: context_batch = [self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) accumulated_memories.extend(search_results) - + reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") - + logger.info(f"Reflection status: {status} - {reasoning}") - + if status == "sufficient": logger.info("Sufficient information collected") break @@ -209,114 +207,111 @@ def run(self, query: str, **kwargs) -> str: current_query = query else: break - + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info="" + missing_info="", ) - + logger.info("Deep search pipeline completed") return final_answer def _perform_memory_search( - self, - query: str, - keywords: List[str] = None, - user_id: str = None, - top_k: int = 10 - ) -> List[TextualMemoryItem]: + self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int = 10 + ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. - + Args: query: Search query keywords: Additional keywords for search user_id: User identifier top_k: Number of results to retrieve - + Returns: List of retrieved memory items """ if not self.memory_retriever: logger.warning("Memory retriever not configured, returning empty results") return [] - + try: # Use the memory retriever interface # This is a placeholder - actual implementation depends on the retriever interface search_query = query if keywords and len(keywords) > 1: search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords - + # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, - top_k=top_k, - mode="fast", - user_name=user_id + query=search_query, top_k=top_k, mode="fast", user_name=user_id ) - + return results if isinstance(results, list) else [] - + except Exception as e: logger.error(f"Error performing memory search: {e}") return [] def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: """Extract readable context from a memory item.""" - if hasattr(memory_item, 'memory'): + if hasattr(memory_item, "memory"): return str(memory_item.memory) - elif hasattr(memory_item, 'content'): + elif hasattr(memory_item, "content"): return str(memory_item.content) else: return str(memory_item) - def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + def _refine_query_for_missing_info(self, query: str, missing_entities: list[str]) -> str: """Refine the query to search for missing information.""" if not missing_entities: return query - + # Simple refinement strategy - append missing entities entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities refined_query = f"{query} {entities_str}" - + return refined_query def _generate_final_answer( self, original_query: str, - search_results: List[TextualMemoryItem], - context: List[str], - missing_info: str = "" + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", ) -> str: """ Generate the final comprehensive answer. - + Args: original_query: Original user query search_results: All retrieved memory items context: Extracted context strings missing_info: Information about missing data - + Returns: Final answer string """ # Prepare context for the prompt context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context - sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" - + sources = ( + f"Retrieved {len(search_results)} memory items" + if search_results + else "No specific sources" + ) + prompt = FINAL_GENERATION_PROMPT.format( query=original_query, sources=sources, context=context_str if context_str else "No specific context retrieved", - missing_info=missing_info if missing_info else "None identified" + missing_info=missing_info if missing_info else "None identified", ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) return response.strip() diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index 33e7b6a6b..415d4916e 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import Optional, Any, ClassVar from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent @@ -14,19 +14,16 @@ class MemAgentFactory: @classmethod def from_config( - cls, - config_factory: MemAgentConfigFactory, - llm: Any, - memory_retriever: Any = None + cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None ) -> BaseMemAgent: """ Create a MemAgent instance from configuration. - + Args: config_factory: Configuration factory for the agent llm: Language model instance memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) - + Returns: Initialized MemAgent instance """ @@ -35,7 +32,5 @@ def from_config( raise ValueError(f"Invalid backend: {backend}") mem_agent_class = cls.backend_to_class[backend] return mem_agent_class( - llm=llm, - memory_retriever=memory_retriever, - config=config_factory.config - ) \ No newline at end of file + llm=llm, memory_retriever=memory_retriever, config=config_factory.config + ) diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 399b2cd75..17e3c0ad9 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -73,4 +73,4 @@ 4. Provide a well-structured, coherent response 5. Use natural, conversational tone -Response:""" \ No newline at end of file +Response:""" From bd72e9bb804d84f1a12e4c817bae494f0359d600 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:33:53 +0800 Subject: [PATCH 06/17] feat: update deepsearch mem code --- examples/mem_agent/deepsearch_example.py | 3 --- src/memos/configs/mem_agent.py | 2 +- src/memos/mem_agent/base.py | 1 + src/memos/mem_agent/deepsearch_agent.py | 15 +++++++++++---- src/memos/mem_agent/factory.py | 4 ++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index dacfe7f1e..d0842b6ec 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -328,9 +328,6 @@ def main(): # Run Example 1: Factory pattern agent_factory, components_factory = example_1_factory_initialization() - # Run Example 2: Direct initialization - agent_direct, components_direct = example_2_direct_initialization() - # Run Example 4: Minimal initialization agent_minimal, components_minimal = example_4_minimal_initialization() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py index 3abe5d64d..7cb623899 100644 --- a/src/memos/configs/mem_agent.py +++ b/src/memos/configs/mem_agent.py @@ -48,7 +48,7 @@ def validate_backend(cls, backend: str) -> str: return backend @model_validator(mode="after") - def create_config(self) -> "AgentConfigFactory": + def create_config(self) -> "MemAgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index e4746cad6..daa5f075b 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -10,6 +10,7 @@ class BaseMemAgent(ABC): def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" + self.config = config @abstractmethod def run(self, input: str) -> str: diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 70212921a..6c6fc32a9 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -7,7 +7,7 @@ import json -from typing import Optional, Any +from typing import TYPE_CHECKING, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM @@ -19,7 +19,10 @@ QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, ) -from memos.types import MessageList + + +if TYPE_CHECKING: + from memos.types import MessageList logger = get_logger(__name__) @@ -120,7 +123,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Optional[Any] = None, + memory_retriever: Any | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -220,7 +223,11 @@ def run(self, query: str, **kwargs) -> str: return final_answer def _perform_memory_search( - self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int = 10 + self, + query: str, + keywords: list[str] | None = None, + user_id: str | None = None, + top_k: int = 10, ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index 415d4916e..09537bd8a 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -1,4 +1,4 @@ -from typing import Optional, Any, ClassVar +from typing import Any, ClassVar from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent @@ -14,7 +14,7 @@ class MemAgentFactory: @classmethod def from_config( - cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None + cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Any | None = None ) -> BaseMemAgent: """ Create a MemAgent instance from configuration. From f332ef22c3a2a29e2c1820de3916e45d70e4d33a Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 22 Nov 2025 18:13:04 +0800 Subject: [PATCH 07/17] feat: update deepsearch agent --- examples/mem_agent/deepsearch_example.py | 358 +++++++---------------- src/memos/mem_agent/deepsearch_agent.py | 200 ++++++++----- src/memos/templates/mem_agent_prompts.py | 5 +- 3 files changed, 227 insertions(+), 336 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index d0842b6ec..6a9405456 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -1,35 +1,36 @@ """ -DeepSearch Agent Usage Examples +DeepSearch Agent Usage Examples - Simplified Version -This example demonstrates two ways to initialize DeepSearchMemAgent: -1. Using Factory pattern (recommended) -2. Direct initialization - -DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through: -- Query rewriting: Optimize queries based on conversation history -- Iterative retrieval: Collect information through multiple search rounds -- Reflective analysis: Determine if information is sufficient -- Comprehensive response: Generate complete final answers +This example demonstrates simplified initialization of DeepSearchMemAgent without +external config builders, using APIConfig methods directly. """ import os +from typing import Any + +from memos.api.config import APIConfig from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.graph_db import GraphConfigFactory -from memos.configs.llms import LLMConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory from memos.configs.mem_agent import MemAgentConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory -from memos.configs.textual_memory import TreeTextMemoryConfig +from memos.configs.reranker import RerankerConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory from memos.log import get_logger from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.mem_agent.factory import MemAgentFactory -from memos.mem_cube.naive_cube import NaiveMemCube +from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.factory import MemReaderFactory -from memos.memories.memory_manager.memory_manager import MemoryManager -from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory logger = get_logger(__name__) @@ -37,114 +38,105 @@ def build_minimal_components(): """ - Build the minimal component set required for DeepSearchMemAgent. - - Only need to initialize: - 1. LLM - Used for query rewriting, reflection, and final answer generation - 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval + Build minimal components for DeepSearchMemAgent with simplified configuration. - Returns: - dict: Dictionary containing llm and naive_mem_cube + This function creates all necessary components using APIConfig methods, + similar to config_builders.py but inline for easier customization. """ - logger.info("Starting to build minimal component set...") + logger.info("Initializing simplified MemOS components...") - # 1. Initialize LLM - llm_config = LLMConfigFactory( - backend="openai", - config={ - "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), - "api_key": os.getenv("OPENAI_API_KEY"), - "base_url": os.getenv("OPENAI_BASE_URL"), - "temperature": 0.7, - }, + # Build component configurations using APIConfig methods (like config_builders.py) + + # Graph DB configuration - using APIConfig.get_nebular_config() + graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() + graph_db_backend_map = { + "polardb": APIConfig.get_polardb_config(), + } + graph_db_config = GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } ) - llm = LLMFactory.from_config(llm_config) - logger.info("LLM initialization completed") - # 2. Initialize Embedder (required by text_mem) - embedder_config = EmbedderConfigFactory( - backend="universal_api", - config={ - "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), - "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), - "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - }, + # LLM configuration - using APIConfig.get_openai_config() + llm_config = LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } ) - embedder = EmbedderFactory.from_config(embedder_config) - logger.info("Embedder initialization completed") - # 3. Initialize GraphDB (required by text_mem) - graph_db_config = GraphConfigFactory( - backend="polardb", - config={ - "host": os.getenv("POLAR_DB_HOST", "localhost"), - "port": int(os.getenv("POLAR_DB_PORT", "5432")), - "user": os.getenv("POLAR_DB_USER", "root"), - "password": os.getenv("POLAR_DB_PASSWORD", "123456"), - "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), - "user_name": "memos_default", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - }, + # Embedder configuration - using APIConfig.get_embedder_config() + embedder_config = EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + # Memory reader configuration - using APIConfig.get_product_default_config() + mem_reader_config = MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] ) - graph_db = GraphStoreFactory.from_config(graph_db_config) - logger.info("✓ GraphDB init Done") - # 4. Initialize MemReader (required by text_mem) - mem_reader_config = MemReaderConfigFactory( - backend="simple_struct", - config={ - "llm": { - "backend": "openai", - "config": { - "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), - "api_key": os.getenv("OPENAI_API_KEY"), - "base_url": os.getenv("OPENAI_BASE_URL"), - }, - } - }, + # Reranker configuration - using APIConfig.get_reranker_config() + reranker_config = RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + # Internet retriever configuration - using APIConfig.get_internet_config() + internet_retriever_config = InternetRetrieverConfigFactory.model_validate( + APIConfig.get_internet_config() ) - mem_reader = MemReaderFactory.from_config(mem_reader_config) - logger.info("✓ MemReader init Done") - # 5. Initialize MemoryManager - memory_manager = MemoryManager( - graph_db=graph_db, - embedder=embedder, - llm=llm, - memory_size=1000, - is_reorganize=False, + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder ) - logger.info("✓ MemoryManager init Done") - # 6. Initialize TextMemory - text_mem_config = TreeTextMemoryConfig( - reorganize=False, - max_depth=3, + logger.debug("Core components instantiated") + + # Get default cube configuration like component_init.py + default_cube_config = APIConfig.get_default_cube_config() + + # Get default memory size from cube config (like component_init.py) + def get_memory_size_from_config(cube_config): + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + memory_size = get_memory_size_from_config(default_cube_config) + is_reorganize = getattr(default_cube_config.text_mem.config, "reorganize", False) + + # Initialize memory manager with config from APIConfig + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=memory_size, + is_reorganize=is_reorganize, ) + text_memory_config = default_cube_config.text_mem.config text_mem = SimpleTreeTextMemory( llm=llm, embedder=embedder, mem_reader=mem_reader, graph_db=graph_db, - reranker=None, + reranker=reranker, memory_manager=memory_manager, - config=text_mem_config, - internet_retriever=None, + config=text_memory_config, + internet_retriever=internet_retriever, ) - logger.info("✓ TextMemory initialization completed") - # 7. Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=None, + pref_mem=None, # Simplified: no preference memory act_mem=None, para_mem=None, ) - logger.info("✓ NaiveMemCube creation completed") - - logger.info("All components initialized!") return { "llm": llm, @@ -155,30 +147,18 @@ def build_minimal_components(): } -def example_1_factory_initialization(): - """ - Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - - Advantages: - - Separation of configuration and code - - Easy to manage and modify - - Support loading from configuration files - """ - logger.info("\n" + "=" * 60) - logger.info("Example 1: Initialize using Factory pattern") - logger.info("=" * 60 + "\n") - - # Build necessary components +def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]: + # Build necessary components with simplified setup components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - # Create configuration Factory + # Create configuration Factory with simplified config agent_config_factory = MemAgentConfigFactory( backend="deep_search", config={ - "agent_name": "MyDeepSearchAgent", - "description": "Intelligent agent for deep search", + "agent_name": "SimplifiedDeepSearchAgent", + "description": "Simplified intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds }, @@ -198,151 +178,13 @@ def example_1_factory_initialization(): return deep_search_agent, components -def example_3_usage(deep_search_agent, components): - """ - Example 3: Using DeepSearchMemAgent for search - - Demonstrates how to: - 1. Add memories to the system - 2. Use Agent for deep search - 3. Get comprehensive answers - """ - logger.info("\n" + "=" * 60) - logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("=" * 60 + "\n") - - naive_mem_cube = components["naive_mem_cube"] - text_mem = naive_mem_cube.text_mem - - # Simulate adding some memories - logger.info("1. Adding test memories...") - test_memories = [ - { - "user_name": "test_user", - "messages": [ - "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." - ], - "source": "manual", - }, - ] - - for memory in test_memories: - try: - text_mem.add( - user_name=memory["user_name"], - messages=memory["messages"], - source=memory.get("source", "manual"), - ) - logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") - except Exception as e: - logger.warning(f" ✗ Failed to add memory: {e}") - - # Use Agent for search - logger.info("\n2. Executing deep search...") - query = "What is the relationship between artificial intelligence and machine learning?" - logger.info(f" Query: {query}") - - try: - response = deep_search_agent.run( - query=query, - user_id="test_user", - history=["Hello", "I want to learn about artificial intelligence"], - ) - - logger.info("\n3. Search results:") - logger.info("-" * 60) - logger.info(response) - logger.info("-" * 60) - - except Exception as e: - logger.error(f"Error during search: {e}") - import traceback - - traceback.print_exc() - - -def example_4_minimal_initialization(): - """ - Example 4: Minimal initialization (using default configuration) - - Use cases: - - Quick testing - - Prototype development - - No need for custom configuration - """ - logger.info("\n" + "=" * 60) - logger.info("Example 4: Minimal initialization (default configuration)") - logger.info("=" * 60 + "\n") - - # Build necessary components - components = build_minimal_components() - llm = components["llm"] - naive_mem_cube = components["naive_mem_cube"] - - # Direct initialization using default configuration - deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - # config parameter omitted, will use default configuration - ) - - logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)") - logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") - logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - - return deep_search_agent, components - - def main(): - """Main function: Run all examples""" - logger.info("DeepSearch Agent Usage Examples") - logger.info("=" * 60) - - # Check environment variables - required_env_vars = ["OPENAI_API_KEY"] - missing_vars = [var for var in required_env_vars if not os.getenv(var)] - - if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") - logger.error("Please set the following environment variables:") - logger.error(" - OPENAI_API_KEY: OpenAI API key") - logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") - logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") - return - - try: - # Run Example 1: Factory pattern - agent_factory, components_factory = example_1_factory_initialization() - - # Run Example 4: Minimal initialization - agent_minimal, components_minimal = example_4_minimal_initialization() - - # Run Example 3: Actual usage (using agent created with factory method) - example_3_usage(agent_factory, components_factory) - - logger.info("\n" + "=" * 60) - logger.info("All examples completed!") - logger.info("=" * 60) - - except Exception as e: - logger.error(f"Error running examples: {e}") - import traceback - - traceback.print_exc() + agent_factory, components_factory = factory_initialization() + results = agent_factory.run( + "Caroline met up with friends, family, and mentors in early July 2023.", + user_id="locomo_exp_user_0_speaker_b_ct-1118", + ) + print(results) if __name__ == "__main__": diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6c6fc32a9..5a070c6ad 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,6 +6,7 @@ """ import json +import re from typing import TYPE_CHECKING, Any @@ -14,6 +15,7 @@ from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory from memos.templates.mem_agent_prompts import ( FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, @@ -25,91 +27,96 @@ from memos.types import MessageList +class JSONResponseParser: + """Elegant JSON response parser for LLM outputs""" + + @staticmethod + def parse(response: str) -> dict[str, Any]: + """Parse JSON response from LLM output with fallback strategies""" + # Clean response text by removing code block markers + cleaned = re.sub(r"^```(?:json)?\s*\n?|```\s*$", "", response.strip(), flags=re.IGNORECASE) + + # Try parsing with multiple strategies + for text in [cleaned, re.search(r"\{.*\}", cleaned, re.DOTALL)]: + if not text: + continue + try: + return json.loads(text if isinstance(text, str) else text.group()) + except json.JSONDecodeError: + continue + + raise ValueError(f"Cannot parse JSON response: {response[:100]}...") + + logger = get_logger(__name__) class QueryRewriter(BaseMemAgent): - """ - Specialized agent for rewriting queries based on conversation history. - Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. - """ + """Specialized agent for rewriting queries based on conversation history""" def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name def run(self, query: str, history: list[str] | None = None) -> str: - """ - Rewrite the query to be standalone and more searchable. - - Args: - query: Original user query - history: List of previous conversation messages - - Returns: - Rewritten query string - """ - if history is None: - history = [] - - history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - - prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", query=query - ) - - messages: MessageList = [{"role": "user", "content": prompt}] + """Rewrite query to be standalone and more searchable""" + history = history or [] + history_context = self._format_history(history) + prompt = QUERY_REWRITE_PROMPT.format(history=history_context, query=query) + messages = [{"role": "user", "content": prompt}] try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") return response.strip() except Exception as e: - logger.error(f"[{self.name}] Error rewriting query: {e}") - return query # Fallback to original query + logger.error(f"[{self.name}] Query rewrite failed: {e}") + return query + + def _format_history(self, history: list[str]) -> str: + """Format conversation history for prompt context""" + if not history: + return "No previous conversation" + return "\n".join(f"- {msg}" for msg in history[-5:]) class ReflectionAgent: - """ - Specialized agent for analyzing information sufficiency. - Corresponds to the decision diamond in the architecture diagram. - """ + """Specialized agent for analyzing information sufficiency""" def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name def run(self, query: str, context: list[str]) -> dict[str, Any]: - """ - Analyze whether retrieved context is sufficient to answer the query. + """Analyze whether retrieved context is sufficient to answer the query""" + context_summary = self._format_context(context) + prompt = REFLECTION_PROMPT.format(query=query, context=context_summary) - Args: - query: User query - context: List of retrieved context strings + try: + response = self.llm.generate([{"role": "user", "content": prompt}]) + logger.info(f"[{self.name}] Reflection response: {response}") - Returns: - Dictionary with status, reasoning, and missing entities - """ - context_str = "\n".join( - [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] - ) # Limit context size + result = JSONResponseParser.parse(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result}") + return result - prompt = REFLECTION_PROMPT.format(query=query, context=context_str) - messages: MessageList = [{"role": "user", "content": prompt}] + except Exception as e: + logger.error(f"[{self.name}] Reflection analysis failed: {e}") + return self._fallback_response() - try: - response = self.llm.generate(messages) - result = json.loads(response.strip()) - logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") - return result - except (json.JSONDecodeError, Exception) as e: - logger.error(f"[{self.name}] Error in reflection analysis: {e}") - # Fallback response - return { - "status": "sufficient", - "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [], - } + def _format_context(self, context: list[str]) -> str: + """Format context strings for analysis with length limits""" + return "\n".join( + f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10] + ) + + def _fallback_response(self) -> dict[str, Any]: + """Return safe fallback when reflection fails""" + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [], + } class DeepSearchMemAgent(BaseMemAgent): @@ -123,7 +130,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Any | None = None, + memory_retriever: TreeTextMemory | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -142,12 +149,7 @@ def __init__( self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_memory_retriever(self, retriever) -> None: - """Set the memory retrieval interface.""" - self.memory_retriever = retriever - logger.info("Memory retriever interface set") - - def run(self, query: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: """ Main execution method implementing the deep search pipeline. @@ -162,6 +164,7 @@ def run(self, query: str, **kwargs) -> str: history = kwargs.get("history", []) user_id = kwargs.get("user_id") + generated_answer = kwargs.get("generated_answer") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) @@ -175,7 +178,7 @@ def run(self, query: str, **kwargs) -> str: logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") search_results = self._perform_memory_search( - current_query, keywords=search_keywords, user_id=user_id + current_query, keywords=search_keywords, user_id=user_id, history=history ) if search_results: @@ -198,7 +201,8 @@ def run(self, query: str, **kwargs) -> str: elif status == "missing_info": missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") - if missing_entities: + current_query = reflection_result.get("new_search_query") + if not current_query: refined_query = self._refine_query_for_missing_info( current_query, missing_entities ) @@ -211,22 +215,63 @@ def run(self, query: str, **kwargs) -> str: else: break - # Step 3: Generate final answer - final_answer = self._generate_final_answer( - original_query=query, - search_results=accumulated_memories, - context=accumulated_context, - missing_info="", - ) + if not generated_answer: + return self._remove_duplicate_memories(accumulated_memories) + else: + return self._generate_final_answer( + query, accumulated_memories, accumulated_context, "", history + ) + + def _remove_duplicate_memories( + self, memories: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Remove duplicate memories based on memory content. + + Args: + memories: List of TextualMemoryItem objects to deduplicate - logger.info("Deep search pipeline completed") - return final_answer + Returns: + List of unique TextualMemoryItem objects (first occurrence kept) + """ + seen = set() + return [ + memory + for memory in memories + if (content := getattr(memory, "memory", "").strip()) + and content not in seen + and not seen.add(content) + ] + + def _generate_final_answer( + self, + original_query: str, + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", + history: list[str] | None = None, + sources: list[str] | None = None, + ) -> str: + """ + Generate the final answer. + """ + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified", + ) + messages: MessageList = [{"role": "user", "content": prompt}] + response = self.llm.generate(messages) + return response.strip() def _perform_memory_search( self, query: str, keywords: list[str] | None = None, user_id: str | None = None, + history: list[str] | None = None, top_k: int = 10, ) -> list[TextualMemoryItem]: """ @@ -254,7 +299,11 @@ def _perform_memory_search( # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, top_k=top_k, mode="fast", user_name=user_id + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id, + info={"history": history}, ) return results if isinstance(results, list) else [] @@ -316,7 +365,6 @@ def _generate_final_answer( context=context_str if context_str else "No specific context retrieved", missing_info=missing_info if missing_info else "None identified", ) - messages: MessageList = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 17e3c0ad9..477cd2409 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -25,14 +25,15 @@ {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", - "missing_entities": ["entity1", "entity2"] // Only if status is "missing_info", - + "missing_entities": ["entity1", "entity2"], + "new_search_query": "new search query", }} Status definitions: - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources +- "new_search_query": New search query to retrieve more information Response:""" From c21fc586a723477f3f2a2465e895228b4c1b8767 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 22 Nov 2025 18:29:39 +0800 Subject: [PATCH 08/17] feat: update test code --- tests/configs/test_embedder.py | 4 ++-- tests/embedders/test_universal_api.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/configs/test_embedder.py b/tests/configs/test_embedder.py index 8201f9bd8..10572f33e 100644 --- a/tests/configs/test_embedder.py +++ b/tests/configs/test_embedder.py @@ -17,7 +17,7 @@ def test_base_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims"], + optional_fields=["embedding_dims", "headers_extra"], ) check_config_instantiation_valid( @@ -36,7 +36,7 @@ def test_ollama_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "api_base"], + optional_fields=["embedding_dims", "headers_extra", "api_base"], ) check_config_instantiation_valid( diff --git a/tests/embedders/test_universal_api.py b/tests/embedders/test_universal_api.py index e4ebb7019..1419b2849 100644 --- a/tests/embedders/test_universal_api.py +++ b/tests/embedders/test_universal_api.py @@ -30,6 +30,7 @@ def test_embed_single_text(self, mock_openai_client): mock_openai_client.assert_called_once_with( api_key="fake-api-key", base_url="https://api.openai.com/v1", + default_headers=None, ) # Assert embeddings.create called with correct params From 0f62af8173c8c1614e582c77b7ce572f44415e75 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 24 Nov 2025 14:30:28 +0800 Subject: [PATCH 09/17] fix: remove dup config --- src/memos/api/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d738c878a..c62cd3b08 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -409,7 +409,6 @@ def get_embedder_config() -> dict[str, Any]: "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), - "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama From dac3394cf5f79e4b08481b661645786a6d73e716 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 10:55:16 +0800 Subject: [PATCH 10/17] feat: dock search pipeline --- src/memos/api/handlers/base_handler.py | 5 +++++ src/memos/api/handlers/component_init.py | 6 ++++++ src/memos/api/handlers/search_handler.py | 10 +++++++--- src/memos/mem_agent/deepsearch_agent.py | 7 +++---- src/memos/multi_mem_cube/single_cube.py | 8 ++++++-- src/memos/templates/mem_agent_prompts.py | 4 +++- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index a686ac8f9..7a47f05e3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -161,6 +161,11 @@ def mos_server(self): """Get MOS server instance.""" return self.deps.mos_server + @property + def deepsearch_agent(self): + """Get deepsearch agent instance.""" + return self.deps.deepsearch_agent + def _validate_dependencies(self, *required_deps: str) -> None: """ Validate that required dependencies are available. diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 3ef1d529d..7b34fcfae 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -307,6 +308,10 @@ def init_server() -> dict[str, Any]: online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") + deepsearch_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=tree_mem, + ) # Return all components as a dictionary for easy access and extension return { "graph_db": graph_db, @@ -330,4 +335,5 @@ def init_server() -> dict[str, Any]: "text_mem": text_mem, "pref_mem": pref_mem, "online_bot": online_bot, + "deepsearch_agent": deepsearch_agent, } diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8a2c21aad..b454568f8 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -31,7 +31,9 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") + self._validate_dependencies( + "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" + ) def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -52,10 +54,10 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse results = cube_view.search_memories(search_req) - self.logger.info(f"[AddHandler] Final add results count={len(results)}") + self.logger.info(f"[SearchHandler] Final search results count={len(results)}") return SearchResponse( - message="Memory searched successfully", + message="Search completed successfully", data=results, ) @@ -87,6 +89,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) else: single_views = [ @@ -97,6 +100,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) for cube_id in cube_ids ] diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5a070c6ad..16b600980 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from memos.types import MessageList +logger = get_logger(__name__) + class JSONResponseParser: """Elegant JSON response parser for LLM outputs""" @@ -48,9 +50,6 @@ def parse(response: str) -> dict[str, Any]: raise ValueError(f"Cannot parse JSON response: {response[:100]}...") -logger = get_logger(__name__) - - class QueryRewriter(BaseMemAgent): """Specialized agent for rewriting queries based on conversation history""" @@ -141,7 +140,7 @@ def __init__( memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) config: Configuration for deep search behavior """ - self.config = config or DeepSearchAgentConfig() + self.config = config or DeepSearchAgentConfig(agent_name="DeepSearchMemAgent") self.max_iterations = self.config.max_iterations self.timeout = self.config.timeout self.llm: BaseLLM = llm diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2055615d2..d2fde36a3 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,6 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any + deepsearch_agent: Any def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ @@ -247,8 +248,11 @@ def _fast_search( def _deep_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: - logger.error("waiting to be implemented") - return [] + deepsearch_results = self.deepsearch_agent.run( + search_req.query, user_id=user_context.mem_cube_id + ) + formatted_memories = [format_memory_item(data) for data in deepsearch_results] + return formatted_memories def _fine_search( self, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 477cd2409..eb624ef89 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -22,12 +22,14 @@ {context} Analyze the context and determine the next step. Return your response in JSON format with the following structure: -{{ + ```json + {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", "missing_entities": ["entity1", "entity2"], "new_search_query": "new search query", }} +``` Status definitions: - "sufficient": Context fully answers the query From 9489d54b0a71ba8db05b7d13158fcaad6135d50f Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 11:17:52 +0800 Subject: [PATCH 11/17] fix: code test --- tests/api/test_server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 853a271f6..2aa96257b 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -48,6 +48,7 @@ def mock_init_server(): "pref_mem": None, "online_bot": None, "chat_llms": Mock(), + "deepsearch_agent": Mock(), } with patch("memos.api.handlers.init_server", return_value=mock_components): From e43e5db62bf9fab56ed432a6657412a868e58763 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 14:59:14 +0800 Subject: [PATCH 12/17] feat: add test scripts --- src/memos/mem_agent/deepsearch_agent.py | 4 +- tests/mem_agent/test_deepsearch_agent.py | 233 +++++++++++++++++++++++ 2 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 tests/mem_agent/test_deepsearch_agent.py diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 16b600980..5e51aec44 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -218,7 +218,7 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: return self._remove_duplicate_memories(accumulated_memories) else: return self._generate_final_answer( - query, accumulated_memories, accumulated_context, "", history + query, accumulated_memories, accumulated_context, history ) def _remove_duplicate_memories( @@ -247,9 +247,9 @@ def _generate_final_answer( original_query: str, search_results: list[TextualMemoryItem], context: list[str], - missing_info: str = "", history: list[str] | None = None, sources: list[str] | None = None, + missing_info: str | None = None, ) -> str: """ Generate the final answer. diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py new file mode 100644 index 000000000..707d80283 --- /dev/null +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -0,0 +1,233 @@ +"""Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" + +import json +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import ( + DeepSearchMemAgent, + JSONResponseParser, +) +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata + + +class TestJSONResponseParser: + """Test JSONResponseParser class.""" + + def test_parse_clean_json(self): + """Test parsing clean JSON response.""" + response = '{"status": "sufficient", "reasoning": "test"}' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_json_with_code_blocks(self): + """Test parsing JSON wrapped in code blocks.""" + response = '```json\n{"status": "sufficient", "reasoning": "test"}\n```' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_invalid_json_raises_error(self): + """Test that invalid JSON raises ValueError.""" + with pytest.raises(ValueError, match="Cannot parse JSON response"): + JSONResponseParser.parse("This is not JSON at all") + + +class TestDeepSearchMemAgent: + """Test DeepSearchMemAgent core functionality.""" + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + mock = MagicMock() + mock.generate.return_value = "Generated answer" + return mock + + @pytest.fixture + def mock_memory_retriever(self): + """Create a mock memory retriever.""" + mock = MagicMock() + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a programming language", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python was created by Guido van Rossum", + metadata=TextualMemoryMetadata(type="fact"), + ), + ] + mock.search.return_value = memory_items + return mock + + @pytest.fixture + def config(self): + """Create DeepSearchAgentConfig.""" + return DeepSearchAgentConfig(agent_name="TestDeepSearch", max_iterations=3, timeout=30) + + @pytest.fixture + def agent(self, mock_llm, mock_memory_retriever, config): + """Create DeepSearchMemAgent instance.""" + agent = DeepSearchMemAgent( + llm=mock_llm, memory_retriever=mock_memory_retriever, config=config + ) + # Mock the sub-agents to avoid complex interactions + agent.query_rewriter.run = MagicMock(return_value="Rewritten query") + agent.reflector.run = MagicMock( + return_value={ + "status": "sufficient", + "reasoning": "Enough info", + "missing_entities": [], + } + ) + return agent + + def test_init_with_config(self, mock_llm, mock_memory_retriever, config): + """Test DeepSearchMemAgent initialization with config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + assert agent.llm == mock_llm + assert agent.memory_retriever == mock_memory_retriever + assert agent.config == config + assert agent.max_iterations == 3 + assert agent.timeout == 30 + + def test_init_without_config(self, mock_llm, mock_memory_retriever): + """Test DeepSearchMemAgent initialization without config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever) + assert isinstance(agent.config, DeepSearchAgentConfig) + assert agent.config.agent_name == "DeepSearchMemAgent" + + def test_run_no_llm_raises_error(self, config): + """Test that running without LLM raises RuntimeError.""" + agent = DeepSearchMemAgent(llm=None, config=config) + with pytest.raises(RuntimeError, match="LLM not initialized"): + agent.run("test query") + + def test_run_returns_memories_when_no_generated_answer(self, agent, mock_memory_retriever): + """Test run returns memories when generated_answer is not requested.""" + result = agent.run("What is Python?", generated_answer=False) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, TextualMemoryItem) for item in result) + agent.query_rewriter.run.assert_called_once() + + def test_run_returns_answer_when_generated_answer(self, agent, mock_llm): + """Test run returns generated answer when requested.""" + result = agent.run("What is Python?", generated_answer=True) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_run_with_user_id(self, agent, mock_memory_retriever): + """Test run with user_id.""" + agent.run("What is Python?", user_id="user123", generated_answer=False) + + # Check that user_id was passed to search + call_kwargs = mock_memory_retriever.search.call_args[1] + assert call_kwargs.get("user_name") == "user123" + + def test_run_no_search_results(self, agent, mock_memory_retriever): + """Test behavior when search returns no results.""" + mock_memory_retriever.search.return_value = [] + + result = agent.run("What is Python?", generated_answer=False) + + assert result == [] + + def test_remove_duplicate_memories(self, agent): + """Test removing duplicate memories.""" + mem_id1 = str(uuid.uuid4()) + mem_id2 = str(uuid.uuid4()) + mem_id3 = str(uuid.uuid4()) + + memories = [ + TextualMemoryItem( + id=mem_id1, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + TextualMemoryItem( + id=mem_id2, + memory="Different content", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=mem_id3, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + ] + + result = agent._remove_duplicate_memories(memories) + + assert len(result) == 2 + assert result[0].id == mem_id1 + assert result[1].id == mem_id2 + + def test_generate_final_answer(self, agent, mock_llm): + """Test final answer generation.""" + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a language", + metadata=TextualMemoryMetadata(type="fact"), + ) + ] + context = ["Python is a programming language"] + + result = agent._generate_final_answer("What is Python?", memory_items, context) + + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_generate_final_answer_with_missing_info(self, agent, mock_llm): + """Test final answer generation with missing info.""" + result = agent._generate_final_answer( + "What is Python?", [], [], missing_info="Version details not found" + ) + + assert result == "Generated answer" + call_args = mock_llm.generate.call_args[0][0] + assert "Version details not found" in call_args[0]["content"] + + def test_generate_final_answer_llm_error(self, agent, mock_llm): + """Test final answer generation handles LLM errors.""" + mock_llm.generate.side_effect = Exception("LLM error") + + result = agent._generate_final_answer("What is Python?", [], []) + + assert "error" in result.lower() + assert "What is Python?" in result + + def test_perform_memory_search_no_retriever(self, mock_llm, config): + """Test memory search when retriever is not configured.""" + agent = DeepSearchMemAgent(mock_llm, memory_retriever=None, config=config) + result = agent._perform_memory_search("test query") + + assert result == [] + + def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config): + """Test full pipeline integration.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + + with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): + with patch.object( + agent.reflector, + "run", + return_value={ + "status": "sufficient", + "reasoning": "Info is sufficient", + "missing_entities": [], + }, + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From ecd45084935131bb7fb02ac09542c91ebae9e773 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 14:59:52 +0800 Subject: [PATCH 13/17] feat: add test --- tests/mem_agent/test_deepsearch_agent.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py index 707d80283..a80dd10ea 100644 --- a/tests/mem_agent/test_deepsearch_agent.py +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -1,6 +1,5 @@ """Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" -import json import uuid from unittest.mock import MagicMock, patch @@ -213,8 +212,9 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config """Test full pipeline integration.""" agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) - with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): - with patch.object( + with ( + patch.object(agent.query_rewriter, "run", return_value="Rewritten query"), + patch.object( agent.reflector, "run", return_value={ @@ -222,12 +222,13 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config "reasoning": "Info is sufficient", "missing_entities": [], }, - ): - result = agent.run( - "What is Python?", user_id="user123", history=[], generated_answer=True - ) - - assert isinstance(result, str) - assert result == "Generated answer" - mock_memory_retriever.search.assert_called() - mock_llm.generate.assert_called() + ), + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From 6e21032fc87464f4569f99c26e4309247b533d74 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 19:06:23 +0800 Subject: [PATCH 14/17] feat: update need_raw process --- src/memos/mem_agent/deepsearch_agent.py | 21 +++++++++++++++++++-- src/memos/templates/mem_agent_prompts.py | 8 +++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5e51aec44..051ac03d3 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -183,8 +183,6 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if search_results: context_batch = [self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) - accumulated_memories.extend(search_results) - reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") @@ -193,11 +191,14 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if status == "sufficient": logger.info("Sufficient information collected") + accumulated_memories.extend(search_results) break elif status == "needs_raw": logger.info("Need original sources, retrieving raw content") + accumulated_memories.extend(self._set_source_from_memory(search_results)) break elif status == "missing_info": + accumulated_memories.extend(search_results) missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") current_query = reflection_result.get("new_search_query") @@ -331,6 +332,22 @@ def _refine_query_for_missing_info(self, query: str, missing_entities: list[str] return refined_query + def _set_source_from_memory( + self, memory_items: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """set source from memory item""" + for memory_item in memory_items: + if not hasattr(memory_item.metadata, "sources"): + continue + chat_sources = [ + f"{source.chat_time} {source.role}: {source.content}" + for source in memory_item.metadata.sources + if hasattr(source, "type") and source.type == "chat" + ] + if chat_sources: + memory_item.memory = "\n".join(chat_sources) + "\n" + return memory_items + def _generate_final_answer( self, original_query: str, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index eb624ef89..d7163e4a8 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -35,7 +35,13 @@ - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources -- "new_search_query": New search query to retrieve more information + +IMPORTANT for "new_search_query": +- MUST preserve ALL specific entities from the original query (names, dates, times, locations, etc.) +- DO NOT replace specific information with generic terms like "user", "person", "they", etc. +- Keep the exact same subjects, time references, and key details as in the original query +- Only modify the query to focus on the missing information while maintaining all original specifics +- Example: If original query mentions "May 2024", keep "May 2024" in new query, don't change to "that month" Response:""" From 592f637bc3090dd682013f3b74f290de72450050 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 20:02:22 +0800 Subject: [PATCH 15/17] fix: add initter --- src/memos/multi_mem_cube/single_cube.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d2fde36a3..dbc527bb7 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,7 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any - deepsearch_agent: Any + deepsearch_agent: Any | None = None def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ From ad997459ffa8d8a1fc30442b8d47f50d338da9c7 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 27 Nov 2025 16:24:58 +0800 Subject: [PATCH 16/17] fix: change agent search func name --- src/memos/multi_mem_cube/single_cube.py | 4 +++- src/memos/types/general_types.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8f4a25a0b..8e37cb92d 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -203,7 +203,7 @@ def _deep_search( formatted_memories = [format_memory_item(data) for data in enhanced_memories] return formatted_memories - def _deep_search( + def _agentic_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: deepsearch_results = self.deepsearch_agent.run( @@ -229,6 +229,8 @@ def _fine_search( """ if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: return self._deep_search(search_req=search_req, user_context=user_context) + elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: + return self._agentic_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 2b7206c74..f796e682a 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -102,6 +102,7 @@ class FineStrategy(str, Enum): REWRITE = "rewrite" RECREATE = "recreate" DEEP_SEARCH = "deep_search" + AGENTIC_SEARCH = "agentic_search" # algorithm strategies From 94dba83f4655e45ca671e57eb19731e85f4fe2c6 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 28 Nov 2025 11:10:02 +0800 Subject: [PATCH 17/17] feat: update logs and defined --- src/memos/multi_mem_cube/single_cube.py | 12 ++++++++---- src/memos/reranker/factory.py | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8e37cb92d..e60e8e71e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -35,14 +35,17 @@ if TYPE_CHECKING: from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.mem_cube.navie import NaiveMemCube + from memos.mem_reader.simple_struct import SimpleStructMemReader + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler @dataclass class SingleCubeView(MemCubeView): cube_id: str - naive_mem_cube: Any - mem_reader: Any - mem_scheduler: Any + naive_mem_cube: NaiveMemCube + mem_reader: SimpleStructMemReader + mem_scheduler: OptimizedScheduler logger: Any searcher: Any deepsearch_agent: Any | None = None @@ -155,7 +158,7 @@ def _search_text( Args: search_req: Search request user_context: User context - search_mode: Search mode (FAST, FINE, or MIXTURE) + search_mode: Search mode (fast, fine, or mixture) Returns: List of formatted memory items @@ -227,6 +230,7 @@ def _fine_search( Returns: List of enhanced search results """ + logger.info(f"Fine strategy: {FINE_STRATEGY}") if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: return self._deep_search(search_req=search_req, user_context=user_context) elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py index d2c50ba5e..1440704a6 100644 --- a/src/memos/reranker/factory.py +++ b/src/memos/reranker/factory.py @@ -2,6 +2,7 @@ from __future__ import annotations import json + from typing import TYPE_CHECKING, Any # Import singleton decorator