From 587e5a7204d415882fbd5c98d41dae45ad8666d8 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:02:40 +0800 Subject: [PATCH 01/21] feat: update memos headers --- src/memos/api/config.py | 3 ++- src/memos/configs/embedder.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a276fa63d..3236c6705 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -379,7 +379,7 @@ def get_reranker_config() -> dict[str, Any]: "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, - "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), + "headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, @@ -406,6 +406,7 @@ def get_embedder_config() -> dict[str, Any]: "api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"), "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), + "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 70095a194..d88b6005e 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig): embedding_dims: int | None = Field( default=None, description="Number of dimensions for the embedding" ) + headers_extra: dict[str, Any] | None = Field( + default=None, + description="Extra headers for the embedding model, only for universal_api backend", + ) class OllamaEmbedderConfig(BaseEmbedderConfig): From 5637c9d16adeff6a0a8439dae9c8f1f452b02ede Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:03:07 +0800 Subject: [PATCH 02/21] feat: headers add --- 
src/memos/embedders/universal_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index fc51cf073..6cbd06443 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -16,7 +16,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig): self.config = config if self.provider == "openai": - self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url) + self.client = OpenAIClient( + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.headers_extra if config.headers_extra else None, + ) elif self.provider == "azure": self.client = AzureClient( azure_endpoint=config.base_url, From 68831c080e9db6633136f0ed7424b8b4bc275ccb Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 20 Nov 2025 16:45:03 +0800 Subject: [PATCH 03/21] feat: update search agent --- examples/deepsearch_example.py | 76 ++++++ src/memos/configs/mem_agent.py | 60 +++++ src/memos/mem_agent/base.py | 17 ++ src/memos/mem_agent/deepsearch_agent.py | 320 +++++++++++++++++++++++ src/memos/mem_agent/factory.py | 21 ++ src/memos/templates/mem_agent_prompts.py | 76 ++++++ 6 files changed, 570 insertions(+) create mode 100644 examples/deepsearch_example.py create mode 100644 src/memos/configs/mem_agent.py create mode 100644 src/memos/mem_agent/base.py create mode 100644 src/memos/mem_agent/deepsearch_agent.py create mode 100644 src/memos/mem_agent/factory.py create mode 100644 src/memos/templates/mem_agent_prompts.py diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py new file mode 100644 index 000000000..43bf4c012 --- /dev/null +++ b/examples/deepsearch_example.py @@ -0,0 +1,76 @@ +""" +Example usage of DeepSearchAgent in MemOS. + +This example demonstrates how to initialize and use the DeepSearchAgent +for comprehensive information retrieval and synthesis. 
+""" + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.llms.factory import LLMFactory +from memos.configs.llm import LLMConfigFactory + + +def main(): + """Example usage of DeepSearchAgent.""" + + # 1. Configure the LLM + llm_config = LLMConfigFactory( + backend="openai", # or "ollama", "azure", etc. + config={ + "api_key": "your-api-key-here", + "model_name_or_path": "gpt-4", + "temperature": 0.7, + "max_tokens": 2048 + } + ) + + # 2. Create LLM instance + llm = LLMFactory.from_config(llm_config) + + # 3. Configure DeepSearchAgent + agent_config = DeepSearchAgentConfig( + agent_name="DeepSearchAgent", + description="Advanced deep search agent for comprehensive information retrieval", + max_iterations=3, + timeout=60 + ) + + # 4. Create DeepSearchAgent instance + agent = DeepSearchAgent(agent_config) + + # 5. Initialize the agent with LLM + agent.set_llm(llm) + + # 6. Set up memory retriever (this would typically be injected by the framework) + # agent.set_memory_retriever(your_memory_retriever) + + # 7. Example queries + queries = [ + "What are the latest developments in AI research?", + "Tell me about my recent project meetings and their outcomes", + "What are the key trends in machine learning this year?" + ] + + # 8. 
Process queries + for query in queries: + print(f"\n{'='*60}") + print(f"Query: {query}") + print(f"{'='*60}") + + try: + # Run the deep search pipeline + response = agent.run( + input=query, + history=["Previous conversation context"], + user_id="example_user" + ) + + print(f"Response: {response}") + + except Exception as e: + print(f"Error processing query: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py new file mode 100644 index 000000000..e58f8e67c --- /dev/null +++ b/src/memos/configs/mem_agent.py @@ -0,0 +1,60 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAgentConfig(BaseConfig): + """Base configuration class for agents.""" + + agent_name: str = Field(..., description="Name of the agent") + description: str | None = Field( + default=None, description="Description of the agent" + ) + + +class SimpleAgentConfig(BaseAgentConfig): + """Simple agent configuration class.""" + + max_iterations: int = Field( + default=10, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class DeepSearchAgentConfig(BaseAgentConfig): + """Deep search agent configuration class.""" + + max_iterations: int = Field( + default=3, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class MemAgentConfigFactory(BaseConfig): + """Factory class for creating agent configurations.""" + + backend: str = Field(..., description="Backend for agent") + config: dict[str, Any] = Field(..., description="Configuration for the agent backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "simple": SimpleAgentConfig, + "deep_search": DeepSearchAgentConfig, + } + + @field_validator("backend") + 
@classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AgentConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self \ No newline at end of file diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py new file mode 100644 index 000000000..3b54a846f --- /dev/null +++ b/src/memos/mem_agent/base.py @@ -0,0 +1,17 @@ +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod +from pydantic import BaseModel +from memos.configs.mem_agent import BaseAgentConfig + +class BaseMemAgent(ABC): + """ + Base class for all agents. + """ + def __init__(self, config: BaseAgentConfig): + """Initialize the BaseMemAgent with the given configuration.""" + + @abstractmethod + def run(self, input: str) -> str: + """ + Run the agent. + """ \ No newline at end of file diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py new file mode 100644 index 000000000..eb67fabba --- /dev/null +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -0,0 +1,320 @@ +""" +Deep Search Agent implementation for MemOS. + +This module implements a sophisticated deep search agent that performs iterative +query refinement and memory retrieval to provide comprehensive answers. 
+""" + +import json +import logging +from memos.log import get_logger +from typing import Any, Dict, List, Optional, Tuple + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.llms.base import BaseLLM +from memos.mem_agent.base import BaseMemAgent +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageDict, MessageList +from memos.templates.mem_agent_prompts import ( + QUERY_REWRITE_PROMPT, + REFLECTION_PROMPT, + KEYWORD_EXTRACTION_PROMPT, + FINAL_GENERATION_PROMPT + ) + +logger = get_logger(__name__) + + +class QueryRewriter(BaseMemAgent): + """ + Specialized agent for rewriting queries based on conversation history. + Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): + self.llm = llm + self.name = name + + def run(self, query: str, history: list[str]| None = None) -> str: + """ + Rewrite the query to be standalone and more searchable. + + Args: + query: Original user query + history: List of previous conversation messages + + Returns: + Rewritten query string + """ + if history is None: + history = [] + + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages + + prompt = QUERY_REWRITE_PROMPT.format( + history=history_str if history_str else "No previous conversation", + query=query + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + logger.info(f"[{self.name}] Rewritten query: {response.strip()}") + return response.strip() + except Exception as e: + logger.error(f"[{self.name}] Error rewriting query: {e}") + return query # Fallback to original query + + +class ReflectionAgent: + """ + Specialized agent for analyzing information sufficiency. + Corresponds to the decision diamond in the architecture diagram. 
+ """ + + def __init__(self, llm: BaseLLM, name: str = "Reflector"): + self.llm = llm + self.name = name + + def run(self, query: str, context: List[str]) -> Dict[str, Any]: + """ + Analyze whether retrieved context is sufficient to answer the query. + + Args: + query: User query + context: List of retrieved context strings + + Returns: + Dictionary with status, reasoning, and missing entities + """ + context_str = "\n".join([f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" + for ctx in context[:10]]) # Limit context size + + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + result = json.loads(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") + return result + except (json.JSONDecodeError, Exception) as e: + logger.error(f"[{self.name}] Error in reflection analysis: {e}") + # Fallback response + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [] + } + + +class DeepSearchMemAgent(BaseMemAgent): + """ + Main orchestrator agent implementing the deep search pipeline. + + This agent coordinates multiple sub-agents to perform iterative query refinement, + memory retrieval, and information synthesis as shown in the architecture diagram. 
+ """ + + def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): + super().__init__(config) + self.config = config + self.max_iterations = config.max_iterations + self.timeout = config.timeout + self.llm: Optional[BaseLLM] = llm + self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") + self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + self.memory_retriever = memory_retriever + + def _set_llm(self, llm: BaseLLM) -> None: + """Set the LLM and initialize sub-agents.""" + self.llm = llm + self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") + self.reflector = ReflectionAgent(llm, "Reflector") + self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") + logger.info("LLM and sub-agents initialized") + + def _set_memory_retriever(self, retriever) -> None: + """Set the memory retrieval interface.""" + self.memory_retriever = retriever + logger.info("Memory retriever interface set") + + def run(self, input: str, **kwargs) -> str: + """ + Main execution method implementing the deep search pipeline. + + Args: + input: User query string + Returns: + Comprehensive response string + """ + if not self.llm: + raise RuntimeError("LLM not initialized. 
Call set_llm() first.") + + query = input + history = kwargs.get("history", []) + user_id = kwargs.get("user_id") + + # Step 1: Query Rewriting + current_query = self.query_rewriter.run(query, history) + + # Step 2: Keyword Extraction and Planning + keyword_analysis = self.keyword_extractor.run(current_query) + search_keywords = keyword_analysis.get("keywords", [current_query]) + + accumulated_context = [] + accumulated_memories = [] + + # Step 3: Iterative Search and Reflection Loop + for iteration in range(self.max_iterations): + search_results = self._perform_memory_search( + current_query, + keywords=search_keywords, + user_id=user_id + ) + + if search_results: + context_batch = [self._extract_context_from_memory(mem) for mem in search_results] + accumulated_context.extend(context_batch) + accumulated_memories.extend(search_results) + + reflection_result = self.reflector.run(current_query, context_batch) + status = reflection_result.get("status", "sufficient") + reasoning = reflection_result.get("reasoning", "") + + logger.info(f"Reflection status: {status} - {reasoning}") + + if status == "sufficient": + logger.info("Sufficient information collected") + break + elif status == "needs_raw": + logger.info("Need original sources, retrieving raw content") + break + elif status == "missing_info": + missing_entities = reflection_result.get("missing_entities", []) + logger.info(f"Missing information: {missing_entities}") + if missing_entities: + refined_query = self._refine_query_for_missing_info( + current_query, missing_entities + ) + current_query = refined_query + logger.info(f"Refined query: {current_query}") + else: + logger.warning(f"No search results for iteration {iteration + 1}") + if iteration == 0: + current_query = query + else: + break + final_answer = self._generate_final_answer( + original_query=query, + search_results=accumulated_memories, + context=accumulated_context, + missing_info=keyword_analysis.get("search_strategy", "") + ) + + 
logger.info("Deep search pipeline completed") + return final_answer + + def _perform_memory_search( + self, + query: str, + keywords: List[str] = None, + user_id: str = None, + top_k: int = 10 + ) -> List[TextualMemoryItem]: + """ + Perform memory search using the configured retriever. + + Args: + query: Search query + keywords: Additional keywords for search + user_id: User identifier + top_k: Number of results to retrieve + + Returns: + List of retrieved memory items + """ + if not self.memory_retriever: + logger.warning("Memory retriever not configured, returning empty results") + return [] + + try: + # Use the memory retriever interface + # This is a placeholder - actual implementation depends on the retriever interface + search_query = query + if keywords and len(keywords) > 1: + search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords + + # Assuming the retriever has a search method similar to TreeTextMemory + results = self.memory_retriever.search( + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id + ) + + return results if isinstance(results, list) else [] + + except Exception as e: + logger.error(f"Error performing memory search: {e}") + return [] + + def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: + """Extract readable context from a memory item.""" + if hasattr(memory_item, 'memory'): + return str(memory_item.memory) + elif hasattr(memory_item, 'content'): + return str(memory_item.content) + else: + return str(memory_item) + + def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + """Refine the query to search for missing information.""" + if not missing_entities: + return query + + # Simple refinement strategy - append missing entities + entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities + refined_query = f"{query} {entities_str}" + + return refined_query + + def _generate_final_answer( + self, + original_query: str, + 
search_results: List[TextualMemoryItem], + context: List[str], + missing_info: str = "" + ) -> str: + """ + Generate the final comprehensive answer. + + Args: + original_query: Original user query + search_results: All retrieved memory items + context: Extracted context strings + missing_info: Information about missing data + + Returns: + Final answer string + """ + # Prepare context for the prompt + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context + sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" + + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified" + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"Error generating final answer: {e}") + return f"I apologize, but I encountered an error while processing your query: {original_query}. Please try again." 
# --- src/memos/mem_agent/factory.py ---
from typing import Any, ClassVar

from memos.configs.mem_agent import MemAgentConfigFactory
from memos.mem_agent.base import BaseMemAgent
from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent


class MemAgentFactory:
    """Factory class for creating MemAgent instances."""

    backend_to_class: ClassVar[dict[str, Any]] = {
        # BUG fixed: the module exports ``DeepSearchMemAgent``; the original
        # imported a non-existent ``DeepSearchAgent``.
        "deep_search": DeepSearchMemAgent,
    }

    @classmethod
    def from_config(
        cls, config_factory: MemAgentConfigFactory, **components: Any
    ) -> BaseMemAgent:
        """Build the agent registered for ``config_factory.backend``.

        Extra keyword components (e.g. ``llm``, ``memory_retriever``) are
        forwarded to the agent constructor alongside the validated config.
        """
        backend = config_factory.backend
        if backend not in cls.backend_to_class:
            raise ValueError(f"Invalid backend: {backend}")
        mem_agent_class = cls.backend_to_class[backend]
        return mem_agent_class(config=config_factory.config, **components)


# --- src/memos/templates/mem_agent_prompts.py ---
QUERY_REWRITE_PROMPT = """
You are a query rewriting specialist. Your task is to rewrite user queries to be more standalone and searchable.

Given the conversation history and current user query, rewrite the query to:
1. Be self-contained and independent of conversation context
2. Include relevant context from history when necessary
3. Maintain the original intent and scope
4. Use clear, specific terminology

Conversation History:
{history}

Current Query: {query}

Rewritten Query:"""

REFLECTION_PROMPT = """
You are an information sufficiency analyst. Evaluate whether the retrieved context is sufficient to answer the user's query.

Query: {query}
Retrieved Context:
{context}

Analyze the context and determine the next step. Return your response in JSON format with the following structure:
{{
    "status": "sufficient|missing_info|needs_raw",
    "reasoning": "Brief explanation of your decision",
    "missing_entities": ["entity1", "entity2"]
}}

Include "missing_entities" only when status is "missing_info".

Status definitions:
- "sufficient": Context fully answers the query
- "missing_info": Key information is missing (e.g., specific dates, locations, details)
- "needs_raw": Content is relevant but too summarized/vague, need original sources

Response:"""

KEYWORD_EXTRACTION_PROMPT = """
Analyze the user query and extract key search terms and identify optimal data sources.

Query: {query}

Extract:
1. Key search terms and concepts
2. Important entities (people, places, dates, etc.)
3. Suggested data sources or memory types to search

Return response in JSON format:
{{
    "keywords": ["keyword1", "keyword2"],
    "entities": ["entity1", "entity2"],
    "search_strategy": "Brief strategy description"
}}

Response:"""


FINAL_GENERATION_PROMPT = """
You are a comprehensive information synthesizer. Generate a complete answer based on the retrieved information.

User Query: {query}
Search Sources: {sources}
Retrieved Information:
{context}

Missing Information (if any): {missing_info}

Instructions:
1. Synthesize all relevant information to answer the query comprehensively
2. If information is missing, acknowledge gaps and suggest next steps
3. Maintain accuracy and cite sources when possible
4. Provide a well-structured, coherent response
5. Use natural, conversational tone

Response:"""
Use natural, conversational tone + +Response:""" \ No newline at end of file From 58c512d8a7086e0e723bf232257067f8c7516876 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:08:50 +0800 Subject: [PATCH 04/21] feat: upadte mem story --- examples/mem_agent/deepsearch_example.py | 350 ++++++++++++++++++ .../deepsearch_with_api_components.py | 223 +++++++++++ src/memos/mem_agent/deepsearch_agent.py | 57 +-- src/memos/mem_agent/factory.py | 28 +- 4 files changed, 628 insertions(+), 30 deletions(-) create mode 100644 examples/mem_agent/deepsearch_example.py create mode 100644 examples/mem_agent/deepsearch_with_api_components.py diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py new file mode 100644 index 000000000..62f26f0b1 --- /dev/null +++ b/examples/mem_agent/deepsearch_example.py @@ -0,0 +1,350 @@ +""" +DeepSearch Agent Usage Examples + +This example demonstrates two ways to initialize DeepSearchMemAgent: +1. Using Factory pattern (recommended) +2. 
Direct initialization + +DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through: +- Query rewriting: Optimize queries based on conversation history +- Iterative retrieval: Collect information through multiple search rounds +- Reflective analysis: Determine if information is sufficient +- Comprehensive response: Generate complete final answers +""" + +import os +import sys + +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphConfigFactory +from memos.configs.llms import LLMConfigFactory +from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.textual_memory import TreeTextMemoryConfig +from memos.embedders.base import BaseEmbedder +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent +from memos.mem_agent.factory import MemAgentFactory +from memos.mem_cube.naive_cube import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.memory_manager.memory_manager import MemoryManager +from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory +from memos.log import get_logger + +logger = get_logger(__name__) + + +def build_minimal_components(): + """ + Build the minimal component set required for DeepSearchMemAgent. + + Only need to initialize: + 1. LLM - Used for query rewriting, reflection, and final answer generation + 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval + + Returns: + dict: Dictionary containing llm and naive_mem_cube + """ + logger.info("Starting to build minimal component set...") + + # 1. 
Initialize LLM + llm_config = LLMConfigFactory( + backend="openai", + config={ + "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + "temperature": 0.7, + } + ) + llm = LLMFactory.from_config(llm_config) + logger.info("LLM initialization completed") + + # 2. Initialize Embedder (required by text_mem) + embedder_config = EmbedderConfigFactory( + backend="universal_api", + config={ + "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), + "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), + "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), + } + ) + embedder = EmbedderFactory.from_config(embedder_config) + logger.info("Embedder initialization completed") + + # 3. Initialize GraphDB (required by text_mem) + graph_db_config = GraphConfigFactory( + backend="polardb", + config={ + "host": os.getenv("POLAR_DB_HOST", "localhost"), + "port": int(os.getenv("POLAR_DB_PORT", "5432")), + "user": os.getenv("POLAR_DB_USER", "root"), + "password": os.getenv("POLAR_DB_PASSWORD", "123456"), + "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), + "user_name": "memos_default", + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + } + ) + graph_db = GraphStoreFactory.from_config(graph_db_config) + logger.info("✓ GraphDB init Done") + + # 4. Initialize MemReader (required by text_mem) + mem_reader_config = MemReaderConfigFactory( + backend="simple_struct", + config={ + "llm": { + "backend": "openai", + "config": { + "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + } + } + } + ) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + logger.info("✓ MemReader init Done") + + # 5. 
Initialize MemoryManager + memory_manager = MemoryManager( + graph_db=graph_db, + embedder=embedder, + llm=llm, + memory_size=1000, + is_reorganize=False, + ) + logger.info("✓ MemoryManager init Done") + + # 6. Initialize TextMemory + text_mem_config = TreeTextMemoryConfig( + reorganize=False, + max_depth=3, + ) + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=None, + memory_manager=memory_manager, + config=text_mem_config, + internet_retriever=None, + ) + logger.info("✓ TextMemory initialization completed") + + # 7. Create NaiveMemCube + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=None, + act_mem=None, + para_mem=None, + ) + logger.info("✓ NaiveMemCube creation completed") + + logger.info("All components initialized!") + + return { + "llm": llm, + "naive_mem_cube": naive_mem_cube, + "embedder": embedder, + "graph_db": graph_db, + "mem_reader": mem_reader, + } + + +def example_1_factory_initialization(): + """ + Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) + + Advantages: + - Separation of configuration and code + - Easy to manage and modify + - Support loading from configuration files + """ + logger.info("\n" + "="*60) + logger.info("Example 1: Initialize using Factory pattern") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Create configuration Factory + agent_config_factory = MemAgentConfigFactory( + backend="deep_search", + config={ + "agent_name": "MyDeepSearchAgent", + "description": "Intelligent agent for deep search", + "max_iterations": 3, # Maximum number of iterations + "timeout": 60, # Timeout in seconds + } + ) + + # Create Agent using Factory + # Pass text_mem as memory_retriever, it provides search method + deep_search_agent = MemAgentFactory.from_config( + 
config_factory=agent_config_factory, + llm=llm, + memory_retriever=naive_mem_cube.text_mem + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully") + logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def example_3_usage(deep_search_agent, components): + """ + Example 3: Using DeepSearchMemAgent for search + + Demonstrates how to: + 1. Add memories to the system + 2. Use Agent for deep search + 3. Get comprehensive answers + """ + logger.info("\n" + "="*60) + logger.info("Example 3: Using DeepSearchMemAgent") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # Simulate adding some memories + logger.info("1. Adding test memories...") + test_memories = [ + { + "user_name": "test_user", + "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], + "source": "manual" + } + ] + + for memory in test_memories: + try: + text_mem.add( + user_name=memory["user_name"], + messages=memory["messages"], + source=memory.get("source", "manual") + ) + logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") + except Exception as e: + logger.warning(f" ✗ Failed to add memory: {e}") + + # Use Agent for search + logger.info("\n2. 
Executing deep search...") + query = "What is the relationship between artificial intelligence and machine learning?" + logger.info(f" Query: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="test_user", + history=[ + "Hello", + "I want to learn about artificial intelligence" + ] + ) + + logger.info("\n3. Search results:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + except Exception as e: + logger.error(f"Error during search: {e}") + import traceback + traceback.print_exc() + + +def example_4_minimal_initialization(): + """ + Example 4: Minimal initialization (using default configuration) + + Use cases: + - Quick testing + - Prototype development + - No need for custom configuration + """ + logger.info("\n" + "="*60) + logger.info("Example 4: Minimal initialization (default configuration)") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Direct initialization using default configuration + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem + # config parameter omitted, will use default configuration + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def main(): + """Main function: Run all examples""" + logger.info("DeepSearch Agent Usage Examples") + logger.info("="*60) + + # Check environment variables + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set the following environment variables:") + logger.error(" - 
OPENAI_API_KEY: OpenAI API key") + logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") + logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") + return + + try: + # Run Example 1: Factory pattern + agent_factory, components_factory = example_1_factory_initialization() + + # Run Example 2: Direct initialization + agent_direct, components_direct = example_2_direct_initialization() + + # Run Example 4: Minimal initialization + agent_minimal, components_minimal = example_4_minimal_initialization() + + # Run Example 3: Actual usage (using agent created with factory method) + example_3_usage(agent_factory, components_factory) + + logger.info("\n" + "="*60) + logger.info("All examples completed!") + logger.info("="*60) + + except Exception as e: + logger.error(f"Error running examples: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py new file mode 100644 index 000000000..48e4df43c --- /dev/null +++ b/examples/mem_agent/deepsearch_with_api_components.py @@ -0,0 +1,223 @@ +""" +DeepSearch Agent 与 API 组件集成示例 + +本示例展示如何直接使用 API 服务器初始化的组件来创建 DeepSearch Agent。 +这种方式可以避免重复初始化,直接复用已有的组件。 + +适用场景: +- 在已有的 API 服务器中添加 DeepSearch 功能 +- 使用统一的组件配置 +- 避免重复初始化开销 +""" + +import os +import sys + +# 确保可以导入 memos 模块 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from memos.api.handlers.component_init import init_server +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent +from memos.log import get_logger + +logger = get_logger(__name__) + + +def create_deepsearch_from_api_components(): + """ + 从 API 服务器组件创建 DeepSearch Agent + + 这个函数演示了如何: + 1. 使用 init_server() 初始化所有组件 + 2. 提取需要的组件(llm 和 naive_mem_cube) + 3. 
创建 DeepSearch Agent + """ + logger.info("="*60) + logger.info("使用 API 组件初始化 DeepSearch Agent") + logger.info("="*60 + "\n") + + # 步骤 1: 初始化所有服务器组件 + # 这会创建 llm, naive_mem_cube, scheduler 等所有组件 + logger.info("1. 初始化服务器组件...") + components = init_server() + logger.info(" ✓ 服务器组件初始化完成") + + # 步骤 2: 提取需要的组件 + logger.info("\n2. 提取必需组件...") + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + logger.info(f" ✓ LLM: {type(llm).__name__}") + logger.info(f" ✓ MemCube: {type(naive_mem_cube).__name__}") + logger.info(f" ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}") + + # 步骤 3: 创建 DeepSearch Agent 配置 + logger.info("\n3. 创建 DeepSearch Agent...") + config = DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + description="基于 API 组件的深度搜索代理", + max_iterations=3, + timeout=60, + ) + + # 步骤 4: 初始化 DeepSearch Agent + # memory_retriever 使用 naive_mem_cube.text_mem + # 它提供了 search() 方法用于检索记忆 + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=config + ) + + logger.info(f" ✓ DeepSearch Agent 创建成功") + logger.info(f" ✓ Agent 名称: {config.agent_name}") + logger.info(f" ✓ 最大迭代次数: {config.max_iterations}") + + return deep_search_agent, components + + +def demo_usage(deep_search_agent, components): + """ + 演示如何使用 DeepSearch Agent + """ + logger.info("\n" + "="*60) + logger.info("DeepSearch Agent 使用演示") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # 添加一些测试记忆 + logger.info("1. 添加测试记忆...") + test_data = [ + "MemOS 是一个先进的记忆操作系统,专门为 AI 系统设计。", + "MemOS 支持多种记忆类型:文本记忆、偏好记忆、行为记忆等。", + "DeepSearch Agent 是 MemOS 中的一个重要组件,用于深度搜索和信息检索。", + ] + + for i, content in enumerate(test_data, 1): + try: + text_mem.add( + user_name="demo_user", + messages=[content], + source="demo" + ) + logger.info(f" ✓ 记忆 {i}: {content[:40]}...") + except Exception as e: + logger.warning(f" ✗ 添加记忆失败: {e}") + + # 执行深度搜索 + logger.info("\n2. 
执行深度搜索...") + query = "MemOS 支持哪些功能?" + logger.info(f" 查询: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="demo_user", + history=[] + ) + + logger.info("\n3. 搜索结果:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + return response + + except Exception as e: + logger.error(f" ✗ 搜索失败: {e}") + import traceback + traceback.print_exc() + return None + + +def integrate_with_api_router(): + """ + 展示如何将 DeepSearch Agent 集成到 API Router 中 + + 这段代码展示了在 server_router.py 中如何添加一个新的端点 + 来使用 DeepSearch Agent。 + """ + logger.info("\n" + "="*60) + logger.info("API Router 集成示例") + logger.info("="*60 + "\n") + + logger.info("在 server_router.py 中添加以下代码:\n") + + integration_code = ''' +# 在 server_router.py 顶部导入 +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent + +# 在初始化 handlers 之后,创建 DeepSearch Agent +deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + max_iterations=3, + timeout=60, + ) +) + +# 添加新的 API 端点 +@router.post("/deepsearch", summary="Deep search with memory") +def deep_search( + user_id: str, + query: str, + history: list[str] = None +): + """Execute deep search with iterative memory retrieval.""" + try: + response = deep_search_agent.run( + query=query, + user_id=user_id, + history=history or [] + ) + return {"status": "success", "response": response} + except Exception as e: + return {"status": "error", "message": str(e)} +''' + + print(integration_code) + + +def main(): + """主函数""" + logger.info("DeepSearch Agent 与 API 组件集成示例\n") + + # 检查环境变量 + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"❌ 缺少必要的环境变量: {', '.join(missing_vars)}") + logger.error("\n请设置以下环境变量:") + logger.error(" export OPENAI_API_KEY='your-api-key'") + 
logger.error(" export OPENAI_BASE_URL='your-base-url' # 可选") + return + + try: + # 创建 DeepSearch Agent + agent, components = create_deepsearch_from_api_components() + + # 演示使用 + demo_usage(agent, components) + + # 展示集成方法 + integrate_with_api_router() + + logger.info("\n" + "="*60) + logger.info("示例运行完成!") + logger.info("="*60) + + except Exception as e: + logger.error(f"\n❌ 运行失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() + diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index eb67fabba..6dd072728 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -117,57 +117,60 @@ class DeepSearchMemAgent(BaseMemAgent): memory retrieval, and information synthesis as shown in the architecture diagram. """ - def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): - super().__init__(config) - self.config = config - self.max_iterations = config.max_iterations - self.timeout = config.timeout - self.llm: Optional[BaseLLM] = llm - self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") - self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + def __init__( + self, + llm: BaseLLM, + memory_retriever: Any = None, + config: Optional[DeepSearchAgentConfig] = None + ): + """ + Initialize DeepSearchMemAgent. 
+ + Args: + llm: Language model for query rewriting and response generation + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + config: Configuration for deep search behavior + """ + self.config = config or DeepSearchAgentConfig() + self.max_iterations = self.config.max_iterations + self.timeout = self.config.timeout + self.llm: BaseLLM = llm + self.query_rewriter: QueryRewriter = QueryRewriter(llm, "QueryRewriter") + self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_llm(self, llm: BaseLLM) -> None: - """Set the LLM and initialize sub-agents.""" - self.llm = llm - self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") - self.reflector = ReflectionAgent(llm, "Reflector") - self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") - logger.info("LLM and sub-agents initialized") - def _set_memory_retriever(self, retriever) -> None: """Set the memory retrieval interface.""" self.memory_retriever = retriever logger.info("Memory retriever interface set") - def run(self, input: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. Args: - input: User query string + query: User query string + **kwargs: Additional arguments (history, user_id, etc.) Returns: Comprehensive response string """ if not self.llm: - raise RuntimeError("LLM not initialized. 
Call set_llm() first.") + raise RuntimeError("LLM not initialized.") - query = input history = kwargs.get("history", []) user_id = kwargs.get("user_id") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - # Step 2: Keyword Extraction and Planning - keyword_analysis = self.keyword_extractor.run(current_query) - search_keywords = keyword_analysis.get("keywords", [current_query]) - accumulated_context = [] accumulated_memories = [] + search_keywords = [] # Can be extended with keyword extraction - # Step 3: Iterative Search and Reflection Loop + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): + logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") + search_results = self._perform_memory_search( current_query, keywords=search_keywords, @@ -206,11 +209,13 @@ def run(self, input: str, **kwargs) -> str: current_query = query else: break + + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info=keyword_analysis.get("search_strategy", "") + missing_info="" ) logger.info("Deep search pipeline completed") diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index be656acc0..33e7b6a6b 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -2,20 +2,40 @@ from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent -from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent class MemAgentFactory: """Factory class for creating MemAgent instances.""" backend_to_class: ClassVar[dict[str, Any]] = { - "deep_search": DeepSearchAgent, + "deep_search": DeepSearchMemAgent, } @classmethod - def from_config(cls, config_factory: MemAgentConfigFactory) -> BaseMemAgent: + def from_config( + cls, + config_factory: 
MemAgentConfigFactory, + llm: Any, + memory_retriever: Any = None + ) -> BaseMemAgent: + """ + Create a MemAgent instance from configuration. + + Args: + config_factory: Configuration factory for the agent + llm: Language model instance + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + + Returns: + Initialized MemAgent instance + """ backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") mem_agent_class = cls.backend_to_class[backend] - return mem_agent_class(config_factory.config) \ No newline at end of file + return mem_agent_class( + llm=llm, + memory_retriever=memory_retriever, + config=config_factory.config + ) \ No newline at end of file From a497d4614c8bc39b2f868a28952e7add5692e74e Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:12:07 +0800 Subject: [PATCH 05/21] feat: update mem scehduler --- examples/deepsearch_example.py | 76 ------ examples/mem_agent/deepsearch_example.py | 156 ++++++------ .../deepsearch_with_api_components.py | 223 ------------------ src/memos/configs/mem_agent.py | 20 +- src/memos/mem_agent/base.py | 7 +- src/memos/mem_agent/deepsearch_agent.py | 149 ++++++------ src/memos/mem_agent/factory.py | 17 +- src/memos/templates/mem_agent_prompts.py | 2 +- 8 files changed, 169 insertions(+), 481 deletions(-) delete mode 100644 examples/deepsearch_example.py delete mode 100644 examples/mem_agent/deepsearch_with_api_components.py diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py deleted file mode 100644 index 43bf4c012..000000000 --- a/examples/deepsearch_example.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Example usage of DeepSearchAgent in MemOS. - -This example demonstrates how to initialize and use the DeepSearchAgent -for comprehensive information retrieval and synthesis. 
-""" - -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchAgent -from memos.llms.factory import LLMFactory -from memos.configs.llm import LLMConfigFactory - - -def main(): - """Example usage of DeepSearchAgent.""" - - # 1. Configure the LLM - llm_config = LLMConfigFactory( - backend="openai", # or "ollama", "azure", etc. - config={ - "api_key": "your-api-key-here", - "model_name_or_path": "gpt-4", - "temperature": 0.7, - "max_tokens": 2048 - } - ) - - # 2. Create LLM instance - llm = LLMFactory.from_config(llm_config) - - # 3. Configure DeepSearchAgent - agent_config = DeepSearchAgentConfig( - agent_name="DeepSearchAgent", - description="Advanced deep search agent for comprehensive information retrieval", - max_iterations=3, - timeout=60 - ) - - # 4. Create DeepSearchAgent instance - agent = DeepSearchAgent(agent_config) - - # 5. Initialize the agent with LLM - agent.set_llm(llm) - - # 6. Set up memory retriever (this would typically be injected by the framework) - # agent.set_memory_retriever(your_memory_retriever) - - # 7. Example queries - queries = [ - "What are the latest developments in AI research?", - "Tell me about my recent project meetings and their outcomes", - "What are the key trends in machine learning this year?" - ] - - # 8. 
Process queries - for query in queries: - print(f"\n{'='*60}") - print(f"Query: {query}") - print(f"{'='*60}") - - try: - # Run the deep search pipeline - response = agent.run( - input=query, - history=["Previous conversation context"], - user_id="example_user" - ) - - print(f"Response: {response}") - - except Exception as e: - print(f"Error processing query: {e}") - - -if __name__ == "__main__": - main() diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 62f26f0b1..dacfe7f1e 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -13,25 +13,24 @@ """ import os -import sys from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphConfigFactory from memos.configs.llms import LLMConfigFactory -from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory +from memos.configs.mem_agent import MemAgentConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.textual_memory import TreeTextMemoryConfig -from memos.embedders.base import BaseEmbedder from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory +from memos.log import get_logger from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.mem_agent.factory import MemAgentFactory from memos.mem_cube.naive_cube import NaiveMemCube from memos.mem_reader.factory import MemReaderFactory from memos.memories.memory_manager.memory_manager import MemoryManager from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory -from memos.log import get_logger + logger = get_logger(__name__) @@ -39,16 +38,16 @@ def build_minimal_components(): """ Build the minimal component set required for DeepSearchMemAgent. - + Only need to initialize: 1. LLM - Used for query rewriting, reflection, and final answer generation 2. 
NaiveMemCube - Provides text_mem.search interface for memory retrieval - + Returns: dict: Dictionary containing llm and naive_mem_cube """ logger.info("Starting to build minimal component set...") - + # 1. Initialize LLM llm_config = LLMConfigFactory( backend="openai", @@ -57,11 +56,11 @@ def build_minimal_components(): "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), "temperature": 0.7, - } + }, ) llm = LLMFactory.from_config(llm_config) logger.info("LLM initialization completed") - + # 2. Initialize Embedder (required by text_mem) embedder_config = EmbedderConfigFactory( backend="universal_api", @@ -69,11 +68,11 @@ def build_minimal_components(): "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - } + }, ) embedder = EmbedderFactory.from_config(embedder_config) logger.info("Embedder initialization completed") - + # 3. Initialize GraphDB (required by text_mem) graph_db_config = GraphConfigFactory( backend="polardb", @@ -87,11 +86,11 @@ def build_minimal_components(): "use_multi_db": False, "auto_create": True, "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - } + }, ) graph_db = GraphStoreFactory.from_config(graph_db_config) logger.info("✓ GraphDB init Done") - + # 4. Initialize MemReader (required by text_mem) mem_reader_config = MemReaderConfigFactory( backend="simple_struct", @@ -102,13 +101,13 @@ def build_minimal_components(): "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), - } + }, } - } + }, ) mem_reader = MemReaderFactory.from_config(mem_reader_config) logger.info("✓ MemReader init Done") - + # 5. Initialize MemoryManager memory_manager = MemoryManager( graph_db=graph_db, @@ -118,7 +117,7 @@ def build_minimal_components(): is_reorganize=False, ) logger.info("✓ MemoryManager init Done") - + # 6. 
Initialize TextMemory text_mem_config = TreeTextMemoryConfig( reorganize=False, @@ -135,7 +134,7 @@ def build_minimal_components(): internet_retriever=None, ) logger.info("✓ TextMemory initialization completed") - + # 7. Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, @@ -144,9 +143,9 @@ def build_minimal_components(): para_mem=None, ) logger.info("✓ NaiveMemCube creation completed") - + logger.info("All components initialized!") - + return { "llm": llm, "naive_mem_cube": naive_mem_cube, @@ -159,21 +158,21 @@ def build_minimal_components(): def example_1_factory_initialization(): """ Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - + Advantages: - Separation of configuration and code - Easy to manage and modify - Support loading from configuration files """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 1: Initialize using Factory pattern") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + # Build necessary components components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - + # Create configuration Factory agent_config_factory = MemAgentConfigFactory( backend="deep_search", @@ -182,139 +181,141 @@ def example_1_factory_initialization(): "description": "Intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds - } + }, ) - + # Create Agent using Factory # Pass text_mem as memory_retriever, it provides search method deep_search_agent = MemAgentFactory.from_config( - config_factory=agent_config_factory, - llm=llm, - memory_retriever=naive_mem_cube.text_mem + config_factory=agent_config_factory, llm=llm, memory_retriever=naive_mem_cube.text_mem ) - - logger.info(f"✓ DeepSearchMemAgent created successfully") + + logger.info("✓ DeepSearchMemAgent created successfully") logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") logger.info(f" - 
Max iterations: {deep_search_agent.max_iterations}") logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - + return deep_search_agent, components def example_3_usage(deep_search_agent, components): """ Example 3: Using DeepSearchMemAgent for search - + Demonstrates how to: 1. Add memories to the system 2. Use Agent for deep search 3. Get comprehensive answers """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + naive_mem_cube = components["naive_mem_cube"] text_mem = naive_mem_cube.text_mem - + # Simulate adding some memories logger.info("1. Adding test memories...") test_memories = [ { "user_name": "test_user", - "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], - "source": "manual" + "messages": [ + "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." + ], + "source": "manual", }, { - "user_name": "test_user", - "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], - "source": "manual" + "user_name": "test_user", + "messages": [ + "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." + ], + "source": "manual", }, { "user_name": "test_user", - "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], - "source": "manual" - } + "messages": [ + "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." 
+ ], + "source": "manual", + }, ] - + for memory in test_memories: try: text_mem.add( user_name=memory["user_name"], messages=memory["messages"], - source=memory.get("source", "manual") + source=memory.get("source", "manual"), ) logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") except Exception as e: logger.warning(f" ✗ Failed to add memory: {e}") - + # Use Agent for search logger.info("\n2. Executing deep search...") query = "What is the relationship between artificial intelligence and machine learning?" logger.info(f" Query: {query}") - + try: response = deep_search_agent.run( query=query, user_id="test_user", - history=[ - "Hello", - "I want to learn about artificial intelligence" - ] + history=["Hello", "I want to learn about artificial intelligence"], ) - + logger.info("\n3. Search results:") logger.info("-" * 60) logger.info(response) logger.info("-" * 60) - + except Exception as e: logger.error(f"Error during search: {e}") import traceback + traceback.print_exc() def example_4_minimal_initialization(): """ Example 4: Minimal initialization (using default configuration) - + Use cases: - Quick testing - Prototype development - No need for custom configuration """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 4: Minimal initialization (default configuration)") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + # Build necessary components components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - + # Direct initialization using default configuration deep_search_agent = DeepSearchMemAgent( llm=llm, - memory_retriever=naive_mem_cube.text_mem + memory_retriever=naive_mem_cube.text_mem, # config parameter omitted, will use default configuration ) - - logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)") + + logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)") logger.info(f" - Max 
iterations: {deep_search_agent.max_iterations}") logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - + return deep_search_agent, components def main(): """Main function: Run all examples""" logger.info("DeepSearch Agent Usage Examples") - logger.info("="*60) - + logger.info("=" * 60) + # Check environment variables required_env_vars = ["OPENAI_API_KEY"] missing_vars = [var for var in required_env_vars if not os.getenv(var)] - + if missing_vars: logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") logger.error("Please set the following environment variables:") @@ -322,29 +323,30 @@ def main(): logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") return - + try: # Run Example 1: Factory pattern agent_factory, components_factory = example_1_factory_initialization() - + # Run Example 2: Direct initialization agent_direct, components_direct = example_2_direct_initialization() - + # Run Example 4: Minimal initialization agent_minimal, components_minimal = example_4_minimal_initialization() - + # Run Example 3: Actual usage (using agent created with factory method) example_3_usage(agent_factory, components_factory) - - logger.info("\n" + "="*60) + + logger.info("\n" + "=" * 60) logger.info("All examples completed!") - logger.info("="*60) - + logger.info("=" * 60) + except Exception as e: logger.error(f"Error running examples: {e}") import traceback + traceback.print_exc() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py deleted file mode 100644 index 48e4df43c..000000000 --- a/examples/mem_agent/deepsearch_with_api_components.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -DeepSearch Agent 与 API 组件集成示例 - -本示例展示如何直接使用 API 服务器初始化的组件来创建 DeepSearch Agent。 -这种方式可以避免重复初始化,直接复用已有的组件。 - -适用场景: -- 在已有的 API 服务器中添加 
DeepSearch 功能 -- 使用统一的组件配置 -- 避免重复初始化开销 -""" - -import os -import sys - -# 确保可以导入 memos 模块 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) - -from memos.api.handlers.component_init import init_server -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent -from memos.log import get_logger - -logger = get_logger(__name__) - - -def create_deepsearch_from_api_components(): - """ - 从 API 服务器组件创建 DeepSearch Agent - - 这个函数演示了如何: - 1. 使用 init_server() 初始化所有组件 - 2. 提取需要的组件(llm 和 naive_mem_cube) - 3. 创建 DeepSearch Agent - """ - logger.info("="*60) - logger.info("使用 API 组件初始化 DeepSearch Agent") - logger.info("="*60 + "\n") - - # 步骤 1: 初始化所有服务器组件 - # 这会创建 llm, naive_mem_cube, scheduler 等所有组件 - logger.info("1. 初始化服务器组件...") - components = init_server() - logger.info(" ✓ 服务器组件初始化完成") - - # 步骤 2: 提取需要的组件 - logger.info("\n2. 提取必需组件...") - llm = components["llm"] - naive_mem_cube = components["naive_mem_cube"] - logger.info(f" ✓ LLM: {type(llm).__name__}") - logger.info(f" ✓ MemCube: {type(naive_mem_cube).__name__}") - logger.info(f" ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}") - - # 步骤 3: 创建 DeepSearch Agent 配置 - logger.info("\n3. 
创建 DeepSearch Agent...") - config = DeepSearchAgentConfig( - agent_name="APIDeepSearchAgent", - description="基于 API 组件的深度搜索代理", - max_iterations=3, - timeout=60, - ) - - # 步骤 4: 初始化 DeepSearch Agent - # memory_retriever 使用 naive_mem_cube.text_mem - # 它提供了 search() 方法用于检索记忆 - deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - config=config - ) - - logger.info(f" ✓ DeepSearch Agent 创建成功") - logger.info(f" ✓ Agent 名称: {config.agent_name}") - logger.info(f" ✓ 最大迭代次数: {config.max_iterations}") - - return deep_search_agent, components - - -def demo_usage(deep_search_agent, components): - """ - 演示如何使用 DeepSearch Agent - """ - logger.info("\n" + "="*60) - logger.info("DeepSearch Agent 使用演示") - logger.info("="*60 + "\n") - - naive_mem_cube = components["naive_mem_cube"] - text_mem = naive_mem_cube.text_mem - - # 添加一些测试记忆 - logger.info("1. 添加测试记忆...") - test_data = [ - "MemOS 是一个先进的记忆操作系统,专门为 AI 系统设计。", - "MemOS 支持多种记忆类型:文本记忆、偏好记忆、行为记忆等。", - "DeepSearch Agent 是 MemOS 中的一个重要组件,用于深度搜索和信息检索。", - ] - - for i, content in enumerate(test_data, 1): - try: - text_mem.add( - user_name="demo_user", - messages=[content], - source="demo" - ) - logger.info(f" ✓ 记忆 {i}: {content[:40]}...") - except Exception as e: - logger.warning(f" ✗ 添加记忆失败: {e}") - - # 执行深度搜索 - logger.info("\n2. 执行深度搜索...") - query = "MemOS 支持哪些功能?" - logger.info(f" 查询: {query}") - - try: - response = deep_search_agent.run( - query=query, - user_id="demo_user", - history=[] - ) - - logger.info("\n3. 
搜索结果:") - logger.info("-" * 60) - logger.info(response) - logger.info("-" * 60) - - return response - - except Exception as e: - logger.error(f" ✗ 搜索失败: {e}") - import traceback - traceback.print_exc() - return None - - -def integrate_with_api_router(): - """ - 展示如何将 DeepSearch Agent 集成到 API Router 中 - - 这段代码展示了在 server_router.py 中如何添加一个新的端点 - 来使用 DeepSearch Agent。 - """ - logger.info("\n" + "="*60) - logger.info("API Router 集成示例") - logger.info("="*60 + "\n") - - logger.info("在 server_router.py 中添加以下代码:\n") - - integration_code = ''' -# 在 server_router.py 顶部导入 -from memos.configs.mem_agent import DeepSearchAgentConfig -from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent - -# 在初始化 handlers 之后,创建 DeepSearch Agent -deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - config=DeepSearchAgentConfig( - agent_name="APIDeepSearchAgent", - max_iterations=3, - timeout=60, - ) -) - -# 添加新的 API 端点 -@router.post("/deepsearch", summary="Deep search with memory") -def deep_search( - user_id: str, - query: str, - history: list[str] = None -): - """Execute deep search with iterative memory retrieval.""" - try: - response = deep_search_agent.run( - query=query, - user_id=user_id, - history=history or [] - ) - return {"status": "success", "response": response} - except Exception as e: - return {"status": "error", "message": str(e)} -''' - - print(integration_code) - - -def main(): - """主函数""" - logger.info("DeepSearch Agent 与 API 组件集成示例\n") - - # 检查环境变量 - required_env_vars = ["OPENAI_API_KEY"] - missing_vars = [var for var in required_env_vars if not os.getenv(var)] - - if missing_vars: - logger.error(f"❌ 缺少必要的环境变量: {', '.join(missing_vars)}") - logger.error("\n请设置以下环境变量:") - logger.error(" export OPENAI_API_KEY='your-api-key'") - logger.error(" export OPENAI_BASE_URL='your-base-url' # 可选") - return - - try: - # 创建 DeepSearch Agent - agent, components = create_deepsearch_from_api_components() - - # 演示使用 - demo_usage(agent, 
components) - - # 展示集成方法 - integrate_with_api_router() - - logger.info("\n" + "="*60) - logger.info("示例运行完成!") - logger.info("="*60) - - except Exception as e: - logger.error(f"\n❌ 运行失败: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - main() - diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py index e58f8e67c..3abe5d64d 100644 --- a/src/memos/configs/mem_agent.py +++ b/src/memos/configs/mem_agent.py @@ -9,9 +9,7 @@ class BaseAgentConfig(BaseConfig): """Base configuration class for agents.""" agent_name: str = Field(..., description="Name of the agent") - description: str | None = Field( - default=None, description="Description of the agent" - ) + description: str | None = Field(default=None, description="Description of the agent") class SimpleAgentConfig(BaseAgentConfig): @@ -20,19 +18,15 @@ class SimpleAgentConfig(BaseAgentConfig): max_iterations: int = Field( default=10, description="Maximum number of iterations for the agent" ) - timeout: int = Field( - default=30, description="Timeout in seconds for agent execution" - ) + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + class DeepSearchAgentConfig(BaseAgentConfig): """Deep search agent configuration class.""" - max_iterations: int = Field( - default=3, description="Maximum number of iterations for the agent" - ) - timeout: int = Field( - default=30, description="Timeout in seconds for agent execution" - ) + max_iterations: int = Field(default=3, description="Maximum number of iterations for the agent") + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + class MemAgentConfigFactory(BaseConfig): """Factory class for creating agent configurations.""" @@ -57,4 +51,4 @@ def validate_backend(cls, backend: str) -> str: def create_config(self) -> "AgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) - return self \ No 
newline at end of file + return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index 3b54a846f..e4746cad6 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -1,12 +1,13 @@ -from typing import List, Dict, Any, Optional from abc import ABC, abstractmethod -from pydantic import BaseModel + from memos.configs.mem_agent import BaseAgentConfig + class BaseMemAgent(ABC): """ Base class for all agents. """ + def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" @@ -14,4 +15,4 @@ def __init__(self, config: BaseAgentConfig): def run(self, input: str) -> str: """ Run the agent. - """ \ No newline at end of file + """ diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6dd072728..70212921a 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,21 +6,21 @@ """ import json -import logging -from memos.log import get_logger -from typing import Any, Dict, List, Optional, Tuple + +from typing import Optional, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM +from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem -from memos.types import MessageDict, MessageList from memos.templates.mem_agent_prompts import ( + FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, - KEYWORD_EXTRACTION_PROMPT, - FINAL_GENERATION_PROMPT - ) +) +from memos.types import MessageList + logger = get_logger(__name__) @@ -35,29 +35,28 @@ def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name - def run(self, query: str, history: list[str]| None = None) -> str: + def run(self, query: str, history: list[str] | None = None) -> str: """ Rewrite the query to be standalone and more searchable. 
- + Args: query: Original user query history: List of previous conversation messages - + Returns: Rewritten query string """ if history is None: history = [] - + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - + prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", - query=query + history=history_str if history_str else "No previous conversation", query=query ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") @@ -77,23 +76,24 @@ def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name - def run(self, query: str, context: List[str]) -> Dict[str, Any]: + def run(self, query: str, context: list[str]) -> dict[str, Any]: """ Analyze whether retrieved context is sufficient to answer the query. - + Args: query: User query context: List of retrieved context strings - + Returns: Dictionary with status, reasoning, and missing entities """ - context_str = "\n".join([f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" - for ctx in context[:10]]) # Limit context size - + context_str = "\n".join( + [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] + ) # Limit context size + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) result = json.loads(response.strip()) @@ -105,27 +105,27 @@ def run(self, query: str, context: List[str]) -> Dict[str, Any]: return { "status": "sufficient", "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [] + "missing_entities": [], } class DeepSearchMemAgent(BaseMemAgent): """ Main orchestrator agent implementing the deep search pipeline. 
- + This agent coordinates multiple sub-agents to perform iterative query refinement, memory retrieval, and information synthesis as shown in the architecture diagram. """ def __init__( - self, - llm: BaseLLM, - memory_retriever: Any = None, - config: Optional[DeepSearchAgentConfig] = None + self, + llm: BaseLLM, + memory_retriever: Optional[Any] = None, + config: DeepSearchAgentConfig | None = None, ): """ Initialize DeepSearchMemAgent. - + Args: llm: Language model for query rewriting and response generation memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) @@ -147,7 +147,7 @@ def _set_memory_retriever(self, retriever) -> None: def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. - + Args: query: User query string **kwargs: Additional arguments (history, user_id, etc.) @@ -156,38 +156,36 @@ def run(self, query: str, **kwargs) -> str: """ if not self.llm: raise RuntimeError("LLM not initialized.") - + history = kwargs.get("history", []) user_id = kwargs.get("user_id") - + # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - + accumulated_context = [] accumulated_memories = [] search_keywords = [] # Can be extended with keyword extraction - + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") - + search_results = self._perform_memory_search( - current_query, - keywords=search_keywords, - user_id=user_id + current_query, keywords=search_keywords, user_id=user_id ) - + if search_results: context_batch = [self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) accumulated_memories.extend(search_results) - + reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") - + 
logger.info(f"Reflection status: {status} - {reasoning}") - + if status == "sufficient": logger.info("Sufficient information collected") break @@ -209,114 +207,111 @@ def run(self, query: str, **kwargs) -> str: current_query = query else: break - + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info="" + missing_info="", ) - + logger.info("Deep search pipeline completed") return final_answer def _perform_memory_search( - self, - query: str, - keywords: List[str] = None, - user_id: str = None, - top_k: int = 10 - ) -> List[TextualMemoryItem]: + self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int = 10 + ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. - + Args: query: Search query keywords: Additional keywords for search user_id: User identifier top_k: Number of results to retrieve - + Returns: List of retrieved memory items """ if not self.memory_retriever: logger.warning("Memory retriever not configured, returning empty results") return [] - + try: # Use the memory retriever interface # This is a placeholder - actual implementation depends on the retriever interface search_query = query if keywords and len(keywords) > 1: search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords - + # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, - top_k=top_k, - mode="fast", - user_name=user_id + query=search_query, top_k=top_k, mode="fast", user_name=user_id ) - + return results if isinstance(results, list) else [] - + except Exception as e: logger.error(f"Error performing memory search: {e}") return [] def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: """Extract readable context from a memory item.""" - if hasattr(memory_item, 'memory'): + 
if hasattr(memory_item, "memory"): return str(memory_item.memory) - elif hasattr(memory_item, 'content'): + elif hasattr(memory_item, "content"): return str(memory_item.content) else: return str(memory_item) - def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + def _refine_query_for_missing_info(self, query: str, missing_entities: list[str]) -> str: """Refine the query to search for missing information.""" if not missing_entities: return query - + # Simple refinement strategy - append missing entities entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities refined_query = f"{query} {entities_str}" - + return refined_query def _generate_final_answer( self, original_query: str, - search_results: List[TextualMemoryItem], - context: List[str], - missing_info: str = "" + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", ) -> str: """ Generate the final comprehensive answer. - + Args: original_query: Original user query search_results: All retrieved memory items context: Extracted context strings missing_info: Information about missing data - + Returns: Final answer string """ # Prepare context for the prompt context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context - sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" - + sources = ( + f"Retrieved {len(search_results)} memory items" + if search_results + else "No specific sources" + ) + prompt = FINAL_GENERATION_PROMPT.format( query=original_query, sources=sources, context=context_str if context_str else "No specific context retrieved", - missing_info=missing_info if missing_info else "None identified" + missing_info=missing_info if missing_info else "None identified", ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) return response.strip() diff --git a/src/memos/mem_agent/factory.py 
b/src/memos/mem_agent/factory.py index 33e7b6a6b..415d4916e 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import Optional, Any, ClassVar from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent @@ -14,19 +14,16 @@ class MemAgentFactory: @classmethod def from_config( - cls, - config_factory: MemAgentConfigFactory, - llm: Any, - memory_retriever: Any = None + cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None ) -> BaseMemAgent: """ Create a MemAgent instance from configuration. - + Args: config_factory: Configuration factory for the agent llm: Language model instance memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) - + Returns: Initialized MemAgent instance """ @@ -35,7 +32,5 @@ def from_config( raise ValueError(f"Invalid backend: {backend}") mem_agent_class = cls.backend_to_class[backend] return mem_agent_class( - llm=llm, - memory_retriever=memory_retriever, - config=config_factory.config - ) \ No newline at end of file + llm=llm, memory_retriever=memory_retriever, config=config_factory.config + ) diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 399b2cd75..17e3c0ad9 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -73,4 +73,4 @@ 4. Provide a well-structured, coherent response 5. 
Use natural, conversational tone -Response:""" \ No newline at end of file +Response:""" From bd72e9bb804d84f1a12e4c817bae494f0359d600 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:33:53 +0800 Subject: [PATCH 06/21] feat: update deepsearch mem code --- examples/mem_agent/deepsearch_example.py | 3 --- src/memos/configs/mem_agent.py | 2 +- src/memos/mem_agent/base.py | 1 + src/memos/mem_agent/deepsearch_agent.py | 15 +++++++++++---- src/memos/mem_agent/factory.py | 4 ++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index dacfe7f1e..d0842b6ec 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -328,9 +328,6 @@ def main(): # Run Example 1: Factory pattern agent_factory, components_factory = example_1_factory_initialization() - # Run Example 2: Direct initialization - agent_direct, components_direct = example_2_direct_initialization() - # Run Example 4: Minimal initialization agent_minimal, components_minimal = example_4_minimal_initialization() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py index 3abe5d64d..7cb623899 100644 --- a/src/memos/configs/mem_agent.py +++ b/src/memos/configs/mem_agent.py @@ -48,7 +48,7 @@ def validate_backend(cls, backend: str) -> str: return backend @model_validator(mode="after") - def create_config(self) -> "AgentConfigFactory": + def create_config(self) -> "MemAgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index e4746cad6..daa5f075b 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -10,6 +10,7 @@ class BaseMemAgent(ABC): def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" + self.config = config 
@abstractmethod def run(self, input: str) -> str: diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 70212921a..6c6fc32a9 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -7,7 +7,7 @@ import json -from typing import Optional, Any +from typing import TYPE_CHECKING, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM @@ -19,7 +19,10 @@ QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, ) -from memos.types import MessageList + + +if TYPE_CHECKING: + from memos.types import MessageList logger = get_logger(__name__) @@ -120,7 +123,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Optional[Any] = None, + memory_retriever: Any | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -220,7 +223,11 @@ def run(self, query: str, **kwargs) -> str: return final_answer def _perform_memory_search( - self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int = 10 + self, + query: str, + keywords: list[str] | None = None, + user_id: str | None = None, + top_k: int = 10, ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. 
diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index 415d4916e..09537bd8a 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -1,4 +1,4 @@ -from typing import Optional, Any, ClassVar +from typing import Any, ClassVar from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent @@ -14,7 +14,7 @@ class MemAgentFactory: @classmethod def from_config( - cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None + cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Any | None = None ) -> BaseMemAgent: """ Create a MemAgent instance from configuration. From f332ef22c3a2a29e2c1820de3916e45d70e4d33a Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 22 Nov 2025 18:13:04 +0800 Subject: [PATCH 07/21] feat: update deepsearch agent --- examples/mem_agent/deepsearch_example.py | 358 +++++++---------------- src/memos/mem_agent/deepsearch_agent.py | 200 ++++++++----- src/memos/templates/mem_agent_prompts.py | 5 +- 3 files changed, 227 insertions(+), 336 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index d0842b6ec..6a9405456 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -1,35 +1,36 @@ """ -DeepSearch Agent Usage Examples +DeepSearch Agent Usage Examples - Simplified Version -This example demonstrates two ways to initialize DeepSearchMemAgent: -1. Using Factory pattern (recommended) -2. 
Direct initialization - -DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through: -- Query rewriting: Optimize queries based on conversation history -- Iterative retrieval: Collect information through multiple search rounds -- Reflective analysis: Determine if information is sufficient -- Comprehensive response: Generate complete final answers +This example demonstrates simplified initialization of DeepSearchMemAgent without +external config builders, using APIConfig methods directly. """ import os +from typing import Any + +from memos.api.config import APIConfig from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.graph_db import GraphConfigFactory -from memos.configs.llms import LLMConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory from memos.configs.mem_agent import MemAgentConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory -from memos.configs.textual_memory import TreeTextMemoryConfig +from memos.configs.reranker import RerankerConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory from memos.log import get_logger from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.mem_agent.factory import MemAgentFactory -from memos.mem_cube.naive_cube import NaiveMemCube +from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.factory import MemReaderFactory -from memos.memories.memory_manager.memory_manager import MemoryManager -from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from 
memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory logger = get_logger(__name__) @@ -37,114 +38,105 @@ def build_minimal_components(): """ - Build the minimal component set required for DeepSearchMemAgent. - - Only need to initialize: - 1. LLM - Used for query rewriting, reflection, and final answer generation - 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval + Build minimal components for DeepSearchMemAgent with simplified configuration. - Returns: - dict: Dictionary containing llm and naive_mem_cube + This function creates all necessary components using APIConfig methods, + similar to config_builders.py but inline for easier customization. """ - logger.info("Starting to build minimal component set...") + logger.info("Initializing simplified MemOS components...") - # 1. Initialize LLM - llm_config = LLMConfigFactory( - backend="openai", - config={ - "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), - "api_key": os.getenv("OPENAI_API_KEY"), - "base_url": os.getenv("OPENAI_BASE_URL"), - "temperature": 0.7, - }, + # Build component configurations using APIConfig methods (like config_builders.py) + + # Graph DB configuration - using APIConfig.get_nebular_config() + graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() + graph_db_backend_map = { + "polardb": APIConfig.get_polardb_config(), + } + graph_db_config = GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } ) - llm = LLMFactory.from_config(llm_config) - logger.info("LLM initialization completed") - # 2. 
Initialize Embedder (required by text_mem) - embedder_config = EmbedderConfigFactory( - backend="universal_api", - config={ - "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), - "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), - "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - }, + # LLM configuration - using APIConfig.get_openai_config() + llm_config = LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } ) - embedder = EmbedderFactory.from_config(embedder_config) - logger.info("Embedder initialization completed") - # 3. Initialize GraphDB (required by text_mem) - graph_db_config = GraphConfigFactory( - backend="polardb", - config={ - "host": os.getenv("POLAR_DB_HOST", "localhost"), - "port": int(os.getenv("POLAR_DB_PORT", "5432")), - "user": os.getenv("POLAR_DB_USER", "root"), - "password": os.getenv("POLAR_DB_PASSWORD", "123456"), - "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), - "user_name": "memos_default", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - }, + # Embedder configuration - using APIConfig.get_embedder_config() + embedder_config = EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + # Memory reader configuration - using APIConfig.get_product_default_config() + mem_reader_config = MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] ) - graph_db = GraphStoreFactory.from_config(graph_db_config) - logger.info("✓ GraphDB init Done") - # 4. 
Initialize MemReader (required by text_mem) - mem_reader_config = MemReaderConfigFactory( - backend="simple_struct", - config={ - "llm": { - "backend": "openai", - "config": { - "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), - "api_key": os.getenv("OPENAI_API_KEY"), - "base_url": os.getenv("OPENAI_BASE_URL"), - }, - } - }, + # Reranker configuration - using APIConfig.get_reranker_config() + reranker_config = RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + # Internet retriever configuration - using APIConfig.get_internet_config() + internet_retriever_config = InternetRetrieverConfigFactory.model_validate( + APIConfig.get_internet_config() ) - mem_reader = MemReaderFactory.from_config(mem_reader_config) - logger.info("✓ MemReader init Done") - # 5. Initialize MemoryManager - memory_manager = MemoryManager( - graph_db=graph_db, - embedder=embedder, - llm=llm, - memory_size=1000, - is_reorganize=False, + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder ) - logger.info("✓ MemoryManager init Done") - # 6. 
Initialize TextMemory - text_mem_config = TreeTextMemoryConfig( - reorganize=False, - max_depth=3, + logger.debug("Core components instantiated") + + # Get default cube configuration like component_init.py + default_cube_config = APIConfig.get_default_cube_config() + + # Get default memory size from cube config (like component_init.py) + def get_memory_size_from_config(cube_config): + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + memory_size = get_memory_size_from_config(default_cube_config) + is_reorganize = getattr(default_cube_config.text_mem.config, "reorganize", False) + + # Initialize memory manager with config from APIConfig + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=memory_size, + is_reorganize=is_reorganize, ) + text_memory_config = default_cube_config.text_mem.config text_mem = SimpleTreeTextMemory( llm=llm, embedder=embedder, mem_reader=mem_reader, graph_db=graph_db, - reranker=None, + reranker=reranker, memory_manager=memory_manager, - config=text_mem_config, - internet_retriever=None, + config=text_memory_config, + internet_retriever=internet_retriever, ) - logger.info("✓ TextMemory initialization completed") - # 7. 
Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=None, + pref_mem=None, # Simplified: no preference memory act_mem=None, para_mem=None, ) - logger.info("✓ NaiveMemCube creation completed") - - logger.info("All components initialized!") return { "llm": llm, @@ -155,30 +147,18 @@ def build_minimal_components(): } -def example_1_factory_initialization(): - """ - Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - - Advantages: - - Separation of configuration and code - - Easy to manage and modify - - Support loading from configuration files - """ - logger.info("\n" + "=" * 60) - logger.info("Example 1: Initialize using Factory pattern") - logger.info("=" * 60 + "\n") - - # Build necessary components +def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]: + # Build necessary components with simplified setup components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - # Create configuration Factory + # Create configuration Factory with simplified config agent_config_factory = MemAgentConfigFactory( backend="deep_search", config={ - "agent_name": "MyDeepSearchAgent", - "description": "Intelligent agent for deep search", + "agent_name": "SimplifiedDeepSearchAgent", + "description": "Simplified intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds }, @@ -198,151 +178,13 @@ def example_1_factory_initialization(): return deep_search_agent, components -def example_3_usage(deep_search_agent, components): - """ - Example 3: Using DeepSearchMemAgent for search - - Demonstrates how to: - 1. Add memories to the system - 2. Use Agent for deep search - 3. 
Get comprehensive answers - """ - logger.info("\n" + "=" * 60) - logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("=" * 60 + "\n") - - naive_mem_cube = components["naive_mem_cube"] - text_mem = naive_mem_cube.text_mem - - # Simulate adding some memories - logger.info("1. Adding test memories...") - test_memories = [ - { - "user_name": "test_user", - "messages": [ - "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." - ], - "source": "manual", - }, - ] - - for memory in test_memories: - try: - text_mem.add( - user_name=memory["user_name"], - messages=memory["messages"], - source=memory.get("source", "manual"), - ) - logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") - except Exception as e: - logger.warning(f" ✗ Failed to add memory: {e}") - - # Use Agent for search - logger.info("\n2. Executing deep search...") - query = "What is the relationship between artificial intelligence and machine learning?" - logger.info(f" Query: {query}") - - try: - response = deep_search_agent.run( - query=query, - user_id="test_user", - history=["Hello", "I want to learn about artificial intelligence"], - ) - - logger.info("\n3. 
Search results:") - logger.info("-" * 60) - logger.info(response) - logger.info("-" * 60) - - except Exception as e: - logger.error(f"Error during search: {e}") - import traceback - - traceback.print_exc() - - -def example_4_minimal_initialization(): - """ - Example 4: Minimal initialization (using default configuration) - - Use cases: - - Quick testing - - Prototype development - - No need for custom configuration - """ - logger.info("\n" + "=" * 60) - logger.info("Example 4: Minimal initialization (default configuration)") - logger.info("=" * 60 + "\n") - - # Build necessary components - components = build_minimal_components() - llm = components["llm"] - naive_mem_cube = components["naive_mem_cube"] - - # Direct initialization using default configuration - deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - # config parameter omitted, will use default configuration - ) - - logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)") - logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") - logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - - return deep_search_agent, components - - def main(): - """Main function: Run all examples""" - logger.info("DeepSearch Agent Usage Examples") - logger.info("=" * 60) - - # Check environment variables - required_env_vars = ["OPENAI_API_KEY"] - missing_vars = [var for var in required_env_vars if not os.getenv(var)] - - if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") - logger.error("Please set the following environment variables:") - logger.error(" - OPENAI_API_KEY: OpenAI API key") - logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") - logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") - return - - try: - # Run Example 1: Factory pattern - agent_factory, components_factory = example_1_factory_initialization() - - # Run Example 4: Minimal 
initialization - agent_minimal, components_minimal = example_4_minimal_initialization() - - # Run Example 3: Actual usage (using agent created with factory method) - example_3_usage(agent_factory, components_factory) - - logger.info("\n" + "=" * 60) - logger.info("All examples completed!") - logger.info("=" * 60) - - except Exception as e: - logger.error(f"Error running examples: {e}") - import traceback - - traceback.print_exc() + agent_factory, components_factory = factory_initialization() + results = agent_factory.run( + "Caroline met up with friends, family, and mentors in early July 2023.", + user_id="locomo_exp_user_0_speaker_b_ct-1118", + ) + print(results) if __name__ == "__main__": diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6c6fc32a9..5a070c6ad 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,6 +6,7 @@ """ import json +import re from typing import TYPE_CHECKING, Any @@ -14,6 +15,7 @@ from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory from memos.templates.mem_agent_prompts import ( FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, @@ -25,91 +27,96 @@ from memos.types import MessageList +class JSONResponseParser: + """Elegant JSON response parser for LLM outputs""" + + @staticmethod + def parse(response: str) -> dict[str, Any]: + """Parse JSON response from LLM output with fallback strategies""" + # Clean response text by removing code block markers + cleaned = re.sub(r"^```(?:json)?\s*\n?|```\s*$", "", response.strip(), flags=re.IGNORECASE) + + # Try parsing with multiple strategies + for text in [cleaned, re.search(r"\{.*\}", cleaned, re.DOTALL)]: + if not text: + continue + try: + return json.loads(text if isinstance(text, str) else text.group()) + except json.JSONDecodeError: + continue + + raise 
ValueError(f"Cannot parse JSON response: {response[:100]}...") + + logger = get_logger(__name__) class QueryRewriter(BaseMemAgent): - """ - Specialized agent for rewriting queries based on conversation history. - Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. - """ + """Specialized agent for rewriting queries based on conversation history""" def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name def run(self, query: str, history: list[str] | None = None) -> str: - """ - Rewrite the query to be standalone and more searchable. - - Args: - query: Original user query - history: List of previous conversation messages - - Returns: - Rewritten query string - """ - if history is None: - history = [] - - history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - - prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", query=query - ) - - messages: MessageList = [{"role": "user", "content": prompt}] + """Rewrite query to be standalone and more searchable""" + history = history or [] + history_context = self._format_history(history) + prompt = QUERY_REWRITE_PROMPT.format(history=history_context, query=query) + messages = [{"role": "user", "content": prompt}] try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") return response.strip() except Exception as e: - logger.error(f"[{self.name}] Error rewriting query: {e}") - return query # Fallback to original query + logger.error(f"[{self.name}] Query rewrite failed: {e}") + return query + + def _format_history(self, history: list[str]) -> str: + """Format conversation history for prompt context""" + if not history: + return "No previous conversation" + return "\n".join(f"- {msg}" for msg in history[-5:]) class ReflectionAgent: - """ - Specialized agent for analyzing information sufficiency. 
- Corresponds to the decision diamond in the architecture diagram. - """ + """Specialized agent for analyzing information sufficiency""" def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name def run(self, query: str, context: list[str]) -> dict[str, Any]: - """ - Analyze whether retrieved context is sufficient to answer the query. + """Analyze whether retrieved context is sufficient to answer the query""" + context_summary = self._format_context(context) + prompt = REFLECTION_PROMPT.format(query=query, context=context_summary) - Args: - query: User query - context: List of retrieved context strings + try: + response = self.llm.generate([{"role": "user", "content": prompt}]) + logger.info(f"[{self.name}] Reflection response: {response}") - Returns: - Dictionary with status, reasoning, and missing entities - """ - context_str = "\n".join( - [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] - ) # Limit context size + result = JSONResponseParser.parse(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result}") + return result - prompt = REFLECTION_PROMPT.format(query=query, context=context_str) - messages: MessageList = [{"role": "user", "content": prompt}] + except Exception as e: + logger.error(f"[{self.name}] Reflection analysis failed: {e}") + return self._fallback_response() - try: - response = self.llm.generate(messages) - result = json.loads(response.strip()) - logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") - return result - except (json.JSONDecodeError, Exception) as e: - logger.error(f"[{self.name}] Error in reflection analysis: {e}") - # Fallback response - return { - "status": "sufficient", - "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [], - } + def _format_context(self, context: list[str]) -> str: + """Format context strings for analysis with length limits""" + return "\n".join( + f"- 
{ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10] + ) + + def _fallback_response(self) -> dict[str, Any]: + """Return safe fallback when reflection fails""" + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [], + } class DeepSearchMemAgent(BaseMemAgent): @@ -123,7 +130,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Any | None = None, + memory_retriever: TreeTextMemory | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -142,12 +149,7 @@ def __init__( self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_memory_retriever(self, retriever) -> None: - """Set the memory retrieval interface.""" - self.memory_retriever = retriever - logger.info("Memory retriever interface set") - - def run(self, query: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: """ Main execution method implementing the deep search pipeline. 
@@ -162,6 +164,7 @@ def run(self, query: str, **kwargs) -> str: history = kwargs.get("history", []) user_id = kwargs.get("user_id") + generated_answer = kwargs.get("generated_answer") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) @@ -175,7 +178,7 @@ def run(self, query: str, **kwargs) -> str: logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") search_results = self._perform_memory_search( - current_query, keywords=search_keywords, user_id=user_id + current_query, keywords=search_keywords, user_id=user_id, history=history ) if search_results: @@ -198,7 +201,8 @@ def run(self, query: str, **kwargs) -> str: elif status == "missing_info": missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") - if missing_entities: + current_query = reflection_result.get("new_search_query") + if not current_query: refined_query = self._refine_query_for_missing_info( current_query, missing_entities ) @@ -211,22 +215,63 @@ def run(self, query: str, **kwargs) -> str: else: break - # Step 3: Generate final answer - final_answer = self._generate_final_answer( - original_query=query, - search_results=accumulated_memories, - context=accumulated_context, - missing_info="", - ) + if not generated_answer: + return self._remove_duplicate_memories(accumulated_memories) + else: + return self._generate_final_answer( + query, accumulated_memories, accumulated_context, "", history + ) + + def _remove_duplicate_memories( + self, memories: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Remove duplicate memories based on memory content. 
+ + Args: + memories: List of TextualMemoryItem objects to deduplicate - logger.info("Deep search pipeline completed") - return final_answer + Returns: + List of unique TextualMemoryItem objects (first occurrence kept) + """ + seen = set() + return [ + memory + for memory in memories + if (content := getattr(memory, "memory", "").strip()) + and content not in seen + and not seen.add(content) + ] + + def _generate_final_answer( + self, + original_query: str, + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", + history: list[str] | None = None, + sources: list[str] | None = None, + ) -> str: + """ + Generate the final answer. + """ + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified", + ) + messages: MessageList = [{"role": "user", "content": prompt}] + response = self.llm.generate(messages) + return response.strip() def _perform_memory_search( self, query: str, keywords: list[str] | None = None, user_id: str | None = None, + history: list[str] | None = None, top_k: int = 10, ) -> list[TextualMemoryItem]: """ @@ -254,7 +299,11 @@ def _perform_memory_search( # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, top_k=top_k, mode="fast", user_name=user_id + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id, + info={"history": history}, ) return results if isinstance(results, list) else [] @@ -316,7 +365,6 @@ def _generate_final_answer( context=context_str if context_str else "No specific context retrieved", missing_info=missing_info if missing_info else "None identified", ) - messages: MessageList = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/templates/mem_agent_prompts.py 
b/src/memos/templates/mem_agent_prompts.py index 17e3c0ad9..477cd2409 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -25,14 +25,15 @@ {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", - "missing_entities": ["entity1", "entity2"] // Only if status is "missing_info", - + "missing_entities": ["entity1", "entity2"], + "new_search_query": "new search query", }} Status definitions: - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources +- "new_search_query": New search query to retrieve more information Response:""" From c21fc586a723477f3f2a2465e895228b4c1b8767 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 22 Nov 2025 18:29:39 +0800 Subject: [PATCH 08/21] feat: update test code --- tests/configs/test_embedder.py | 4 ++-- tests/embedders/test_universal_api.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/configs/test_embedder.py b/tests/configs/test_embedder.py index 8201f9bd8..10572f33e 100644 --- a/tests/configs/test_embedder.py +++ b/tests/configs/test_embedder.py @@ -17,7 +17,7 @@ def test_base_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims"], + optional_fields=["embedding_dims", "headers_extra"], ) check_config_instantiation_valid( @@ -36,7 +36,7 @@ def test_ollama_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "api_base"], + optional_fields=["embedding_dims", "headers_extra", "api_base"], ) check_config_instantiation_valid( diff --git a/tests/embedders/test_universal_api.py b/tests/embedders/test_universal_api.py index e4ebb7019..1419b2849 100644 --- a/tests/embedders/test_universal_api.py +++ b/tests/embedders/test_universal_api.py @@ -30,6 +30,7 @@ def 
test_embed_single_text(self, mock_openai_client): mock_openai_client.assert_called_once_with( api_key="fake-api-key", base_url="https://api.openai.com/v1", + default_headers=None, ) # Assert embeddings.create called with correct params From 0f62af8173c8c1614e582c77b7ce572f44415e75 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 24 Nov 2025 14:30:28 +0800 Subject: [PATCH 09/21] fix: remove dup config --- src/memos/api/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d738c878a..c62cd3b08 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -409,7 +409,6 @@ def get_embedder_config() -> dict[str, Any]: "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), - "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama From dac3394cf5f79e4b08481b661645786a6d73e716 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 10:55:16 +0800 Subject: [PATCH 10/21] feat: dock search pipeline --- src/memos/api/handlers/base_handler.py | 5 +++++ src/memos/api/handlers/component_init.py | 6 ++++++ src/memos/api/handlers/search_handler.py | 10 +++++++--- src/memos/mem_agent/deepsearch_agent.py | 7 +++---- src/memos/multi_mem_cube/single_cube.py | 8 ++++++-- src/memos/templates/mem_agent_prompts.py | 4 +++- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index a686ac8f9..7a47f05e3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -161,6 +161,11 @@ def mos_server(self): """Get MOS server instance.""" return self.deps.mos_server + @property + def deepsearch_agent(self): + """Get deepsearch agent instance.""" + return self.deps.deepsearch_agent + def 
_validate_dependencies(self, *required_deps: str) -> None: """ Validate that required dependencies are available. diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 3ef1d529d..7b34fcfae 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -307,6 +308,10 @@ def init_server() -> dict[str, Any]: online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") + deepsearch_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=tree_mem, + ) # Return all components as a dictionary for easy access and extension return { "graph_db": graph_db, @@ -330,4 +335,5 @@ def init_server() -> dict[str, Any]: "text_mem": text_mem, "pref_mem": pref_mem, "online_bot": online_bot, + "deepsearch_agent": deepsearch_agent, } diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8a2c21aad..b454568f8 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -31,7 +31,9 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") + self._validate_dependencies( + "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" + ) def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -52,10 +54,10 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse results = cube_view.search_memories(search_req) - self.logger.info(f"[AddHandler] Final add results count={len(results)}") 
+ self.logger.info(f"[SearchHandler] Final search results count={len(results)}") return SearchResponse( - message="Memory searched successfully", + message="Search completed successfully", data=results, ) @@ -87,6 +89,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) else: single_views = [ @@ -97,6 +100,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) for cube_id in cube_ids ] diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5a070c6ad..16b600980 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from memos.types import MessageList +logger = get_logger(__name__) + class JSONResponseParser: """Elegant JSON response parser for LLM outputs""" @@ -48,9 +50,6 @@ def parse(response: str) -> dict[str, Any]: raise ValueError(f"Cannot parse JSON response: {response[:100]}...") -logger = get_logger(__name__) - - class QueryRewriter(BaseMemAgent): """Specialized agent for rewriting queries based on conversation history""" @@ -141,7 +140,7 @@ def __init__( memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) config: Configuration for deep search behavior """ - self.config = config or DeepSearchAgentConfig() + self.config = config or DeepSearchAgentConfig(agent_name="DeepSearchMemAgent") self.max_iterations = self.config.max_iterations self.timeout = self.config.timeout self.llm: BaseLLM = llm diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2055615d2..d2fde36a3 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,6 +42,7 
@@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any + deepsearch_agent: Any def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ @@ -247,8 +248,11 @@ def _fast_search( def _deep_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: - logger.error("waiting to be implemented") - return [] + deepsearch_results = self.deepsearch_agent.run( + search_req.query, user_id=user_context.mem_cube_id + ) + formatted_memories = [format_memory_item(data) for data in deepsearch_results] + return formatted_memories def _fine_search( self, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 477cd2409..eb624ef89 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -22,12 +22,14 @@ {context} Analyze the context and determine the next step. Return your response in JSON format with the following structure: -{{ + ```json + {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", "missing_entities": ["entity1", "entity2"], "new_search_query": "new search query", }} +``` Status definitions: - "sufficient": Context fully answers the query From 9489d54b0a71ba8db05b7d13158fcaad6135d50f Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 11:17:52 +0800 Subject: [PATCH 11/21] fix: code test --- tests/api/test_server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 853a271f6..2aa96257b 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -48,6 +48,7 @@ def mock_init_server(): "pref_mem": None, "online_bot": None, "chat_llms": Mock(), + "deepsearch_agent": Mock(), } with patch("memos.api.handlers.init_server", return_value=mock_components): From e43e5db62bf9fab56ed432a6657412a868e58763 Mon Sep 17 00:00:00 2001 From: fridayL 
Date: Tue, 25 Nov 2025 14:59:14 +0800 Subject: [PATCH 12/21] feat: add test scripts --- src/memos/mem_agent/deepsearch_agent.py | 4 +- tests/mem_agent/test_deepsearch_agent.py | 233 +++++++++++++++++++++++ 2 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 tests/mem_agent/test_deepsearch_agent.py diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 16b600980..5e51aec44 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -218,7 +218,7 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: return self._remove_duplicate_memories(accumulated_memories) else: return self._generate_final_answer( - query, accumulated_memories, accumulated_context, "", history + query, accumulated_memories, accumulated_context, history ) def _remove_duplicate_memories( @@ -247,9 +247,9 @@ def _generate_final_answer( original_query: str, search_results: list[TextualMemoryItem], context: list[str], - missing_info: str = "", history: list[str] | None = None, sources: list[str] | None = None, + missing_info: str | None = None, ) -> str: """ Generate the final answer. 
diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py new file mode 100644 index 000000000..707d80283 --- /dev/null +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -0,0 +1,233 @@ +"""Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" + +import json +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import ( + DeepSearchMemAgent, + JSONResponseParser, +) +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata + + +class TestJSONResponseParser: + """Test JSONResponseParser class.""" + + def test_parse_clean_json(self): + """Test parsing clean JSON response.""" + response = '{"status": "sufficient", "reasoning": "test"}' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_json_with_code_blocks(self): + """Test parsing JSON wrapped in code blocks.""" + response = '```json\n{"status": "sufficient", "reasoning": "test"}\n```' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_invalid_json_raises_error(self): + """Test that invalid JSON raises ValueError.""" + with pytest.raises(ValueError, match="Cannot parse JSON response"): + JSONResponseParser.parse("This is not JSON at all") + + +class TestDeepSearchMemAgent: + """Test DeepSearchMemAgent core functionality.""" + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + mock = MagicMock() + mock.generate.return_value = "Generated answer" + return mock + + @pytest.fixture + def mock_memory_retriever(self): + """Create a mock memory retriever.""" + mock = MagicMock() + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a programming language", + metadata=TextualMemoryMetadata(type="fact"), + ), + 
TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python was created by Guido van Rossum", + metadata=TextualMemoryMetadata(type="fact"), + ), + ] + mock.search.return_value = memory_items + return mock + + @pytest.fixture + def config(self): + """Create DeepSearchAgentConfig.""" + return DeepSearchAgentConfig(agent_name="TestDeepSearch", max_iterations=3, timeout=30) + + @pytest.fixture + def agent(self, mock_llm, mock_memory_retriever, config): + """Create DeepSearchMemAgent instance.""" + agent = DeepSearchMemAgent( + llm=mock_llm, memory_retriever=mock_memory_retriever, config=config + ) + # Mock the sub-agents to avoid complex interactions + agent.query_rewriter.run = MagicMock(return_value="Rewritten query") + agent.reflector.run = MagicMock( + return_value={ + "status": "sufficient", + "reasoning": "Enough info", + "missing_entities": [], + } + ) + return agent + + def test_init_with_config(self, mock_llm, mock_memory_retriever, config): + """Test DeepSearchMemAgent initialization with config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + assert agent.llm == mock_llm + assert agent.memory_retriever == mock_memory_retriever + assert agent.config == config + assert agent.max_iterations == 3 + assert agent.timeout == 30 + + def test_init_without_config(self, mock_llm, mock_memory_retriever): + """Test DeepSearchMemAgent initialization without config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever) + assert isinstance(agent.config, DeepSearchAgentConfig) + assert agent.config.agent_name == "DeepSearchMemAgent" + + def test_run_no_llm_raises_error(self, config): + """Test that running without LLM raises RuntimeError.""" + agent = DeepSearchMemAgent(llm=None, config=config) + with pytest.raises(RuntimeError, match="LLM not initialized"): + agent.run("test query") + + def test_run_returns_memories_when_no_generated_answer(self, agent, mock_memory_retriever): + """Test run returns memories when generated_answer is not 
requested.""" + result = agent.run("What is Python?", generated_answer=False) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, TextualMemoryItem) for item in result) + agent.query_rewriter.run.assert_called_once() + + def test_run_returns_answer_when_generated_answer(self, agent, mock_llm): + """Test run returns generated answer when requested.""" + result = agent.run("What is Python?", generated_answer=True) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_run_with_user_id(self, agent, mock_memory_retriever): + """Test run with user_id.""" + agent.run("What is Python?", user_id="user123", generated_answer=False) + + # Check that user_id was passed to search + call_kwargs = mock_memory_retriever.search.call_args[1] + assert call_kwargs.get("user_name") == "user123" + + def test_run_no_search_results(self, agent, mock_memory_retriever): + """Test behavior when search returns no results.""" + mock_memory_retriever.search.return_value = [] + + result = agent.run("What is Python?", generated_answer=False) + + assert result == [] + + def test_remove_duplicate_memories(self, agent): + """Test removing duplicate memories.""" + mem_id1 = str(uuid.uuid4()) + mem_id2 = str(uuid.uuid4()) + mem_id3 = str(uuid.uuid4()) + + memories = [ + TextualMemoryItem( + id=mem_id1, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + TextualMemoryItem( + id=mem_id2, + memory="Different content", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=mem_id3, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + ] + + result = agent._remove_duplicate_memories(memories) + + assert len(result) == 2 + assert result[0].id == mem_id1 + assert result[1].id == mem_id2 + + def test_generate_final_answer(self, agent, mock_llm): + """Test final answer generation.""" + memory_items = [ + TextualMemoryItem( + 
id=str(uuid.uuid4()), + memory="Python is a language", + metadata=TextualMemoryMetadata(type="fact"), + ) + ] + context = ["Python is a programming language"] + + result = agent._generate_final_answer("What is Python?", memory_items, context) + + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_generate_final_answer_with_missing_info(self, agent, mock_llm): + """Test final answer generation with missing info.""" + result = agent._generate_final_answer( + "What is Python?", [], [], missing_info="Version details not found" + ) + + assert result == "Generated answer" + call_args = mock_llm.generate.call_args[0][0] + assert "Version details not found" in call_args[0]["content"] + + def test_generate_final_answer_llm_error(self, agent, mock_llm): + """Test final answer generation handles LLM errors.""" + mock_llm.generate.side_effect = Exception("LLM error") + + result = agent._generate_final_answer("What is Python?", [], []) + + assert "error" in result.lower() + assert "What is Python?" 
in result + + def test_perform_memory_search_no_retriever(self, mock_llm, config): + """Test memory search when retriever is not configured.""" + agent = DeepSearchMemAgent(mock_llm, memory_retriever=None, config=config) + result = agent._perform_memory_search("test query") + + assert result == [] + + def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config): + """Test full pipeline integration.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + + with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): + with patch.object( + agent.reflector, + "run", + return_value={ + "status": "sufficient", + "reasoning": "Info is sufficient", + "missing_entities": [], + }, + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From ecd45084935131bb7fb02ac09542c91ebae9e773 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 14:59:52 +0800 Subject: [PATCH 13/21] feat: add test --- tests/mem_agent/test_deepsearch_agent.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py index 707d80283..a80dd10ea 100644 --- a/tests/mem_agent/test_deepsearch_agent.py +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -1,6 +1,5 @@ """Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" -import json import uuid from unittest.mock import MagicMock, patch @@ -213,8 +212,9 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config """Test full pipeline integration.""" agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) - with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): - with patch.object( + with ( + 
patch.object(agent.query_rewriter, "run", return_value="Rewritten query"), + patch.object( agent.reflector, "run", return_value={ @@ -222,12 +222,13 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config "reasoning": "Info is sufficient", "missing_entities": [], }, - ): - result = agent.run( - "What is Python?", user_id="user123", history=[], generated_answer=True - ) - - assert isinstance(result, str) - assert result == "Generated answer" - mock_memory_retriever.search.assert_called() - mock_llm.generate.assert_called() + ), + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From 6e21032fc87464f4569f99c26e4309247b533d74 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 19:06:23 +0800 Subject: [PATCH 14/21] feat: update need_raw process --- src/memos/mem_agent/deepsearch_agent.py | 21 +++++++++++++++++++-- src/memos/templates/mem_agent_prompts.py | 8 +++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5e51aec44..051ac03d3 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -183,8 +183,6 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if search_results: context_batch = [self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) - accumulated_memories.extend(search_results) - reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") @@ -193,11 +191,14 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if status == "sufficient": logger.info("Sufficient 
information collected") + accumulated_memories.extend(search_results) break elif status == "needs_raw": logger.info("Need original sources, retrieving raw content") + accumulated_memories.extend(self._set_source_from_memory(search_results)) break elif status == "missing_info": + accumulated_memories.extend(search_results) missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") current_query = reflection_result.get("new_search_query") @@ -331,6 +332,22 @@ def _refine_query_for_missing_info(self, query: str, missing_entities: list[str] return refined_query + def _set_source_from_memory( + self, memory_items: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """set source from memory item""" + for memory_item in memory_items: + if not hasattr(memory_item.metadata, "sources"): + continue + chat_sources = [ + f"{source.chat_time} {source.role}: {source.content}" + for source in memory_item.metadata.sources + if hasattr(source, "type") and source.type == "chat" + ] + if chat_sources: + memory_item.memory = "\n".join(chat_sources) + "\n" + return memory_items + def _generate_final_answer( self, original_query: str, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index eb624ef89..d7163e4a8 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -35,7 +35,13 @@ - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources -- "new_search_query": New search query to retrieve more information + +IMPORTANT for "new_search_query": +- MUST preserve ALL specific entities from the original query (names, dates, times, locations, etc.) +- DO NOT replace specific information with generic terms like "user", "person", "they", etc. 
+- Keep the exact same subjects, time references, and key details as in the original query +- Only modify the query to focus on the missing information while maintaining all original specifics +- Example: If original query mentions "May 2024", keep "May 2024" in new query, don't change to "that month" Response:""" From 592f637bc3090dd682013f3b74f290de72450050 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 20:02:22 +0800 Subject: [PATCH 15/21] fix: add initter --- src/memos/multi_mem_cube/single_cube.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d2fde36a3..dbc527bb7 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,7 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any - deepsearch_agent: Any + deepsearch_agent: Any | None = None def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ From ad997459ffa8d8a1fc30442b8d47f50d338da9c7 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 27 Nov 2025 16:24:58 +0800 Subject: [PATCH 16/21] fix: change agent search func name --- src/memos/multi_mem_cube/single_cube.py | 4 +++- src/memos/types/general_types.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8f4a25a0b..8e37cb92d 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -203,7 +203,7 @@ def _deep_search( formatted_memories = [format_memory_item(data) for data in enhanced_memories] return formatted_memories - def _deep_search( + def _agentic_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: deepsearch_results = self.deepsearch_agent.run( @@ -229,6 +229,8 @@ def _fine_search( """ if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: 
return self._deep_search(search_req=search_req, user_context=user_context) + elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: + return self._agentic_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 2b7206c74..f796e682a 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -102,6 +102,7 @@ class FineStrategy(str, Enum): REWRITE = "rewrite" RECREATE = "recreate" DEEP_SEARCH = "deep_search" + AGENTIC_SEARCH = "agentic_search" # algorithm strategies From 94dba83f4655e45ca671e57eb19731e85f4fe2c6 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 28 Nov 2025 11:10:02 +0800 Subject: [PATCH 17/21] feat: update logs and defined --- src/memos/multi_mem_cube/single_cube.py | 12 ++++++++---- src/memos/reranker/factory.py | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8e37cb92d..e60e8e71e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -35,14 +35,17 @@ if TYPE_CHECKING: from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.mem_cube.navie import NaiveMemCube + from memos.mem_reader.simple_struct import SimpleStructMemReader + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler @dataclass class SingleCubeView(MemCubeView): cube_id: str - naive_mem_cube: Any - mem_reader: Any - mem_scheduler: Any + naive_mem_cube: NaiveMemCube + mem_reader: SimpleStructMemReader + mem_scheduler: OptimizedScheduler logger: Any searcher: Any deepsearch_agent: Any | None = None @@ -155,7 +158,7 @@ def _search_text( Args: search_req: Search request user_context: User context - search_mode: Search mode (FAST, FINE, or 
MIXTURE) + search_mode: Search mode (fast, fine, or mixture) Returns: List of formatted memory items @@ -227,6 +230,7 @@ def _fine_search( Returns: List of enhanced search results """ + logger.info(f"Fine strategy: {FINE_STRATEGY}") if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: return self._deep_search(search_req=search_req, user_context=user_context) elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py index d2c50ba5e..1440704a6 100644 --- a/src/memos/reranker/factory.py +++ b/src/memos/reranker/factory.py @@ -2,6 +2,7 @@ from __future__ import annotations import json + from typing import TYPE_CHECKING, Any # Import singleton decorator From 953872e46cd9ef4da5fcfadb3cbb97bbb73d49a5 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 1 Dec 2025 15:10:41 +0800 Subject: [PATCH 18/21] feat: update full text mem search --- src/memos/graph_dbs/polardb.py | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index bfde8c80c..7207f05da 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1450,6 +1450,131 @@ def get_subgraph( def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + + @timed + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + """ + Full-text search functionality using PostgreSQL's full-text search capabilities. 
+ + Args: + query_words: list of query words to match (OR-combined in the tsquery) + top_k: maximum number of results to return + scope: memory type filter (memory_type) + status: status filter, defaults to "activated" + threshold: similarity threshold filter + search_filter: additional property filter conditions + user_name: username filter + knowledgebase_ids: knowledgebase ids filter + filter: filter conditions with 'and' or 'or' logic for search results. + tsvector_field: full-text index field name, defaults to properties_tsvector_zh + tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + **kwargs: other parameters (e.g. cube_name) + + Returns: + list[dict]: result list containing id and score + """ + # Build WHERE clause dynamically, same as search_by_embedding + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + 
where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + # Add fulltext search condition + # Convert query_text to OR query format: "word1 | word2 | word3" + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid) + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + + return output[:top_k] + finally: + self._return_connection(conn) @timed def search_by_embedding( From 2591c107812d8b48c6b80b262365ca8e2e675fce Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 1 Dec 2025 17:46:24 +0800 Subject: [PATCH 19/21] feat: cp plugin to dev --- src/memos/api/handlers/component_init.py | 3 + src/memos/api/product_models.py | 6 ++ src/memos/graph_dbs/polardb.py | 3 +- src/memos/memories/textual/simple_tree.py | 3 + src/memos/memories/textual/tree.py | 5 ++ .../retrieve/advanced_searcher.py | 3 + .../tree_text_memory/retrieve/recall.py | 19 +++++ 
.../retrieve/retrieve_utils.py | 27 ++++++ .../tree_text_memory/retrieve/searcher.py | 83 ++++++++++++++++--- src/memos/multi_mem_cube/single_cube.py | 4 +- 10 files changed, 141 insertions(+), 15 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 706269b52..574f2ae17 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -41,6 +41,7 @@ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer if TYPE_CHECKING: @@ -196,6 +197,7 @@ def init_server() -> dict[str, Any]: logger.debug("Memory manager initialized") + tokenizer = FastTokenizer() # Initialize text memory text_mem = SimpleTreeTextMemory( llm=llm, @@ -206,6 +208,7 @@ def init_server() -> dict[str, Any]: memory_manager=memory_manager, config=default_cube_config.text_mem.config, internet_retriever=internet_retriever, + tokenizer=tokenizer, ) logger.debug("Text memory initialized") diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 4f445e9ab..cc76e6751 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -388,6 +388,12 @@ class APISearchRequest(BaseRequest): description="(Internal) Operation definitions for multi-cube read permissions.", ) + # ==== Source for plugin ==== + source: str | None = Field( + None, + description="Source of the search query [plugin will router diff search]", + ) + @model_validator(mode="after") def _convert_deprecated_fields(self) -> "APISearchRequest": """ diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 7207f05da..90bea8ea3 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1450,7 
+1450,7 @@ def get_subgraph( def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError - + @timed def search_by_fulltext( self, @@ -1529,7 +1529,6 @@ def search_by_fulltext( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) where_clauses.extend(filter_conditions) diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 05e62e3ee..c67271f76 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -9,6 +9,7 @@ from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.reranker.base import BaseReranker @@ -35,6 +36,7 @@ def __init__( config: TreeTextMemoryConfig, internet_retriever: None = None, is_reorganize: bool = False, + tokenizer: FastTokenizer | None = None, ): """Initialize memory with the given configuration.""" self.config: TreeTextMemoryConfig = config @@ -51,6 +53,7 @@ def __init__( if self.search_strategy and self.search_strategy.get("bm25", False) else None ) + self.tokenizer = tokenizer self.reranker = reranker self.memory_manager: MemoryManager = memory_manager # Create internet retriever if configured diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 2a109bf71..ad2bcd9c4 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -91,6 +91,7 @@ def __init__(self, config: TreeTextMemoryConfig): ) else: logger.info("No internet retriever configured") + self.tokenizer = None def 
add( self, @@ -165,6 +166,7 @@ def search( search_priority: dict | None = None, search_filter: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """Search for memories based on a query. User query -> TaskGoalParser -> MemoryPathResolver -> @@ -197,6 +199,7 @@ def search( internet_retriever=None, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) else: searcher = Searcher( @@ -208,6 +211,7 @@ def search( internet_retriever=self.internet_retriever, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) return searcher.search( query, @@ -218,6 +222,7 @@ def search( search_filter, search_priority, user_name=user_name, + plugin=kwargs.get("plugin", False), ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 22cd44b8c..9c892d8b8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -10,6 +10,7 @@ from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, parse_structured_output, ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -33,6 +34,7 @@ def __init__( search_strategy: dict | None = None, manual_close_internet: bool = True, process_llm: Any | None = None, + tokenizer: FastTokenizer | None = None, ): super().__init__( dispatcher_llm=dispatcher_llm, @@ -43,6 +45,7 @@ def __init__( internet_retriever=internet_retriever, search_strategy=search_strategy, manual_close_internet=manual_close_internet, + tokenizer=tokenizer, ) 
self.stage_retrieve_top = 3 diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 7fa8a87be..7ac274a62 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -148,6 +148,25 @@ def retrieve_from_cube( return list(combined.values()) + def retrieve_from_mixed( + self, + top_k: int, + memory_scope: str | None = None, + query_embedding: list[list[float]] | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + ) -> list[TextualMemoryItem]: + """Retrieve from mixed and memory""" + vector_results = self._vector_recall( + query_embedding or [], + memory_scope, + top_k, + search_filter=search_filter, + user_name=user_name, + ) # Merge and deduplicate by ID + combined = {item.id: item for item in vector_results} + return list(combined.values()) + def _graph_recall( self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs ) -> list[TextualMemoryItem]: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 0720d1fca..9e1e6c240 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Any +import numpy as np + from memos.dependency import require_python_package from memos.log import get_logger @@ -463,3 +465,28 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]: memory["metadata"]["memory"] = memory["memory"] return memory + + +def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8): + assert len(sentences) == len(similarity_matrix) + + num_sentence = len(sentences) + selected_sentences = [] + selected_indices = [] + 
for i in range(num_sentence): + can_add = True + for j in selected_indices: + if similarity_matrix[i][j] > bar: + can_add = False + break + if can_add: + selected_sentences.append(i) + selected_indices.append(i) + return selected_sentences, selected_indices + + +def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]: + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + x_normalized = embeddings / norms + similarity_matrix = np.dot(x_normalized, x_normalized.T) + return similarity_matrix diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 976be6a54..f428bf5c0 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -8,7 +8,10 @@ from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, + cosine_similarity_matrix, detect_lang, + find_best_unrelated_subgroup, parse_json_result, ) from memos.reranker.base import BaseReranker @@ -43,6 +46,7 @@ def __init__( internet_retriever: None = None, search_strategy: dict | None = None, manual_close_internet: bool = True, + tokenizer: FastTokenizer | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -58,6 +62,7 @@ def __init__( self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self.manual_close_internet = manual_close_internet + self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -104,9 +109,10 @@ def post_retrieve( top_k: int, user_name: str | None = None, info=None, + 
plugin=False, ): deduped = self._deduplicate_results(retrieved_results) - final_results = self._sort_and_trim(deduped, top_k) + final_results = self._sort_and_trim(deduped, top_k, plugin) self._update_usage_history(final_results, info, user_name) return final_results @@ -121,6 +127,7 @@ def search( search_filter: dict | None = None, search_priority: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. @@ -149,22 +156,29 @@ def search( else: logger.debug(f"[SEARCH] Received info dict: {info}") - retrieved_results = self.retrieve( - query=query, - top_k=top_k, - info=info, - mode=mode, - memory_type=memory_type, - search_filter=search_filter, - search_priority=search_priority, - user_name=user_name, - ) + if kwargs.get("plugin"): + logger.info(f"[SEARCH] Retrieve from plugin: {query}") + retrieved_results = self._retrieve_simple( + query=query, top_k=top_k, search_filter=search_filter, user_name=user_name + ) + else: + retrieved_results = self.retrieve( + query=query, + top_k=top_k, + info=info, + mode=mode, + memory_type=memory_type, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + ) final_results = self.post_retrieve( retrieved_results=retrieved_results, top_k=top_k, user_name=user_name, info=None, + plugin=kwargs.get("plugin", False), ) logger.info(f"[SEARCH] Done. 
Total {len(final_results)} results.") @@ -484,6 +498,49 @@ def _retrieve_from_internet( parsed_goal=parsed_goal, ) + @timed + def _retrieve_simple( + self, + query: str, + top_k: int, + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + """Retrieve by keywords and embedding""" + query_words = [] + if self.tokenizer: + query_words = self.tokenizer.tokenize_mixed(query) + else: + query_words = query.strip().split() + query_words = list(set(query_words))[: top_k * 3] + query_words = [query, *query_words] + logger.info(f"[SIMPLESEARCH] Query words: {query_words}") + query_embeddings = self.embedder.embed(query_words) + + items = self.graph_retriever.retrieve_from_mixed( + top_k=top_k * 2, + memory_scope=None, + query_embedding=query_embeddings, + search_filter=search_filter, + user_name=user_name, + ) + logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") + documents = [getattr(item, "memory", "") for item in items] + documents_embeddings = self.embedder.embed(documents) + similarity_matrix = cosine_similarity_matrix(documents_embeddings) + selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) + selected_items = [items[i] for i in selected_indices] + logger.info( + f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" + ) + return self.reranker.rerank( + query=query, + query_embedding=query_embeddings[0], + graph_results=selected_items, + top_k=top_k, + ) + @timed def _deduplicate_results(self, results): """Deduplicate results by memory text""" @@ -494,12 +551,14 @@ def _deduplicate_results(self, results): return list(deduped.values()) @timed - def _sort_and_trim(self, results, top_k): + def _sort_and_trim(self, results, top_k, plugin=False): """Sort results by score and trim to top_k""" sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: + if plugin and round(score, 2) == 0.00: + 
continue meta_data = item.metadata.model_dump() meta_data["relativity"] = score final_items.append( diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index e346bdf1f..880646939 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -363,7 +363,8 @@ def _fast_search( target_session_id = search_req.session_id or "default_session" search_priority = {"session_id": search_req.session_id} if search_req.session_id else None search_filter = search_req.filter or None - print(f"type of text_mem: {type(self.naive_mem_cube.text_mem)}") + plugin = bool(search_req.source is not None and search_req.source == "plugin") + search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -377,6 +378,7 @@ def _fast_search( "session_id": target_session_id, "chat_history": search_req.chat_history, }, + plugin=plugin, ) formatted_memories = [format_memory_item(data) for data in search_results] From 383eaaa790a0d625698f03c2462b098978433ce7 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 1 Dec 2025 20:43:38 +0800 Subject: [PATCH 20/21] feat: add one recall for fulltext retrieval --- src/memos/graph_dbs/polardb.py | 2 +- .../tree_text_memory/retrieve/recall.py | 68 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index e731ef138..da88604b6 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1553,7 +1553,7 @@ def search_by_fulltext( """ params = [tsquery_string, tsquery_string] - + logger.info(f"[search_by_fulltext] query: {query}, params: {params}") conn = self._get_connection() try: with conn.cursor() as cursor: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 7ac274a62..5a6ff0233 100644 --- 
a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -101,13 +101,27 @@ def retrieve( user_name=user_name, search_filter=id_filter, ) + if use_fast_graph: + future_fulltext = executor.submit( + self._fulltext_recall, + query_words=parsed_goal.keys or [], + memory_scope=memory_scope, + top_k=top_k, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + ) graph_results = future_graph.result() vector_results = future_vector.result() bm25_results = future_bm25.result() if self.use_bm25 else [] + fulltext_results = future_fulltext.result() # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results + bm25_results} + combined = { + item.id: item + for item in graph_results + vector_results + bm25_results + fulltext_results + } return list(combined.values()) @@ -404,3 +418,55 @@ def _bm25_recall( ) return [TextualMemoryItem.from_dict(n) for n in bm25_results] + + def _fulltext_recall( + self, + query_words: list[str], + memory_scope: str, + top_k: int = 20, + max_num: int = 5, + status: str = "activated", + cube_name: str | None = None, + search_filter: dict | None = None, + search_priority: dict | None = None, + user_name: str | None = None, + ): + """Perform fulltext-based retrieval. 
+ Args: + query_words: list of query words + memory_scope: memory scope + top_k: top k results + max_num: max number of query words + status: status + cube_name: cube name + search_filter: search filter + search_priority: search priority + user_name: user name + Returns: + list of TextualMemoryItem + """ + if not query_words: + return [] + logger.info(f"[FULLTEXT] query_words: {query_words}") + all_hits = self.graph_store.search_by_fulltext( + query_words=query_words, + top_k=top_k, + status=status, + scope=memory_scope, + cube_name=cube_name, + search_filter=search_priority, + filter=search_filter, + user_name=user_name, + ) + if not all_hits: + return [] + + # merge and deduplicate + unique_ids = {r["id"] for r in all_hits if r.get("id")} + node_dicts = ( + self.graph_store.get_nodes( + list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + ) + or [] + ) + return [TextualMemoryItem.from_dict(n) for n in node_dicts] From 502e15e5216f3f08c927f724560f57ee41e2e604 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 2 Dec 2025 09:57:00 +0800 Subject: [PATCH 21/21] fix: set default for fulltext search --- src/memos/memories/textual/tree_text_memory/retrieve/recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 5a6ff0233..5dfbde704 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -115,7 +115,7 @@ def retrieve( graph_results = future_graph.result() vector_results = future_vector.result() bm25_results = future_bm25.result() if self.use_bm25 else [] - fulltext_results = future_fulltext.result() + fulltext_results = future_fulltext.result() if use_fast_graph else [] # Merge and deduplicate by ID combined = {