From 8e27d612f9d0c4dbc20ac732ff2ad7e433f9996c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?=
Date: Wed, 3 Dec 2025 15:50:17 +0800
Subject: [PATCH 1/4] feat: split chunk for pure string

---
 src/memos/mem_reader/multi_modal_struct.py    |  6 +-
 src/memos/mem_reader/read_multi_modal/base.py | 29 ++++++++++
 .../read_multi_modal/file_content_parser.py   | 29 +---------
 .../read_multi_modal/string_parser.py         | 55 +++++++++++--------
 4 files changed, 63 insertions(+), 56 deletions(-)

diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 94ffb5afc..64a2d685f 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -255,14 +255,12 @@ def _process_multi_modal_data(
             for msg in scene_data_info:
                 items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
                 all_memory_items.extend(items)
-            fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
-
         else:
             # Parse as single message
-            fast_memory_items = self.multi_modal_parser.parse(
+            all_memory_items = self.multi_modal_parser.parse(
                 scene_data_info, info, mode="fast", **kwargs
             )
-
+        fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
         if mode == "fast":
             return fast_memory_items
         else:
diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py
index e59b6a6bc..123eb22bc 100644
--- a/src/memos/mem_reader/read_multi_modal/base.py
+++ b/src/memos/mem_reader/read_multi_modal/base.py
@@ -16,6 +16,8 @@
     TreeNodeTextualMemoryMetadata,
 )
 
+from .utils import get_text_splitter
+
 logger = log.get_logger(__name__)
 
 
@@ -223,3 +225,30 @@ def parse(
             return self.parse_fine(message, info, **kwargs)
         else:
             raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'")
+
+    def _split_text(self, text: str) -> list[str]:
+        """
+        Split text into chunks using the text splitter from utils.
+
+        Args:
+            text: Text to split
+
+        Returns:
+            List of text chunks
+        """
+        if not text or not text.strip():
+            return []
+
+        splitter = get_text_splitter()
+        if not splitter:
+            # If the text splitter is not available, return the text as a single chunk
+            return [text] if text.strip() else []
+
+        try:
+            chunks = splitter.split_text(text)
+            logger.debug(f"[BaseMessageParser] Split text into {len(chunks)} chunks")
+            return chunks
+        except Exception as e:
+            logger.error(f"[BaseMessageParser] Error splitting text: {e}")
+            # Fall back to a single chunk
+            return [text] if text.strip() else []
diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index 8a08d6a93..c8ca9a400 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -16,7 +16,7 @@
 from memos.types.openai_chat_completion_types import File
 
 from .base import BaseMessageParser, _derive_key
-from .utils import get_parser, get_text_splitter
+from .utils import get_parser
 
 logger = get_logger(__name__)
 
@@ -108,33 +108,6 @@ def __init__(
         else:
             self.direct_markdown_hostnames = []
 
-    def _split_text(self, text: str) -> list[str]:
-        """
-        Split text into chunks using text splitter from utils.
-
-        Args:
-            text: Text to split
-
-        Returns:
-            List of text chunks
-        """
-        if not text or not text.strip():
-            return []
-
-        splitter = get_text_splitter()
-        if not splitter:
-            # If text splitter is not available, return text as single chunk
-            return [text] if text.strip() else []
-
-        try:
-            chunks = splitter.split_text(text)
-            logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks")
-            return chunks
-        except Exception as e:
-            logger.error(f"[FileContentParser] Error splitting text: {e}")
-            # Fallback to single chunk
-            return [text] if text.strip() else []
-
     def create_source(
         self,
         message: File,
diff --git a/src/memos/mem_reader/read_multi_modal/string_parser.py b/src/memos/mem_reader/read_multi_modal/string_parser.py
index 3d0837425..b5a58d68c 100644
--- a/src/memos/mem_reader/read_multi_modal/string_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/string_parser.py
@@ -83,8 +83,8 @@ def parse_fast(
         if not content:
             return []
 
-        # Create source
-        source = self.create_source(message, info)
+        # Split parsed text into chunks
+        content_chunks = self._split_text(content)
 
         # Extract info fields
         info_ = info.copy()
@@ -92,30 +92,37 @@
         session_id = info_.pop("session_id", "")
 
         # For string messages, default to LongTermMemory
-        # (since we don't have role information)
         memory_type = "LongTermMemory"
 
-        # Create memory item
-        memory_item = TextualMemoryItem(
-            memory=content,
-            metadata=TreeNodeTextualMemoryMetadata(
-                user_id=user_id,
-                session_id=session_id,
-                memory_type=memory_type,
-                status="activated",
-                tags=["mode:fast"],
-                key=_derive_key(content),
-                embedding=self.embedder.embed([content])[0],
-                usage=[],
-                sources=[source],
-                background="",
-                confidence=0.99,
-                type="fact",
-                info=info_,
-            ),
-        )
-
-        return [memory_item]
+        # Create a memory item for each chunk
+        memory_items = []
+        for chunk_text in content_chunks:
+            if not chunk_text.strip():
+                continue
+
+            # Create source
+            source = self.create_source(chunk_text, info)
+
+            memory_item = TextualMemoryItem(
+                memory=chunk_text,
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id=user_id,
+                    session_id=session_id,
+                    memory_type=memory_type,
+                    status="activated",
+                    tags=["mode:fast"],
+                    key=_derive_key(chunk_text),
+                    embedding=self.embedder.embed([chunk_text])[0],
+                    usage=[],
+                    sources=[source],
+                    background="",
+                    confidence=0.99,
+                    type="fact",
+                    info=info_,
+                ),
+            )
+            memory_items.append(memory_item)
+        return memory_items
 
     def parse_fine(
         self,
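
Illustration for PATCH 1/4: parse_fast now fans a single string message out into one memory item per chunk, via the _split_text helper that moved from FileContentParser into BaseMessageParser. A minimal sketch of that flow, with a hypothetical fixed-size splitter standing in for the real get_text_splitter() (whose chunking strategy is configuration-dependent):

    # Hypothetical stand-in for the splitter returned by get_text_splitter().
    class FixedSizeSplitter:
        def __init__(self, chunk_size: int = 20) -> None:
            self.chunk_size = chunk_size

        def split_text(self, text: str) -> list[str]:
            # Naive slicing; the real splitter is boundary-aware.
            return [text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)]

    def split_for_fast_mode(content: str, splitter: FixedSizeSplitter) -> list[str]:
        # Mirrors _split_text: empty input yields no chunks, and a splitter
        # failure falls back to the whole text as a single chunk.
        if not content or not content.strip():
            return []
        try:
            return [c for c in splitter.split_text(content) if c.strip()]
        except Exception:
            return [content]

    print(split_for_fast_mode("a" * 45, FixedSizeSplitter()))  # -> three chunks

Each surviving chunk then becomes its own TextualMemoryItem with its own embedding and _derive_key result, which is why the patch moves source creation inside the loop.
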
From 785b5a4a65d2773392a3304bc1efbd41997843ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?=
Date: Wed, 3 Dec 2025 16:20:52 +0800
Subject: [PATCH 2/4] feat: add default truncation in embedder

---
 src/memos/configs/embedder.py               |  4 +
 src/memos/embedders/ark.py                  |  3 +
 src/memos/embedders/base.py                 | 91 +++++++++++++++++++++
 src/memos/embedders/ollama.py               |  3 +
 src/memos/embedders/sentence_transformer.py |  3 +
 src/memos/embedders/universal_api.py        |  3 +
 6 files changed, 107 insertions(+)

diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py
index d88b6005e..c2e648247 100644
--- a/src/memos/configs/embedder.py
+++ b/src/memos/configs/embedder.py
@@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
     embedding_dims: int | None = Field(
         default=None, description="Number of dimensions for the embedding"
    )
+    max_tokens: int | None = Field(
+        default=8192,
+        description="Maximum number of tokens per text. Texts exceeding this limit will be automatically truncated. Set to None to disable truncation.",
+    )
     headers_extra: dict[str, Any] | None = Field(
         default=None,
         description="Extra headers for the embedding model, only for universal_api backend",
diff --git a/src/memos/embedders/ark.py b/src/memos/embedders/ark.py
index db6b42bd4..a8b47e200 100644
--- a/src/memos/embedders/ark.py
+++ b/src/memos/embedders/ark.py
@@ -49,6 +49,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
             MultimodalEmbeddingContentPartTextParam,
         )
 
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.config.multi_modal:
             texts_input = [
                 MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
diff --git a/src/memos/embedders/base.py b/src/memos/embedders/base.py
index 05c0fd1f3..d573521f6 100644
--- a/src/memos/embedders/base.py
+++ b/src/memos/embedders/base.py
@@ -1,14 +1,105 @@
+import re
+
 from abc import ABC, abstractmethod
 
 from memos.configs.embedder import BaseEmbedderConfig
 
 
+def _count_tokens_for_embedding(text: str) -> int:
+    """
+    Count tokens in text for embedding truncation.
+    Uses tiktoken if available, otherwise falls back to a heuristic.
+
+    Args:
+        text: Text to count tokens for.
+
+    Returns:
+        Number of tokens.
+    """
+    try:
+        import tiktoken
+
+        try:
+            enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        except Exception:
+            enc = tiktoken.get_encoding("cl100k_base")
+        return len(enc.encode(text or ""))
+    except Exception:
+        # Heuristic fallback: CJK chars count as ~1 token each, other text as ~4 chars per token
+        if not text:
+            return 0
+        zh_chars = re.findall(r"[\u4e00-\u9fff]", text)
+        zh = len(zh_chars)
+        rest = len(text) - zh
+        return zh + max(1, rest // 4)
+
+
+def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
+    """
+    Truncate text to fit within the max_tokens limit.
+    Uses binary search to find the longest prefix that still fits.
+
+    Args:
+        text: Text to truncate.
+        max_tokens: Maximum number of tokens allowed.
+
+    Returns:
+        Truncated text.
+    """
+    if not text or max_tokens is None or max_tokens <= 0:
+        return text
+
+    current_tokens = _count_tokens_for_embedding(text)
+    if current_tokens <= max_tokens:
+        return text
+
+    # Binary search for the right truncation point
+    low, high = 0, len(text)
+    best_text = ""
+
+    while low < high:
+        mid = (low + high + 1) // 2  # Bias upward to avoid an infinite loop
+        truncated = text[:mid]
+        tokens = _count_tokens_for_embedding(truncated)
+
+        if tokens <= max_tokens:
+            best_text = truncated
+            low = mid
+        else:
+            high = mid - 1
+
+    return best_text if best_text else text[:1]  # Fall back to at least one character
+
+
 class BaseEmbedder(ABC):
     """Base class for all Embedding models."""
 
     @abstractmethod
     def __init__(self, config: BaseEmbedderConfig):
         """Initialize the embedding model with the given configuration."""
+        self.config = config
+
+    def _truncate_texts(self, texts: list[str], approx_char_per_token: float = 1.1) -> list[str]:
+        """
+        Truncate texts to fit within the max_tokens limit if configured.
+
+        Args:
+            texts: List of texts to truncate.
+            approx_char_per_token: Approximate chars-per-token lower bound; texts
+                shorter than max_tokens * this ratio skip exact token counting.
+
+        Returns:
+            List of truncated texts.
+        """
+        if not hasattr(self, "config") or self.config.max_tokens is None:
+            return texts
+        max_tokens = self.config.max_tokens
+
+        truncated = []
+        for t in texts:
+            if len(t) < max_tokens * approx_char_per_token:
+                truncated.append(t)
+            else:
+                truncated.append(_truncate_text_to_tokens(t, max_tokens))
+        return truncated
 
     @abstractmethod
     def embed(self, texts: list[str]) -> list[list[float]]:
diff --git a/src/memos/embedders/ollama.py b/src/memos/embedders/ollama.py
index 2461d629a..dfd8e230d 100644
--- a/src/memos/embedders/ollama.py
+++ b/src/memos/embedders/ollama.py
@@ -67,6 +67,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         response = self.client.embed(
             model=self.config.model_name_or_path,
             input=texts,
diff --git a/src/memos/embedders/sentence_transformer.py b/src/memos/embedders/sentence_transformer.py
index 1ae818ad6..de086cb49 100644
--- a/src/memos/embedders/sentence_transformer.py
+++ b/src/memos/embedders/sentence_transformer.py
@@ -42,5 +42,8 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         embeddings = self.model.encode(texts, convert_to_numpy=True)
         return embeddings.tolist()
diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index 79a5d9ea6..e74e50614 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -36,6 +36,9 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
         log_extra_args={"model_name_or_path": "text-embedding-3-large"},
     )
     def embed(self, texts: list[str]) -> list[list[float]]:
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.provider == "openai" or self.provider == "azure":
             try:
                 response = self.client.embeddings.create(
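
Illustration for PATCH 2/4: _truncate_text_to_tokens binary-searches for the longest prefix whose token count stays within max_tokens. A self-contained check using only the heuristic counter from the patch (when tiktoken is installed it replaces the heuristic):

    import re

    def count_tokens(text: str) -> int:
        # Heuristic from the patch: CJK chars count as ~1 token each,
        # everything else as ~1 token per 4 characters.
        if not text:
            return 0
        zh = len(re.findall(r"[\u4e00-\u9fff]", text))
        rest = len(text) - zh
        return zh + max(1, rest // 4)

    def truncate(text: str, max_tokens: int) -> str:
        if count_tokens(text) <= max_tokens:
            return text
        low, high, best = 0, len(text), ""
        while low < high:
            mid = (low + high + 1) // 2  # bias upward so the loop terminates
            if count_tokens(text[:mid]) <= max_tokens:
                best, low = text[:mid], mid
            else:
                high = mid - 1
        return best or text[:1]

    assert count_tokens(truncate("x" * 1000, 50)) <= 50

Note that _truncate_texts only runs the binary search when a text is at least max_tokens * approx_char_per_token characters long; shorter texts skip the exact count as a fast path.
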
From ee6a0ba2baeb1e2bcffef19d006946281aaf02ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?=
Date: Wed, 3 Dec 2025 16:48:42 +0800
Subject: [PATCH 3/4] feat: chunking each item after fast mode

---
 src/memos/mem_reader/multi_modal_struct.py | 88 +++++++++++++++++++---
 1 file changed, 78 insertions(+), 10 deletions(-)

diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 64a2d685f..57774cf3a 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -47,6 +47,61 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
             direct_markdown_hostnames=direct_markdown_hostnames,
         )
 
+    def _split_large_memory_item(
+        self, item: TextualMemoryItem, max_tokens: int
+    ) -> list[TextualMemoryItem]:
+        """
+        Split a single memory item that exceeds max_tokens into multiple chunks.
+
+        Args:
+            item: TextualMemoryItem to split
+            max_tokens: Maximum tokens per chunk
+
+        Returns:
+            List of TextualMemoryItem chunks
+        """
+        item_text = item.memory or ""
+        if not item_text:
+            return [item]
+
+        item_tokens = self._count_tokens(item_text)
+        if item_tokens <= max_tokens:
+            return [item]
+
+        # Use the chunker to split the text
+        try:
+            chunks = self.chunker.chunk(item_text)
+            split_items = []
+
+            for chunk in chunks:
+                # Chunk objects have a 'text' attribute
+                chunk_text = chunk.text
+                if not chunk_text or not chunk_text.strip():
+                    continue
+
+                # Create a new memory item for each chunk, preserving the original metadata
+                split_item = self._make_memory_item(
+                    value=chunk_text,
+                    info={
+                        "user_id": item.metadata.user_id,
+                        "session_id": item.metadata.session_id,
+                        **(item.metadata.info or {}),
+                    },
+                    memory_type=item.metadata.memory_type,
+                    tags=item.metadata.tags or [],
+                    key=item.metadata.key,
+                    sources=item.metadata.sources or [],
+                    background=item.metadata.background or "",
+                )
+                split_items.append(split_item)
+
+            return split_items if split_items else [item]
+        except Exception as e:
+            logger.warning(
+                f"[MultiModalStruct] Failed to split large memory item: {e}. Returning original item."
+            )
+            return [item]
+
     def _concat_multi_modal_memories(
         self, all_memory_items: list[TextualMemoryItem], max_tokens=None, overlap=200
     ) -> list[TextualMemoryItem]:
@@ -57,35 +112,49 @@ def _concat_multi_modal_memories(
         2. Each window has overlap tokens for context continuity
         3. Aggregates items within each window into a single memory item
         4. Determines memory_type based on roles in each window
+        5. Splits any single memory item that exceeds max_tokens before windowing
         """
         if not all_memory_items:
            return []
 
-        # If only one item, return as-is (no need to aggregate)
-        if len(all_memory_items) == 1:
-            return all_memory_items
-
         max_tokens = max_tokens or self.chat_window_max_tokens
+
+        # Split large memory items before processing
+        processed_items = []
+        for item in all_memory_items:
+            item_text = item.memory or ""
+            item_tokens = self._count_tokens(item_text)
+            if item_tokens > max_tokens:
+                # Split the large item into multiple chunks
+                split_items = self._split_large_memory_item(item, max_tokens)
+                processed_items.extend(split_items)
+            else:
+                processed_items.append(item)
+
+        # If only one item remains after processing, return it as-is
+        if len(processed_items) == 1:
+            return processed_items
+
         windows = []
         buf_items = []
         cur_text = ""
 
         # Extract info from first item (all items should have same user_id, session_id)
-        first_item = all_memory_items[0]
+        first_item = processed_items[0]
         info = {
             "user_id": first_item.metadata.user_id,
             "session_id": first_item.metadata.session_id,
             **(first_item.metadata.info or {}),
         }
 
-        for _idx, item in enumerate(all_memory_items):
+        for item in processed_items:
             item_text = item.memory or ""
             # Ensure line ends with newline (same format as simple_struct)
             line = item_text if item_text.endswith("\n") else f"{item_text}\n"
 
             # Check if adding this item would exceed max_tokens (same logic as _iter_chat_windows)
-            # Note: The `and cur_text` condition ensures that single large messages are not truncated.
-            # If cur_text is empty (new window), even if line exceeds max_tokens, it won't trigger output.
+ # Note: After splitting large items, each item should be <= max_tokens, + # but we still check to handle edge cases if self._count_tokens(cur_text + line) > max_tokens and cur_text: # Yield current window window = self._build_window_from_items(buf_items, info) @@ -102,8 +171,7 @@ def _concat_multi_modal_memories( # Recalculate cur_text from remaining items cur_text = "".join([it.memory or "" for it in buf_items]) - # Add item to current window (always, even if it exceeds max_tokens) - # This ensures single large messages are not truncated, same as simple_struct + # Add item to current window buf_items.append(item) # Recalculate cur_text from all items in buffer (same as _iter_chat_windows) cur_text = "".join([it.memory or "" for it in buf_items]) From 39bc266766ca91fd4e9bb85a00bd28e91b28a56f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 3 Dec 2025 17:32:15 +0800 Subject: [PATCH 4/4] fix: test --- tests/configs/test_embedder.py | 4 ++-- tests/utils.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/configs/test_embedder.py b/tests/configs/test_embedder.py index 10572f33e..002de2259 100644 --- a/tests/configs/test_embedder.py +++ b/tests/configs/test_embedder.py @@ -17,7 +17,7 @@ def test_base_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "headers_extra"], + optional_fields=["embedding_dims", "max_tokens", "headers_extra"], ) check_config_instantiation_valid( @@ -36,7 +36,7 @@ def test_ollama_embedder_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["embedding_dims", "headers_extra", "api_base"], + optional_fields=["embedding_dims", "max_tokens", "headers_extra", "api_base"], ) check_config_instantiation_valid( diff --git a/tests/utils.py b/tests/utils.py index e88d4fbcd..132cd7138 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,9 @@ def check_module_base_class(cls: Any) -> None: # Check 3: Verify abstract methods for method_name in all_class_methods: method = getattr(cls, method_name) + # Skip private methods (starting with _) as they are typically helper methods + if method_name.startswith("_") and method_name != "__init__": + continue assert getattr(method, "__isabstractmethod__", False), ( f"The method '{method_name}' in {cls.__name__} should be marked as @abstractmethod" )
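
Illustration for PATCH 3/4: with oversized items pre-split by _split_large_memory_item, the window-packing loop in _concat_multi_modal_memories can assume every item fits under max_tokens. A minimal sketch of the packing behavior, using character counts as a stand-in for _count_tokens and naive slicing in place of the real chunker (overlap carry-over and metadata aggregation omitted; pack_windows is a hypothetical name):

    def pack_windows(items: list[str], max_tokens: int) -> list[list[str]]:
        # Pre-split any oversized item, as _split_large_memory_item does.
        processed: list[str] = []
        for text in items:
            if len(text) > max_tokens:
                processed.extend(text[i : i + max_tokens] for i in range(0, len(text), max_tokens))
            else:
                processed.append(text)

        windows: list[list[str]] = []
        buf: list[str] = []
        for text in processed:
            # Close the current window before it would overflow; requiring a
            # non-empty buf mirrors the `and cur_text` guard in the patch.
            if buf and sum(len(t) for t in buf) + len(text) > max_tokens:
                windows.append(buf)
                buf = []
            buf.append(text)
        if buf:
            windows.append(buf)
        return windows

    # The 120-char item is pre-split into 100 + 20, so no window exceeds 100.
    print(pack_windows(["a" * 120, "b" * 30, "c" * 40], max_tokens=100))
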