Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/memos/configs/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
embedding_dims: int | None = Field(
default=None, description="Number of dimensions for the embedding"
)
max_tokens: int | None = Field(
default=8192,
description="Maximum number of tokens per text. Texts exceeding this limit will be automatically truncated. Set to None to disable truncation.",
)
headers_extra: dict[str, Any] | None = Field(
default=None,
description="Extra headers for the embedding model, only for universal_api backend",
Expand Down
3 changes: 3 additions & 0 deletions src/memos/embedders/ark.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
MultimodalEmbeddingContentPartTextParam,
)

# Truncate texts if max_tokens is configured
texts = self._truncate_texts(texts)

if self.config.multi_modal:
texts_input = [
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
Expand Down
91 changes: 91 additions & 0 deletions src/memos/embedders/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,105 @@
import re

from abc import ABC, abstractmethod

from memos.configs.embedder import BaseEmbedderConfig


def _count_tokens_for_embedding(text: str) -> int:
    """
    Count tokens in text for embedding truncation.

    Prefers tiktoken when it is importable; otherwise estimates with a
    heuristic: each CJK character counts as one token and the remaining
    characters count as roughly one token per four characters.

    Args:
        text: Text to count tokens for. ``None`` or empty counts as 0.

    Returns:
        Number of tokens (exact with tiktoken, estimated otherwise).
    """
    try:
        import tiktoken

        try:
            enc = tiktoken.encoding_for_model("gpt-4o-mini")
        except Exception:
            # Model not known to this tiktoken version; use the base encoding.
            enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text or ""))
    except Exception:
        # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars
        if not text:
            return 0
        zh = len(re.findall(r"[\u4e00-\u9fff]", text))
        rest = len(text) - zh
        # Only charge for the non-CJK remainder when it exists; an
        # unconditional max(1, ...) would over-count pure-CJK text by one.
        return zh + (max(1, rest // 4) if rest else 0)


def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate ``text`` so it fits within ``max_tokens`` tokens.

    Binary-searches the longest character prefix whose token count stays
    within the limit, so only O(log n) token counts are computed.

    Args:
        text: Text to truncate.
        max_tokens: Maximum number of tokens allowed.

    Returns:
        The longest prefix that fits; the input unchanged when it already
        fits or when the limit is unset/non-positive.
    """
    if not text or max_tokens is None or max_tokens <= 0:
        return text
    if _count_tokens_for_embedding(text) <= max_tokens:
        return text

    # Search prefix lengths 1..len(text) for the largest one that fits.
    lo, hi = 1, len(text)
    fit_len = 0
    while lo <= hi:
        mid = (lo + hi) // 2
        if _count_tokens_for_embedding(text[:mid]) <= max_tokens:
            fit_len = mid
            lo = mid + 1
        else:
            hi = mid - 1

    # A single character can exceed the limit on its own; never return "".
    return text[:fit_len] if fit_len else text[:1]


class BaseEmbedder(ABC):
"""Base class for all Embedding models."""

@abstractmethod
def __init__(self, config: BaseEmbedderConfig) -> None:
    """Initialize the embedding model with the given configuration.

    Stores the config on ``self.config`` so shared helpers such as
    ``_truncate_texts`` can read settings like ``max_tokens``. Concrete
    subclasses override this; calling ``super().__init__(config)`` keeps
    ``self.config`` populated.

    Args:
        config: Embedder configuration for this backend.
    """
    self.config = config

def _truncate_texts(self, texts: list[str], approx_char_per_token: float = 1.1) -> list[str]:
    """
    Truncate texts to fit within the configured ``max_tokens`` limit.

    A cheap character-length screen avoids tokenizing texts that are
    clearly short enough; only longer texts pay for exact token counting.

    Args:
        texts: List of texts to truncate.
        approx_char_per_token: Characters-per-token estimate used by the
            fast screen. NOTE(review): for CJK-heavy text (~1 char/token)
            values above 1.0 can let slightly over-limit texts through
            untruncated — confirm this tolerance is acceptable.

    Returns:
        List of texts, each truncated to the token limit (best effort).
    """
    # getattr guards against config objects created before max_tokens existed;
    # the original attribute access would raise AttributeError on them.
    max_tokens = getattr(self.config, "max_tokens", None) if hasattr(self, "config") else None
    if max_tokens is None:
        return texts

    # Hoist the loop-invariant character budget out of the loop.
    char_budget = max_tokens * approx_char_per_token
    truncated: list[str] = []
    for text in texts:
        if len(text) < char_budget:
            # Clearly short enough — skip exact token counting.
            truncated.append(text)
        else:
            truncated.append(_truncate_text_to_tokens(text, max_tokens))
    return truncated

@abstractmethod
def embed(self, texts: list[str]) -> list[list[float]]:
Expand Down
3 changes: 3 additions & 0 deletions src/memos/embedders/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
Returns:
List of embeddings, each represented as a list of floats.
"""
# Truncate texts if max_tokens is configured
texts = self._truncate_texts(texts)

response = self.client.embed(
model=self.config.model_name_or_path,
input=texts,
Expand Down
3 changes: 3 additions & 0 deletions src/memos/embedders/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ def embed(self, texts: list[str]) -> list[list[float]]:
Returns:
List of embeddings, each represented as a list of floats.
"""
# Truncate texts if max_tokens is configured
texts = self._truncate_texts(texts)

embeddings = self.model.encode(texts, convert_to_numpy=True)
return embeddings.tolist()
3 changes: 3 additions & 0 deletions src/memos/embedders/universal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
log_extra_args={"model_name_or_path": "text-embedding-3-large"},
)
def embed(self, texts: list[str]) -> list[list[float]]:
# Truncate texts if max_tokens is configured
texts = self._truncate_texts(texts)

if self.provider == "openai" or self.provider == "azure":
try:
response = self.client.embeddings.create(
Expand Down
94 changes: 80 additions & 14 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,61 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
direct_markdown_hostnames=direct_markdown_hostnames,
)

def _split_large_memory_item(
    self, item: TextualMemoryItem, max_tokens: int
) -> list[TextualMemoryItem]:
    """
    Split one memory item whose text exceeds ``max_tokens`` into chunk items.

    The original item is returned untouched when its text is empty, when it
    already fits the limit, or when chunking fails for any reason
    (best-effort behavior: never lose the item).

    Args:
        item: TextualMemoryItem to split.
        max_tokens: Maximum tokens per chunk.

    Returns:
        List of TextualMemoryItem chunks (possibly just ``[item]``).
    """
    text = item.memory or ""
    if not text or self._count_tokens(text) <= max_tokens:
        return [item]

    try:
        # Each non-blank chunk becomes a new item carrying over the
        # original metadata (user/session info, type, tags, key, sources).
        pieces = [
            self._make_memory_item(
                value=chunk_text,
                info={
                    "user_id": item.metadata.user_id,
                    "session_id": item.metadata.session_id,
                    **(item.metadata.info or {}),
                },
                memory_type=item.metadata.memory_type,
                tags=item.metadata.tags or [],
                key=item.metadata.key,
                sources=item.metadata.sources or [],
                background=item.metadata.background or "",
            )
            for chunk_text in (chunk.text for chunk in self.chunker.chunk(text))
            if chunk_text and chunk_text.strip()
        ]
        return pieces or [item]
    except Exception as e:
        logger.warning(
            f"[MultiModalStruct] Failed to split large memory item: {e}. Returning original item."
        )
        return [item]

def _concat_multi_modal_memories(
self, all_memory_items: list[TextualMemoryItem], max_tokens=None, overlap=200
) -> list[TextualMemoryItem]:
Expand All @@ -57,35 +112,49 @@ def _concat_multi_modal_memories(
2. Each window has overlap tokens for context continuity
3. Aggregates items within each window into a single memory item
4. Determines memory_type based on roles in each window
5. Splits single large memory items that exceed max_tokens
"""
if not all_memory_items:
return []

# If only one item, return as-is (no need to aggregate)
if len(all_memory_items) == 1:
return all_memory_items

max_tokens = max_tokens or self.chat_window_max_tokens

# Split large memory items before processing
processed_items = []
for item in all_memory_items:
item_text = item.memory or ""
item_tokens = self._count_tokens(item_text)
if item_tokens > max_tokens:
# Split the large item into multiple chunks
split_items = self._split_large_memory_item(item, max_tokens)
processed_items.extend(split_items)
else:
processed_items.append(item)

# If only one item after processing, return as-is
if len(processed_items) == 1:
return processed_items

windows = []
buf_items = []
cur_text = ""

# Extract info from first item (all items should have same user_id, session_id)
first_item = all_memory_items[0]
first_item = processed_items[0]
info = {
"user_id": first_item.metadata.user_id,
"session_id": first_item.metadata.session_id,
**(first_item.metadata.info or {}),
}

for _idx, item in enumerate(all_memory_items):
for _idx, item in enumerate(processed_items):
item_text = item.memory or ""
# Ensure line ends with newline (same format as simple_struct)
line = item_text if item_text.endswith("\n") else f"{item_text}\n"

# Check if adding this item would exceed max_tokens (same logic as _iter_chat_windows)
# Note: The `and cur_text` condition ensures that single large messages are not truncated.
# If cur_text is empty (new window), even if line exceeds max_tokens, it won't trigger output.
# Note: After splitting large items, each item should be <= max_tokens,
# but we still check to handle edge cases
if self._count_tokens(cur_text + line) > max_tokens and cur_text:
# Yield current window
window = self._build_window_from_items(buf_items, info)
Expand All @@ -102,8 +171,7 @@ def _concat_multi_modal_memories(
# Recalculate cur_text from remaining items
cur_text = "".join([it.memory or "" for it in buf_items])

# Add item to current window (always, even if it exceeds max_tokens)
# This ensures single large messages are not truncated, same as simple_struct
# Add item to current window
buf_items.append(item)
# Recalculate cur_text from all items in buffer (same as _iter_chat_windows)
cur_text = "".join([it.memory or "" for it in buf_items])
Expand Down Expand Up @@ -255,14 +323,12 @@ def _process_multi_modal_data(
for msg in scene_data_info:
items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
all_memory_items.extend(items)
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)

else:
# Parse as single message
fast_memory_items = self.multi_modal_parser.parse(
all_memory_items = self.multi_modal_parser.parse(
scene_data_info, info, mode="fast", **kwargs
)

fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
if mode == "fast":
return fast_memory_items
else:
Expand Down
29 changes: 29 additions & 0 deletions src/memos/mem_reader/read_multi_modal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
TreeNodeTextualMemoryMetadata,
)

from .utils import get_text_splitter


logger = log.get_logger(__name__)

Expand Down Expand Up @@ -223,3 +225,30 @@ def parse(
return self.parse_fine(message, info, **kwargs)
else:
raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'")

def _split_text(self, text: str) -> list[str]:
    """
    Split ``text`` into chunks via the shared text splitter from utils.

    Empty or whitespace-only input produces no chunks. When no splitter is
    configured, or splitting raises, the whole text is returned as a single
    chunk (best effort).

    Args:
        text: Text to split.

    Returns:
        List of text chunks.
    """
    if not text or not text.strip():
        return []

    splitter = get_text_splitter()
    if not splitter:
        # No splitter available — fall back to one chunk for the whole text.
        return [text]

    try:
        chunks = splitter.split_text(text)
        logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks")
        return chunks
    except Exception as e:
        logger.error(f"[FileContentParser] Error splitting text: {e}")
        return [text]
29 changes: 1 addition & 28 deletions src/memos/mem_reader/read_multi_modal/file_content_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from memos.types.openai_chat_completion_types import File

from .base import BaseMessageParser, _derive_key
from .utils import get_parser, get_text_splitter
from .utils import get_parser


logger = get_logger(__name__)
Expand Down Expand Up @@ -108,33 +108,6 @@ def __init__(
else:
self.direct_markdown_hostnames = []

def _split_text(self, text: str) -> list[str]:
    """
    Split text into chunks using text splitter from utils.

    Empty or whitespace-only input yields no chunks; when the splitter is
    unavailable or raises, the whole text is returned as a single chunk.

    Args:
        text: Text to split

    Returns:
        List of text chunks
    """
    if not text or not text.strip():
        return []

    splitter = get_text_splitter()
    if not splitter:
        # If text splitter is not available, return text as single chunk
        # (the strip() re-check is always true here given the guard above)
        return [text] if text.strip() else []

    try:
        chunks = splitter.split_text(text)
        logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks")
        return chunks
    except Exception as e:
        # Best effort: a splitter failure must not lose the content.
        logger.error(f"[FileContentParser] Error splitting text: {e}")
        # Fallback to single chunk
        return [text] if text.strip() else []

def create_source(
self,
message: File,
Expand Down
Loading
Loading