From 587e5a7204d415882fbd5c98d41dae45ad8666d8 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:02:40 +0800 Subject: [PATCH 01/35] feat: update memos headers --- src/memos/api/config.py | 3 ++- src/memos/configs/embedder.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a276fa63d..3236c6705 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -379,7 +379,7 @@ def get_reranker_config() -> dict[str, Any]: "url": os.getenv("MOS_RERANKER_URL"), "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"), "timeout": 10, - "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"), + "headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")), "rerank_source": os.getenv("MOS_RERANK_SOURCE"), "reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"), }, @@ -406,6 +406,7 @@ def get_embedder_config() -> dict[str, Any]: "api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"), "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), + "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 70095a194..d88b6005e 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig): embedding_dims: int | None = Field( default=None, description="Number of dimensions for the embedding" ) + headers_extra: dict[str, Any] | None = Field( + default=None, + description="Extra headers for the embedding model, only for universal_api backend", + ) class OllamaEmbedderConfig(BaseEmbedderConfig): From 5637c9d16adeff6a0a8439dae9c8f1f452b02ede Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 19 Nov 2025 21:03:07 +0800 Subject: [PATCH 02/35] feat: headers add --- src/memos/embedders/universal_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index fc51cf073..6cbd06443 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -16,7 +16,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig): self.config = config if self.provider == "openai": - self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url) + self.client = OpenAIClient( + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.headers_extra if config.headers_extra else None, + ) elif self.provider == "azure": self.client = AzureClient( azure_endpoint=config.base_url, From 68831c080e9db6633136f0ed7424b8b4bc275ccb Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 20 Nov 2025 16:45:03 +0800 Subject: [PATCH 03/35] feat: update search agent --- examples/deepsearch_example.py | 76 ++++++ src/memos/configs/mem_agent.py | 60 +++++ src/memos/mem_agent/base.py | 17 ++ src/memos/mem_agent/deepsearch_agent.py | 320 +++++++++++++++++++++++ src/memos/mem_agent/factory.py | 21 ++ src/memos/templates/mem_agent_prompts.py | 76 ++++++ 6 files changed, 570 insertions(+) create mode 100644 examples/deepsearch_example.py create mode 100644 src/memos/configs/mem_agent.py create mode 100644 src/memos/mem_agent/base.py create mode 100644 src/memos/mem_agent/deepsearch_agent.py create mode 100644 src/memos/mem_agent/factory.py create mode 100644 src/memos/templates/mem_agent_prompts.py 
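Before the new agent files, the first two patches thread optional HTTP headers end to end: the API config now parses MOS_RERANKER_HEADERS_EXTRA and MOS_EMBEDDER_HEADERS_EXTRA with json.loads instead of passing the raw string through, and UniversalAPIEmbedder forwards the parsed dict to its OpenAI client. A minimal sketch of the resulting flow, assuming a deployment that sets the embedder variable (the header name below is illustrative, not part of the patch):

    import json
    import os

    # The variable must hold a JSON object; config.py falls back to "{}".
    os.environ["MOS_EMBEDDER_HEADERS_EXTRA"] = '{"X-Request-Source": "memos"}'

    headers_extra = json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}"))

    # UniversalAPIEmbedder then builds its client roughly as:
    #   OpenAIClient(api_key=..., base_url=..., default_headers=headers_extra or None)
    # so every embedding request carries the extra headers.

One consequence of the json.loads change: a malformed value now fails loudly at config-load time with json.JSONDecodeError instead of sending a useless raw string downstream.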
diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py new file mode 100644 index 000000000..43bf4c012 --- /dev/null +++ b/examples/deepsearch_example.py @@ -0,0 +1,76 @@ +""" +Example usage of DeepSearchAgent in MemOS. + +This example demonstrates how to initialize and use the DeepSearchAgent +for comprehensive information retrieval and synthesis. +""" + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.llms.factory import LLMFactory +from memos.configs.llm import LLMConfigFactory + + +def main(): + """Example usage of DeepSearchAgent.""" + + # 1. Configure the LLM + llm_config = LLMConfigFactory( + backend="openai", # or "ollama", "azure", etc. + config={ + "api_key": "your-api-key-here", + "model_name_or_path": "gpt-4", + "temperature": 0.7, + "max_tokens": 2048 + } + ) + + # 2. Create LLM instance + llm = LLMFactory.from_config(llm_config) + + # 3. Configure DeepSearchAgent + agent_config = DeepSearchAgentConfig( + agent_name="DeepSearchAgent", + description="Advanced deep search agent for comprehensive information retrieval", + max_iterations=3, + timeout=60 + ) + + # 4. Create DeepSearchAgent instance + agent = DeepSearchAgent(agent_config) + + # 5. Initialize the agent with LLM + agent.set_llm(llm) + + # 6. Set up memory retriever (this would typically be injected by the framework) + # agent.set_memory_retriever(your_memory_retriever) + + # 7. Example queries + queries = [ + "What are the latest developments in AI research?", + "Tell me about my recent project meetings and their outcomes", + "What are the key trends in machine learning this year?" + ] + + # 8. Process queries + for query in queries: + print(f"\n{'='*60}") + print(f"Query: {query}") + print(f"{'='*60}") + + try: + # Run the deep search pipeline + response = agent.run( + input=query, + history=["Previous conversation context"], + user_id="example_user" + ) + + print(f"Response: {response}") + + except Exception as e: + print(f"Error processing query: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py new file mode 100644 index 000000000..e58f8e67c --- /dev/null +++ b/src/memos/configs/mem_agent.py @@ -0,0 +1,60 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAgentConfig(BaseConfig): + """Base configuration class for agents.""" + + agent_name: str = Field(..., description="Name of the agent") + description: str | None = Field( + default=None, description="Description of the agent" + ) + + +class SimpleAgentConfig(BaseAgentConfig): + """Simple agent configuration class.""" + + max_iterations: int = Field( + default=10, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class DeepSearchAgentConfig(BaseAgentConfig): + """Deep search agent configuration class.""" + + max_iterations: int = Field( + default=3, description="Maximum number of iterations for the agent" + ) + timeout: int = Field( + default=30, description="Timeout in seconds for agent execution" + ) + +class MemAgentConfigFactory(BaseConfig): + """Factory class for creating agent configurations.""" + + backend: str = Field(..., description="Backend for agent") + config: dict[str, Any] = Field(..., description="Configuration for the agent backend") + + backend_to_class: 
ClassVar[dict[str, Any]] = { + "simple": SimpleAgentConfig, + "deep_search": DeepSearchAgentConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AgentConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self \ No newline at end of file diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py new file mode 100644 index 000000000..3b54a846f --- /dev/null +++ b/src/memos/mem_agent/base.py @@ -0,0 +1,17 @@ +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod +from pydantic import BaseModel +from memos.configs.mem_agent import BaseAgentConfig + +class BaseMemAgent(ABC): + """ + Base class for all agents. + """ + def __init__(self, config: BaseAgentConfig): + """Initialize the BaseMemAgent with the given configuration.""" + + @abstractmethod + def run(self, input: str) -> str: + """ + Run the agent. + """ \ No newline at end of file diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py new file mode 100644 index 000000000..eb67fabba --- /dev/null +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -0,0 +1,320 @@ +""" +Deep Search Agent implementation for MemOS. + +This module implements a sophisticated deep search agent that performs iterative +query refinement and memory retrieval to provide comprehensive answers. +""" + +import json +import logging +from memos.log import get_logger +from typing import Any, Dict, List, Optional, Tuple + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.llms.base import BaseLLM +from memos.mem_agent.base import BaseMemAgent +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessageDict, MessageList +from memos.templates.mem_agent_prompts import ( + QUERY_REWRITE_PROMPT, + REFLECTION_PROMPT, + KEYWORD_EXTRACTION_PROMPT, + FINAL_GENERATION_PROMPT + ) + +logger = get_logger(__name__) + + +class QueryRewriter(BaseMemAgent): + """ + Specialized agent for rewriting queries based on conversation history. + Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): + self.llm = llm + self.name = name + + def run(self, query: str, history: list[str]| None = None) -> str: + """ + Rewrite the query to be standalone and more searchable. + + Args: + query: Original user query + history: List of previous conversation messages + + Returns: + Rewritten query string + """ + if history is None: + history = [] + + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages + + prompt = QUERY_REWRITE_PROMPT.format( + history=history_str if history_str else "No previous conversation", + query=query + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + logger.info(f"[{self.name}] Rewritten query: {response.strip()}") + return response.strip() + except Exception as e: + logger.error(f"[{self.name}] Error rewriting query: {e}") + return query # Fallback to original query + + +class ReflectionAgent: + """ + Specialized agent for analyzing information sufficiency. + Corresponds to the decision diamond in the architecture diagram. 
+ """ + + def __init__(self, llm: BaseLLM, name: str = "Reflector"): + self.llm = llm + self.name = name + + def run(self, query: str, context: List[str]) -> Dict[str, Any]: + """ + Analyze whether retrieved context is sufficient to answer the query. + + Args: + query: User query + context: List of retrieved context strings + + Returns: + Dictionary with status, reasoning, and missing entities + """ + context_str = "\n".join([f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" + for ctx in context[:10]]) # Limit context size + + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + result = json.loads(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") + return result + except (json.JSONDecodeError, Exception) as e: + logger.error(f"[{self.name}] Error in reflection analysis: {e}") + # Fallback response + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [] + } + + +class DeepSearchMemAgent(BaseMemAgent): + """ + Main orchestrator agent implementing the deep search pipeline. + + This agent coordinates multiple sub-agents to perform iterative query refinement, + memory retrieval, and information synthesis as shown in the architecture diagram. + """ + + def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): + super().__init__(config) + self.config = config + self.max_iterations = config.max_iterations + self.timeout = config.timeout + self.llm: Optional[BaseLLM] = llm + self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") + self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + self.memory_retriever = memory_retriever + + def _set_llm(self, llm: BaseLLM) -> None: + """Set the LLM and initialize sub-agents.""" + self.llm = llm + self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") + self.reflector = ReflectionAgent(llm, "Reflector") + self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") + logger.info("LLM and sub-agents initialized") + + def _set_memory_retriever(self, retriever) -> None: + """Set the memory retrieval interface.""" + self.memory_retriever = retriever + logger.info("Memory retriever interface set") + + def run(self, input: str, **kwargs) -> str: + """ + Main execution method implementing the deep search pipeline. + + Args: + input: User query string + Returns: + Comprehensive response string + """ + if not self.llm: + raise RuntimeError("LLM not initialized. 
Call set_llm() first.") + + query = input + history = kwargs.get("history", []) + user_id = kwargs.get("user_id") + + # Step 1: Query Rewriting + current_query = self.query_rewriter.run(query, history) + + # Step 2: Keyword Extraction and Planning + keyword_analysis = self.keyword_extractor.run(current_query) + search_keywords = keyword_analysis.get("keywords", [current_query]) + + accumulated_context = [] + accumulated_memories = [] + + # Step 3: Iterative Search and Reflection Loop + for iteration in range(self.max_iterations): + search_results = self._perform_memory_search( + current_query, + keywords=search_keywords, + user_id=user_id + ) + + if search_results: + context_batch = [self._extract_context_from_memory(mem) for mem in search_results] + accumulated_context.extend(context_batch) + accumulated_memories.extend(search_results) + + reflection_result = self.reflector.run(current_query, context_batch) + status = reflection_result.get("status", "sufficient") + reasoning = reflection_result.get("reasoning", "") + + logger.info(f"Reflection status: {status} - {reasoning}") + + if status == "sufficient": + logger.info("Sufficient information collected") + break + elif status == "needs_raw": + logger.info("Need original sources, retrieving raw content") + break + elif status == "missing_info": + missing_entities = reflection_result.get("missing_entities", []) + logger.info(f"Missing information: {missing_entities}") + if missing_entities: + refined_query = self._refine_query_for_missing_info( + current_query, missing_entities + ) + current_query = refined_query + logger.info(f"Refined query: {current_query}") + else: + logger.warning(f"No search results for iteration {iteration + 1}") + if iteration == 0: + current_query = query + else: + break + final_answer = self._generate_final_answer( + original_query=query, + search_results=accumulated_memories, + context=accumulated_context, + missing_info=keyword_analysis.get("search_strategy", "") + ) + + logger.info("Deep search pipeline completed") + return final_answer + + def _perform_memory_search( + self, + query: str, + keywords: List[str] = None, + user_id: str = None, + top_k: int = 10 + ) -> List[TextualMemoryItem]: + """ + Perform memory search using the configured retriever. 
+ + Args: + query: Search query + keywords: Additional keywords for search + user_id: User identifier + top_k: Number of results to retrieve + + Returns: + List of retrieved memory items + """ + if not self.memory_retriever: + logger.warning("Memory retriever not configured, returning empty results") + return [] + + try: + # Use the memory retriever interface + # This is a placeholder - actual implementation depends on the retriever interface + search_query = query + if keywords and len(keywords) > 1: + search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords + + # Assuming the retriever has a search method similar to TreeTextMemory + results = self.memory_retriever.search( + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id + ) + + return results if isinstance(results, list) else [] + + except Exception as e: + logger.error(f"Error performing memory search: {e}") + return [] + + def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: + """Extract readable context from a memory item.""" + if hasattr(memory_item, 'memory'): + return str(memory_item.memory) + elif hasattr(memory_item, 'content'): + return str(memory_item.content) + else: + return str(memory_item) + + def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + """Refine the query to search for missing information.""" + if not missing_entities: + return query + + # Simple refinement strategy - append missing entities + entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities + refined_query = f"{query} {entities_str}" + + return refined_query + + def _generate_final_answer( + self, + original_query: str, + search_results: List[TextualMemoryItem], + context: List[str], + missing_info: str = "" + ) -> str: + """ + Generate the final comprehensive answer. + + Args: + original_query: Original user query + search_results: All retrieved memory items + context: Extracted context strings + missing_info: Information about missing data + + Returns: + Final answer string + """ + # Prepare context for the prompt + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context + sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" + + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified" + ) + + messages: MessageList = [{"role": "user", "content": prompt}] + + try: + response = self.llm.generate(messages) + return response.strip() + except Exception as e: + logger.error(f"Error generating final answer: {e}") + return f"I apologize, but I encountered an error while processing your query: {original_query}. Please try again." 
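The iterative loop above is driven entirely by the JSON contract that ReflectionAgent emits ({"status": ..., "reasoning": ..., "missing_entities": [...]}): "sufficient" and "needs_raw" exit the loop, while "missing_info" folds the missing entities back into the next query. A minimal sketch of that control flow with a stubbed LLM; FakeLLM and the canned replies are illustrative and not part of this patch:

    import json

    class FakeLLM:
        """Stands in for BaseLLM: generate() returns the next canned reply."""
        def __init__(self, replies):
            self._replies = iter(replies)

        def generate(self, messages):
            return next(self._replies)

    reflector = ReflectionAgent(FakeLLM([
        json.dumps({"status": "missing_info", "reasoning": "no date found",
                    "missing_entities": ["kickoff date"]}),
        json.dumps({"status": "sufficient", "reasoning": "date is present",
                    "missing_entities": []}),
    ]))

    query = "When was the project kickoff?"
    for context in (["the project had a kickoff"], ["kickoff was on 2025-03-01"]):
        result = reflector.run(query, context)
        if result["status"] == "sufficient":
            break  # stop searching and synthesize the final answer
        # "missing_info": append result["missing_entities"] to the query, as
        # _refine_query_for_missing_info does; "needs_raw" would also break.

Note that ReflectionAgent.run catches JSON parse failures and falls back to "sufficient", so a misbehaving model ends the search early rather than erroring out.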
diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py new file mode 100644 index 000000000..be656acc0 --- /dev/null +++ b/src/memos/mem_agent/factory.py @@ -0,0 +1,21 @@ +from typing import Any, ClassVar + +from memos.configs.mem_agent import MemAgentConfigFactory +from memos.mem_agent.base import BaseMemAgent +from memos.mem_agent.deepsearch_agent import DeepSearchAgent + + +class MemAgentFactory: + """Factory class for creating MemAgent instances.""" + + backend_to_class: ClassVar[dict[str, Any]] = { + "deep_search": DeepSearchAgent, + } + + @classmethod + def from_config(cls, config_factory: MemAgentConfigFactory) -> BaseMemAgent: + backend = config_factory.backend + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + mem_agent_class = cls.backend_to_class[backend] + return mem_agent_class(config_factory.config) \ No newline at end of file diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py new file mode 100644 index 000000000..399b2cd75 --- /dev/null +++ b/src/memos/templates/mem_agent_prompts.py @@ -0,0 +1,76 @@ +QUERY_REWRITE_PROMPT = """ +You are a query rewriting specialist. Your task is to rewrite user queries to be more standalone and searchable. + +Given the conversation history and current user query, rewrite the query to: +1. Be self-contained and independent of conversation context +2. Include relevant context from history when necessary +3. Maintain the original intent and scope +4. Use clear, specific terminology + +Conversation History: +{history} + +Current Query: {query} + +Rewritten Query:""" + +REFLECTION_PROMPT = """ +You are an information sufficiency analyst. Evaluate whether the retrieved context is sufficient to answer the user's query. + +Query: {query} +Retrieved Context: +{context} + +Analyze the context and determine the next step. Return your response in JSON format with the following structure: +{{ + "status": "sufficient|missing_info|needs_raw", + "reasoning": "Brief explanation of your decision", + "missing_entities": ["entity1", "entity2"] // Only if status is "missing_info", + +}} + +Status definitions: +- "sufficient": Context fully answers the query +- "missing_info": Key information is missing (e.g., specific dates, locations, details) +- "needs_raw": Content is relevant but too summarized/vague, need original sources + +Response:""" + +KEYWORD_EXTRACTION_PROMPT = """ +Analyze the user query and extract key search terms and identify optimal data sources. + +Query: {query} + +Extract: +1. Key search terms and concepts +2. Important entities (people, places, dates, etc.) +3. Suggested data sources or memory types to search + +Return response in JSON format: +{{ + "keywords": ["keyword1", "keyword2"], + "entities": ["entity1", "entity2"], + "search_strategy": "Brief strategy description" +}} + +Response:""" + + +FINAL_GENERATION_PROMPT = """ +You are a comprehensive information synthesizer. Generate a complete answer based on the retrieved information. + +User Query: {query} +Search Sources: {sources} +Retrieved Information: +{context} + +Missing Information (if any): {missing_info} + +Instructions: +1. Synthesize all relevant information to answer the query comprehensively +2. If information is missing, acknowledge gaps and suggest next steps +3. Maintain accuracy and cite sources when possible +4. Provide a well-structured, coherent response +5. 
Use natural, conversational tone

Response:"""
\ No newline at end of file
From 58c512d8a7086e0e723bf232257067f8c7516876 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Fri, 21 Nov 2025 15:08:50 +0800
Subject: [PATCH 04/35] feat: update mem story

---
 examples/mem_agent/deepsearch_example.py | 350 ++++++++++++++++++
 .../deepsearch_with_api_components.py | 223 +++++++++++
 src/memos/mem_agent/deepsearch_agent.py | 57 +--
 src/memos/mem_agent/factory.py | 28 +-
 4 files changed, 628 insertions(+), 30 deletions(-)
 create mode 100644 examples/mem_agent/deepsearch_example.py
 create mode 100644 examples/mem_agent/deepsearch_with_api_components.py

diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
new file mode 100644
index 000000000..62f26f0b1
--- /dev/null
+++ b/examples/mem_agent/deepsearch_example.py
@@ -0,0 +1,350 @@
+"""
+DeepSearch Agent Usage Examples
+
+This example demonstrates two ways to initialize DeepSearchMemAgent:
+1. Using Factory pattern (recommended)
+2. Direct initialization
+
+DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through:
+- Query rewriting: Optimize queries based on conversation history
+- Iterative retrieval: Collect information through multiple search rounds
+- Reflective analysis: Determine if information is sufficient
+- Comprehensive response: Generate complete final answers
+"""
+
+import os
+import sys
+
+from memos.configs.embedder import EmbedderConfigFactory
+from memos.configs.graph_db import GraphConfigFactory
+from memos.configs.llms import LLMConfigFactory
+from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory
+from memos.configs.mem_reader import MemReaderConfigFactory
+from memos.configs.textual_memory import TreeTextMemoryConfig
+from memos.embedders.base import BaseEmbedder
+from memos.embedders.factory import EmbedderFactory
+from memos.graph_dbs.factory import GraphStoreFactory
+from memos.llms.factory import LLMFactory
+from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
+from memos.mem_agent.factory import MemAgentFactory
+from memos.mem_cube.naive_cube import NaiveMemCube
+from memos.mem_reader.factory import MemReaderFactory
+from memos.memories.memory_manager.memory_manager import MemoryManager
+from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory
+from memos.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def build_minimal_components():
+    """
+    Build the minimal component set required for DeepSearchMemAgent.
+
+    Only need to initialize:
+    1. LLM - Used for query rewriting, reflection, and final answer generation
+    2. NaiveMemCube - Provides text_mem.search interface for memory retrieval
+
+    Returns:
+        dict: Dictionary containing llm and naive_mem_cube
+    """
+    logger.info("Starting to build minimal component set...")
+
+    # 1. Initialize LLM
+    llm_config = LLMConfigFactory(
+        backend="openai",
+        config={
+            "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"),
+            "api_key": os.getenv("OPENAI_API_KEY"),
+            "base_url": os.getenv("OPENAI_BASE_URL"),
+            "temperature": 0.7,
+        }
+    )
+    llm = LLMFactory.from_config(llm_config)
+    logger.info("LLM initialization completed")
+
+    # 2.
Initialize Embedder (required by text_mem) + embedder_config = EmbedderConfigFactory( + backend="universal_api", + config={ + "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), + "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), + "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), + } + ) + embedder = EmbedderFactory.from_config(embedder_config) + logger.info("Embedder initialization completed") + + # 3. Initialize GraphDB (required by text_mem) + graph_db_config = GraphConfigFactory( + backend="polardb", + config={ + "host": os.getenv("POLAR_DB_HOST", "localhost"), + "port": int(os.getenv("POLAR_DB_PORT", "5432")), + "user": os.getenv("POLAR_DB_USER", "root"), + "password": os.getenv("POLAR_DB_PASSWORD", "123456"), + "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), + "user_name": "memos_default", + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + } + ) + graph_db = GraphStoreFactory.from_config(graph_db_config) + logger.info("✓ GraphDB init Done") + + # 4. Initialize MemReader (required by text_mem) + mem_reader_config = MemReaderConfigFactory( + backend="simple_struct", + config={ + "llm": { + "backend": "openai", + "config": { + "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + } + } + } + ) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + logger.info("✓ MemReader init Done") + + # 5. Initialize MemoryManager + memory_manager = MemoryManager( + graph_db=graph_db, + embedder=embedder, + llm=llm, + memory_size=1000, + is_reorganize=False, + ) + logger.info("✓ MemoryManager init Done") + + # 6. Initialize TextMemory + text_mem_config = TreeTextMemoryConfig( + reorganize=False, + max_depth=3, + ) + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=None, + memory_manager=memory_manager, + config=text_mem_config, + internet_retriever=None, + ) + logger.info("✓ TextMemory initialization completed") + + # 7. 
Create NaiveMemCube + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=None, + act_mem=None, + para_mem=None, + ) + logger.info("✓ NaiveMemCube creation completed") + + logger.info("All components initialized!") + + return { + "llm": llm, + "naive_mem_cube": naive_mem_cube, + "embedder": embedder, + "graph_db": graph_db, + "mem_reader": mem_reader, + } + + +def example_1_factory_initialization(): + """ + Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) + + Advantages: + - Separation of configuration and code + - Easy to manage and modify + - Support loading from configuration files + """ + logger.info("\n" + "="*60) + logger.info("Example 1: Initialize using Factory pattern") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Create configuration Factory + agent_config_factory = MemAgentConfigFactory( + backend="deep_search", + config={ + "agent_name": "MyDeepSearchAgent", + "description": "Intelligent agent for deep search", + "max_iterations": 3, # Maximum number of iterations + "timeout": 60, # Timeout in seconds + } + ) + + # Create Agent using Factory + # Pass text_mem as memory_retriever, it provides search method + deep_search_agent = MemAgentFactory.from_config( + config_factory=agent_config_factory, + llm=llm, + memory_retriever=naive_mem_cube.text_mem + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully") + logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def example_3_usage(deep_search_agent, components): + """ + Example 3: Using DeepSearchMemAgent for search + + Demonstrates how to: + 1. Add memories to the system + 2. Use Agent for deep search + 3. Get comprehensive answers + """ + logger.info("\n" + "="*60) + logger.info("Example 3: Using DeepSearchMemAgent") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # Simulate adding some memories + logger.info("1. Adding test memories...") + test_memories = [ + { + "user_name": "test_user", + "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], + "source": "manual" + }, + { + "user_name": "test_user", + "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], + "source": "manual" + } + ] + + for memory in test_memories: + try: + text_mem.add( + user_name=memory["user_name"], + messages=memory["messages"], + source=memory.get("source", "manual") + ) + logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") + except Exception as e: + logger.warning(f" ✗ Failed to add memory: {e}") + + # Use Agent for search + logger.info("\n2. Executing deep search...") + query = "What is the relationship between artificial intelligence and machine learning?" 
+ logger.info(f" Query: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="test_user", + history=[ + "Hello", + "I want to learn about artificial intelligence" + ] + ) + + logger.info("\n3. Search results:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + except Exception as e: + logger.error(f"Error during search: {e}") + import traceback + traceback.print_exc() + + +def example_4_minimal_initialization(): + """ + Example 4: Minimal initialization (using default configuration) + + Use cases: + - Quick testing + - Prototype development + - No need for custom configuration + """ + logger.info("\n" + "="*60) + logger.info("Example 4: Minimal initialization (default configuration)") + logger.info("="*60 + "\n") + + # Build necessary components + components = build_minimal_components() + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + + # Direct initialization using default configuration + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem + # config parameter omitted, will use default configuration + ) + + logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)") + logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") + logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") + + return deep_search_agent, components + + +def main(): + """Main function: Run all examples""" + logger.info("DeepSearch Agent Usage Examples") + logger.info("="*60) + + # Check environment variables + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set the following environment variables:") + logger.error(" - OPENAI_API_KEY: OpenAI API key") + logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") + logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") + return + + try: + # Run Example 1: Factory pattern + agent_factory, components_factory = example_1_factory_initialization() + + # Run Example 2: Direct initialization + agent_direct, components_direct = example_2_direct_initialization() + + # Run Example 4: Minimal initialization + agent_minimal, components_minimal = example_4_minimal_initialization() + + # Run Example 3: Actual usage (using agent created with factory method) + example_3_usage(agent_factory, components_factory) + + logger.info("\n" + "="*60) + logger.info("All examples completed!") + logger.info("="*60) + + except Exception as e: + logger.error(f"Error running examples: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py new file mode 100644 index 000000000..48e4df43c --- /dev/null +++ b/examples/mem_agent/deepsearch_with_api_components.py @@ -0,0 +1,223 @@ +""" +DeepSearch Agent 与 API 组件集成示例 + +本示例展示如何直接使用 API 服务器初始化的组件来创建 DeepSearch Agent。 +这种方式可以避免重复初始化,直接复用已有的组件。 + +适用场景: +- 在已有的 API 服务器中添加 DeepSearch 功能 +- 使用统一的组件配置 +- 避免重复初始化开销 +""" + +import os +import sys + +# 确保可以导入 memos 模块 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from memos.api.handlers.component_init import init_server +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent 
import DeepSearchMemAgent +from memos.log import get_logger + +logger = get_logger(__name__) + + +def create_deepsearch_from_api_components(): + """ + 从 API 服务器组件创建 DeepSearch Agent + + 这个函数演示了如何: + 1. 使用 init_server() 初始化所有组件 + 2. 提取需要的组件(llm 和 naive_mem_cube) + 3. 创建 DeepSearch Agent + """ + logger.info("="*60) + logger.info("使用 API 组件初始化 DeepSearch Agent") + logger.info("="*60 + "\n") + + # 步骤 1: 初始化所有服务器组件 + # 这会创建 llm, naive_mem_cube, scheduler 等所有组件 + logger.info("1. 初始化服务器组件...") + components = init_server() + logger.info(" ✓ 服务器组件初始化完成") + + # 步骤 2: 提取需要的组件 + logger.info("\n2. 提取必需组件...") + llm = components["llm"] + naive_mem_cube = components["naive_mem_cube"] + logger.info(f" ✓ LLM: {type(llm).__name__}") + logger.info(f" ✓ MemCube: {type(naive_mem_cube).__name__}") + logger.info(f" ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}") + + # 步骤 3: 创建 DeepSearch Agent 配置 + logger.info("\n3. 创建 DeepSearch Agent...") + config = DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + description="基于 API 组件的深度搜索代理", + max_iterations=3, + timeout=60, + ) + + # 步骤 4: 初始化 DeepSearch Agent + # memory_retriever 使用 naive_mem_cube.text_mem + # 它提供了 search() 方法用于检索记忆 + deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=config + ) + + logger.info(f" ✓ DeepSearch Agent 创建成功") + logger.info(f" ✓ Agent 名称: {config.agent_name}") + logger.info(f" ✓ 最大迭代次数: {config.max_iterations}") + + return deep_search_agent, components + + +def demo_usage(deep_search_agent, components): + """ + 演示如何使用 DeepSearch Agent + """ + logger.info("\n" + "="*60) + logger.info("DeepSearch Agent 使用演示") + logger.info("="*60 + "\n") + + naive_mem_cube = components["naive_mem_cube"] + text_mem = naive_mem_cube.text_mem + + # 添加一些测试记忆 + logger.info("1. 添加测试记忆...") + test_data = [ + "MemOS 是一个先进的记忆操作系统,专门为 AI 系统设计。", + "MemOS 支持多种记忆类型:文本记忆、偏好记忆、行为记忆等。", + "DeepSearch Agent 是 MemOS 中的一个重要组件,用于深度搜索和信息检索。", + ] + + for i, content in enumerate(test_data, 1): + try: + text_mem.add( + user_name="demo_user", + messages=[content], + source="demo" + ) + logger.info(f" ✓ 记忆 {i}: {content[:40]}...") + except Exception as e: + logger.warning(f" ✗ 添加记忆失败: {e}") + + # 执行深度搜索 + logger.info("\n2. 执行深度搜索...") + query = "MemOS 支持哪些功能?" + logger.info(f" 查询: {query}") + + try: + response = deep_search_agent.run( + query=query, + user_id="demo_user", + history=[] + ) + + logger.info("\n3. 
搜索结果:") + logger.info("-" * 60) + logger.info(response) + logger.info("-" * 60) + + return response + + except Exception as e: + logger.error(f" ✗ 搜索失败: {e}") + import traceback + traceback.print_exc() + return None + + +def integrate_with_api_router(): + """ + 展示如何将 DeepSearch Agent 集成到 API Router 中 + + 这段代码展示了在 server_router.py 中如何添加一个新的端点 + 来使用 DeepSearch Agent。 + """ + logger.info("\n" + "="*60) + logger.info("API Router 集成示例") + logger.info("="*60 + "\n") + + logger.info("在 server_router.py 中添加以下代码:\n") + + integration_code = ''' +# 在 server_router.py 顶部导入 +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent + +# 在初始化 handlers 之后,创建 DeepSearch Agent +deep_search_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=naive_mem_cube.text_mem, + config=DeepSearchAgentConfig( + agent_name="APIDeepSearchAgent", + max_iterations=3, + timeout=60, + ) +) + +# 添加新的 API 端点 +@router.post("/deepsearch", summary="Deep search with memory") +def deep_search( + user_id: str, + query: str, + history: list[str] = None +): + """Execute deep search with iterative memory retrieval.""" + try: + response = deep_search_agent.run( + query=query, + user_id=user_id, + history=history or [] + ) + return {"status": "success", "response": response} + except Exception as e: + return {"status": "error", "message": str(e)} +''' + + print(integration_code) + + +def main(): + """主函数""" + logger.info("DeepSearch Agent 与 API 组件集成示例\n") + + # 检查环境变量 + required_env_vars = ["OPENAI_API_KEY"] + missing_vars = [var for var in required_env_vars if not os.getenv(var)] + + if missing_vars: + logger.error(f"❌ 缺少必要的环境变量: {', '.join(missing_vars)}") + logger.error("\n请设置以下环境变量:") + logger.error(" export OPENAI_API_KEY='your-api-key'") + logger.error(" export OPENAI_BASE_URL='your-base-url' # 可选") + return + + try: + # 创建 DeepSearch Agent + agent, components = create_deepsearch_from_api_components() + + # 演示使用 + demo_usage(agent, components) + + # 展示集成方法 + integrate_with_api_router() + + logger.info("\n" + "="*60) + logger.info("示例运行完成!") + logger.info("="*60) + + except Exception as e: + logger.error(f"\n❌ 运行失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() + diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index eb67fabba..6dd072728 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -117,57 +117,60 @@ class DeepSearchMemAgent(BaseMemAgent): memory retrieval, and information synthesis as shown in the architecture diagram. """ - def __init__(self, llm: BaseLLM, memory_retriever: BaseMemoryRetriever | None = None): - super().__init__(config) - self.config = config - self.max_iterations = config.max_iterations - self.timeout = config.timeout - self.llm: Optional[BaseLLM] = llm - self.query_rewriter: Optional[QueryRewriteAgent] = QueryRewriter(llm, "QueryRewriter") - self.reflector: Optional[ReflectionAgent] = ReflectionAgent(llm, "Reflector") + def __init__( + self, + llm: BaseLLM, + memory_retriever: Any = None, + config: Optional[DeepSearchAgentConfig] = None + ): + """ + Initialize DeepSearchMemAgent. 
+ + Args: + llm: Language model for query rewriting and response generation + memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) + config: Configuration for deep search behavior + """ + self.config = config or DeepSearchAgentConfig() + self.max_iterations = self.config.max_iterations + self.timeout = self.config.timeout + self.llm: BaseLLM = llm + self.query_rewriter: QueryRewriter = QueryRewriter(llm, "QueryRewriter") + self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_llm(self, llm: BaseLLM) -> None: - """Set the LLM and initialize sub-agents.""" - self.llm = llm - self.query_rewriter = QueryRewriteAgent(llm, "QueryRewriter") - self.reflector = ReflectionAgent(llm, "Reflector") - self.keyword_extractor = KeywordExtractionAgent(llm, "KeywordExtractor") - logger.info("LLM and sub-agents initialized") - def _set_memory_retriever(self, retriever) -> None: """Set the memory retrieval interface.""" self.memory_retriever = retriever logger.info("Memory retriever interface set") - def run(self, input: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. Args: - input: User query string + query: User query string + **kwargs: Additional arguments (history, user_id, etc.) Returns: Comprehensive response string """ if not self.llm: - raise RuntimeError("LLM not initialized. Call set_llm() first.") + raise RuntimeError("LLM not initialized.") - query = input history = kwargs.get("history", []) user_id = kwargs.get("user_id") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - # Step 2: Keyword Extraction and Planning - keyword_analysis = self.keyword_extractor.run(current_query) - search_keywords = keyword_analysis.get("keywords", [current_query]) - accumulated_context = [] accumulated_memories = [] + search_keywords = [] # Can be extended with keyword extraction - # Step 3: Iterative Search and Reflection Loop + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): + logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") + search_results = self._perform_memory_search( current_query, keywords=search_keywords, @@ -206,11 +209,13 @@ def run(self, input: str, **kwargs) -> str: current_query = query else: break + + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info=keyword_analysis.get("search_strategy", "") + missing_info="" ) logger.info("Deep search pipeline completed") diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index be656acc0..33e7b6a6b 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -2,20 +2,40 @@ from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent -from memos.mem_agent.deepsearch_agent import DeepSearchAgent +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent class MemAgentFactory: """Factory class for creating MemAgent instances.""" backend_to_class: ClassVar[dict[str, Any]] = { - "deep_search": DeepSearchAgent, + "deep_search": DeepSearchMemAgent, } @classmethod - def from_config(cls, config_factory: MemAgentConfigFactory) -> BaseMemAgent: + def from_config( + cls, + config_factory: MemAgentConfigFactory, + llm: Any, + memory_retriever: Any = None + ) -> BaseMemAgent: + """ + 
Create a MemAgent instance from configuration.
+
+        Args:
+            config_factory: Configuration factory for the agent
+            llm: Language model instance
+            memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem)
+
+        Returns:
+            Initialized MemAgent instance
+        """
         backend = config_factory.backend
         if backend not in cls.backend_to_class:
             raise ValueError(f"Invalid backend: {backend}")
         mem_agent_class = cls.backend_to_class[backend]
-        return mem_agent_class(config_factory.config)
\ No newline at end of file
+        return mem_agent_class(
+            llm=llm,
+            memory_retriever=memory_retriever,
+            config=config_factory.config
+        )
\ No newline at end of file
From a497d4614c8bc39b2f868a28952e7add5692e74e Mon Sep 17 00:00:00 2001
From: fridayL
Date: Fri, 21 Nov 2025 15:12:07 +0800
Subject: [PATCH 05/35] feat: update mem scheduler

---
 examples/deepsearch_example.py | 76 ------
 examples/mem_agent/deepsearch_example.py | 156 ++++++------
 .../deepsearch_with_api_components.py | 223 ------------------
 src/memos/configs/mem_agent.py | 20 +-
 src/memos/mem_agent/base.py | 7 +-
 src/memos/mem_agent/deepsearch_agent.py | 149 ++++++------
 src/memos/mem_agent/factory.py | 17 +-
 src/memos/templates/mem_agent_prompts.py | 2 +-
 8 files changed, 169 insertions(+), 481 deletions(-)
 delete mode 100644 examples/deepsearch_example.py
 delete mode 100644 examples/mem_agent/deepsearch_with_api_components.py

diff --git a/examples/deepsearch_example.py b/examples/deepsearch_example.py
deleted file mode 100644
index 43bf4c012..000000000
--- a/examples/deepsearch_example.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""
-Example usage of DeepSearchAgent in MemOS.
-
-This example demonstrates how to initialize and use the DeepSearchAgent
-for comprehensive information retrieval and synthesis.
-"""
-
-from memos.configs.mem_agent import DeepSearchAgentConfig
-from memos.mem_agent.deepsearch_agent import DeepSearchAgent
-from memos.llms.factory import LLMFactory
-from memos.configs.llm import LLMConfigFactory
-
-
-def main():
-    """Example usage of DeepSearchAgent."""
-
-    # 1. Configure the LLM
-    llm_config = LLMConfigFactory(
-        backend="openai",  # or "ollama", "azure", etc.
-        config={
-            "api_key": "your-api-key-here",
-            "model_name_or_path": "gpt-4",
-            "temperature": 0.7,
-            "max_tokens": 2048
-        }
-    )
-
-    # 2. Create LLM instance
-    llm = LLMFactory.from_config(llm_config)
-
-    # 3. Configure DeepSearchAgent
-    agent_config = DeepSearchAgentConfig(
-        agent_name="DeepSearchAgent",
-        description="Advanced deep search agent for comprehensive information retrieval",
-        max_iterations=3,
-        timeout=60
-    )
-
-    # 4. Create DeepSearchAgent instance
-    agent = DeepSearchAgent(agent_config)
-
-    # 5. Initialize the agent with LLM
-    agent.set_llm(llm)
-
-    # 6. Set up memory retriever (this would typically be injected by the framework)
-    # agent.set_memory_retriever(your_memory_retriever)
-
-    # 7. Example queries
-    queries = [
-        "What are the latest developments in AI research?",
-        "Tell me about my recent project meetings and their outcomes",
-        "What are the key trends in machine learning this year?"
-    ]
-
-    # 8.
Process queries - for query in queries: - print(f"\n{'='*60}") - print(f"Query: {query}") - print(f"{'='*60}") - - try: - # Run the deep search pipeline - response = agent.run( - input=query, - history=["Previous conversation context"], - user_id="example_user" - ) - - print(f"Response: {response}") - - except Exception as e: - print(f"Error processing query: {e}") - - -if __name__ == "__main__": - main() diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 62f26f0b1..dacfe7f1e 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -13,25 +13,24 @@ """ import os -import sys from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphConfigFactory from memos.configs.llms import LLMConfigFactory -from memos.configs.mem_agent import DeepSearchAgentConfig, MemAgentConfigFactory +from memos.configs.mem_agent import MemAgentConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.textual_memory import TreeTextMemoryConfig -from memos.embedders.base import BaseEmbedder from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory +from memos.log import get_logger from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.mem_agent.factory import MemAgentFactory from memos.mem_cube.naive_cube import NaiveMemCube from memos.mem_reader.factory import MemReaderFactory from memos.memories.memory_manager.memory_manager import MemoryManager from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory -from memos.log import get_logger + logger = get_logger(__name__) @@ -39,16 +38,16 @@ def build_minimal_components(): """ Build the minimal component set required for DeepSearchMemAgent. - + Only need to initialize: 1. LLM - Used for query rewriting, reflection, and final answer generation 2. NaiveMemCube - Provides text_mem.search interface for memory retrieval - + Returns: dict: Dictionary containing llm and naive_mem_cube """ logger.info("Starting to build minimal component set...") - + # 1. Initialize LLM llm_config = LLMConfigFactory( backend="openai", @@ -57,11 +56,11 @@ def build_minimal_components(): "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), "temperature": 0.7, - } + }, ) llm = LLMFactory.from_config(llm_config) logger.info("LLM initialization completed") - + # 2. Initialize Embedder (required by text_mem) embedder_config = EmbedderConfigFactory( backend="universal_api", @@ -69,11 +68,11 @@ def build_minimal_components(): "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - } + }, ) embedder = EmbedderFactory.from_config(embedder_config) logger.info("Embedder initialization completed") - + # 3. Initialize GraphDB (required by text_mem) graph_db_config = GraphConfigFactory( backend="polardb", @@ -87,11 +86,11 @@ def build_minimal_components(): "use_multi_db": False, "auto_create": True, "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - } + }, ) graph_db = GraphStoreFactory.from_config(graph_db_config) logger.info("✓ GraphDB init Done") - + # 4. 
Initialize MemReader (required by text_mem) mem_reader_config = MemReaderConfigFactory( backend="simple_struct", @@ -102,13 +101,13 @@ def build_minimal_components(): "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), "api_key": os.getenv("OPENAI_API_KEY"), "base_url": os.getenv("OPENAI_BASE_URL"), - } + }, } - } + }, ) mem_reader = MemReaderFactory.from_config(mem_reader_config) logger.info("✓ MemReader init Done") - + # 5. Initialize MemoryManager memory_manager = MemoryManager( graph_db=graph_db, @@ -118,7 +117,7 @@ def build_minimal_components(): is_reorganize=False, ) logger.info("✓ MemoryManager init Done") - + # 6. Initialize TextMemory text_mem_config = TreeTextMemoryConfig( reorganize=False, @@ -135,7 +134,7 @@ def build_minimal_components(): internet_retriever=None, ) logger.info("✓ TextMemory initialization completed") - + # 7. Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, @@ -144,9 +143,9 @@ def build_minimal_components(): para_mem=None, ) logger.info("✓ NaiveMemCube creation completed") - + logger.info("All components initialized!") - + return { "llm": llm, "naive_mem_cube": naive_mem_cube, @@ -159,21 +158,21 @@ def build_minimal_components(): def example_1_factory_initialization(): """ Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - + Advantages: - Separation of configuration and code - Easy to manage and modify - Support loading from configuration files """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 1: Initialize using Factory pattern") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + # Build necessary components components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - + # Create configuration Factory agent_config_factory = MemAgentConfigFactory( backend="deep_search", @@ -182,139 +181,141 @@ def example_1_factory_initialization(): "description": "Intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds - } + }, ) - + # Create Agent using Factory # Pass text_mem as memory_retriever, it provides search method deep_search_agent = MemAgentFactory.from_config( - config_factory=agent_config_factory, - llm=llm, - memory_retriever=naive_mem_cube.text_mem + config_factory=agent_config_factory, llm=llm, memory_retriever=naive_mem_cube.text_mem ) - - logger.info(f"✓ DeepSearchMemAgent created successfully") + + logger.info("✓ DeepSearchMemAgent created successfully") logger.info(f" - Agent name: {deep_search_agent.config.agent_name}") logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - + return deep_search_agent, components def example_3_usage(deep_search_agent, components): """ Example 3: Using DeepSearchMemAgent for search - + Demonstrates how to: 1. Add memories to the system 2. Use Agent for deep search 3. Get comprehensive answers """ - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("="*60 + "\n") - + logger.info("=" * 60 + "\n") + naive_mem_cube = components["naive_mem_cube"] text_mem = naive_mem_cube.text_mem - + # Simulate adding some memories logger.info("1. 
Adding test memories...") test_memories = [ { "user_name": "test_user", - "messages": ["Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence."], - "source": "manual" + "messages": [ + "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." + ], + "source": "manual", }, { - "user_name": "test_user", - "messages": ["Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance."], - "source": "manual" + "user_name": "test_user", + "messages": [ + "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." + ], + "source": "manual", }, { "user_name": "test_user", - "messages": ["Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks."], - "source": "manual" - } + "messages": [ + "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." + ], + "source": "manual", + }, ] - + for memory in test_memories: try: text_mem.add( user_name=memory["user_name"], messages=memory["messages"], - source=memory.get("source", "manual") + source=memory.get("source", "manual"), ) logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") except Exception as e: logger.warning(f" ✗ Failed to add memory: {e}") - + # Use Agent for search logger.info("\n2. Executing deep search...") query = "What is the relationship between artificial intelligence and machine learning?" logger.info(f" Query: {query}") - + try: response = deep_search_agent.run( query=query, user_id="test_user", - history=[ - "Hello", - "I want to learn about artificial intelligence" - ] + history=["Hello", "I want to learn about artificial intelligence"], ) - + logger.info("\n3. 
Search results:")
         logger.info("-" * 60)
         logger.info(response)
         logger.info("-" * 60)
-        
+
     except Exception as e:
         logger.error(f"Error during search: {e}")
         import traceback
+
         traceback.print_exc()
 
 
 def example_4_minimal_initialization():
     """
     Example 4: Minimal initialization (using default configuration)
-    
+
     Use cases:
     - Quick testing
     - Prototype development
     - No need for custom configuration
     """
-    logger.info("\n" + "="*60)
+    logger.info("\n" + "=" * 60)
     logger.info("Example 4: Minimal initialization (default configuration)")
-    logger.info("="*60 + "\n")
-    
+    logger.info("=" * 60 + "\n")
+
     # Build necessary components
     components = build_minimal_components()
     llm = components["llm"]
     naive_mem_cube = components["naive_mem_cube"]
-    
+
     # Direct initialization using default configuration
     deep_search_agent = DeepSearchMemAgent(
         llm=llm,
-        memory_retriever=naive_mem_cube.text_mem
+        memory_retriever=naive_mem_cube.text_mem,
         # config parameter omitted, will use default configuration
     )
-    
-    logger.info(f"✓ DeepSearchMemAgent created successfully (using default configuration)")
+
+    logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)")
     logger.info(f"  - Max iterations: {deep_search_agent.max_iterations}")
     logger.info(f"  - Timeout: {deep_search_agent.timeout} seconds")
-    
+
     return deep_search_agent, components
 
 
 def main():
     """Main function: Run all examples"""
     logger.info("DeepSearch Agent Usage Examples")
-    logger.info("="*60)
-    
+    logger.info("=" * 60)
+
     # Check environment variables
     required_env_vars = ["OPENAI_API_KEY"]
     missing_vars = [var for var in required_env_vars if not os.getenv(var)]
-    
+
     if missing_vars:
         logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
         logger.error("Please set the following environment variables:")
@@ -322,29 +323,30 @@ def main():
         logger.error("  - OPENAI_BASE_URL (optional): OpenAI API base URL")
         logger.error("  - NEBULA_HOST (optional): NebulaGraph host address")
         return
-    
+
     try:
         # Run Example 1: Factory pattern
         agent_factory, components_factory = example_1_factory_initialization()
-    
+
         # Run Example 2: Direct initialization
         agent_direct, components_direct = example_2_direct_initialization()
-    
+
         # Run Example 4: Minimal initialization
         agent_minimal, components_minimal = example_4_minimal_initialization()
-    
+
         # Run Example 3: Actual usage (using agent created with factory method)
         example_3_usage(agent_factory, components_factory)
-    
-        logger.info("\n" + "="*60)
+
+        logger.info("\n" + "=" * 60)
         logger.info("All examples completed!")
-        logger.info("="*60)
-    
+        logger.info("=" * 60)
+
     except Exception as e:
         logger.error(f"Error running examples: {e}")
         import traceback
+
         traceback.print_exc()
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/examples/mem_agent/deepsearch_with_api_components.py b/examples/mem_agent/deepsearch_with_api_components.py
deleted file mode 100644
index 48e4df43c..000000000
--- a/examples/mem_agent/deepsearch_with_api_components.py
+++ /dev/null
@@ -1,223 +0,0 @@
-"""
-DeepSearch Agent integration example with API components
-
-This example shows how to create a DeepSearch Agent directly from the components
-initialized by the API server, which avoids duplicate initialization by reusing
-the existing components.
-
-Use cases:
-- Adding DeepSearch functionality to an existing API server
-- Using a unified component configuration
-- Avoiding the overhead of repeated initialization
-"""
-
-import os
-import sys
-
-# Make sure the memos module can be imported
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
-
-from memos.api.handlers.component_init import init_server
-from memos.configs.mem_agent import DeepSearchAgentConfig
-from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
-from memos.log import get_logger
-
-logger = get_logger(__name__)
-
-
-def create_deepsearch_from_api_components():
-    """
-    Create a DeepSearch Agent from API server components
-
-    This function demonstrates how to:
-    1. Initialize all components with init_server()
-    2. Extract the required components (llm and naive_mem_cube)
-    3. Create the DeepSearch Agent
-    """
-    logger.info("="*60)
-    logger.info("Initializing DeepSearch Agent from API components")
-    logger.info("="*60 + "\n")
-
-    # Step 1: Initialize all server components
-    # This creates llm, naive_mem_cube, scheduler, and all other components
-    logger.info("1. Initializing server components...")
-    components = init_server()
-    logger.info("   ✓ Server components initialized")
-
-    # Step 2: Extract the required components
-    logger.info("\n2. Extracting required components...")
-    llm = components["llm"]
-    naive_mem_cube = components["naive_mem_cube"]
-    logger.info(f"   ✓ LLM: {type(llm).__name__}")
-    logger.info(f"   ✓ MemCube: {type(naive_mem_cube).__name__}")
-    logger.info(f"   ✓ TextMemory: {type(naive_mem_cube.text_mem).__name__}")
-
-    # Step 3: Create the DeepSearch Agent configuration
-    logger.info("\n3. Creating DeepSearch Agent...")
-    config = DeepSearchAgentConfig(
-        agent_name="APIDeepSearchAgent",
-        description="Deep search agent built on API components",
-        max_iterations=3,
-        timeout=60,
-    )
-
-    # Step 4: Initialize the DeepSearch Agent
-    # memory_retriever uses naive_mem_cube.text_mem,
-    # which provides a search() method for retrieving memories
-    deep_search_agent = DeepSearchMemAgent(
-        llm=llm,
-        memory_retriever=naive_mem_cube.text_mem,
-        config=config
-    )
-
-    logger.info(f"   ✓ DeepSearch Agent created successfully")
-    logger.info(f"   ✓ Agent name: {config.agent_name}")
-    logger.info(f"   ✓ Max iterations: {config.max_iterations}")
-
-    return deep_search_agent, components
-
-
-def demo_usage(deep_search_agent, components):
-    """
-    Demonstrate how to use the DeepSearch Agent
-    """
-    logger.info("\n" + "="*60)
-    logger.info("DeepSearch Agent usage demo")
-    logger.info("="*60 + "\n")
-
-    naive_mem_cube = components["naive_mem_cube"]
-    text_mem = naive_mem_cube.text_mem
-
-    # Add some test memories
-    logger.info("1. Adding test memories...")
-    test_data = [
-        "MemOS is an advanced memory operating system designed specifically for AI systems.",
-        "MemOS supports multiple memory types: textual memory, preference memory, behavioral memory, and more.",
-        "The DeepSearch Agent is an important component in MemOS for deep search and information retrieval.",
-    ]
-
-    for i, content in enumerate(test_data, 1):
-        try:
-            text_mem.add(
-                user_name="demo_user",
-                messages=[content],
-                source="demo"
-            )
-            logger.info(f"   ✓ Memory {i}: {content[:40]}...")
-        except Exception as e:
-            logger.warning(f"   ✗ Failed to add memory: {e}")
-
-    # Execute a deep search
-    logger.info("\n2. Executing deep search...")
-    query = "What features does MemOS support?"
-    logger.info(f"   Query: {query}")
-
-    try:
-        response = deep_search_agent.run(
-            query=query,
-            user_id="demo_user",
-            history=[]
-        )
-
-        logger.info("\n3. Search results:")
-        logger.info("-" * 60)
-        logger.info(response)
-        logger.info("-" * 60)
-
-        return response
-
-    except Exception as e:
-        logger.error(f"   ✗ Search failed: {e}")
-        import traceback
-        traceback.print_exc()
-        return None
-
-
-def integrate_with_api_router():
-    """
-    Show how to integrate the DeepSearch Agent into the API Router
-
-    This code shows how to add a new endpoint in server_router.py
-    that uses the DeepSearch Agent.
-    """
-    logger.info("\n" + "="*60)
-    logger.info("API Router integration example")
-    logger.info("="*60 + "\n")
-
-    logger.info("Add the following code to server_router.py:\n")
-
-    integration_code = '''
-# Imports at the top of server_router.py
-from memos.configs.mem_agent import DeepSearchAgentConfig
-from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
-
-# After initializing the handlers, create the DeepSearch Agent
-deep_search_agent = DeepSearchMemAgent(
-    llm=llm,
-    memory_retriever=naive_mem_cube.text_mem,
-    config=DeepSearchAgentConfig(
-        agent_name="APIDeepSearchAgent",
-        max_iterations=3,
-        timeout=60,
-    )
-)
-
-# Add a new API endpoint
-@router.post("/deepsearch", summary="Deep search with memory")
-def deep_search(
-    user_id: str,
-    query: str,
-    history: list[str] = None
-):
-    """Execute deep search with iterative memory retrieval."""
-    try:
-        response = deep_search_agent.run(
-            query=query,
-            user_id=user_id,
-            history=history or []
-        )
-        return {"status": "success", "response": response}
-    except Exception as e:
-        return {"status": "error", "message": str(e)}
-'''
-
-    print(integration_code)
-
-
-def main():
-    """Main function"""
-    logger.info("DeepSearch Agent integration example with API components\n")
-
-    # Check environment variables
-    required_env_vars = ["OPENAI_API_KEY"]
-    missing_vars = [var for var in required_env_vars if not os.getenv(var)]
-
-    if missing_vars:
-        logger.error(f"❌ Missing required environment variables: {', '.join(missing_vars)}")
-        logger.error("\nPlease set the following environment variables:")
-        logger.error("  export OPENAI_API_KEY='your-api-key'")
-        logger.error("  export OPENAI_BASE_URL='your-base-url'  # optional")
-        return
-
-    try:
-        # Create the DeepSearch Agent
-        agent, components = create_deepsearch_from_api_components()
-
-        # Demonstrate usage
-        demo_usage(agent, components)
-
-        # Show the integration approach
-        integrate_with_api_router()
-
-        logger.info("\n" + "="*60)
-        logger.info("Example run completed!")
-        logger.info("="*60)
-
-    except Exception as e:
-        logger.error(f"\n❌ Run failed: {e}")
-        import traceback
-        traceback.print_exc()
-
-
-if __name__ == "__main__":
-    main()
-
diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py
index e58f8e67c..3abe5d64d 100644
--- a/src/memos/configs/mem_agent.py
+++ b/src/memos/configs/mem_agent.py
@@ -9,9 +9,7 @@ class BaseAgentConfig(BaseConfig):
     """Base configuration class for agents."""
 
     agent_name: str = Field(..., description="Name of the agent")
-    description: str | None = Field(
-        default=None, description="Description of the agent"
-    )
+    description: str | None = Field(default=None, description="Description of the agent")
 
 
 class SimpleAgentConfig(BaseAgentConfig):
@@ -20,19 +18,15 @@ class SimpleAgentConfig(BaseAgentConfig):
     max_iterations: int = Field(
         default=10, description="Maximum number of iterations for the agent"
     )
-    timeout: int = Field(
-        default=30, description="Timeout in seconds for agent execution"
-    )
+    timeout: int = Field(default=30, description="Timeout in seconds for agent execution")
+
 
 class DeepSearchAgentConfig(BaseAgentConfig):
     """Deep search agent configuration class."""
 
-    max_iterations: int = Field(
-        default=3, description="Maximum number of iterations for the agent"
-    )
-    timeout: int = Field(
-        default=30, description="Timeout in seconds for agent execution"
-    )
+    max_iterations: int = Field(default=3, description="Maximum number of 
iterations for the agent") + timeout: int = Field(default=30, description="Timeout in seconds for agent execution") + class MemAgentConfigFactory(BaseConfig): """Factory class for creating agent configurations.""" @@ -57,4 +51,4 @@ def validate_backend(cls, backend: str) -> str: def create_config(self) -> "AgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) - return self \ No newline at end of file + return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index 3b54a846f..e4746cad6 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -1,12 +1,13 @@ -from typing import List, Dict, Any, Optional from abc import ABC, abstractmethod -from pydantic import BaseModel + from memos.configs.mem_agent import BaseAgentConfig + class BaseMemAgent(ABC): """ Base class for all agents. """ + def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" @@ -14,4 +15,4 @@ def __init__(self, config: BaseAgentConfig): def run(self, input: str) -> str: """ Run the agent. - """ \ No newline at end of file + """ diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6dd072728..70212921a 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,21 +6,21 @@ """ import json -import logging -from memos.log import get_logger -from typing import Any, Dict, List, Optional, Tuple + +from typing import Optional, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM +from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem -from memos.types import MessageDict, MessageList from memos.templates.mem_agent_prompts import ( + FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, - KEYWORD_EXTRACTION_PROMPT, - FINAL_GENERATION_PROMPT - ) +) +from memos.types import MessageList + logger = get_logger(__name__) @@ -35,29 +35,28 @@ def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name - def run(self, query: str, history: list[str]| None = None) -> str: + def run(self, query: str, history: list[str] | None = None) -> str: """ Rewrite the query to be standalone and more searchable. - + Args: query: Original user query history: List of previous conversation messages - + Returns: Rewritten query string """ if history is None: history = [] - + history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - + prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", - query=query + history=history_str if history_str else "No previous conversation", query=query ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") @@ -77,23 +76,24 @@ def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name - def run(self, query: str, context: List[str]) -> Dict[str, Any]: + def run(self, query: str, context: list[str]) -> dict[str, Any]: """ Analyze whether retrieved context is sufficient to answer the query. - + Args: query: User query context: List of retrieved context strings - + Returns: Dictionary with status, reasoning, and missing entities """ - context_str = "\n".join([f"- {ctx[:200]}..." 
if len(ctx) > 200 else f"- {ctx}" - for ctx in context[:10]]) # Limit context size - + context_str = "\n".join( + [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] + ) # Limit context size + prompt = REFLECTION_PROMPT.format(query=query, context=context_str) messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) result = json.loads(response.strip()) @@ -105,27 +105,27 @@ def run(self, query: str, context: List[str]) -> Dict[str, Any]: return { "status": "sufficient", "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [] + "missing_entities": [], } class DeepSearchMemAgent(BaseMemAgent): """ Main orchestrator agent implementing the deep search pipeline. - + This agent coordinates multiple sub-agents to perform iterative query refinement, memory retrieval, and information synthesis as shown in the architecture diagram. """ def __init__( - self, - llm: BaseLLM, - memory_retriever: Any = None, - config: Optional[DeepSearchAgentConfig] = None + self, + llm: BaseLLM, + memory_retriever: Optional[Any] = None, + config: DeepSearchAgentConfig | None = None, ): """ Initialize DeepSearchMemAgent. - + Args: llm: Language model for query rewriting and response generation memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) @@ -147,7 +147,7 @@ def _set_memory_retriever(self, retriever) -> None: def run(self, query: str, **kwargs) -> str: """ Main execution method implementing the deep search pipeline. - + Args: query: User query string **kwargs: Additional arguments (history, user_id, etc.) @@ -156,38 +156,36 @@ def run(self, query: str, **kwargs) -> str: """ if not self.llm: raise RuntimeError("LLM not initialized.") - + history = kwargs.get("history", []) user_id = kwargs.get("user_id") - + # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) - + accumulated_context = [] accumulated_memories = [] search_keywords = [] # Can be extended with keyword extraction - + # Step 2: Iterative Search and Reflection Loop for iteration in range(self.max_iterations): logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") - + search_results = self._perform_memory_search( - current_query, - keywords=search_keywords, - user_id=user_id + current_query, keywords=search_keywords, user_id=user_id ) - + if search_results: context_batch = [self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) accumulated_memories.extend(search_results) - + reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") - + logger.info(f"Reflection status: {status} - {reasoning}") - + if status == "sufficient": logger.info("Sufficient information collected") break @@ -209,114 +207,111 @@ def run(self, query: str, **kwargs) -> str: current_query = query else: break - + # Step 3: Generate final answer final_answer = self._generate_final_answer( original_query=query, search_results=accumulated_memories, context=accumulated_context, - missing_info="" + missing_info="", ) - + logger.info("Deep search pipeline completed") return final_answer def _perform_memory_search( - self, - query: str, - keywords: List[str] = None, - user_id: str = None, - top_k: int = 10 - ) -> List[TextualMemoryItem]: + self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int 
= 10 + ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. - + Args: query: Search query keywords: Additional keywords for search user_id: User identifier top_k: Number of results to retrieve - + Returns: List of retrieved memory items """ if not self.memory_retriever: logger.warning("Memory retriever not configured, returning empty results") return [] - + try: # Use the memory retriever interface # This is a placeholder - actual implementation depends on the retriever interface search_query = query if keywords and len(keywords) > 1: search_query = f"{query} {' '.join(keywords[:3])}" # Combine with top keywords - + # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, - top_k=top_k, - mode="fast", - user_name=user_id + query=search_query, top_k=top_k, mode="fast", user_name=user_id ) - + return results if isinstance(results, list) else [] - + except Exception as e: logger.error(f"Error performing memory search: {e}") return [] def _extract_context_from_memory(self, memory_item: TextualMemoryItem) -> str: """Extract readable context from a memory item.""" - if hasattr(memory_item, 'memory'): + if hasattr(memory_item, "memory"): return str(memory_item.memory) - elif hasattr(memory_item, 'content'): + elif hasattr(memory_item, "content"): return str(memory_item.content) else: return str(memory_item) - def _refine_query_for_missing_info(self, query: str, missing_entities: List[str]) -> str: + def _refine_query_for_missing_info(self, query: str, missing_entities: list[str]) -> str: """Refine the query to search for missing information.""" if not missing_entities: return query - + # Simple refinement strategy - append missing entities entities_str = " ".join(missing_entities[:3]) # Limit to top 3 entities refined_query = f"{query} {entities_str}" - + return refined_query def _generate_final_answer( self, original_query: str, - search_results: List[TextualMemoryItem], - context: List[str], - missing_info: str = "" + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", ) -> str: """ Generate the final comprehensive answer. 
- + Args: original_query: Original user query search_results: All retrieved memory items context: Extracted context strings missing_info: Information about missing data - + Returns: Final answer string """ # Prepare context for the prompt context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) # Limit context - sources = f"Retrieved {len(search_results)} memory items" if search_results else "No specific sources" - + sources = ( + f"Retrieved {len(search_results)} memory items" + if search_results + else "No specific sources" + ) + prompt = FINAL_GENERATION_PROMPT.format( query=original_query, sources=sources, context=context_str if context_str else "No specific context retrieved", - missing_info=missing_info if missing_info else "None identified" + missing_info=missing_info if missing_info else "None identified", ) - + messages: MessageList = [{"role": "user", "content": prompt}] - + try: response = self.llm.generate(messages) return response.strip() diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py index 33e7b6a6b..415d4916e 100644 --- a/src/memos/mem_agent/factory.py +++ b/src/memos/mem_agent/factory.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import Optional, Any, ClassVar from memos.configs.mem_agent import MemAgentConfigFactory from memos.mem_agent.base import BaseMemAgent @@ -14,19 +14,16 @@ class MemAgentFactory: @classmethod def from_config( - cls, - config_factory: MemAgentConfigFactory, - llm: Any, - memory_retriever: Any = None + cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None ) -> BaseMemAgent: """ Create a MemAgent instance from configuration. - + Args: config_factory: Configuration factory for the agent llm: Language model instance memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) - + Returns: Initialized MemAgent instance """ @@ -35,7 +32,5 @@ def from_config( raise ValueError(f"Invalid backend: {backend}") mem_agent_class = cls.backend_to_class[backend] return mem_agent_class( - llm=llm, - memory_retriever=memory_retriever, - config=config_factory.config - ) \ No newline at end of file + llm=llm, memory_retriever=memory_retriever, config=config_factory.config + ) diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 399b2cd75..17e3c0ad9 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -73,4 +73,4 @@ 4. Provide a well-structured, coherent response 5. 
Use natural, conversational tone -Response:""" \ No newline at end of file +Response:""" From bd72e9bb804d84f1a12e4c817bae494f0359d600 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 21 Nov 2025 15:33:53 +0800 Subject: [PATCH 06/35] feat: update deepsearch mem code --- examples/mem_agent/deepsearch_example.py | 3 --- src/memos/configs/mem_agent.py | 2 +- src/memos/mem_agent/base.py | 1 + src/memos/mem_agent/deepsearch_agent.py | 15 +++++++++++---- src/memos/mem_agent/factory.py | 4 ++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index dacfe7f1e..d0842b6ec 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -328,9 +328,6 @@ def main(): # Run Example 1: Factory pattern agent_factory, components_factory = example_1_factory_initialization() - # Run Example 2: Direct initialization - agent_direct, components_direct = example_2_direct_initialization() - # Run Example 4: Minimal initialization agent_minimal, components_minimal = example_4_minimal_initialization() diff --git a/src/memos/configs/mem_agent.py b/src/memos/configs/mem_agent.py index 3abe5d64d..7cb623899 100644 --- a/src/memos/configs/mem_agent.py +++ b/src/memos/configs/mem_agent.py @@ -48,7 +48,7 @@ def validate_backend(cls, backend: str) -> str: return backend @model_validator(mode="after") - def create_config(self) -> "AgentConfigFactory": + def create_config(self) -> "MemAgentConfigFactory": config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self diff --git a/src/memos/mem_agent/base.py b/src/memos/mem_agent/base.py index e4746cad6..daa5f075b 100644 --- a/src/memos/mem_agent/base.py +++ b/src/memos/mem_agent/base.py @@ -10,6 +10,7 @@ class BaseMemAgent(ABC): def __init__(self, config: BaseAgentConfig): """Initialize the BaseMemAgent with the given configuration.""" + self.config = config @abstractmethod def run(self, input: str) -> str: diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 70212921a..6c6fc32a9 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -7,7 +7,7 @@ import json -from typing import Optional, Any +from typing import TYPE_CHECKING, Any from memos.configs.mem_agent import DeepSearchAgentConfig from memos.llms.base import BaseLLM @@ -19,7 +19,10 @@ QUERY_REWRITE_PROMPT, REFLECTION_PROMPT, ) -from memos.types import MessageList + + +if TYPE_CHECKING: + from memos.types import MessageList logger = get_logger(__name__) @@ -120,7 +123,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Optional[Any] = None, + memory_retriever: Any | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -220,7 +223,11 @@ def run(self, query: str, **kwargs) -> str: return final_answer def _perform_memory_search( - self, query: str, keywords: Optional[list[str]] = None, user_id: Optional[str] = None, top_k: int = 10 + self, + query: str, + keywords: list[str] | None = None, + user_id: str | None = None, + top_k: int = 10, ) -> list[TextualMemoryItem]: """ Perform memory search using the configured retriever. 
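
For reference, the `create_config` validator corrected above is what turns the raw `config` dict into the backend-specific config class at validation time. A minimal sketch of the intended usage (field values here are illustrative, not part of the patch):

```python
from memos.configs.mem_agent import MemAgentConfigFactory

# Constructing the factory runs the model_validator, which replaces the raw
# dict with a DeepSearchAgentConfig instance registered under "deep_search".
factory = MemAgentConfigFactory(
    backend="deep_search",
    config={"agent_name": "ExampleAgent", "max_iterations": 3, "timeout": 60},
)
assert factory.config.max_iterations == 3  # config is now a typed object
```
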
diff --git a/src/memos/mem_agent/factory.py b/src/memos/mem_agent/factory.py
index 415d4916e..09537bd8a 100644
--- a/src/memos/mem_agent/factory.py
+++ b/src/memos/mem_agent/factory.py
@@ -1,4 +1,4 @@
-from typing import Optional, Any, ClassVar
+from typing import Any, ClassVar
 
 from memos.configs.mem_agent import MemAgentConfigFactory
 from memos.mem_agent.base import BaseMemAgent
@@ -14,7 +14,7 @@ class MemAgentFactory:
 
     @classmethod
     def from_config(
-        cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Optional[Any] = None
+        cls, config_factory: MemAgentConfigFactory, llm: Any, memory_retriever: Any | None = None
     ) -> BaseMemAgent:
         """
         Create a MemAgent instance from configuration.
From f332ef22c3a2a29e2c1820de3916e45d70e4d33a Mon Sep 17 00:00:00 2001
From: fridayL
Date: Sat, 22 Nov 2025 18:13:04 +0800
Subject: [PATCH 07/35] feat: update deepsearch agent

---
 examples/mem_agent/deepsearch_example.py | 358 +++++++----------------
 src/memos/mem_agent/deepsearch_agent.py  | 200 ++++++++-----
 src/memos/templates/mem_agent_prompts.py |   5 +-
 3 files changed, 227 insertions(+), 336 deletions(-)

diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
index d0842b6ec..6a9405456 100644
--- a/examples/mem_agent/deepsearch_example.py
+++ b/examples/mem_agent/deepsearch_example.py
@@ -1,35 +1,36 @@
 """
-DeepSearch Agent Usage Examples
+DeepSearch Agent Usage Examples - Simplified Version
 
-This example demonstrates two ways to initialize DeepSearchMemAgent:
-1. Using Factory pattern (recommended)
-2. Direct initialization
-
-DeepSearchMemAgent implements iterative deep search, providing comprehensive answers through:
-- Query rewriting: Optimize queries based on conversation history
-- Iterative retrieval: Collect information through multiple search rounds
-- Reflective analysis: Determine if information is sufficient
-- Comprehensive response: Generate complete final answers
+This example demonstrates simplified initialization of DeepSearchMemAgent without
+external config builders, using APIConfig methods directly.
 """
 
 import os
+from typing import Any
+
+from memos.api.config import APIConfig
 from memos.configs.embedder import EmbedderConfigFactory
-from memos.configs.graph_db import GraphConfigFactory
-from memos.configs.llms import LLMConfigFactory
+from memos.configs.graph_db import GraphDBConfigFactory
+from memos.configs.internet_retriever import InternetRetrieverConfigFactory
+from memos.configs.llm import LLMConfigFactory
 from memos.configs.mem_agent import MemAgentConfigFactory
 from memos.configs.mem_reader import MemReaderConfigFactory
-from memos.configs.textual_memory import TreeTextMemoryConfig
+from memos.configs.reranker import RerankerConfigFactory
 from memos.embedders.factory import EmbedderFactory
 from memos.graph_dbs.factory import GraphStoreFactory
 from memos.llms.factory import LLMFactory
 from memos.log import get_logger
 from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
 from memos.mem_agent.factory import MemAgentFactory
-from memos.mem_cube.naive_cube import NaiveMemCube
+from memos.mem_cube.navie import NaiveMemCube
 from memos.mem_reader.factory import MemReaderFactory
-from memos.memories.memory_manager.memory_manager import MemoryManager
-from memos.memories.textual.simple_tree_memory import SimpleTreeTextMemory
+from memos.memories.textual.simple_tree import SimpleTreeTextMemory
+from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
+from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
+    InternetRetrieverFactory,
+)
+from memos.reranker.factory import RerankerFactory
 
 
 logger = get_logger(__name__)
@@ -37,114 +38,105 @@
 
 def build_minimal_components():
     """
-    Build the minimal component set required for DeepSearchMemAgent.
-
-    Only need to initialize:
-    1. LLM - Used for query rewriting, reflection, and final answer generation
-    2. NaiveMemCube - Provides text_mem.search interface for memory retrieval
+    Build minimal components for DeepSearchMemAgent with simplified configuration.
 
-    Returns:
-        dict: Dictionary containing llm and naive_mem_cube
+    This function creates all necessary components using APIConfig methods,
+    similar to config_builders.py but inline for easier customization.
     """
-    logger.info("Starting to build minimal component set...")
+    logger.info("Initializing simplified MemOS components...")
 
-    # 1. Initialize LLM
-    llm_config = LLMConfigFactory(
-        backend="openai",
-        config={
-            "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"),
-            "api_key": os.getenv("OPENAI_API_KEY"),
-            "base_url": os.getenv("OPENAI_BASE_URL"),
-            "temperature": 0.7,
-        },
+    # Build component configurations using APIConfig methods (like config_builders.py)
+
+    # Graph DB configuration - using APIConfig.get_polardb_config()
+    graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower()
+    graph_db_backend_map = {
+        "polardb": APIConfig.get_polardb_config(),
+    }
+    graph_db_config = GraphDBConfigFactory.model_validate(
+        {
+            "backend": graph_db_backend,
+            "config": graph_db_backend_map[graph_db_backend],
+        }
     )
-    llm = LLMFactory.from_config(llm_config)
-    logger.info("LLM initialization completed")
 
-    # 2. 
Initialize Embedder (required by text_mem) - embedder_config = EmbedderConfigFactory( - backend="universal_api", - config={ - "model_name": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-small"), - "api_key": os.getenv("MOS_EMBEDDER_API_KEY"), - "base_url": os.getenv("MOS_EMBEDDER_API_BASE"), - }, + # LLM configuration - using APIConfig.get_openai_config() + llm_config = LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } ) - embedder = EmbedderFactory.from_config(embedder_config) - logger.info("Embedder initialization completed") - # 3. Initialize GraphDB (required by text_mem) - graph_db_config = GraphConfigFactory( - backend="polardb", - config={ - "host": os.getenv("POLAR_DB_HOST", "localhost"), - "port": int(os.getenv("POLAR_DB_PORT", "5432")), - "user": os.getenv("POLAR_DB_USER", "root"), - "password": os.getenv("POLAR_DB_PASSWORD", "123456"), - "db_name": os.getenv("POLAR_DB_DB_NAME", "shared_memos_db"), - "user_name": "memos_default", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), - }, + # Embedder configuration - using APIConfig.get_embedder_config() + embedder_config = EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + # Memory reader configuration - using APIConfig.get_product_default_config() + mem_reader_config = MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] ) - graph_db = GraphStoreFactory.from_config(graph_db_config) - logger.info("✓ GraphDB init Done") - # 4. Initialize MemReader (required by text_mem) - mem_reader_config = MemReaderConfigFactory( - backend="simple_struct", - config={ - "llm": { - "backend": "openai", - "config": { - "model_name": os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini"), - "api_key": os.getenv("OPENAI_API_KEY"), - "base_url": os.getenv("OPENAI_BASE_URL"), - }, - } - }, + # Reranker configuration - using APIConfig.get_reranker_config() + reranker_config = RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + # Internet retriever configuration - using APIConfig.get_internet_config() + internet_retriever_config = InternetRetrieverConfigFactory.model_validate( + APIConfig.get_internet_config() ) - mem_reader = MemReaderFactory.from_config(mem_reader_config) - logger.info("✓ MemReader init Done") - # 5. Initialize MemoryManager - memory_manager = MemoryManager( - graph_db=graph_db, - embedder=embedder, - llm=llm, - memory_size=1000, - is_reorganize=False, + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder ) - logger.info("✓ MemoryManager init Done") - # 6. 
Initialize TextMemory - text_mem_config = TreeTextMemoryConfig( - reorganize=False, - max_depth=3, + logger.debug("Core components instantiated") + + # Get default cube configuration like component_init.py + default_cube_config = APIConfig.get_default_cube_config() + + # Get default memory size from cube config (like component_init.py) + def get_memory_size_from_config(cube_config): + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + memory_size = get_memory_size_from_config(default_cube_config) + is_reorganize = getattr(default_cube_config.text_mem.config, "reorganize", False) + + # Initialize memory manager with config from APIConfig + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=memory_size, + is_reorganize=is_reorganize, ) + text_memory_config = default_cube_config.text_mem.config text_mem = SimpleTreeTextMemory( llm=llm, embedder=embedder, mem_reader=mem_reader, graph_db=graph_db, - reranker=None, + reranker=reranker, memory_manager=memory_manager, - config=text_mem_config, - internet_retriever=None, + config=text_memory_config, + internet_retriever=internet_retriever, ) - logger.info("✓ TextMemory initialization completed") - # 7. Create NaiveMemCube naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=None, + pref_mem=None, # Simplified: no preference memory act_mem=None, para_mem=None, ) - logger.info("✓ NaiveMemCube creation completed") - - logger.info("All components initialized!") return { "llm": llm, @@ -155,30 +147,18 @@ def build_minimal_components(): } -def example_1_factory_initialization(): - """ - Example 1: Initialize DeepSearchMemAgent using Factory pattern (recommended) - - Advantages: - - Separation of configuration and code - - Easy to manage and modify - - Support loading from configuration files - """ - logger.info("\n" + "=" * 60) - logger.info("Example 1: Initialize using Factory pattern") - logger.info("=" * 60 + "\n") - - # Build necessary components +def factory_initialization() -> tuple[DeepSearchMemAgent, dict[str, Any]]: + # Build necessary components with simplified setup components = build_minimal_components() llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] - # Create configuration Factory + # Create configuration Factory with simplified config agent_config_factory = MemAgentConfigFactory( backend="deep_search", config={ - "agent_name": "MyDeepSearchAgent", - "description": "Intelligent agent for deep search", + "agent_name": "SimplifiedDeepSearchAgent", + "description": "Simplified intelligent agent for deep search", "max_iterations": 3, # Maximum number of iterations "timeout": 60, # Timeout in seconds }, @@ -198,151 +178,13 @@ def example_1_factory_initialization(): return deep_search_agent, components -def example_3_usage(deep_search_agent, components): - """ - Example 3: Using DeepSearchMemAgent for search - - Demonstrates how to: - 1. Add memories to the system - 2. Use Agent for deep search - 3. Get comprehensive answers - """ - logger.info("\n" + "=" * 60) - logger.info("Example 3: Using DeepSearchMemAgent") - logger.info("=" * 60 + "\n") - - naive_mem_cube = components["naive_mem_cube"] - text_mem = naive_mem_cube.text_mem - - # Simulate adding some memories - logger.info("1. 
Adding test memories...") - test_memories = [ - { - "user_name": "test_user", - "messages": [ - "Artificial Intelligence is a branch of computer science dedicated to creating systems capable of performing tasks that typically require human intelligence." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Machine Learning is a subfield of artificial intelligence that enables computers to learn from data and improve performance." - ], - "source": "manual", - }, - { - "user_name": "test_user", - "messages": [ - "Deep Learning is a branch of machine learning that uses multi-layer neural networks to handle complex pattern recognition tasks." - ], - "source": "manual", - }, - ] - - for memory in test_memories: - try: - text_mem.add( - user_name=memory["user_name"], - messages=memory["messages"], - source=memory.get("source", "manual"), - ) - logger.info(f" ✓ Memory added: {memory['messages'][0][:30]}...") - except Exception as e: - logger.warning(f" ✗ Failed to add memory: {e}") - - # Use Agent for search - logger.info("\n2. Executing deep search...") - query = "What is the relationship between artificial intelligence and machine learning?" - logger.info(f" Query: {query}") - - try: - response = deep_search_agent.run( - query=query, - user_id="test_user", - history=["Hello", "I want to learn about artificial intelligence"], - ) - - logger.info("\n3. Search results:") - logger.info("-" * 60) - logger.info(response) - logger.info("-" * 60) - - except Exception as e: - logger.error(f"Error during search: {e}") - import traceback - - traceback.print_exc() - - -def example_4_minimal_initialization(): - """ - Example 4: Minimal initialization (using default configuration) - - Use cases: - - Quick testing - - Prototype development - - No need for custom configuration - """ - logger.info("\n" + "=" * 60) - logger.info("Example 4: Minimal initialization (default configuration)") - logger.info("=" * 60 + "\n") - - # Build necessary components - components = build_minimal_components() - llm = components["llm"] - naive_mem_cube = components["naive_mem_cube"] - - # Direct initialization using default configuration - deep_search_agent = DeepSearchMemAgent( - llm=llm, - memory_retriever=naive_mem_cube.text_mem, - # config parameter omitted, will use default configuration - ) - - logger.info("✓ DeepSearchMemAgent created successfully (using default configuration)") - logger.info(f" - Max iterations: {deep_search_agent.max_iterations}") - logger.info(f" - Timeout: {deep_search_agent.timeout} seconds") - - return deep_search_agent, components - - def main(): - """Main function: Run all examples""" - logger.info("DeepSearch Agent Usage Examples") - logger.info("=" * 60) - - # Check environment variables - required_env_vars = ["OPENAI_API_KEY"] - missing_vars = [var for var in required_env_vars if not os.getenv(var)] - - if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") - logger.error("Please set the following environment variables:") - logger.error(" - OPENAI_API_KEY: OpenAI API key") - logger.error(" - OPENAI_BASE_URL (optional): OpenAI API base URL") - logger.error(" - NEBULA_HOST (optional): NebulaGraph host address") - return - - try: - # Run Example 1: Factory pattern - agent_factory, components_factory = example_1_factory_initialization() - - # Run Example 4: Minimal initialization - agent_minimal, components_minimal = example_4_minimal_initialization() - - # Run Example 3: Actual usage (using agent created with factory 
method) - example_3_usage(agent_factory, components_factory) - - logger.info("\n" + "=" * 60) - logger.info("All examples completed!") - logger.info("=" * 60) - - except Exception as e: - logger.error(f"Error running examples: {e}") - import traceback - - traceback.print_exc() + agent_factory, components_factory = factory_initialization() + results = agent_factory.run( + "Caroline met up with friends, family, and mentors in early July 2023.", + user_id="locomo_exp_user_0_speaker_b_ct-1118", + ) + print(results) if __name__ == "__main__": diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 6c6fc32a9..5a070c6ad 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -6,6 +6,7 @@ """ import json +import re from typing import TYPE_CHECKING, Any @@ -14,6 +15,7 @@ from memos.log import get_logger from memos.mem_agent.base import BaseMemAgent from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree import TreeTextMemory from memos.templates.mem_agent_prompts import ( FINAL_GENERATION_PROMPT, QUERY_REWRITE_PROMPT, @@ -25,91 +27,96 @@ from memos.types import MessageList +class JSONResponseParser: + """Elegant JSON response parser for LLM outputs""" + + @staticmethod + def parse(response: str) -> dict[str, Any]: + """Parse JSON response from LLM output with fallback strategies""" + # Clean response text by removing code block markers + cleaned = re.sub(r"^```(?:json)?\s*\n?|```\s*$", "", response.strip(), flags=re.IGNORECASE) + + # Try parsing with multiple strategies + for text in [cleaned, re.search(r"\{.*\}", cleaned, re.DOTALL)]: + if not text: + continue + try: + return json.loads(text if isinstance(text, str) else text.group()) + except json.JSONDecodeError: + continue + + raise ValueError(f"Cannot parse JSON response: {response[:100]}...") + + logger = get_logger(__name__) class QueryRewriter(BaseMemAgent): - """ - Specialized agent for rewriting queries based on conversation history. - Corresponds to the "LLM subAgent (Rewrite...)" in the architecture diagram. - """ + """Specialized agent for rewriting queries based on conversation history""" def __init__(self, llm: BaseLLM, name: str = "QueryRewriter"): self.llm = llm self.name = name def run(self, query: str, history: list[str] | None = None) -> str: - """ - Rewrite the query to be standalone and more searchable. 
- - Args: - query: Original user query - history: List of previous conversation messages - - Returns: - Rewritten query string - """ - if history is None: - history = [] - - history_str = "\n".join([f"- {msg}" for msg in history[-5:]]) # Last 5 messages - - prompt = QUERY_REWRITE_PROMPT.format( - history=history_str if history_str else "No previous conversation", query=query - ) - - messages: MessageList = [{"role": "user", "content": prompt}] + """Rewrite query to be standalone and more searchable""" + history = history or [] + history_context = self._format_history(history) + prompt = QUERY_REWRITE_PROMPT.format(history=history_context, query=query) + messages = [{"role": "user", "content": prompt}] try: response = self.llm.generate(messages) logger.info(f"[{self.name}] Rewritten query: {response.strip()}") return response.strip() except Exception as e: - logger.error(f"[{self.name}] Error rewriting query: {e}") - return query # Fallback to original query + logger.error(f"[{self.name}] Query rewrite failed: {e}") + return query + + def _format_history(self, history: list[str]) -> str: + """Format conversation history for prompt context""" + if not history: + return "No previous conversation" + return "\n".join(f"- {msg}" for msg in history[-5:]) class ReflectionAgent: - """ - Specialized agent for analyzing information sufficiency. - Corresponds to the decision diamond in the architecture diagram. - """ + """Specialized agent for analyzing information sufficiency""" def __init__(self, llm: BaseLLM, name: str = "Reflector"): self.llm = llm self.name = name def run(self, query: str, context: list[str]) -> dict[str, Any]: - """ - Analyze whether retrieved context is sufficient to answer the query. + """Analyze whether retrieved context is sufficient to answer the query""" + context_summary = self._format_context(context) + prompt = REFLECTION_PROMPT.format(query=query, context=context_summary) - Args: - query: User query - context: List of retrieved context strings + try: + response = self.llm.generate([{"role": "user", "content": prompt}]) + logger.info(f"[{self.name}] Reflection response: {response}") - Returns: - Dictionary with status, reasoning, and missing entities - """ - context_str = "\n".join( - [f"- {ctx[:200]}..." if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10]] - ) # Limit context size + result = JSONResponseParser.parse(response.strip()) + logger.info(f"[{self.name}] Reflection result: {result}") + return result - prompt = REFLECTION_PROMPT.format(query=query, context=context_str) - messages: MessageList = [{"role": "user", "content": prompt}] + except Exception as e: + logger.error(f"[{self.name}] Reflection analysis failed: {e}") + return self._fallback_response() - try: - response = self.llm.generate(messages) - result = json.loads(response.strip()) - logger.info(f"[{self.name}] Reflection result: {result.get('status', 'unknown')}") - return result - except (json.JSONDecodeError, Exception) as e: - logger.error(f"[{self.name}] Error in reflection analysis: {e}") - # Fallback response - return { - "status": "sufficient", - "reasoning": "Unable to analyze, proceeding with available information", - "missing_entities": [], - } + def _format_context(self, context: list[str]) -> str: + """Format context strings for analysis with length limits""" + return "\n".join( + f"- {ctx[:200]}..." 
if len(ctx) > 200 else f"- {ctx}" for ctx in context[:10] + ) + + def _fallback_response(self) -> dict[str, Any]: + """Return safe fallback when reflection fails""" + return { + "status": "sufficient", + "reasoning": "Unable to analyze, proceeding with available information", + "missing_entities": [], + } class DeepSearchMemAgent(BaseMemAgent): @@ -123,7 +130,7 @@ class DeepSearchMemAgent(BaseMemAgent): def __init__( self, llm: BaseLLM, - memory_retriever: Any | None = None, + memory_retriever: TreeTextMemory | None = None, config: DeepSearchAgentConfig | None = None, ): """ @@ -142,12 +149,7 @@ def __init__( self.reflector: ReflectionAgent = ReflectionAgent(llm, "Reflector") self.memory_retriever = memory_retriever - def _set_memory_retriever(self, retriever) -> None: - """Set the memory retrieval interface.""" - self.memory_retriever = retriever - logger.info("Memory retriever interface set") - - def run(self, query: str, **kwargs) -> str: + def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: """ Main execution method implementing the deep search pipeline. @@ -162,6 +164,7 @@ def run(self, query: str, **kwargs) -> str: history = kwargs.get("history", []) user_id = kwargs.get("user_id") + generated_answer = kwargs.get("generated_answer") # Step 1: Query Rewriting current_query = self.query_rewriter.run(query, history) @@ -175,7 +178,7 @@ def run(self, query: str, **kwargs) -> str: logger.info(f"Starting iteration {iteration + 1}/{self.max_iterations}") search_results = self._perform_memory_search( - current_query, keywords=search_keywords, user_id=user_id + current_query, keywords=search_keywords, user_id=user_id, history=history ) if search_results: @@ -198,7 +201,8 @@ def run(self, query: str, **kwargs) -> str: elif status == "missing_info": missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") - if missing_entities: + current_query = reflection_result.get("new_search_query") + if not current_query: refined_query = self._refine_query_for_missing_info( current_query, missing_entities ) @@ -211,22 +215,63 @@ def run(self, query: str, **kwargs) -> str: else: break - # Step 3: Generate final answer - final_answer = self._generate_final_answer( - original_query=query, - search_results=accumulated_memories, - context=accumulated_context, - missing_info="", - ) + if not generated_answer: + return self._remove_duplicate_memories(accumulated_memories) + else: + return self._generate_final_answer( + query, accumulated_memories, accumulated_context, "", history + ) + + def _remove_duplicate_memories( + self, memories: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """ + Remove duplicate memories based on memory content. + + Args: + memories: List of TextualMemoryItem objects to deduplicate - logger.info("Deep search pipeline completed") - return final_answer + Returns: + List of unique TextualMemoryItem objects (first occurrence kept) + """ + seen = set() + return [ + memory + for memory in memories + if (content := getattr(memory, "memory", "").strip()) + and content not in seen + and not seen.add(content) + ] + + def _generate_final_answer( + self, + original_query: str, + search_results: list[TextualMemoryItem], + context: list[str], + missing_info: str = "", + history: list[str] | None = None, + sources: list[str] | None = None, + ) -> str: + """ + Generate the final answer. 
+ """ + context_str = "\n".join([f"- {ctx}" for ctx in context[:20]]) + prompt = FINAL_GENERATION_PROMPT.format( + query=original_query, + sources=sources, + context=context_str if context_str else "No specific context retrieved", + missing_info=missing_info if missing_info else "None identified", + ) + messages: MessageList = [{"role": "user", "content": prompt}] + response = self.llm.generate(messages) + return response.strip() def _perform_memory_search( self, query: str, keywords: list[str] | None = None, user_id: str | None = None, + history: list[str] | None = None, top_k: int = 10, ) -> list[TextualMemoryItem]: """ @@ -254,7 +299,11 @@ def _perform_memory_search( # Assuming the retriever has a search method similar to TreeTextMemory results = self.memory_retriever.search( - query=search_query, top_k=top_k, mode="fast", user_name=user_id + query=search_query, + top_k=top_k, + mode="fast", + user_name=user_id, + info={"history": history}, ) return results if isinstance(results, list) else [] @@ -316,7 +365,6 @@ def _generate_final_answer( context=context_str if context_str else "No specific context retrieved", missing_info=missing_info if missing_info else "None identified", ) - messages: MessageList = [{"role": "user", "content": prompt}] try: diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 17e3c0ad9..477cd2409 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -25,14 +25,15 @@ {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", - "missing_entities": ["entity1", "entity2"] // Only if status is "missing_info", - + "missing_entities": ["entity1", "entity2"], + "new_search_query": "new search query", }} Status definitions: - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources +- "new_search_query": New search query to retrieve more information Response:""" From c21fc586a723477f3f2a2465e895228b4c1b8767 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sat, 22 Nov 2025 18:29:39 +0800 Subject: [PATCH 08/35] feat: update test code --- tests/configs/test_embedder.py | 4 ++-- tests/embedders/test_universal_api.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/configs/test_embedder.py b/tests/configs/test_embedder.py index 8201f9bd8..10572f33e 100644 --- a/tests/configs/test_embedder.py +++ b/tests/configs/test_embedder.py @@ -17,7 +17,7 @@ def test_base_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims"], + optional_fields=["embedding_dims", "headers_extra"], ) check_config_instantiation_valid( @@ -36,7 +36,7 @@ def test_ollama_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "api_base"], + optional_fields=["embedding_dims", "headers_extra", "api_base"], ) check_config_instantiation_valid( diff --git a/tests/embedders/test_universal_api.py b/tests/embedders/test_universal_api.py index e4ebb7019..1419b2849 100644 --- a/tests/embedders/test_universal_api.py +++ b/tests/embedders/test_universal_api.py @@ -30,6 +30,7 @@ def test_embed_single_text(self, mock_openai_client): mock_openai_client.assert_called_once_with( api_key="fake-api-key", base_url="https://api.openai.com/v1", + default_headers=None, ) # Assert embeddings.create called with correct params From 
0f62af8173c8c1614e582c77b7ce572f44415e75 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 24 Nov 2025 14:30:28 +0800 Subject: [PATCH 09/35] fix: remove dup config --- src/memos/api/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d738c878a..c62cd3b08 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -409,7 +409,6 @@ def get_embedder_config() -> dict[str, Any]: "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), - "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), }, } else: # ollama From dac3394cf5f79e4b08481b661645786a6d73e716 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 10:55:16 +0800 Subject: [PATCH 10/35] feat: dock search pipeline --- src/memos/api/handlers/base_handler.py | 5 +++++ src/memos/api/handlers/component_init.py | 6 ++++++ src/memos/api/handlers/search_handler.py | 10 +++++++--- src/memos/mem_agent/deepsearch_agent.py | 7 +++---- src/memos/multi_mem_cube/single_cube.py | 8 ++++++-- src/memos/templates/mem_agent_prompts.py | 4 +++- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index a686ac8f9..7a47f05e3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -161,6 +161,11 @@ def mos_server(self): """Get MOS server instance.""" return self.deps.mos_server + @property + def deepsearch_agent(self): + """Get deepsearch agent instance.""" + return self.deps.deepsearch_agent + def _validate_dependencies(self, *required_deps: str) -> None: """ Validate that required dependencies are available. 
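
The `deepsearch_agent` property added above follows the same pattern as the existing dependency accessors, so any handler can reach the shared agent through its injected dependencies. A rough sketch of the intended call path (`handle_deep_search` is a hypothetical helper for illustration, not part of this patch):

```python
# Hypothetical method on a BaseHandler subclass whose dependencies were
# populated by init_server() (see the component_init diff below).
def handle_deep_search(self, query: str, user_id: str) -> str:
    # self.deepsearch_agent resolves to self.deps.deepsearch_agent
    return self.deepsearch_agent.run(query, user_id=user_id)
```
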
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 3ef1d529d..7b34fcfae 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -307,6 +308,10 @@ def init_server() -> dict[str, Any]: online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") + deepsearch_agent = DeepSearchMemAgent( + llm=llm, + memory_retriever=tree_mem, + ) # Return all components as a dictionary for easy access and extension return { "graph_db": graph_db, @@ -330,4 +335,5 @@ def init_server() -> dict[str, Any]: "text_mem": text_mem, "pref_mem": pref_mem, "online_bot": online_bot, + "deepsearch_agent": deepsearch_agent, } diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8a2c21aad..b454568f8 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -31,7 +31,9 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") + self._validate_dependencies( + "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" + ) def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -52,10 +54,10 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse results = cube_view.search_memories(search_req) - self.logger.info(f"[AddHandler] Final add results count={len(results)}") + self.logger.info(f"[SearchHandler] Final search results count={len(results)}") return SearchResponse( - message="Memory searched successfully", + message="Search completed successfully", data=results, ) @@ -87,6 +89,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) else: single_views = [ @@ -97,6 +100,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_scheduler=self.mem_scheduler, logger=self.logger, searcher=self.searcher, + deepsearch_agent=self.deepsearch_agent, ) for cube_id in cube_ids ] diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5a070c6ad..16b600980 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from memos.types import MessageList +logger = get_logger(__name__) + class JSONResponseParser: """Elegant JSON response parser for LLM outputs""" @@ -48,9 +50,6 @@ def parse(response: str) -> dict[str, Any]: raise ValueError(f"Cannot parse JSON response: {response[:100]}...") -logger = get_logger(__name__) - - class QueryRewriter(BaseMemAgent): """Specialized agent for rewriting queries based on conversation history""" @@ -141,7 +140,7 @@ def __init__( memory_retriever: Memory retrieval interface (e.g., naive_mem_cube.text_mem) config: Configuration for deep search behavior """ - self.config = config or DeepSearchAgentConfig() + self.config = config or DeepSearchAgentConfig(agent_name="DeepSearchMemAgent") 
self.max_iterations = self.config.max_iterations self.timeout = self.config.timeout self.llm: BaseLLM = llm diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2055615d2..d2fde36a3 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,6 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any + deepsearch_agent: Any def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ @@ -247,8 +248,11 @@ def _fast_search( def _deep_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: - logger.error("waiting to be implemented") - return [] + deepsearch_results = self.deepsearch_agent.run( + search_req.query, user_id=user_context.mem_cube_id + ) + formatted_memories = [format_memory_item(data) for data in deepsearch_results] + return formatted_memories def _fine_search( self, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index 477cd2409..eb624ef89 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -22,12 +22,14 @@ {context} Analyze the context and determine the next step. Return your response in JSON format with the following structure: -{{ + ```json + {{ "status": "sufficient|missing_info|needs_raw", "reasoning": "Brief explanation of your decision", "missing_entities": ["entity1", "entity2"], "new_search_query": "new search query", }} +``` Status definitions: - "sufficient": Context fully answers the query From 9489d54b0a71ba8db05b7d13158fcaad6135d50f Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 11:17:52 +0800 Subject: [PATCH 11/35] fix: code test --- tests/api/test_server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 853a271f6..2aa96257b 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -48,6 +48,7 @@ def mock_init_server(): "pref_mem": None, "online_bot": None, "chat_llms": Mock(), + "deepsearch_agent": Mock(), } with patch("memos.api.handlers.init_server", return_value=mock_components): From e43e5db62bf9fab56ed432a6657412a868e58763 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 14:59:14 +0800 Subject: [PATCH 12/35] feat: add test scripts --- src/memos/mem_agent/deepsearch_agent.py | 4 +- tests/mem_agent/test_deepsearch_agent.py | 233 +++++++++++++++++++++++ 2 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 tests/mem_agent/test_deepsearch_agent.py diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 16b600980..5e51aec44 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -218,7 +218,7 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: return self._remove_duplicate_memories(accumulated_memories) else: return self._generate_final_answer( - query, accumulated_memories, accumulated_context, "", history + query, accumulated_memories, accumulated_context, history ) def _remove_duplicate_memories( @@ -247,9 +247,9 @@ def _generate_final_answer( original_query: str, search_results: list[TextualMemoryItem], context: list[str], - missing_info: str = "", history: list[str] | None = None, sources: list[str] | None = None, + missing_info: str | None = None, ) -> str: """ Generate the final answer. 
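
After this change, `run()` has two return modes: by default it returns the deduplicated list of `TextualMemoryItem`s, and when a truthy `generated_answer` kwarg is supplied it instead synthesizes an answer string via `FINAL_GENERATION_PROMPT`. A minimal sketch of both call styles, assuming `agent` is an already-constructed `DeepSearchMemAgent`:

```python
# Default mode: deduplicated TextualMemoryItem objects for downstream formatting.
memories = agent.run("What did Caroline do in early July 2023?", user_id="user_1")

# Answer mode: the accumulated context is passed through the final-generation
# prompt and a synthesized answer string is returned instead.
answer = agent.run(
    "What did Caroline do in early July 2023?",
    user_id="user_1",
    history=["Hello", "Tell me about Caroline"],
    generated_answer=True,
)
```
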
diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py new file mode 100644 index 000000000..707d80283 --- /dev/null +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -0,0 +1,233 @@ +"""Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" + +import json +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.mem_agent import DeepSearchAgentConfig +from memos.mem_agent.deepsearch_agent import ( + DeepSearchMemAgent, + JSONResponseParser, +) +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata + + +class TestJSONResponseParser: + """Test JSONResponseParser class.""" + + def test_parse_clean_json(self): + """Test parsing clean JSON response.""" + response = '{"status": "sufficient", "reasoning": "test"}' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_json_with_code_blocks(self): + """Test parsing JSON wrapped in code blocks.""" + response = '```json\n{"status": "sufficient", "reasoning": "test"}\n```' + result = JSONResponseParser.parse(response) + assert result == {"status": "sufficient", "reasoning": "test"} + + def test_parse_invalid_json_raises_error(self): + """Test that invalid JSON raises ValueError.""" + with pytest.raises(ValueError, match="Cannot parse JSON response"): + JSONResponseParser.parse("This is not JSON at all") + + +class TestDeepSearchMemAgent: + """Test DeepSearchMemAgent core functionality.""" + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + mock = MagicMock() + mock.generate.return_value = "Generated answer" + return mock + + @pytest.fixture + def mock_memory_retriever(self): + """Create a mock memory retriever.""" + mock = MagicMock() + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a programming language", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python was created by Guido van Rossum", + metadata=TextualMemoryMetadata(type="fact"), + ), + ] + mock.search.return_value = memory_items + return mock + + @pytest.fixture + def config(self): + """Create DeepSearchAgentConfig.""" + return DeepSearchAgentConfig(agent_name="TestDeepSearch", max_iterations=3, timeout=30) + + @pytest.fixture + def agent(self, mock_llm, mock_memory_retriever, config): + """Create DeepSearchMemAgent instance.""" + agent = DeepSearchMemAgent( + llm=mock_llm, memory_retriever=mock_memory_retriever, config=config + ) + # Mock the sub-agents to avoid complex interactions + agent.query_rewriter.run = MagicMock(return_value="Rewritten query") + agent.reflector.run = MagicMock( + return_value={ + "status": "sufficient", + "reasoning": "Enough info", + "missing_entities": [], + } + ) + return agent + + def test_init_with_config(self, mock_llm, mock_memory_retriever, config): + """Test DeepSearchMemAgent initialization with config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + assert agent.llm == mock_llm + assert agent.memory_retriever == mock_memory_retriever + assert agent.config == config + assert agent.max_iterations == 3 + assert agent.timeout == 30 + + def test_init_without_config(self, mock_llm, mock_memory_retriever): + """Test DeepSearchMemAgent initialization without config.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever) + assert isinstance(agent.config, DeepSearchAgentConfig) + assert agent.config.agent_name 
== "DeepSearchMemAgent" + + def test_run_no_llm_raises_error(self, config): + """Test that running without LLM raises RuntimeError.""" + agent = DeepSearchMemAgent(llm=None, config=config) + with pytest.raises(RuntimeError, match="LLM not initialized"): + agent.run("test query") + + def test_run_returns_memories_when_no_generated_answer(self, agent, mock_memory_retriever): + """Test run returns memories when generated_answer is not requested.""" + result = agent.run("What is Python?", generated_answer=False) + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, TextualMemoryItem) for item in result) + agent.query_rewriter.run.assert_called_once() + + def test_run_returns_answer_when_generated_answer(self, agent, mock_llm): + """Test run returns generated answer when requested.""" + result = agent.run("What is Python?", generated_answer=True) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_run_with_user_id(self, agent, mock_memory_retriever): + """Test run with user_id.""" + agent.run("What is Python?", user_id="user123", generated_answer=False) + + # Check that user_id was passed to search + call_kwargs = mock_memory_retriever.search.call_args[1] + assert call_kwargs.get("user_name") == "user123" + + def test_run_no_search_results(self, agent, mock_memory_retriever): + """Test behavior when search returns no results.""" + mock_memory_retriever.search.return_value = [] + + result = agent.run("What is Python?", generated_answer=False) + + assert result == [] + + def test_remove_duplicate_memories(self, agent): + """Test removing duplicate memories.""" + mem_id1 = str(uuid.uuid4()) + mem_id2 = str(uuid.uuid4()) + mem_id3 = str(uuid.uuid4()) + + memories = [ + TextualMemoryItem( + id=mem_id1, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + TextualMemoryItem( + id=mem_id2, + memory="Different content", + metadata=TextualMemoryMetadata(type="fact"), + ), + TextualMemoryItem( + id=mem_id3, memory="Same content", metadata=TextualMemoryMetadata(type="fact") + ), + ] + + result = agent._remove_duplicate_memories(memories) + + assert len(result) == 2 + assert result[0].id == mem_id1 + assert result[1].id == mem_id2 + + def test_generate_final_answer(self, agent, mock_llm): + """Test final answer generation.""" + memory_items = [ + TextualMemoryItem( + id=str(uuid.uuid4()), + memory="Python is a language", + metadata=TextualMemoryMetadata(type="fact"), + ) + ] + context = ["Python is a programming language"] + + result = agent._generate_final_answer("What is Python?", memory_items, context) + + assert result == "Generated answer" + mock_llm.generate.assert_called_once() + + def test_generate_final_answer_with_missing_info(self, agent, mock_llm): + """Test final answer generation with missing info.""" + result = agent._generate_final_answer( + "What is Python?", [], [], missing_info="Version details not found" + ) + + assert result == "Generated answer" + call_args = mock_llm.generate.call_args[0][0] + assert "Version details not found" in call_args[0]["content"] + + def test_generate_final_answer_llm_error(self, agent, mock_llm): + """Test final answer generation handles LLM errors.""" + mock_llm.generate.side_effect = Exception("LLM error") + + result = agent._generate_final_answer("What is Python?", [], []) + + assert "error" in result.lower() + assert "What is Python?" 
in result + + def test_perform_memory_search_no_retriever(self, mock_llm, config): + """Test memory search when retriever is not configured.""" + agent = DeepSearchMemAgent(mock_llm, memory_retriever=None, config=config) + result = agent._perform_memory_search("test query") + + assert result == [] + + def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config): + """Test full pipeline integration.""" + agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) + + with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): + with patch.object( + agent.reflector, + "run", + return_value={ + "status": "sufficient", + "reasoning": "Info is sufficient", + "missing_entities": [], + }, + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From ecd45084935131bb7fb02ac09542c91ebae9e773 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 14:59:52 +0800 Subject: [PATCH 13/35] feat: add test --- tests/mem_agent/test_deepsearch_agent.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/mem_agent/test_deepsearch_agent.py b/tests/mem_agent/test_deepsearch_agent.py index 707d80283..a80dd10ea 100644 --- a/tests/mem_agent/test_deepsearch_agent.py +++ b/tests/mem_agent/test_deepsearch_agent.py @@ -1,6 +1,5 @@ """Simplified unit tests for DeepSearchAgent - focusing on core functionality.""" -import json import uuid from unittest.mock import MagicMock, patch @@ -213,8 +212,9 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config """Test full pipeline integration.""" agent = DeepSearchMemAgent(mock_llm, mock_memory_retriever, config) - with patch.object(agent.query_rewriter, "run", return_value="Rewritten query"): - with patch.object( + with ( + patch.object(agent.query_rewriter, "run", return_value="Rewritten query"), + patch.object( agent.reflector, "run", return_value={ @@ -222,12 +222,13 @@ def test_integration_full_pipeline(self, mock_llm, mock_memory_retriever, config "reasoning": "Info is sufficient", "missing_entities": [], }, - ): - result = agent.run( - "What is Python?", user_id="user123", history=[], generated_answer=True - ) - - assert isinstance(result, str) - assert result == "Generated answer" - mock_memory_retriever.search.assert_called() - mock_llm.generate.assert_called() + ), + ): + result = agent.run( + "What is Python?", user_id="user123", history=[], generated_answer=True + ) + + assert isinstance(result, str) + assert result == "Generated answer" + mock_memory_retriever.search.assert_called() + mock_llm.generate.assert_called() From 6e21032fc87464f4569f99c26e4309247b533d74 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 19:06:23 +0800 Subject: [PATCH 14/35] feat: update need_raw process --- src/memos/mem_agent/deepsearch_agent.py | 21 +++++++++++++++++++-- src/memos/templates/mem_agent_prompts.py | 8 +++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_agent/deepsearch_agent.py b/src/memos/mem_agent/deepsearch_agent.py index 5e51aec44..051ac03d3 100644 --- a/src/memos/mem_agent/deepsearch_agent.py +++ b/src/memos/mem_agent/deepsearch_agent.py @@ -183,8 +183,6 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if search_results: context_batch = 
[self._extract_context_from_memory(mem) for mem in search_results] accumulated_context.extend(context_batch) - accumulated_memories.extend(search_results) - reflection_result = self.reflector.run(current_query, context_batch) status = reflection_result.get("status", "sufficient") reasoning = reflection_result.get("reasoning", "") @@ -193,11 +191,14 @@ def run(self, query: str, **kwargs) -> str | list[TextualMemoryItem]: if status == "sufficient": logger.info("Sufficient information collected") + accumulated_memories.extend(search_results) break elif status == "needs_raw": logger.info("Need original sources, retrieving raw content") + accumulated_memories.extend(self._set_source_from_memory(search_results)) break elif status == "missing_info": + accumulated_memories.extend(search_results) missing_entities = reflection_result.get("missing_entities", []) logger.info(f"Missing information: {missing_entities}") current_query = reflection_result.get("new_search_query") @@ -331,6 +332,22 @@ def _refine_query_for_missing_info(self, query: str, missing_entities: list[str] return refined_query + def _set_source_from_memory( + self, memory_items: list[TextualMemoryItem] + ) -> list[TextualMemoryItem]: + """set source from memory item""" + for memory_item in memory_items: + if not hasattr(memory_item.metadata, "sources"): + continue + chat_sources = [ + f"{source.chat_time} {source.role}: {source.content}" + for source in memory_item.metadata.sources + if hasattr(source, "type") and source.type == "chat" + ] + if chat_sources: + memory_item.memory = "\n".join(chat_sources) + "\n" + return memory_items + def _generate_final_answer( self, original_query: str, diff --git a/src/memos/templates/mem_agent_prompts.py b/src/memos/templates/mem_agent_prompts.py index eb624ef89..d7163e4a8 100644 --- a/src/memos/templates/mem_agent_prompts.py +++ b/src/memos/templates/mem_agent_prompts.py @@ -35,7 +35,13 @@ - "sufficient": Context fully answers the query - "missing_info": Key information is missing (e.g., specific dates, locations, details) - "needs_raw": Content is relevant but too summarized/vague, need original sources -- "new_search_query": New search query to retrieve more information + +IMPORTANT for "new_search_query": +- MUST preserve ALL specific entities from the original query (names, dates, times, locations, etc.) +- DO NOT replace specific information with generic terms like "user", "person", "they", etc. 
+- Keep the exact same subjects, time references, and key details as in the original query +- Only modify the query to focus on the missing information while maintaining all original specifics +- Example: If original query mentions "May 2024", keep "May 2024" in new query, don't change to "that month" Response:""" From 592f637bc3090dd682013f3b74f290de72450050 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 25 Nov 2025 20:02:22 +0800 Subject: [PATCH 15/35] fix: add initter --- src/memos/multi_mem_cube/single_cube.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d2fde36a3..dbc527bb7 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -42,7 +42,7 @@ class SingleCubeView(MemCubeView): mem_scheduler: Any logger: Any searcher: Any - deepsearch_agent: Any + deepsearch_agent: Any | None = None def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ From ad997459ffa8d8a1fc30442b8d47f50d338da9c7 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 27 Nov 2025 16:24:58 +0800 Subject: [PATCH 16/35] fix: change agent search func name --- src/memos/multi_mem_cube/single_cube.py | 4 +++- src/memos/types/general_types.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8f4a25a0b..8e37cb92d 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -203,7 +203,7 @@ def _deep_search( formatted_memories = [format_memory_item(data) for data in enhanced_memories] return formatted_memories - def _deep_search( + def _agentic_search( self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int ) -> list: deepsearch_results = self.deepsearch_agent.run( @@ -229,6 +229,8 @@ def _fine_search( """ if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: return self._deep_search(search_req=search_req, user_context=user_context) + elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: + return self._agentic_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 2b7206c74..f796e682a 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -102,6 +102,7 @@ class FineStrategy(str, Enum): REWRITE = "rewrite" RECREATE = "recreate" DEEP_SEARCH = "deep_search" + AGENTIC_SEARCH = "agentic_search" # algorithm strategies From 94dba83f4655e45ca671e57eb19731e85f4fe2c6 Mon Sep 17 00:00:00 2001 From: fridayL Date: Fri, 28 Nov 2025 11:10:02 +0800 Subject: [PATCH 17/35] feat: update logs and defined --- src/memos/multi_mem_cube/single_cube.py | 12 ++++++++---- src/memos/reranker/factory.py | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8e37cb92d..e60e8e71e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -35,14 +35,17 @@ if TYPE_CHECKING: from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.mem_cube.navie import NaiveMemCube + from memos.mem_reader.simple_struct import SimpleStructMemReader + from memos.mem_scheduler.optimized_scheduler import 
OptimizedScheduler


 @dataclass
 class SingleCubeView(MemCubeView):
     cube_id: str
-    naive_mem_cube: Any
-    mem_reader: Any
-    mem_scheduler: Any
+    naive_mem_cube: NaiveMemCube
+    mem_reader: SimpleStructMemReader
+    mem_scheduler: OptimizedScheduler
     logger: Any
     searcher: Any
     deepsearch_agent: Any | None = None
@@ -155,7 +158,7 @@ def _search_text(
         Args:
             search_req: Search request
             user_context: User context
-            search_mode: Search mode (FAST, FINE, or MIXTURE)
+            search_mode: Search mode (fast, fine, or mixture)

         Returns:
             List of formatted memory items
@@ -227,6 +230,7 @@ def _fine_search(
         Returns:
             List of enhanced search results
         """
+        logger.info(f"Fine strategy: {FINE_STRATEGY}")
         if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
             return self._deep_search(search_req=search_req, user_context=user_context)
         elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH:
             return self._agentic_search(search_req=search_req, user_context=user_context)
diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py
index d2c50ba5e..1440704a6 100644
--- a/src/memos/reranker/factory.py
+++ b/src/memos/reranker/factory.py
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import json
+
 from typing import TYPE_CHECKING, Any

 # Import singleton decorator

From 953872e46cd9ef4da5fcfadb3cbb97bbb73d49a5 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Mon, 1 Dec 2025 15:10:41 +0800
Subject: [PATCH 18/35] feat: update full text mem search

---
 src/memos/graph_dbs/polardb.py | 125 +++++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index bfde8c80c..7207f05da 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1450,6 +1450,131 @@ def get_subgraph(
     def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
         """Get the ordered context chain starting from a node."""
         raise NotImplementedError
+
+    @timed
+    def search_by_fulltext(
+        self,
+        query_words: list[str],
+        top_k: int = 10,
+        scope: str | None = None,
+        status: str | None = None,
+        threshold: float | None = None,
+        search_filter: dict | None = None,
+        user_name: str | None = None,
+        filter: dict | None = None,
+        knowledgebase_ids: list[str] | None = None,
+        tsvector_field: str = "properties_tsvector_zh",
+        tsquery_config: str = "jiebaqry",
+        **kwargs,
+    ) -> list[dict]:
+        """
+        Full-text search functionality using PostgreSQL's full-text search capabilities.
+
+        Args:
+            query_words: list of query words used to build the tsquery
+            top_k: maximum number of results to return
+            scope: memory type filter (memory_type)
+            status: status filter, defaults to "activated"
+            threshold: similarity threshold filter
+            search_filter: additional property filter conditions
+            user_name: username filter
+            knowledgebase_ids: knowledgebase ids filter
+            filter: filter conditions with 'and' or 'or' logic for search results.
+            tsvector_field: full-text index field name, defaults to properties_tsvector_zh
+            tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
+            **kwargs: other parameters (e.g.
cube_name) + + Returns: + list[dict]: result list containing id and score + """ + # Build WHERE clause dynamically, same as search_by_embedding + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + # Add fulltext search condition + # Convert query_text to OR query format: "word1 | word2 | word3" + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid) + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + + return output[:top_k] + finally: + self._return_connection(conn) @timed def search_by_embedding( From 2591c107812d8b48c6b80b262365ca8e2e675fce Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 1 Dec 2025 17:46:24 +0800 Subject: [PATCH 19/35] feat: cp plugin to dev --- src/memos/api/handlers/component_init.py | 3 + src/memos/api/product_models.py | 6 ++ src/memos/graph_dbs/polardb.py | 3 +- src/memos/memories/textual/simple_tree.py | 3 + src/memos/memories/textual/tree.py | 5 ++ .../retrieve/advanced_searcher.py | 3 + .../tree_text_memory/retrieve/recall.py | 19 +++++ .../retrieve/retrieve_utils.py | 27 ++++++ .../tree_text_memory/retrieve/searcher.py | 83 ++++++++++++++++--- src/memos/multi_mem_cube/single_cube.py | 4 +- 10 files changed, 141 
insertions(+), 15 deletions(-)

diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 706269b52..574f2ae17 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -41,6 +41,7 @@
 from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
 from memos.memories.textual.simple_tree import SimpleTreeTextMemory
 from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer


 if TYPE_CHECKING:
@@ -196,6 +197,7 @@ def init_server() -> dict[str, Any]:

     logger.debug("Memory manager initialized")

+    tokenizer = FastTokenizer()
     # Initialize text memory
     text_mem = SimpleTreeTextMemory(
         llm=llm,
@@ -206,6 +208,7 @@ def init_server() -> dict[str, Any]:
         memory_manager=memory_manager,
         config=default_cube_config.text_mem.config,
         internet_retriever=internet_retriever,
+        tokenizer=tokenizer,
     )
     logger.debug("Text memory initialized")

diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 4f445e9ab..cc76e6751 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -388,6 +388,12 @@ class APISearchRequest(BaseRequest):
         description="(Internal) Operation definitions for multi-cube read permissions.",
     )

+    # ==== Source for plugin ====
+    source: str | None = Field(
+        None,
+        description="Source of the search query [plugin requests are routed to a different search path]",
+    )
+
     @model_validator(mode="after")
     def _convert_deprecated_fields(self) -> "APISearchRequest":
         """
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 7207f05da..90bea8ea3 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1450,7 +1450,7 @@ def get_subgraph(
     def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
         """Get the ordered context chain starting from a node."""
         raise NotImplementedError
-    
+
     @timed
     def search_by_fulltext(
         self,
@@ -1529,7 +1529,6 @@ def search_by_fulltext(
                     f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
                 )

-
         # Build filter conditions using common method
         filter_conditions = self._build_filter_conditions_sql(filter)
         where_clauses.extend(filter_conditions)
diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py
index 05e62e3ee..c67271f76 100644
--- a/src/memos/memories/textual/simple_tree.py
+++ b/src/memos/memories/textual/simple_tree.py
@@ -9,6 +9,7 @@
 from memos.memories.textual.tree import TreeTextMemory
 from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
 from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
 from memos.reranker.base import BaseReranker


@@ -35,6 +36,7 @@ def __init__(
         config: TreeTextMemoryConfig,
         internet_retriever: None = None,
         is_reorganize: bool = False,
+        tokenizer: FastTokenizer | None = None,
     ):
         """Initialize memory with the given configuration."""
         self.config: TreeTextMemoryConfig = config
@@ -51,6 +53,7 @@ def __init__(
             if self.search_strategy and self.search_strategy.get("bm25", False)
             else None
         )
+        self.tokenizer = tokenizer
         self.reranker = reranker
         self.memory_manager: MemoryManager = memory_manager
         # Create internet retriever if configured
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index
2a109bf71..ad2bcd9c4 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -91,6 +91,7 @@ def __init__(self, config: TreeTextMemoryConfig):
             )
         else:
             logger.info("No internet retriever configured")
+        self.tokenizer = None

     def add(
         self,
@@ -165,6 +166,7 @@ def search(
         search_priority: dict | None = None,
         search_filter: dict | None = None,
         user_name: str | None = None,
+        **kwargs,
     ) -> list[TextualMemoryItem]:
         """Search for memories based on a query.
         User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -197,6 +199,7 @@ def search(
                 internet_retriever=None,
                 search_strategy=self.search_strategy,
                 manual_close_internet=manual_close_internet,
+                tokenizer=self.tokenizer,
             )
         else:
             searcher = Searcher(
@@ -208,6 +211,7 @@ def search(
                 internet_retriever=self.internet_retriever,
                 search_strategy=self.search_strategy,
                 manual_close_internet=manual_close_internet,
+                tokenizer=self.tokenizer,
             )
         return searcher.search(
             query,
@@ -218,6 +222,7 @@ def search(
             search_filter,
             search_priority,
             user_name=user_name,
+            plugin=kwargs.get("plugin", False),
         )

     def get_relevant_subgraph(
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
index 22cd44b8c..9c892d8b8 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
@@ -10,6 +10,7 @@
 from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
 from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    FastTokenizer,
     parse_structured_output,
 )
 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
@@ -33,6 +34,7 @@ def __init__(
         search_strategy: dict | None = None,
         manual_close_internet: bool = True,
         process_llm: Any | None = None,
+        tokenizer: FastTokenizer | None = None,
     ):
         super().__init__(
             dispatcher_llm=dispatcher_llm,
@@ -43,6 +45,7 @@ def __init__(
             internet_retriever=internet_retriever,
             search_strategy=search_strategy,
             manual_close_internet=manual_close_internet,
+            tokenizer=tokenizer,
         )

         self.stage_retrieve_top = 3
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 7fa8a87be..7ac274a62 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -148,6 +148,25 @@ def retrieve_from_cube(

         return list(combined.values())

+    def retrieve_from_mixed(
+        self,
+        top_k: int,
+        memory_scope: str | None = None,
+        query_embedding: list[list[float]] | None = None,
+        search_filter: dict | None = None,
+        user_name: str | None = None,
+    ) -> list[TextualMemoryItem]:
+        """Retrieve memories by vector recall over mixed query and keyword embeddings."""
+        vector_results = self._vector_recall(
+            query_embedding or [],
+            memory_scope,
+            top_k,
+            search_filter=search_filter,
+            user_name=user_name,
+        )
+        combined = {item.id: item for item in vector_results}  # merge and deduplicate by ID
+        return list(combined.values())
+
     def _graph_recall(
         self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
     ) -> list[TextualMemoryItem]:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
index 0720d1fca..9e1e6c240 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Any +import numpy as np + from memos.dependency import require_python_package from memos.log import get_logger @@ -463,3 +465,28 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]: memory["metadata"]["memory"] = memory["memory"] return memory + + +def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8): + assert len(sentences) == len(similarity_matrix) + + num_sentence = len(sentences) + selected_sentences = [] + selected_indices = [] + for i in range(num_sentence): + can_add = True + for j in selected_indices: + if similarity_matrix[i][j] > bar: + can_add = False + break + if can_add: + selected_sentences.append(i) + selected_indices.append(i) + return selected_sentences, selected_indices + + +def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]: + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + x_normalized = embeddings / norms + similarity_matrix = np.dot(x_normalized, x_normalized.T) + return similarity_matrix diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 976be6a54..f428bf5c0 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -8,7 +8,10 @@ from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, + cosine_similarity_matrix, detect_lang, + find_best_unrelated_subgroup, parse_json_result, ) from memos.reranker.base import BaseReranker @@ -43,6 +46,7 @@ def __init__( internet_retriever: None = None, search_strategy: dict | None = None, manual_close_internet: bool = True, + tokenizer: FastTokenizer | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -58,6 +62,7 @@ def __init__( self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self.manual_close_internet = manual_close_internet + self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -104,9 +109,10 @@ def post_retrieve( top_k: int, user_name: str | None = None, info=None, + plugin=False, ): deduped = self._deduplicate_results(retrieved_results) - final_results = self._sort_and_trim(deduped, top_k) + final_results = self._sort_and_trim(deduped, top_k, plugin) self._update_usage_history(final_results, info, user_name) return final_results @@ -121,6 +127,7 @@ def search( search_filter: dict | None = None, search_priority: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. 
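The `retrieve_utils` helpers introduced above amount to a greedy diversity filter: row-normalize the embeddings, take pairwise cosine similarities, then keep a candidate only while its similarity to everything already kept stays at or below `bar`. The `_retrieve_simple` method in the hunks that follow feeds document embeddings through exactly this pair (`cosine_similarity_matrix` then `find_best_unrelated_subgroup`) before reranking. A toy, numpy-only sketch of the same rule (simplified to return indices only; the in-tree helper also returns the selected items alongside the indices):

    import numpy as np

    def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
        x = np.asarray(embeddings, dtype=float)
        x = x / np.linalg.norm(x, axis=1, keepdims=True)  # row-normalize
        return x @ x.T

    def greedy_unrelated(sim: np.ndarray, bar: float = 0.8) -> list[int]:
        kept: list[int] = []
        for i in range(sim.shape[0]):
            if all(sim[i][j] <= bar for j in kept):
                kept.append(i)
        return kept

    sim = cosine_similarity_matrix([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]])
    print(greedy_unrelated(sim))  # [0, 2]: index 1 is a near-duplicate of index 0

Because the first candidate is always kept, the filter is order-sensitive and favors items earlier in the retrieval ranking.
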
@@ -149,22 +156,29 @@ def search(
         else:
             logger.debug(f"[SEARCH] Received info dict: {info}")

-        retrieved_results = self.retrieve(
-            query=query,
-            top_k=top_k,
-            info=info,
-            mode=mode,
-            memory_type=memory_type,
-            search_filter=search_filter,
-            search_priority=search_priority,
-            user_name=user_name,
-        )
+        if kwargs.get("plugin"):
+            logger.info(f"[SEARCH] Retrieve from plugin: {query}")
+            retrieved_results = self._retrieve_simple(
+                query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
+            )
+        else:
+            retrieved_results = self.retrieve(
+                query=query,
+                top_k=top_k,
+                info=info,
+                mode=mode,
+                memory_type=memory_type,
+                search_filter=search_filter,
+                search_priority=search_priority,
+                user_name=user_name,
+            )
         final_results = self.post_retrieve(
             retrieved_results=retrieved_results,
             top_k=top_k,
             user_name=user_name,
             info=None,
+            plugin=kwargs.get("plugin", False),
         )

         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -484,6 +498,49 @@ def _retrieve_from_internet(
             parsed_goal=parsed_goal,
         )

+    @timed
+    def _retrieve_simple(
+        self,
+        query: str,
+        top_k: int,
+        search_filter: dict | None = None,
+        user_name: str | None = None,
+        **kwargs,
+    ):
+        """Retrieve by keywords and embedding."""
+        query_words = []
+        if self.tokenizer:
+            query_words = self.tokenizer.tokenize_mixed(query)
+        else:
+            query_words = query.strip().split()
+        query_words = list(set(query_words))[: top_k * 3]
+        query_words = [query, *query_words]
+        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
+        query_embeddings = self.embedder.embed(query_words)
+
+        items = self.graph_retriever.retrieve_from_mixed(
+            top_k=top_k * 2,
+            memory_scope=None,
+            query_embedding=query_embeddings,
+            search_filter=search_filter,
+            user_name=user_name,
+        )
+        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
+        documents = [getattr(item, "memory", "") for item in items]
+        documents_embeddings = self.embedder.embed(documents)
+        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
+        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
+        selected_items = [items[i] for i in selected_indices]
+        logger.info(
+            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
+        )
+        return self.reranker.rerank(
+            query=query,
+            query_embedding=query_embeddings[0],
+            graph_results=selected_items,
+            top_k=top_k,
+        )
+
     @timed
     def _deduplicate_results(self, results):
         """Deduplicate results by memory text"""
@@ -494,12 +551,14 @@ def _deduplicate_results(self, results):
         return list(deduped.values())

     @timed
-    def _sort_and_trim(self, results, top_k):
+    def _sort_and_trim(self, results, top_k, plugin=False):
         """Sort results by score and trim to top_k"""
         sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k]

         final_items = []
         for item, score in sorted_results:
+            if plugin and round(score, 2) == 0.00:
+                continue
             meta_data = item.metadata.model_dump()
             meta_data["relativity"] = score
             final_items.append(
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index e346bdf1f..880646939 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -363,7 +363,8 @@ def _fast_search(
         target_session_id = search_req.session_id or "default_session"
         search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
         search_filter = search_req.filter or None
-        print(f"type of text_mem: {type(self.naive_mem_cube.text_mem)}")
+        plugin =
bool(search_req.source is not None and search_req.source == "plugin") + search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -377,6 +378,7 @@ def _fast_search( "session_id": target_session_id, "chat_history": search_req.chat_history, }, + plugin=plugin, ) formatted_memories = [format_memory_item(data) for data in search_results] From 383eaaa790a0d625698f03c2462b098978433ce7 Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 1 Dec 2025 20:43:38 +0800 Subject: [PATCH 20/35] feat: add one recall for fulltext retrieval --- src/memos/graph_dbs/polardb.py | 2 +- .../tree_text_memory/retrieve/recall.py | 68 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index e731ef138..da88604b6 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1553,7 +1553,7 @@ def search_by_fulltext( """ params = [tsquery_string, tsquery_string] - + logger.info(f"[search_by_fulltext] query: {query}, params: {params}") conn = self._get_connection() try: with conn.cursor() as cursor: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 7ac274a62..5a6ff0233 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -101,13 +101,27 @@ def retrieve( user_name=user_name, search_filter=id_filter, ) + if use_fast_graph: + future_fulltext = executor.submit( + self._fulltext_recall, + query_words=parsed_goal.keys or [], + memory_scope=memory_scope, + top_k=top_k, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + ) graph_results = future_graph.result() vector_results = future_vector.result() bm25_results = future_bm25.result() if self.use_bm25 else [] + fulltext_results = future_fulltext.result() # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results + bm25_results} + combined = { + item.id: item + for item in graph_results + vector_results + bm25_results + fulltext_results + } return list(combined.values()) @@ -404,3 +418,55 @@ def _bm25_recall( ) return [TextualMemoryItem.from_dict(n) for n in bm25_results] + + def _fulltext_recall( + self, + query_words: list[str], + memory_scope: str, + top_k: int = 20, + max_num: int = 5, + status: str = "activated", + cube_name: str | None = None, + search_filter: dict | None = None, + search_priority: dict | None = None, + user_name: str | None = None, + ): + """Perform fulltext-based retrieval. 
+ Args: + query_words: list of query words + memory_scope: memory scope + top_k: top k results + max_num: max number of query words + status: status + cube_name: cube name + search_filter: search filter + search_priority: search priority + user_name: user name + Returns: + list of TextualMemoryItem + """ + if not query_words: + return [] + logger.info(f"[FULLTEXT] query_words: {query_words}") + all_hits = self.graph_store.search_by_fulltext( + query_words=query_words, + top_k=top_k, + status=status, + scope=memory_scope, + cube_name=cube_name, + search_filter=search_priority, + filter=search_filter, + user_name=user_name, + ) + if not all_hits: + return [] + + # merge and deduplicate + unique_ids = {r["id"] for r in all_hits if r.get("id")} + node_dicts = ( + self.graph_store.get_nodes( + list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name + ) + or [] + ) + return [TextualMemoryItem.from_dict(n) for n in node_dicts] From 502e15e5216f3f08c927f724560f57ee41e2e604 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 2 Dec 2025 09:57:00 +0800 Subject: [PATCH 21/35] fix: set default for fulltext search --- src/memos/memories/textual/tree_text_memory/retrieve/recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 5a6ff0233..5dfbde704 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -115,7 +115,7 @@ def retrieve( graph_results = future_graph.result() vector_results = future_vector.result() bm25_results = future_bm25.result() if self.use_bm25 else [] - fulltext_results = future_fulltext.result() + fulltext_results = future_fulltext.result() if use_fast_graph else [] # Merge and deduplicate by ID combined = { From 861e489fa82b688f9eb4507a061b5fb59019bd64 Mon Sep 17 00:00:00 2001 From: fridayL Date: Tue, 2 Dec 2025 16:47:12 +0800 Subject: [PATCH 22/35] feat: add langchain chunk --- examples/mem_chunk/markdown_chunk.py | 33 ++ poetry.lock | 324 +++--------------- pyproject.toml | 3 +- src/memos/chunkers/factory.py | 2 + src/memos/chunkers/markdown_chunker.py | 53 +++ src/memos/configs/chunker.py | 14 + .../tree_text_memory/retrieve/searcher.py | 5 +- 7 files changed, 164 insertions(+), 270 deletions(-) create mode 100644 examples/mem_chunk/markdown_chunk.py create mode 100644 src/memos/chunkers/markdown_chunker.py diff --git a/examples/mem_chunk/markdown_chunk.py b/examples/mem_chunk/markdown_chunk.py new file mode 100644 index 000000000..ce7d2b9ae --- /dev/null +++ b/examples/mem_chunk/markdown_chunk.py @@ -0,0 +1,33 @@ +from memos.chunkers import ChunkerFactory +from memos.configs.chunker import ChunkerConfigFactory + + +config = ChunkerConfigFactory.model_validate( + { + "backend": "markdown", + "config": { + "chunk_size": 1000, + "chunk_overlap": 100, + "recursive": True, + }, + } +) + +chunker = ChunkerFactory.from_config(config) + +text = """ +# Header 1 +This is the first sentence. This is the second sentence. +And here's a third one with some additional context. + +# Header 2 +This is the fourth sentence. This is the fifth sentence. +And here's a sixth one with some additional context. + +# Header 3 +This is the seventh sentence. This is the eighth sentence. +And here's a ninth one with some additional context. 
+""" +chunks = chunker.chunk(text) +for chunk in chunks: + print("doc:", chunk) diff --git a/poetry.lock b/poetry.lock index 40d0f6210..940697b1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -24,32 +24,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[[package]] -name = "anthropic" -version = "0.57.1" -description = "The official Python library for the anthropic API" -optional = false -python-versions = ">=3.8" -groups = ["eval"] -files = [ - {file = "anthropic-0.57.1-py3-none-any.whl", hash = "sha256:33afc1f395af207d07ff1bffc0a3d1caac53c371793792569c5d2f09283ea306"}, - {file = "anthropic-0.57.1.tar.gz", hash = "sha256:7815dd92245a70d21f65f356f33fc80c5072eada87fb49437767ea2918b2c4b0"}, -] - -[package.dependencies] -anyio = ">=3.5.0,<5" -distro = ">=1.7.0,<2" -httpx = ">=0.25.0,<1" -jiter = ">=0.4.0,<1" -pydantic = ">=1.9.0,<3" -sniffio = "*" -typing-extensions = ">=4.10,<5" - -[package.extras] -aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"] -bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] -vertex = ["google-auth[requests] (>=2,<3)"] - [[package]] name = "anyio" version = "4.9.0" @@ -73,19 +47,6 @@ doc = ["Sphinx (>=8.2,<9.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "blockbuster (>=1.5.23)", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""] trio = ["trio (>=0.26.1)"] -[[package]] -name = "async-timeout" -version = "4.0.3" -description = "Timeout context manager for asyncio programs" -optional = false -python-versions = ">=3.7" -groups = ["main", "eval"] -files = [ - {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, - {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, -] -markers = {main = "(extra == \"mem-scheduler\" or extra == \"all\") and python_version == \"3.10\"", eval = "python_version == \"3.10\""} - [[package]] name = "async-timeout" version = "5.0.1" @@ -93,7 +54,7 @@ description = "Timeout context manager for asyncio programs" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "(extra == \"mem-scheduler\" or extra == \"all\") and python_full_version < \"3.11.3\" and python_version == \"3.11\"" +markers = "(python_version == \"3.10\" or python_version == \"3.11\") and python_full_version < \"3.11.3\" and (extra == \"mem-scheduler\" or extra == \"all\")" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -293,7 +254,7 @@ files = [ {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] -markers = {main = 
"platform_python_implementation != \"PyPy\"", eval = "platform_python_implementation == \"PyPy\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or platform_python_implementation != \"PyPy\"", eval = "platform_python_implementation == \"PyPy\""} [package.dependencies] pycparser = "*" @@ -823,24 +784,6 @@ files = [ [package.dependencies] python-dotenv = "*" -[[package]] -name = "dydantic" -version = "0.0.8" -description = "Dynamically generate pydantic models from JSON schema." -optional = false -python-versions = "<4.0,>=3.9" -groups = ["eval"] -files = [ - {file = "dydantic-0.0.8-py3-none-any.whl", hash = "sha256:cd0a991f523bd8632699872f1c0c4278415dd04783e36adec5428defa0afb721"}, - {file = "dydantic-0.0.8.tar.gz", hash = "sha256:14a31d4cdfce314ce3e69e8f8c7c46cbc26ce3ce4485de0832260386c612942f"}, -] - -[package.dependencies] -pydantic = ">=2,<3" - -[package.extras] -email = ["email-validator (>=2.1,<3.0)"] - [[package]] name = "email-validator" version = "2.2.0" @@ -1137,7 +1080,7 @@ description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.9" groups = ["main", "eval"] -markers = "(platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") and python_version < \"3.14\"" +markers = "(python_version == \"3.10\" or python_version == \"3.11\" or python_version == \"3.12\" or python_version == \"3.13\") and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")" files = [ {file = "greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be"}, {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac"}, @@ -1701,11 +1644,12 @@ version = "1.33" description = "Apply JSON-Patches (RFC 6902)" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" -groups = ["eval"] +groups = ["main", "eval"] files = [ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [package.dependencies] jsonpointer = ">=1.9" @@ -1716,11 +1660,12 @@ version = "3.0.0" description = "Identify specific nodes in a JSON document (RFC 6901)" optional = false python-versions = ">=3.7" -groups = ["eval"] +groups = ["main", "eval"] files = [ {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [[package]] name = "jsonschema" @@ -1849,116 +1794,43 @@ files = [ {file = "kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e"}, ] -[[package]] -name = "langchain" -version = "0.3.26" -description = "Building applications with LLMs 
through composability" -optional = false -python-versions = ">=3.9" -groups = ["eval"] -files = [ - {file = "langchain-0.3.26-py3-none-any.whl", hash = "sha256:361bb2e61371024a8c473da9f9c55f4ee50f269c5ab43afdb2b1309cb7ac36cf"}, - {file = "langchain-0.3.26.tar.gz", hash = "sha256:8ff034ee0556d3e45eff1f1e96d0d745ced57858414dba7171c8ebdbeb5580c9"}, -] - -[package.dependencies] -async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} -langchain-core = ">=0.3.66,<1.0.0" -langchain-text-splitters = ">=0.3.8,<1.0.0" -langsmith = ">=0.1.17" -pydantic = ">=2.7.4,<3.0.0" -PyYAML = ">=5.3" -requests = ">=2,<3" -SQLAlchemy = ">=1.4,<3" - -[package.extras] -anthropic = ["langchain-anthropic"] -aws = ["langchain-aws"] -azure-ai = ["langchain-azure-ai"] -cohere = ["langchain-cohere"] -community = ["langchain-community"] -deepseek = ["langchain-deepseek"] -fireworks = ["langchain-fireworks"] -google-genai = ["langchain-google-genai"] -google-vertexai = ["langchain-google-vertexai"] -groq = ["langchain-groq"] -huggingface = ["langchain-huggingface"] -mistralai = ["langchain-mistralai"] -ollama = ["langchain-ollama"] -openai = ["langchain-openai"] -perplexity = ["langchain-perplexity"] -together = ["langchain-together"] -xai = ["langchain-xai"] - -[[package]] -name = "langchain-anthropic" -version = "0.3.17" -description = "An integration package connecting AnthropicMessages and LangChain" -optional = false -python-versions = ">=3.9" -groups = ["eval"] -files = [ - {file = "langchain_anthropic-0.3.17-py3-none-any.whl", hash = "sha256:6df784615b93aab0336fbd6a50ca2bd16a704ef01c9488c36a4fa7aad2faf2d6"}, - {file = "langchain_anthropic-0.3.17.tar.gz", hash = "sha256:f2c2a0382ed7992204d790ff8538448f5243f4dbb1e798256ef790c9a69033e4"}, -] - -[package.dependencies] -anthropic = ">=0.57.0,<1" -langchain-core = ">=0.3.68,<1.0.0" -pydantic = ">=2.7.4,<3.0.0" - [[package]] name = "langchain-core" -version = "0.3.69" +version = "1.1.0" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.9" -groups = ["eval"] +python-versions = "<4.0.0,>=3.10.0" +groups = ["main", "eval"] files = [ - {file = "langchain_core-0.3.69-py3-none-any.whl", hash = "sha256:383e9cb4919f7ef4b24bf8552ef42e4323c064924fea88b28dd5d7ddb740d3b8"}, - {file = "langchain_core-0.3.69.tar.gz", hash = "sha256:c132961117cc7f0227a4c58dd3e209674a6dd5b7e74abc61a0df93b0d736e283"}, + {file = "langchain_core-1.1.0-py3-none-any.whl", hash = "sha256:2c9f27dadc6d21ed4aa46506a37a56e6a7e2d2f9141922dc5c251ba921822ee6"}, + {file = "langchain_core-1.1.0.tar.gz", hash = "sha256:2b76a82d427922c8bc51c08404af4fc2a29e9f161dfe2297cb05091e810201e7"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [package.dependencies] -jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.3.45" -packaging = ">=23.2" -pydantic = ">=2.7.4" -PyYAML = ">=5.3" +jsonpatch = ">=1.33.0,<2.0.0" +langsmith = ">=0.3.45,<1.0.0" +packaging = ">=23.2.0,<26.0.0" +pydantic = ">=2.7.4,<3.0.0" +pyyaml = ">=5.3.0,<7.0.0" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" -typing-extensions = ">=4.7" - -[[package]] -name = "langchain-openai" -version = "0.3.28" -description = "An integration package connecting OpenAI and LangChain" -optional = false -python-versions = ">=3.9" -groups = ["eval"] -files = [ - {file = "langchain_openai-0.3.28-py3-none-any.whl", hash = "sha256:4cd6d80a5b2ae471a168017bc01b2e0f01548328d83532400a001623624ede67"}, - {file = "langchain_openai-0.3.28.tar.gz", hash = 
"sha256:6c669548dbdea325c034ae5ef699710e2abd054c7354fdb3ef7bf909dc739d9e"}, -] - -[package.dependencies] -langchain-core = ">=0.3.68,<1.0.0" -openai = ">=1.86.0,<2.0.0" -tiktoken = ">=0.7,<1" +typing-extensions = ">=4.7.0,<5.0.0" [[package]] name = "langchain-text-splitters" -version = "0.3.8" +version = "1.0.0" description = "LangChain text splitting utilities" -optional = false -python-versions = "<4.0,>=3.9" -groups = ["eval"] +optional = true +python-versions = "<4.0.0,>=3.10.0" +groups = ["main"] +markers = "extra == \"mem-reader\" or extra == \"all\"" files = [ - {file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"}, - {file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"}, + {file = "langchain_text_splitters-1.0.0-py3-none-any.whl", hash = "sha256:f00c8219d3468f2c5bd951b708b6a7dd9bc3c62d0cfb83124c377f7170f33b2e"}, + {file = "langchain_text_splitters-1.0.0.tar.gz", hash = "sha256:d8580a20ad7ed10b432feb273e5758b2cc0902d094919629cec0e1ad691a6744"}, ] [package.dependencies] -langchain-core = ">=0.3.51,<1.0.0" +langchain-core = ">=1.0.0,<2.0.0" [[package]] name = "langgraph" @@ -2028,39 +1900,18 @@ files = [ httpx = ">=0.25.2" orjson = ">=3.10.1" -[[package]] -name = "langmem" -version = "0.0.27" -description = "Prebuilt utilities for memory management and retrieval." -optional = false -python-versions = ">=3.10" -groups = ["eval"] -files = [ - {file = "langmem-0.0.27-py3-none-any.whl", hash = "sha256:25e9f06ad7c420442cf4b62caff6f805b124dfb2e2cc9cacc464d7a455fbafda"}, - {file = "langmem-0.0.27.tar.gz", hash = "sha256:729c1eb77c4cd8d9f2285f908a68a1e622ef01f074eeeb8cbbc7343f296efc53"}, -] - -[package.dependencies] -langchain = ">=0.3.15" -langchain-anthropic = ">=0.3.3" -langchain-core = ">=0.3.46" -langchain-openai = ">=0.3.1" -langgraph = ">=0.3.23" -langgraph-checkpoint = ">=2.0.12" -langsmith = ">=0.3.8" -trustcall = ">=0.0.39" - [[package]] name = "langsmith" version = "0.4.7" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = ">=3.9" -groups = ["eval"] +groups = ["main", "eval"] files = [ {file = "langsmith-0.4.7-py3-none-any.whl", hash = "sha256:de91f1abdd65da369996f8eedb5201f442110c9c3bde5babc6f5300f07da65df"}, {file = "langsmith-0.4.7.tar.gz", hash = "sha256:3864cf29295c2565c578e93d1533f5b39e2b4af616545ace30f069635a319890"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [package.dependencies] httpx = ">=0.23.0,<1" @@ -2772,7 +2623,7 @@ files = [ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"}, {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-cupti-cu12" @@ -2788,7 +2639,7 @@ files = [ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"}, {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-nvrtc-cu12" @@ -2802,7 +2653,7 @@ files = [ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"}, {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-runtime-cu12" @@ -2818,7 +2669,7 @@ files = [ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"}, {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cudnn-cu12" @@ -2832,7 +2683,7 @@ files = [ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = 
"sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"}, {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2851,7 +2702,7 @@ files = [ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"}, {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2867,7 +2718,7 @@ files = [ {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"}, {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-curand-cu12" @@ -2883,7 +2734,7 @@ files = [ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"}, {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cusolver-cu12" @@ -2899,7 +2750,7 @@ files = [ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"}, {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == 
\"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2920,7 +2771,7 @@ files = [ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"}, {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2937,7 +2788,7 @@ files = [ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"}, {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nccl-cu12" @@ -2950,7 +2801,7 @@ files = [ {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"}, {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nvjitlink-cu12" @@ -2964,7 +2815,7 @@ files = [ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nvtx-cu12" @@ -2980,7 +2831,7 @@ files = [ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"}, {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and 
platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "ollama" @@ -3637,7 +3488,7 @@ files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] -markers = {main = "platform_python_implementation != \"PyPy\"", eval = "platform_python_implementation == \"PyPy\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or platform_python_implementation != \"PyPy\"", eval = "platform_python_implementation == \"PyPy\""} [[package]] name = "pydantic" @@ -4068,7 +3919,7 @@ files = [ {file = "pywin32-311-cp39-cp39-win_amd64.whl", hash = "sha256:e0c4cfb0621281fe40387df582097fd796e80430597cb9944f0ae70447bacd91"}, {file = "pywin32-311-cp39-cp39-win_arm64.whl", hash = "sha256:62ea666235135fee79bb154e695f3ff67370afefd71bd7fea7512fc70ef31e3d"}, ] -markers = {main = "platform_system == \"Windows\" and extra == \"all\" or sys_platform == \"win32\"", eval = "platform_system == \"Windows\""} +markers = {main = "extra == \"all\" and platform_system == \"Windows\" or sys_platform == \"win32\"", eval = "platform_system == \"Windows\""} [[package]] name = "pyyaml" @@ -4352,11 +4203,12 @@ version = "1.0.0" description = "A utility belt for advanced users of python-requests" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -groups = ["eval"] +groups = ["main", "eval"] files = [ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [package.dependencies] requests = ">=2.0.1,<3.0.0" @@ -5065,7 +4917,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {main = "extra == \"all\" or extra == \"pref-mem\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5307,54 +5159,6 @@ files = [ {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"}, ] -[[package]] -name = "tiktoken" -version = "0.9.0" -description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = false -python-versions = ">=3.9" -groups = ["eval"] -files = [ - {file = "tiktoken-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:586c16358138b96ea804c034b8acf3f5d3f0258bd2bc3b0227af4af5d622e382"}, - {file = 
"tiktoken-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d9c59ccc528c6c5dd51820b3474402f69d9a9e1d656226848ad68a8d5b2e5108"}, - {file = "tiktoken-0.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0968d5beeafbca2a72c595e8385a1a1f8af58feaebb02b227229b69ca5357fd"}, - {file = "tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92a5fb085a6a3b7350b8fc838baf493317ca0e17bd95e8642f95fc69ecfed1de"}, - {file = "tiktoken-0.9.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:15a2752dea63d93b0332fb0ddb05dd909371ededa145fe6a3242f46724fa7990"}, - {file = "tiktoken-0.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:26113fec3bd7a352e4b33dbaf1bd8948de2507e30bd95a44e2b1156647bc01b4"}, - {file = "tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e"}, - {file = "tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348"}, - {file = "tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33"}, - {file = "tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136"}, - {file = "tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336"}, - {file = "tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb"}, - {file = "tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03"}, - {file = "tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210"}, - {file = "tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794"}, - {file = "tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22"}, - {file = "tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2"}, - {file = "tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16"}, - {file = "tiktoken-0.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b0e8e05a26eda1249e824156d537015480af7ae222ccb798e5234ae0285dbdb"}, - {file = "tiktoken-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:27d457f096f87685195eea0165a1807fae87b97b2161fe8c9b1df5bd74ca6f63"}, - {file = "tiktoken-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cf8ded49cddf825390e36dd1ad35cd49589e8161fdcb52aa25f0583e90a3e01"}, - {file = "tiktoken-0.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc156cb314119a8bb9748257a2eaebd5cc0753b6cb491d26694ed42fc7cb3139"}, - {file = "tiktoken-0.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cd69372e8c9dd761f0ab873112aba55a0e3e506332dd9f7522ca466e817b1b7a"}, - {file = "tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95"}, - {file = 
"tiktoken-0.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c6386ca815e7d96ef5b4ac61e0048cd32ca5a92d5781255e13b31381d28667dc"}, - {file = "tiktoken-0.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75f6d5db5bc2c6274b674ceab1615c1778e6416b14705827d19b40e6355f03e0"}, - {file = "tiktoken-0.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e15b16f61e6f4625a57a36496d28dd182a8a60ec20a534c5343ba3cafa156ac7"}, - {file = "tiktoken-0.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebcec91babf21297022882344c3f7d9eed855931466c3311b1ad6b64befb3df"}, - {file = "tiktoken-0.9.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e5fd49e7799579240f03913447c0cdfa1129625ebd5ac440787afc4345990427"}, - {file = "tiktoken-0.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:26242ca9dc8b58e875ff4ca078b9a94d2f0813e6a535dcd2205df5d49d927cc7"}, - {file = "tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d"}, -] - -[package.dependencies] -regex = ">=2022.1.18" -requests = ">=2.26.0" - -[package.extras] -blobfile = ["blobfile (>=2)"] - [[package]] name = "tokenizers" version = "0.21.2" @@ -5604,7 +5408,7 @@ files = [ {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] setuptools = ">=40.8.0" @@ -5614,23 +5418,6 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] -[[package]] -name = "trustcall" -version = "0.0.39" -description = "Tenacious & trustworthy tool calling built on LangGraph." 
-optional = false -python-versions = "<4.0,>=3.10" -groups = ["eval"] -files = [ - {file = "trustcall-0.0.39-py3-none-any.whl", hash = "sha256:d7da42e0bba816c0539b2936dfed90ffb3ea8d789e548e73865d416f8ac4ee64"}, - {file = "trustcall-0.0.39.tar.gz", hash = "sha256:ec315818224501b9537ce6b7618dbc21be41210c6e8f2e239169a5a00912cd6e"}, -] - -[package.dependencies] -dydantic = ">=0.0.8,<1.0.0" -jsonpatch = ">=1.33,<2.0" -langgraph = ">=0.2.25" - [[package]] name = "typer" version = "0.16.0" @@ -5830,7 +5617,7 @@ description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" groups = ["main"] -markers = "platform_python_implementation != \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\"" +markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"" files = [ {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"}, {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"}, @@ -6303,7 +6090,7 @@ version = "0.23.0" description = "Zstandard bindings for Python" optional = false python-versions = ">=3.8" -groups = ["eval"] +groups = ["main", "eval"] files = [ {file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"}, {file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"}, @@ -6403,6 +6190,7 @@ files = [ {file = "zstandard-0.23.0-cp39-cp39-win_amd64.whl", hash = "sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58"}, {file = "zstandard-0.23.0.tar.gz", hash = "sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09"}, ] +markers = {main = "extra == \"mem-reader\" or extra == \"all\""} [package.dependencies] cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""} @@ -6411,8 +6199,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["cachetools", "chonkie", "datasketch", "jieba", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] -mem-reader = ["chonkie", "markitdown"] +all = ["cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +mem-reader = ["chonkie", "langchain-text-splitters", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] pref-mem = ["datasketch", "pymilvus"] @@ -6421,4 +6209,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "95e737a53fed62215bcb523c162e19ed67ffc745e27fa081bc3da5e356eba086" +content-hash = "1eae4dc9df321c2e5157497c7ce6fb2b1248cb1d4cf7d57e3d38710be977e07b" diff --git a/pyproject.toml b/pyproject.toml index 9a8db2694..265a5ae5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ mem-user = [ mem-reader = [ "chonkie (>=1.0.7,<2.0.0)", # Sentence chunking library "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", # Markdown parser for various file formats + 
"langchain-text-splitters (>=1.0.0,<2.0.0)", # markdown chunk for langchain ] # PreferenceTextMemory @@ -105,6 +106,7 @@ all = [ "pika (>=1.3.2,<2.0.0)", "pymysql (>=1.1.0,<2.0.0)", "chonkie (>=1.0.7,<2.0.0)", + "langchain-text-splitters (>=1.0.0,<2.0.0)", "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", "pymilvus (>=2.6.1,<3.0.0)", "datasketch (>=1.6.5,<2.0.0)", @@ -174,7 +176,6 @@ bert-score = "^0.3.13" scipy = "^1.10.1" python-dotenv = "^1.1.1" langgraph = "^0.5.1" -langmem = "^0.0.27" [tool.poetry.group.mem-user.dependencies] diff --git a/src/memos/chunkers/factory.py b/src/memos/chunkers/factory.py index 95b306aae..47c8fc71b 100644 --- a/src/memos/chunkers/factory.py +++ b/src/memos/chunkers/factory.py @@ -3,6 +3,7 @@ from memos.configs.chunker import ChunkerConfigFactory from .base import BaseChunker +from .markdown_chunker import MarkdownChunker from .sentence_chunker import SentenceChunker @@ -11,6 +12,7 @@ class ChunkerFactory: backend_to_class: ClassVar[dict[str, Any]] = { "sentence": SentenceChunker, + "markdown": MarkdownChunker, } @classmethod diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py new file mode 100644 index 000000000..477e96b8d --- /dev/null +++ b/src/memos/chunkers/markdown_chunker.py @@ -0,0 +1,53 @@ +from memos.configs.chunker import MarkdownChunkerConfig +from memos.dependency import require_python_package +from memos.log import get_logger + +from .base import BaseChunker, Chunk + + +logger = get_logger(__name__) + + +class MarkdownChunker(BaseChunker): + """Markdown-based text chunker.""" + + @require_python_package( + import_name="langchain_text_splitters", + install_command="pip install langchain_text_splitters==1.0.0", + install_link="https://github.com/langchain-ai/langchain-text-splitters", + ) + def __init__(self, config: MarkdownChunkerConfig): + from langchain_text_splitters import ( + MarkdownHeaderTextSplitter, + RecursiveCharacterTextSplitter, + ) + + self.config = config + self.chunker = MarkdownHeaderTextSplitter( + headers_to_split_on=config.headers_to_split_on, + strip_headers=config.strip_headers, + ) + self.chunker_recursive = None + logger.info(f"Initialized MarkdownHeaderTextSplitter with config: {config}") + if config.recursive: + self.chunker_recursive = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + ) + + def chunk(self, text: str) -> list[str] | list[Chunk]: + """Chunk the given text into smaller chunks based on sentences.""" + md_header_splits = self.chunker.split_text(text) + chunks = [] + if self.chunker_recursive: + md_header_splits = self.chunker_recursive.split_documents(md_header_splits) + for doc in md_header_splits: + try: + chunk = " ".join(list(doc.metadata.values())) + "\n" + doc.page_content + chunks.append(chunk) + except Exception as e: + logger.warning(f"warning chunking document: {e}") + chunks.append(doc.page_content) + + logger.debug(f"Generated {len(chunks)} chunks from input text") + return chunks diff --git a/src/memos/configs/chunker.py b/src/memos/configs/chunker.py index cb4f0e06d..c2af012f0 100644 --- a/src/memos/configs/chunker.py +++ b/src/memos/configs/chunker.py @@ -20,6 +20,19 @@ class SentenceChunkerConfig(BaseChunkerConfig): """Configuration for sentence-based text chunker.""" +class MarkdownChunkerConfig(BaseChunkerConfig): + """Configuration for markdown-based text chunker.""" + + headers_to_split_on: list[tuple[str, str]] = Field( + default=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 
3")], + description="Headers to split on", + ) + strip_headers: bool = Field(default=True, description="Strip headers from the text") + recursive: bool = Field( + default=False, description="Whether to use recursive character text splitter" + ) + + class ChunkerConfigFactory(BaseConfig): """Factory class for creating chunker configurations.""" @@ -28,6 +41,7 @@ class ChunkerConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "sentence": SentenceChunkerConfig, + "markdown": MarkdownChunkerConfig, } @field_validator("backend") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index f428bf5c0..830b915c1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -507,7 +507,10 @@ def _retrieve_simple( user_name: str | None = None, **kwargs, ): - """Retrieve from by keywords and embedding""" + """ + Retrieve from by keywords and embedding, this func is hotfix for sources=plugin mode + will merge with fulltext retrieval in the future + """ query_words = [] if self.tokenizer: query_words = self.tokenizer.tokenize_mixed(query) From b8a953ad909dba5c1d9ca30b39e95d53741cbe6c Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 3 Dec 2025 10:56:29 +0800 Subject: [PATCH 23/35] feat: fix playground for query --- poetry.lock | 42 +++++++++---------- src/memos/api/handlers/chat_handler.py | 58 ++++++++++++++++---------- src/memos/api/product_models.py | 2 +- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/poetry.lock b/poetry.lock index 940697b1c..c6c82cdbb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -54,7 +54,7 @@ description = "Timeout context manager for asyncio programs" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "(python_version == \"3.10\" or python_version == \"3.11\") and python_full_version < \"3.11.3\" and (extra == \"mem-scheduler\" or extra == \"all\")" +markers = "python_full_version < \"3.11.3\" and (extra == \"mem-scheduler\" or extra == \"all\")" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -1080,7 +1080,7 @@ description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.9" groups = ["main", "eval"] -markers = "(python_version == \"3.10\" or python_version == \"3.11\" or python_version == \"3.12\" or python_version == \"3.13\") and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")" +markers = "(platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") and python_version < \"3.14\"" files = [ {file = "greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be"}, {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac"}, @@ -2623,7 +2623,7 @@ files = [ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"}, {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-cuda-cupti-cu12" @@ -2639,7 +2639,7 @@ files = [ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"}, {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-cuda-nvrtc-cu12" @@ -2653,7 +2653,7 @@ files = [ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"}, {file = 
"nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-cuda-runtime-cu12" @@ -2669,7 +2669,7 @@ files = [ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"}, {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-cudnn-cu12" @@ -2683,7 +2683,7 @@ files = [ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"}, {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2702,7 +2702,7 @@ files = [ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"}, {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2718,7 +2718,7 @@ files = [ {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"}, {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = 
"nvidia-curand-cu12" @@ -2734,7 +2734,7 @@ files = [ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"}, {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-cusolver-cu12" @@ -2750,7 +2750,7 @@ files = [ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"}, {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2771,7 +2771,7 @@ files = [ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"}, {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2788,7 +2788,7 @@ files = [ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"}, {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-nccl-cu12" @@ -2801,7 +2801,7 @@ files = [ {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"}, {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine 
== \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-nvjitlink-cu12" @@ -2815,7 +2815,7 @@ files = [ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "nvidia-nvtx-cu12" @@ -2831,7 +2831,7 @@ files = [ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"}, {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [[package]] name = "ollama" @@ -3919,7 +3919,7 @@ files = [ {file = "pywin32-311-cp39-cp39-win_amd64.whl", hash = "sha256:e0c4cfb0621281fe40387df582097fd796e80430597cb9944f0ae70447bacd91"}, {file = "pywin32-311-cp39-cp39-win_arm64.whl", hash = "sha256:62ea666235135fee79bb154e695f3ff67370afefd71bd7fea7512fc70ef31e3d"}, ] -markers = {main = "extra == \"all\" and platform_system == \"Windows\" or sys_platform == \"win32\"", eval = "platform_system == \"Windows\""} +markers = {main = "platform_system == \"Windows\" and extra == \"all\" or sys_platform == \"win32\"", eval = "platform_system == \"Windows\""} [[package]] name = "pyyaml" @@ -4917,7 +4917,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"all\" or extra == \"pref-mem\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5408,7 +5408,7 @@ files = [ {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"}, ] -markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} [package.dependencies] setuptools = ">=40.8.0" @@ -5617,7 +5617,7 @@ description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" groups = ["main"] -markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\"" files = [ {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"}, {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"}, diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 1054644d2..fe6b600b8 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -381,6 +381,10 @@ def generate_chat_response() -> Generator[str, None, None]: readable_cube_ids = chat_req.readable_cube_ids or ( [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] ) + # Resolve writable cube IDs (for add) + writable_cube_ids = chat_req.writable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) search_req = APISearchRequest( query=chat_req.query, @@ -397,6 +401,15 @@ def generate_chat_response() -> Generator[str, None, None]: ) search_response = self.search_handler.handle_search_memories(search_req) + # for playground, add the query to memory without response + self._start_add_to_memory( + user_id=chat_req.user_id, + writable_cube_ids=writable_cube_ids, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=None, + async_mode="sync", + ) yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" # Use first readable cube ID for scheduler (backward compatibility) @@ -539,11 +552,6 @@ def generate_chat_response() -> Generator[str, None, None]: speed_improvement=speed_improvement, current_messages=current_messages, ) - - # Resolve writable cube IDs (for add) - writable_cube_ids = chat_req.writable_cube_ids or ( - [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] - ) self._start_add_to_memory( user_id=chat_req.user_id, writable_cube_ids=writable_cube_ids, @@ -905,25 +913,29 @@ async def _add_conversation_to_memory( writable_cube_ids: list[str], session_id: str, query: str, - clean_response: str, + clean_response: str | None = None, async_mode: Literal["async", "sync"] = "sync", ) -> None: - add_req = APIADDRequest( - user_id=user_id, - writable_cube_ids=writable_cube_ids, - session_id=session_id, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, + messages = [ + { + "role": "user", + "content": query, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + } + ] + if clean_response: + messages.append( { "role": "assistant", "content": clean_response, "chat_time": 
str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], + } + ) + add_req = APIADDRequest( + user_id=user_id, + writable_cube_ids=writable_cube_ids, + session_id=session_id, + messages=messages, async_mode=async_mode, ) @@ -1128,7 +1140,7 @@ def _start_add_to_memory( writable_cube_ids: list[str], session_id: str, query: str, - full_response: str, + full_response: str | None = None, async_mode: Literal["async", "sync"] = "sync", ) -> None: def run_async_in_thread(): @@ -1136,7 +1148,9 @@ def run_async_in_thread(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - clean_response, _ = self._extract_references_from_response(full_response) + clean_response = full_response + if full_response: + clean_response, _ = self._extract_references_from_response(full_response) loop.run_until_complete( self._add_conversation_to_memory( user_id=user_id, @@ -1157,7 +1171,9 @@ def run_async_in_thread(): try: asyncio.get_running_loop() - clean_response, _ = self._extract_references_from_response(full_response) + clean_response = full_response + if full_response: + clean_response, _ = self._extract_references_from_response(full_response) task = asyncio.create_task( self._add_conversation_to_memory( user_id=user_id, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 16ae86638..ffe736aa3 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -6,7 +6,7 @@ # Import message types from core types module from memos.log import get_logger -from memos.types import MessageDict, MessageList, MessagesType, PermissionDict, SearchMode +from memos.types import MessageDict, PermissionDict, SearchMode logger = get_logger(__name__) From 78c1582bb2d60559bb425935aaaddbd2d1060433 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 11:52:43 +0800 Subject: [PATCH 24/35] feat: update file content memory extract --- .../read_multi_modal/file_content_parser.py | 246 ++++++++++++++---- .../mem_reader/read_multi_modal/utils.py | 98 ++++++- 2 files changed, 296 insertions(+), 48 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 8a08d6a93..9c22e43ec 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -1,10 +1,14 @@ """Parser for file content parts (RawMessageList).""" +import concurrent.futures import os import tempfile from typing import Any +from tqdm import tqdm + +from memos.context.context import ContextThreadPoolExecutor from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM from memos.log import get_logger @@ -13,19 +17,72 @@ TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) +from memos.templates.mem_reader_prompts import ( + CUSTOM_TAGS_INSTRUCTION, + CUSTOM_TAGS_INSTRUCTION_ZH, + SIMPLE_STRUCT_DOC_READER_PROMPT, + SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, +) from memos.types.openai_chat_completion_types import File from .base import BaseMessageParser, _derive_key -from .utils import get_parser, get_text_splitter +from .utils import ( + detect_lang, + get_parser, + get_text_splitter, + parse_json_result, +) logger = get_logger(__name__) +# Prompt dictionary for doc processing (shared by simple_struct and file_content_parser) +DOC_PROMPT_DICT = { + "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, +} 
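+# Example lookup: DOC_PROMPT_DICT["doc"][detect_lang(chunk_text)] returns the EN or
+# ZH doc-reader template; _get_doc_llm_response() below performs exactly this lookup.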
+

 class FileContentParser(BaseMessageParser):
     """Parser for file content parts."""

-    def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None]:
+    def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = None) -> dict:
+        """
+        Call LLM to extract memory from document chunk.
+        Uses doc prompts from DOC_PROMPT_DICT.
+
+        Args:
+            chunk_text: Text chunk to extract memory from
+            custom_tags: Optional list of custom tags for LLM extraction
+
+        Returns:
+            Parsed JSON response from LLM, or an empty dict on failure
+        """
+        if not self.llm:
+            logger.warning("[FileContentParser] LLM not available for fine mode")
+            return {}
+
+        lang = detect_lang(chunk_text)
+        template = DOC_PROMPT_DICT["doc"][lang]
+        prompt = template.replace("{chunk_text}", chunk_text)
+
+        custom_tags_prompt = (
+            DOC_PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
+            if custom_tags
+            else ""
+        )
+        prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt)
+
+        messages = [{"role": "user", "content": prompt}]
+        try:
+            response_text = self.llm.generate(messages)
+            response_json = parse_json_result(response_text)
+        except Exception as e:
+            logger.error(f"[FileContentParser] LLM generation error: {e}")
+            response_json = {}
+        return response_json
+
+    def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, bool]:
         """Download and parse file from URL."""
         try:
             from urllib.parse import urlparse
@@ -42,14 +99,14 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None]:
                 filename = os.path.basename(parsed_url.path) or "downloaded_file"

             if hostname in self.direct_markdown_hostnames:
-                return response.text, None
+                return response.text, None, True

             file_ext = os.path.splitext(filename)[1].lower()
             if file_ext in [".md", ".markdown", ".txt"]:
-                return response.text, None
+                return response.text, None, True

             with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_ext) as temp_file:
                 temp_file.write(response.content)
-                return "", temp_file.name
+                return "", temp_file.name, False
         except Exception as e:
             logger.error(f"[FileContentParser] URL processing error: {e}")
-            return f"[File URL download failed: {url_str}]", None
+            # Return a 3-tuple here as well, matching the new signature
+            return f"[File URL download failed: {url_str}]", None, False
@@ -108,7 +165,7 @@ def __init__(
         else:
             self.direct_markdown_hostnames = []

-    def _split_text(self, text: str) -> list[str]:
+    def _split_text(self, text: str, is_markdown: bool = False) -> list[str]:
         """
         Split text into chunks using text splitter from utils.

@@ -288,6 +345,8 @@ def parse_fast(

         # Extract info fields
         info_ = info.copy()
+        if file_id:
+            info_.update({"file_id": file_id})
         user_id = info_.pop("user_id", "")
         session_id = info_.pop("session_id", "")

@@ -358,10 +417,19 @@ def parse_fine(
         """
         Parse file content part in fine mode.
         Fine mode downloads and parses file content, especially for URLs.
+        Then uses LLM to extract structured memories from each chunk.
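+        Per-chunk flow, as implemented below: _get_doc_llm_response() builds a
+        language-specific prompt, parse_json_result() decodes the LLM reply, and
+        the returned value/key/tags populate a TextualMemoryItem; chunks whose
+        LLM call fails fall back to their raw text.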
+ Handles various file parameter scenarios: - file_data: URL (http://, https://, or @http://), base64 encoded data, or plain text content - file_id: ID of an uploaded file - filename: name of the file + + Args: + message: File content part to parse + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters including: + - custom_tags: Optional list of custom tags for LLM extraction + - context_items: Optional list of TextualMemoryItem for context """ if not isinstance(message, dict): logger.warning(f"[FileContentParser] Expected dict, got {type(message)}") @@ -378,6 +446,9 @@ def parse_fine( file_id = file_info.get("file_id", "") filename = file_info.get("filename", "") + # Extract custom_tags from kwargs (for LLM extraction) + custom_tags = kwargs.get("custom_tags") + # Use parser from utils parser = self.parser or get_parser() if not parser: @@ -386,6 +457,7 @@ def parse_fine( parsed_text = "" temp_file_path = None + is_markdown = False try: # Priority 1: If file_data is provided, process it @@ -394,7 +466,9 @@ def parse_fine( url_str = file_data[1:] if file_data.startswith("@") else file_data if url_str.startswith(("http://", "https://")): - parsed_text, temp_file_path = self._handle_url(url_str, filename) + parsed_text, temp_file_path, is_markdown = self._handle_url( + url_str, filename + ) if temp_file_path: try: # Use parser from utils @@ -457,28 +531,32 @@ def parse_fine( memory_type = "LongTermMemory" # Split parsed text into chunks - content_chunks = self._split_text(parsed_text) - - # Create memory items for each chunk - memory_items = [] - for chunk_idx, chunk_text in enumerate(content_chunks): - if not chunk_text.strip(): - continue - - memory_item = TextualMemoryItem( - memory=chunk_text, + content_chunks = self._split_text(parsed_text, is_markdown) + + # Filter out empty chunks and create indexed list + valid_chunks = [ + (idx, chunk_text) for idx, chunk_text in enumerate(content_chunks) if chunk_text.strip() + ] + total_chunks = len(content_chunks) + + # Helper function to create memory item (similar to SimpleStructMemReader._make_memory_item) + def _make_memory_item( + value: str, + mem_type: str = memory_type, + tags: list[str] | None = None, + key: str | None = None, + ) -> TextualMemoryItem: + """Construct memory item with common fields.""" + return TextualMemoryItem( + memory=value, metadata=TreeNodeTextualMemoryMetadata( user_id=user_id, session_id=session_id, - memory_type=memory_type, + memory_type=mem_type, status="activated", - tags=[ - "mode:fine", - "multimodal:file", - f"chunk:{chunk_idx + 1}/{len(content_chunks)}", - ], - key=_derive_key(chunk_text), - embedding=self.embedder.embed([chunk_text])[0], + tags=tags or [], + key=key if key is not None else _derive_key(value), + embedding=self.embedder.embed([value])[0], usage=[], sources=[source], background="", @@ -487,28 +565,102 @@ def parse_fine( info=info_, ), ) - memory_items.append(memory_item) - # If no chunks were created, create a placeholder - if not memory_items: - memory_item = TextualMemoryItem( - memory=parsed_text, - metadata=TreeNodeTextualMemoryMetadata( - user_id=user_id, - session_id=session_id, - memory_type=memory_type, - status="activated", - tags=["mode:fine", "multimodal:file"], - key=_derive_key(parsed_text), - embedding=self.embedder.embed([parsed_text])[0], - usage=[], - sources=[source], - background="", - confidence=0.99, - type="fact", - info=info_, - ), + # Helper function to create fallback item for a chunk + def _make_fallback( + chunk_idx: int, 
chunk_text: str, reason: str = "raw" + ) -> TextualMemoryItem: + """Create fallback memory item with raw chunk text.""" + return _make_memory_item( + value=chunk_text, + tags=[ + "mode:fine", + "multimodal:file", + f"fallback:{reason}", + f"chunk:{chunk_idx + 1}/{total_chunks}", + ], ) - memory_items.append(memory_item) - return memory_items + # Handle empty chunks case + if not valid_chunks: + return [ + _make_memory_item( + value=parsed_text or "[File: empty content]", + tags=["mode:fine", "multimodal:file"], + ) + ] + + # If no LLM available, create memory items directly from chunks + if not self.llm: + return [_make_fallback(idx, text, "no_llm") for idx, text in valid_chunks] + + # Process single chunk with LLM extraction (worker function) + def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: + """Process chunk with LLM, fallback to raw on failure.""" + try: + response_json = self._get_doc_llm_response(chunk_text, custom_tags) + if response_json: + value = response_json.get("value", "").strip() + if value: + tags = response_json.get("tags", []) + tags = tags if isinstance(tags, list) else [] + tags.extend(["mode:fine", "multimodal:file"]) + + llm_mem_type = response_json.get("memory_type", memory_type) + if llm_mem_type not in ["LongTermMemory", "UserMemory"]: + llm_mem_type = memory_type + + return _make_memory_item( + value=value, + mem_type=llm_mem_type, + tags=tags, + key=response_json.get("key"), + ) + except Exception as e: + logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}") + + # Fallback to raw chunk + logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}") + return _make_fallback(chunk_idx, chunk_text) + + # Process chunks concurrently with progress bar + memory_items = [] + chunk_map = dict(valid_chunks) + total_chunks = len(valid_chunks) + + logger.info(f"[FileContentParser] Processing {total_chunks} chunks with LLM...") + + with ContextThreadPoolExecutor(max_workers=20) as executor: + futures = { + executor.submit(_process_chunk, idx, text): idx for idx, text in valid_chunks + } + + # Use tqdm for progress bar (similar to simple_struct.py _process_doc_data) + for future in tqdm( + concurrent.futures.as_completed(futures), + total=total_chunks, + desc="[FileContentParser] Processing chunks", + ): + chunk_idx = futures[future] + try: + node = future.result() + if node: + memory_items.append(node) + except Exception as e: + tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}") + logger.error(f"[FileContentParser] Future failed for chunk {chunk_idx}: {e}") + # Create fallback for failed future + if chunk_idx in chunk_map: + memory_items.append( + _make_fallback(chunk_idx, chunk_map[chunk_idx], "error") + ) + + logger.info( + f"[FileContentParser] Completed processing {len(memory_items)}/{total_chunks} chunks" + ) + + return memory_items or [ + _make_memory_item( + value=parsed_text or "[File: empty content]", tags=["mode:fine", "multimodal:file"] + ) + ] diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index 992011765..60ed5544c 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -1,5 +1,6 @@ """Utility functions for message parsing.""" +import json import os import re @@ -43,6 +44,94 @@ re.I, ) + +def detect_lang(text: str) -> str: + """ + Detect language of text (zh or en). 
+ + Args: + text: Text to detect language for + + Returns: + "zh" for Chinese, "en" for English (default) + """ + try: + if not text or not isinstance(text, str): + return "en" + cleaned_text = text + # remove role and timestamp + cleaned_text = re.sub( + r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE + ) + cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) + + # extract chinese characters + chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" + chinese_chars = re.findall(chinese_pattern, cleaned_text) + text_without_special = re.sub(r"[\s\d\W]", "", cleaned_text) + if text_without_special and len(chinese_chars) / len(text_without_special) > 0.3: + return "zh" + return "en" + except Exception: + return "en" + + +def parse_json_result(response_text: str) -> dict: + """ + Parse JSON result from LLM response. + + Handles various formats including: + - JSON wrapped in markdown code blocks + - Raw JSON + - Incomplete JSON (attempts to fix) + + Args: + response_text: Raw response text from LLM + + Returns: + Parsed dictionary or empty dict if parsing fails + """ + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + try: + return json.loads(s) + except json.JSONDecodeError: + pass + logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw: {response_text}") + return {} + + # Default configuration for parser and text splitter DEFAULT_PARSER_CONFIG = { "backend": "markitdown", @@ -114,7 +203,10 @@ def _simple_split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[s from langchain.text_splitter import RecursiveCharacterTextSplitter except ImportError: try: - from langchain_text_splitters import RecursiveCharacterTextSplitter + from langchain_text_splitters import ( + MarkdownHeaderTextSplitter, + RecursiveCharacterTextSplitter, + ) except ImportError: logger.error( "langchain not available. Install with: pip install langchain or pip install langchain-text-splitters" @@ -126,6 +218,10 @@ def _simple_split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[s length_function=len, separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? 
", " ", ""], ) + markdown_text_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")], + strip_headers=False, + ) logger.debug( f"[FileContentParser] Initialized langchain text splitter with chunk_size={DEFAULT_CHUNK_SIZE}, " f"chunk_overlap={DEFAULT_CHUNK_OVERLAP}" From ef236cb436fabd8d640a61bd15c1d5322c0c3e56 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 14:28:13 +0800 Subject: [PATCH 25/35] feat: update code --- src/memos/mem_reader/read_multi_modal/file_content_parser.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 74d62e98a..8202f03a0 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -29,7 +29,6 @@ from .utils import ( detect_lang, get_parser, - get_text_splitter, parse_json_result, ) @@ -504,7 +503,7 @@ def parse_fine( memory_type = "LongTermMemory" # Split parsed text into chunks - content_chunks = self._split_text(parsed_text, is_markdown) + content_chunks = self._split_text(parsed_text) # Filter out empty chunks and create indexed list valid_chunks = [ From 905b1e22dadb8d2c9acdb39274958e7799ea028e Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 14:35:40 +0800 Subject: [PATCH 26/35] feat: update import --- .../read_multi_modal/file_content_parser.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 8202f03a0..4ec4f5279 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -12,6 +12,12 @@ from memos.embedders.base import BaseEmbedder from memos.llms.base import BaseLLM from memos.log import get_logger +from memos.mem_reader.read_multi_modal.base import BaseMessageParser, _derive_key +from memos.mem_reader.read_multi_modal.utils import ( + detect_lang, + get_parser, + parse_json_result, +) from memos.memories.textual.item import ( SourceMessage, TextualMemoryItem, @@ -25,13 +31,6 @@ ) from memos.types.openai_chat_completion_types import File -from .base import BaseMessageParser, _derive_key -from .utils import ( - detect_lang, - get_parser, - parse_json_result, -) - logger = get_logger(__name__) From eb09595c4d8fdafb14446022fcb97910385aa1f9 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 14:54:25 +0800 Subject: [PATCH 27/35] code: reformat suffix --- .../mem_reader/read_multi_modal/utils.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index 60ed5544c..1f7b0e1f8 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -45,37 +45,6 @@ ) -def detect_lang(text: str) -> str: - """ - Detect language of text (zh or en). 
-
-    Args:
-        text: Text to detect language for
-
-    Returns:
-        "zh" for Chinese, "en" for English (default)
-    """
-    try:
-        if not text or not isinstance(text, str):
-            return "en"
-        cleaned_text = text
-        # remove role and timestamp
-        cleaned_text = re.sub(
-            r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE
-        )
-        cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text)
-
-        # extract chinese characters
-        chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]"
-        chinese_chars = re.findall(chinese_pattern, cleaned_text)
-        text_without_special = re.sub(r"[\s\d\W]", "", cleaned_text)
-        if text_without_special and len(chinese_chars) / len(text_without_special) > 0.3:
-            return "zh"
-        return "en"
-    except Exception:
-        return "en"
-
-
 def parse_json_result(response_text: str) -> dict:
     """
     Parse JSON result from LLM response.
@@ -433,3 +402,34 @@ def coerce_scene_data(scene_data: SceneDataInput, scene_type: str) -> list[Messa
 
     # fallback
     return [str(scene_data)]
+
+
+def detect_lang(text: str) -> str:
+    """
+    Detect language of text (zh or en).
+
+    Args:
+        text: Text to detect language for
+
+    Returns:
+        "zh" for Chinese, "en" for English (default)
+    """
+    try:
+        if not text or not isinstance(text, str):
+            return "en"
+        cleaned_text = text
+        # remove role and timestamp
+        cleaned_text = re.sub(
+            r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE
+        )
+        cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text)
+
+        # extract chinese characters
+        chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]"
+        chinese_chars = re.findall(chinese_pattern, cleaned_text)
+        text_without_special = re.sub(r"[\s\d\W]", "", cleaned_text)
+        if text_without_special and len(chinese_chars) / len(text_without_special) > 0.3:
+            return "zh"
+        return "en"
+    except Exception:
+        return "en"

From a93fc0f81336ba159c9e12a458bd3ac028ae4574 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 4 Dec 2025 16:23:25 +0800
Subject: [PATCH 28/35] feat: update file_id

---
 src/memos/mem_reader/read_multi_modal/file_content_parser.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index 4ec4f5279..b5305af9a 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -497,6 +497,8 @@ def parse_fine(
         info_ = info.copy()
         user_id = info_.pop("user_id", "")
         session_id = info_.pop("session_id", "")
+        if file_id:
+            info_["file_id"] = file_id
 
         # For file content parts, default to LongTermMemory
         memory_type = "LongTermMemory"

From 035013533b6332c22a7bcd0ee97bde697216bd3a Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 4 Dec 2025 16:46:48 +0800
Subject: [PATCH 29/35] remove langchain-text-splitters==1.0.0

---
 docker/requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docker/requirements.txt b/docker/requirements.txt
index 21f246599..d3268edae 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -160,4 +160,3 @@ xlrd==2.0.2
 xlsxwriter==3.2.5
 prometheus-client==0.23.1
 pymilvus==2.5.12
-langchain-text-splitters==1.0.0

From ecce0f1d0259488e17572ebf2c3b19c7883d6e5a Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 4 Dec 2025 16:50:50 +0800
Subject: [PATCH 30/35] feat: add requirement

---
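This restores the pin dropped in the previous patch: the chunkers introduced in
the next patch (PATCH 31) import both splitters from this package. A minimal
usage sketch of the two-pass split those chunkers build on, assuming the
langchain-text-splitters 1.0.0 public API (split_text, Document.page_content):

    from langchain_text_splitters import (
        MarkdownHeaderTextSplitter,
        RecursiveCharacterTextSplitter,
    )

    # Header-aware pass: split on #/##/### while keeping the headers in the text.
    md_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
        strip_headers=False,
    )
    docs = md_splitter.split_text("# Title\n\nIntro.\n\n## Section\n\nBody text.")

    # Size-bounded pass: recursively re-split oversized sections on separators.
    char_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = [c for d in docs for c in char_splitter.split_text(d.page_content)]
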
docker/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/requirements.txt b/docker/requirements.txt index d3268edae..21f246599 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -160,3 +160,4 @@ xlrd==2.0.2 xlsxwriter==3.2.5 prometheus-client==0.23.1 pymilvus==2.5.12 +langchain-text-splitters==1.0.0 From 4f53e293fcdcf03d64a81ebb7fb59d77eb8171ce Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 19:07:45 +0800 Subject: [PATCH 31/35] feat: make test --- src/memos/chunkers/charactertext_chunker.py | 41 ++++++ src/memos/chunkers/markdown_chunker.py | 23 ++-- src/memos/chunkers/simple_chunker.py | 50 ++++++++ src/memos/mem_reader/read_multi_modal/base.py | 4 +- .../read_multi_modal/file_content_parser.py | 2 +- .../mem_reader/read_multi_modal/utils.py | 119 ++++-------------- 6 files changed, 131 insertions(+), 108 deletions(-) create mode 100644 src/memos/chunkers/charactertext_chunker.py create mode 100644 src/memos/chunkers/simple_chunker.py diff --git a/src/memos/chunkers/charactertext_chunker.py b/src/memos/chunkers/charactertext_chunker.py new file mode 100644 index 000000000..15c0958ba --- /dev/null +++ b/src/memos/chunkers/charactertext_chunker.py @@ -0,0 +1,41 @@ +from memos.configs.chunker import MarkdownChunkerConfig +from memos.dependency import require_python_package +from memos.log import get_logger + +from .base import BaseChunker, Chunk + + +logger = get_logger(__name__) + + +class CharacterTextChunker(BaseChunker): + """Character-based text chunker.""" + + @require_python_package( + import_name="langchain_text_splitters", + install_command="pip install langchain_text_splitters==1.0.0", + install_link="https://github.com/langchain-ai/langchain-text-splitters", + ) + def __init__( + self, + config: MarkdownChunkerConfig | None = None, + chunk_size: int = 1000, + chunk_overlap: int = 200, + ): + from langchain_text_splitters import ( + RecursiveCharacterTextSplitter, + ) + + self.config = config + self.chunker = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size if config else chunk_size, + chunk_overlap=config.chunk_overlap if config else chunk_overlap, + length_function=len, + separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? 
", " ", ""], + ) + + def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: + """Chunk the given text into smaller chunks based on sentences.""" + chunks = self.chunker.split_text(text) + logger.debug(f"Generated {len(chunks)} chunks from input text") + return chunks diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py index 477e96b8d..de375a4dc 100644 --- a/src/memos/chunkers/markdown_chunker.py +++ b/src/memos/chunkers/markdown_chunker.py @@ -16,7 +16,13 @@ class MarkdownChunker(BaseChunker): install_command="pip install langchain_text_splitters==1.0.0", install_link="https://github.com/langchain-ai/langchain-text-splitters", ) - def __init__(self, config: MarkdownChunkerConfig): + def __init__( + self, + config: MarkdownChunkerConfig | None = None, + chunk_size: int = 1000, + chunk_overlap: int = 200, + recursive: bool = False, + ): from langchain_text_splitters import ( MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter, @@ -24,18 +30,21 @@ def __init__(self, config: MarkdownChunkerConfig): self.config = config self.chunker = MarkdownHeaderTextSplitter( - headers_to_split_on=config.headers_to_split_on, - strip_headers=config.strip_headers, + headers_to_split_on=config.headers_to_split_on + if config + else [("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")], + strip_headers=config.strip_headers if config else False, ) self.chunker_recursive = None logger.info(f"Initialized MarkdownHeaderTextSplitter with config: {config}") - if config.recursive: + if (config and config.recursive) or recursive: self.chunker_recursive = RecursiveCharacterTextSplitter( - chunk_size=config.chunk_size, - chunk_overlap=config.chunk_overlap, + chunk_size=config.chunk_size if config else chunk_size, + chunk_overlap=config.chunk_overlap if config else chunk_overlap, + length_function=len, ) - def chunk(self, text: str) -> list[str] | list[Chunk]: + def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" md_header_splits = self.chunker.split_text(text) chunks = [] diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py new file mode 100644 index 000000000..5175e8e06 --- /dev/null +++ b/src/memos/chunkers/simple_chunker.py @@ -0,0 +1,50 @@ +class SimpleTextSplitter: + """Simple text splitter wrapper.""" + + def __init__(self, chunk_size: int, chunk_overlap: int): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def chunk(self, text: str, **kwargs) -> list[str]: + return _simple_split_text(text, self.chunk_size, self.chunk_overlap) + + def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> list[str]: + """ + Simple text splitter as fallback when langchain is not available. + + Args: + text: Text to split + chunk_size: Maximum size of chunks + chunk_overlap: Overlap between chunks + + Returns: + List of text chunks + """ + if not text or len(text) <= chunk_size: + return [text] if text.strip() else [] + + chunks = [] + start = 0 + text_len = len(text) + + while start < text_len: + # Calculate end position + end = min(start + chunk_size, text_len) + + # If not the last chunk, try to break at a good position + if end < text_len: + # Try to break at newline, sentence end, or space + for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? 
", " "]: + last_sep = text.rfind(separator, start, end) + if last_sep != -1: + end = last_sep + len(separator) + break + + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + + # Move start position with overlap + start = max(start + 1, end - chunk_overlap) + + return chunks diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py index 123eb22bc..a1d4d6ed6 100644 --- a/src/memos/mem_reader/read_multi_modal/base.py +++ b/src/memos/mem_reader/read_multi_modal/base.py @@ -226,7 +226,7 @@ def parse( else: raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'") - def _split_text(self, text: str) -> list[str]: + def _split_text(self, text: str, is_markdown: bool = False) -> list[str]: """ Split text into chunks using text splitter from utils. @@ -245,7 +245,7 @@ def _split_text(self, text: str) -> list[str]: return [text] if text.strip() else [] try: - chunks = splitter.split_text(text) + chunks = splitter.chunk(text, is_markdown) logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks") return chunks except Exception as e: diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index b5305af9a..39d9f1e8e 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -504,7 +504,7 @@ def parse_fine( memory_type = "LongTermMemory" # Split parsed text into chunks - content_chunks = self._split_text(parsed_text) + content_chunks = self._split_text(parsed_text, is_markdown) # Filter out empty chunks and create indexed list valid_chunks = [ diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index 0c887a9f2..137312af4 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -111,48 +111,6 @@ def _cheap_close(t: str) -> str: DEFAULT_CHUNK_OVERLAP = int(os.getenv("FILE_PARSER_CHUNK_OVERLAP", "200")) -def _simple_split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]: - """ - Simple text splitter as fallback when langchain is not available. - - Args: - text: Text to split - chunk_size: Maximum size of chunks - chunk_overlap: Overlap between chunks - - Returns: - List of text chunks - """ - if not text or len(text) <= chunk_size: - return [text] if text.strip() else [] - - chunks = [] - start = 0 - text_len = len(text) - - while start < text_len: - # Calculate end position - end = min(start + chunk_size, text_len) - - # If not the last chunk, try to break at a good position - if end < text_len: - # Try to break at newline, sentence end, or space - for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? 
", " "]: - last_sep = text.rfind(separator, start, end) - if last_sep != -1: - end = last_sep + len(separator) - break - - chunk = text[start:end].strip() - if chunk: - chunks.append(chunk) - - # Move start position with overlap - start = max(start + 1, end - chunk_overlap) - - return chunks - - # Initialize parser instance file_parser = None try: @@ -163,51 +121,27 @@ def _simple_split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[s logger.error(f"[FileContentParser] Failed to create parser: {e}") file_parser = None -# Initialize text splitter instance -text_splitter = None -_use_simple_splitter = False +markdown_text_splitter = None try: - try: - from langchain.text_splitter import RecursiveCharacterTextSplitter - except ImportError: - try: - from langchain_text_splitters import ( - MarkdownHeaderTextSplitter, - RecursiveCharacterTextSplitter, - ) - except ImportError: - logger.error( - "langchain not available. Install with: pip install langchain or pip install langchain-text-splitters" - ) - - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=DEFAULT_CHUNK_SIZE, - chunk_overlap=DEFAULT_CHUNK_OVERLAP, - length_function=len, - separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " ", ""], - ) - markdown_text_splitter = MarkdownHeaderTextSplitter( - headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")], - strip_headers=False, - ) - logger.debug( - f"[FileContentParser] Initialized langchain text splitter with chunk_size={DEFAULT_CHUNK_SIZE}, " - f"chunk_overlap={DEFAULT_CHUNK_OVERLAP}" + from memos.chunkers.charactertext_chunker import CharacterTextChunker + from memos.chunkers.markdown_chunker import MarkdownChunker + + markdown_text_splitter = MarkdownChunker( + chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP, recursive=True ) -except ImportError as e: - logger.warning( - f"[FileContentParser] langchain not available, using simple text splitter as fallback: {e}. " - "Install with: pip install langchain or pip install langchain-text-splitters" + text_splitter = CharacterTextChunker( + chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP ) - text_splitter = None - _use_simple_splitter = True + logger.info("[FileContentParser] Initialized text splitter instances by lancga") except Exception as e: - logger.error( - f"[FileContentParser] Failed to initialize text splitter: {e}, using simple splitter as fallback" + logger.warning( + f"[FileContentParser] Failed to create text splitter: {e} will use simple splitter fallback" ) + from memos.chunkers.simple_chunker import SimpleTextSplitter + + markdown_text_splitter = None text_splitter = None - _use_simple_splitter = True def get_parser() -> Any: @@ -220,7 +154,9 @@ def get_parser() -> Any: return file_parser -def get_text_splitter(chunk_size: int | None = None, chunk_overlap: int | None = None) -> Any: +def get_text_splitter( + chunk_size: int | None = None, chunk_overlap: int | None = None, is_markdown: bool = False +) -> Any: """ Get text splitter instance or a callable that uses simple splitter. 
@@ -231,28 +167,15 @@ def get_text_splitter(chunk_size: int | None = None, chunk_overlap: int | None = Returns: Text splitter instance (RecursiveCharacterTextSplitter) or a callable wrapper for simple splitter """ - if text_splitter is not None: + if is_markdown and markdown_text_splitter is not None: + return markdown_text_splitter + elif text_splitter is not None: return text_splitter - - # Return a callable wrapper that uses simple splitter - if _use_simple_splitter: + else: actual_chunk_size = chunk_size or DEFAULT_CHUNK_SIZE actual_chunk_overlap = chunk_overlap or DEFAULT_CHUNK_OVERLAP - - class SimpleTextSplitter: - """Simple text splitter wrapper.""" - - def __init__(self, chunk_size: int, chunk_overlap: int): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - def split_text(self, text: str) -> list[str]: - return _simple_split_text(text, self.chunk_size, self.chunk_overlap) - return SimpleTextSplitter(actual_chunk_size, actual_chunk_overlap) - return None - def extract_role(message: dict[str, Any]) -> str: """Extract role from message.""" From 8e3063a307869ab8ede1f3e57ba1acbafbdd057a Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 20:42:04 +0800 Subject: [PATCH 32/35] feat: fix markdown --- src/memos/mem_reader/read_multi_modal/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py index a1d4d6ed6..a3992a1f1 100644 --- a/src/memos/mem_reader/read_multi_modal/base.py +++ b/src/memos/mem_reader/read_multi_modal/base.py @@ -245,7 +245,7 @@ def _split_text(self, text: str, is_markdown: bool = False) -> list[str]: return [text] if text.strip() else [] try: - chunks = splitter.chunk(text, is_markdown) + chunks = splitter.chunk(text) logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks") return chunks except Exception as e: From b022b043865109152909eb652854ec1fbcb87369 Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 4 Dec 2025 20:56:44 +0800 Subject: [PATCH 33/35] feat: fix simple chunker --- src/memos/chunkers/simple_chunker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py index 5175e8e06..cc0dc40d0 100644 --- a/src/memos/chunkers/simple_chunker.py +++ b/src/memos/chunkers/simple_chunker.py @@ -6,7 +6,7 @@ def __init__(self, chunk_size: int, chunk_overlap: int): self.chunk_overlap = chunk_overlap def chunk(self, text: str, **kwargs) -> list[str]: - return _simple_split_text(text, self.chunk_size, self.chunk_overlap) + return self._simple_split_text(text, self.chunk_size, self.chunk_overlap) def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> list[str]: """ From 5847e073384aba9b33895a90c84db57b806989e9 Mon Sep 17 00:00:00 2001 From: fridayL Date: Sun, 7 Dec 2025 17:01:51 +0800 Subject: [PATCH 34/35] feat: add file sources --- .../mem_reader/read_multi_modal/file_content_parser.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 67de3020d..9efb58263 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -167,6 +167,7 @@ def create_source( self, message: File, info: dict[str, Any], + chunk_content: str | None = None, ) -> SourceMessage: """Create SourceMessage 
from file content part."""
         if isinstance(message, dict):
@@ -174,7 +175,7 @@
             return SourceMessage(
                 type="file",
                 doc_path=file_info.get("filename") or file_info.get("file_id", ""),
-                content=file_info.get("file_data", ""),
+                content=chunk_content if chunk_content else file_info.get("file_data", ""),
                 original_part=message,
             )
         return SourceMessage(type="file", doc_path=str(message))
@@ -490,9 +491,6 @@ def parse_fine(
                     f"[FileContentParser] Failed to delete temp file {temp_file_path}: {e}"
                 )
 
-        # Create source
-        source = self.create_source(message, info)
-
         # Extract info fields
         if not info:
             info = {}
@@ -520,8 +518,10 @@ def _make_memory_item(
             mem_type: str = memory_type,
             tags: list[str] | None = None,
             key: str | None = None,
+            chunk_content: str | None = None,
         ) -> TextualMemoryItem:
             """Construct memory item with common fields."""
+            source = self.create_source(message, info, chunk_content)
             return TextualMemoryItem(
                 memory=value,
                 metadata=TreeNodeTextualMemoryMetadata(
@@ -591,6 +591,7 @@ def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem:
                     mem_type=llm_mem_type,
                     tags=tags,
                     key=response_json.get("key"),
+                    chunk_content=chunk_text,
                 )
             except Exception as e:
                 logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}")

From 59c1515bb3d069768ba9fad3ce8b5784fb587a1c Mon Sep 17 00:00:00 2001
From: fridayL
Date: Sun, 7 Dec 2025 19:10:14 +0800
Subject: [PATCH 35/35] feat: add concat doc source

---
 .../reranker/strategies/concat_docsource.py   | 105 ++++++++++++++++++
 src/memos/reranker/strategies/factory.py      |   2 +
 2 files changed, 107 insertions(+)
 create mode 100644 src/memos/reranker/strategies/concat_docsource.py

diff --git a/src/memos/reranker/strategies/concat_docsource.py b/src/memos/reranker/strategies/concat_docsource.py
new file mode 100644
index 000000000..0fb471218
--- /dev/null
+++ b/src/memos/reranker/strategies/concat_docsource.py
@@ -0,0 +1,105 @@
+# memos/reranker/strategies/concat_docsource.py
+from __future__ import annotations
+
+import re
+
+from typing import Any
+
+from .base import BaseRerankerStrategy
+from .dialogue_common import DialogueRankingTracker
+
+
+_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
+
+
+class ConcatDocSourceStrategy(BaseRerankerStrategy):
+    """
+    Concat doc-source strategy.
+
+    This strategy concatenates each memory's text with the content of its
+    file-type sources into a single string, so the reranker scores the memory
+    together with the document chunks it was extracted from. Each memory item
+    becomes one document for ranking.
+    """
+
+    def prepare_documents(
+        self,
+        query: str,
+        graph_results: list,
+        top_k: int,
+        **kwargs,
+    ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
+        """
+        Prepare documents by concatenating memory text with file-source chunk content.
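+
+        For an item whose metadata carries file-type sources, the resulting
+        document has this shape (placeholders are illustrative):
+
+            <memory text>
+
+            [Sources]:
+            <concatenated file chunk content>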
+
+        Args:
+            query: The search query
+            graph_results: List of graph results
+            top_k: Maximum number of items to return
+
+        Returns:
+            tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
+                - tracker: DialogueRankingTracker instance
+                - original_items: Unused here (always empty); kept for interface compatibility
+                - documents: List of text documents ready for ranking
+        """
+
+        original_items = {}
+        tracker = DialogueRankingTracker()
+        documents = []
+        for item in graph_results:
+            memory = getattr(item, "memory", None)
+            if isinstance(memory, str):
+                memory = _TAG1.sub("", memory)
+
+            # Collect the raw file chunks this memory was extracted from, so the
+            # reranker scores the memory together with its source text.
+            chunk_text = ""
+            if hasattr(item, "metadata") and hasattr(item.metadata, "sources"):
+                sources = getattr(item.metadata, "sources", [])
+                for source in sources:
+                    if source.type == "file":
+                        chunk_text += source.content
+            if chunk_text:
+                documents.append(f"{memory}\n\n[Sources]:\n{chunk_text}")
+            else:
+                documents.append(memory)
+        return tracker, original_items, documents
+
+    def reconstruct_items(
+        self,
+        ranked_indices: list[int],
+        scores: list[float],
+        tracker: DialogueRankingTracker,
+        original_items: dict[str, Any],
+        top_k: int,
+        **kwargs,
+    ) -> list[tuple[Any, float]]:
+        """
+        Reconstruct TextualMemoryItem objects from the ranked documents.
+
+        Args:
+            ranked_indices: List of document indices sorted by relevance
+            scores: Corresponding relevance scores
+            tracker: DialogueRankingTracker instance (unused by this strategy)
+            original_items: Dict mapping memory_id to original TextualMemoryItem
+            top_k: Maximum number of items to return
+
+        Returns:
+            List of (memory_item, relevance_score) tuples
+        """
+        graph_results = kwargs.get("graph_results")
+        documents = kwargs.get("documents")
+        reconstructed_items = []
+        for idx in ranked_indices:
+            item = graph_results[idx]
+            # Overwrite the memory text with the concatenated document so callers
+            # see the same text the reranker scored.
+            item.memory = documents[idx]
+            reconstructed_items.append((item, scores[idx]))
+
+        reconstructed_items.sort(key=lambda x: x[1], reverse=True)
+        return reconstructed_items[:top_k]

diff --git a/src/memos/reranker/strategies/factory.py b/src/memos/reranker/strategies/factory.py
index d93cbd65a..c8a8f2256 100644
--- a/src/memos/reranker/strategies/factory.py
+++ b/src/memos/reranker/strategies/factory.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar
 
 from .concat_background import ConcatBackgroundStrategy
+from .concat_docsource import ConcatDocSourceStrategy
 from .single_turn import SingleTurnStrategy
 from .singleturn_outmem import SingleTurnOutMemStrategy
 
@@ -19,6 +20,7 @@ class RerankerStrategyFactory:
         "single_turn": SingleTurnStrategy,
         "concat_background": ConcatBackgroundStrategy,
         "singleturn_outmem": SingleTurnOutMemStrategy,
+        "concat_docsource": ConcatDocSourceStrategy,
     }
 
     @classmethod
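
A minimal end-to-end sketch of the new strategy (illustrative, not the
framework's actual call path). The SimpleNamespace stand-ins mimic only the
attributes the strategy reads (memory, metadata.sources[*].type and .content);
real callers pass TextualMemoryItem objects, the scores list stands in for a
reranker backend, and the no-argument constructor is an assumption about
BaseRerankerStrategy:

    from types import SimpleNamespace

    from memos.reranker.strategies.concat_docsource import ConcatDocSourceStrategy

    # Stand-in memory item carrying one file-type source (illustrative values).
    item = SimpleNamespace(
        memory="[doc] Q3 report summary",
        metadata=SimpleNamespace(
            sources=[SimpleNamespace(type="file", content="Q3 revenue grew 12%.")]
        ),
    )

    strategy = ConcatDocSourceStrategy()
    tracker, original_items, documents = strategy.prepare_documents(
        query="how did Q3 go?", graph_results=[item], top_k=5
    )
    # documents[0] == "Q3 report summary\n\n[Sources]:\nQ3 revenue grew 12%."
    # (the leading "[doc] " tag is stripped by _TAG1)

    scores = [0.9]  # one score per document, from whatever reranker is in use
    results = strategy.reconstruct_items(
        ranked_indices=[0],
        scores=scores,
        tracker=tracker,
        original_items=original_items,
        top_k=5,
        graph_results=[item],  # reconstruct_items reads these two from kwargs
        documents=documents,
    )
    # results == [(item, 0.9)], with item.memory rewritten to documents[0]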