diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py
index f5e1aaba0..a0b72efd1 100644
--- a/src/memos/configs/mem_reader.py
+++ b/src/memos/configs/mem_reader.py
@@ -44,7 +44,6 @@ def parse_datetime(cls, value):
 class SimpleStructMemReaderConfig(BaseMemReaderConfig):
     """SimpleStruct MemReader configuration class."""
 
-    # Allow passing additional fields without raising validation errors
     model_config = ConfigDict(extra="allow", strict=True)
 
 
@@ -61,6 +60,8 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
 class StrategyStructMemReaderConfig(BaseMemReaderConfig):
     """StrategyStruct MemReader configuration class."""
 
+    model_config = ConfigDict(extra="allow", strict=True)
+
 
 class MemReaderConfigFactory(BaseConfig):
     """Factory class for creating MemReader configurations."""
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 7da013b48..0cb4e1542 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -8,7 +8,7 @@
 from memos.configs.mem_reader import MultiModalStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
-from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
 from memos.memories.textual.item import TextualMemoryItem
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -248,6 +248,111 @@ def _build_window_from_items(
 
         return aggregated_item
 
+    def _get_llm_response(
+        self,
+        mem_str: str,
+        custom_tags: list[str] | None = None,
+        sources: list | None = None,
+        prompt_type: str = "chat",
+    ) -> dict:
+        """
+        Override parent method to improve language detection by using actual text content
+        from sources instead of JSON-structured memory string.
+
+        Args:
+            mem_str: Memory string (may contain JSON structures)
+            custom_tags: Optional custom tags
+            sources: Optional list of SourceMessage objects or dicts to extract text
+                content from; non-string content is skipped for language detection
+            prompt_type: Type of prompt to use ("doc"; any other value uses "chat")
+
+        Returns:
+            LLM response dictionary; if generation or JSON parsing raises, a
+            single-item fallback built from mem_str is returned instead
+        """
+        # Try to extract actual text content from sources for better language detection
+        text_for_lang_detection = mem_str
+        if sources:
+            source_texts = []
+            for source in sources:
+                if isinstance(source, dict):
+                    content = source.get("content")
+                else:
+                    content = getattr(source, "content", None)
+                # Keep only non-empty strings: structured content (e.g. a list of
+                # message parts) would make " ".join() raise TypeError
+                if isinstance(content, str) and content:
+                    source_texts.append(content)
+
+            # If we have text content from sources, use it for language detection
+            if source_texts:
+                text_for_lang_detection = " ".join(source_texts)
+
+        # Use the extracted text for language detection
+        lang = detect_lang(text_for_lang_detection)
+
+        # Select prompt template based on prompt_type
+        if prompt_type == "doc":
+            template = PROMPT_DICT["doc"][lang]
+            examples = ""  # doc prompts don't have examples
+            prompt = template.replace("{chunk_text}", mem_str)
+        else:
+            template = PROMPT_DICT["chat"][lang]
+            examples = PROMPT_DICT["chat"][f"{lang}_example"]
+            prompt = template.replace("${conversation}", mem_str)
+
+        custom_tags_prompt = (
+            PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
+            if custom_tags
+            else ""
+        )
+
+        # Replace custom_tags_prompt placeholder (different for doc vs chat)
+        if prompt_type == "doc":
+            prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt)
+        else:
+            prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt)
+
+        if self.config.remove_prompt_example and examples:
+            prompt = prompt.replace(examples, "")
+        messages = [{"role": "user", "content": prompt}]
+        try:
+            response_text = self.llm.generate(messages)
+            response_json = self.parse_json_result(response_text)
+        except Exception as e:
+            logger.error(f"[LLM] Exception during chat generation: {e}")
+            response_json = {
+                "memory list": [
+                    {
+                        "key": mem_str[:10],
+                        "memory_type": "UserMemory",
+                        "value": mem_str,
+                        "tags": [],
+                    }
+                ],
+                "summary": mem_str,
+            }
+        return response_json
+
+    def _determine_prompt_type(self, sources: list) -> str:
+        """
+        Determine prompt type based on sources.
+
+        Returns "chat" when sources is empty or when any source has a role of
+        user, assistant, system or tool; otherwise returns "doc".
+        """
+        if not sources:
+            return "chat"
+        for source in sources:
+            if isinstance(source, dict):
+                source_role = source.get("role")
+            else:
+                source_role = getattr(source, "role", None)
+            # One conversational role is enough to treat the whole item as chat
+            if source_role in {"user", "assistant", "system", "tool"}:
+                return "chat"
+        return "doc"
+
     def _process_string_fine(
         self,
         fast_memory_items: list[TextualMemoryItem],
@@ -270,8 +375,12 @@ def _process_string_fine(
             sources = fast_item.metadata.sources or []
             if not isinstance(sources, list):
                 sources = [sources]
+
+            # Determine prompt type based on sources
+            prompt_type = self._determine_prompt_type(sources)
+
             try:
-                resp = self._get_llm_response(mem_str, custom_tags)
+                resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
             except Exception as e:
                 logger.error(f"[MultiModalFine] Error calling LLM: {e}")
                 continue