diff --git a/examples/mem_reader/multimodal_struct_reader.py b/examples/mem_reader/multimodal_struct_reader.py index 20c141828..790b13f85 100644 --- a/examples/mem_reader/multimodal_struct_reader.py +++ b/examples/mem_reader/multimodal_struct_reader.py @@ -327,6 +327,102 @@ def get_info(self) -> dict[str, Any]: ] ], ), + TestCase( + name="oss_text_file", + description="User message with text and file", + scene_data=[ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请阅读这个PDF,总结里面的要点。"}, + { + "type": "file", + "file": { + "file_id": "file_123", + "filename": "report.pdf", + "file_data": "@http://139.196.232.20:9090/graph-test/algorithm/2025_11_13/1763043889_1763043782_PM1%E8%BD%A6%E9%97%B4PMT%E9%9D%B4%E5%8E%8B%E8%BE%B9%E5%8E%8B%E5%8E%8B%E5%8A%9B%E6%97%A0%E6%B3%95%E5%BB%BA%E7%AB%8B%E6%95%85%E9%9A%9C%E6%8A%A5%E5%91%8A20240720.md", + }, + }, + ], + "chat_time": "2025-11-24T10:21:00Z", + "message_id": "mm-file-1", + } + ] + ], + ), + TestCase( + name="pure_data_file", + description="User message with text and file", + scene_data=[ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请阅读这个PDF,总结里面的要点。"}, + { + "type": "file", + "file": { + "file_id": "file_123", + "filename": "report.pdf", + "file_data": "明文记忆是系统与用户对话、操作等交互中动态习得,以及外部提供的、可显式管理的结构化知识形态,通常以文档、提示模板、图结构或用户规则等形式存在。它具备编辑性、可共享性与治理友好性,适合存储需要频繁修改、可审计或多方协同使用的信息。 在 MemOS 中,明文记忆可用于动态生成推理上下文、个性化偏好注入、多代理协作共享等场景,成为连接人类输入与模型认知的关键桥梁。激活记忆是指模型在推理过程中产生的瞬时性认知状态,包括 KV cache、隐藏层激活、注意力权重等中间张量结构。它通常用于维持上下文连续性、对话一致性与行为风格控制。 MemOS 将激活记忆抽象为可调度资源,支持按需唤醒、延迟卸载与结构变换。例如,某些上下文状态可以被压缩为“半结构化记忆片段”用于未来复用,也可以在任务级别转化为参数化模块,支持短期记忆的长期化演进。这一机制为模型行为一致性、风格保持与状态持续性提供了基础。", + }, + }, + ], + "chat_time": "2025-11-24T10:21:00Z", + "message_id": "mm-file-1", + } + ] + ], + ), + TestCase( + name="local_data_file", + description="User message with text and file", + scene_data=[ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请阅读这个PDF,总结里面的要点。"}, + { + "type": "file", + "file": { + "file_id": "file_123", + "filename": "report.pdf", + "file_data": "./my_local_file/report.pdf", + }, + }, + ], + "chat_time": "2025-11-24T10:21:00Z", + "message_id": "mm-file-1", + } + ] + ], + ), + TestCase( + name="internet_file", + description="User message with text and file", + scene_data=[ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请阅读这个PDF,总结里面的要点。"}, + { + "type": "file", + "file": { + "file_id": "file_123", + "filename": "report.pdf", + "file_data": "https://upload.wikimedia.org/wikipedia/commons/c/cb/NLC416-16jh004830-88775_%E7%B4%85%E6%A8%93%E5%A4%A2.pdf", + }, + }, + ], + "chat_time": "2025-11-24T10:21:00Z", + "message_id": "mm-file-1", + } + ] + ], + ), TestCase( name="multimodal_mixed", description="Mixed multimodal message (text + file + image)", @@ -661,6 +757,12 @@ def get_reader_config() -> dict[str, Any]: }, } + # Get direct markdown hostnames from environment variable + direct_markdown_hostnames = None + env_hostnames = os.getenv("FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES", "139.196.232.20") + if env_hostnames: + direct_markdown_hostnames = [h.strip() for h in env_hostnames.split(",") if h.strip()] + return { "llm": llm_config, "embedder": embedder_config, @@ -673,6 +775,7 @@ def get_reader_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, + "direct_markdown_hostnames": direct_markdown_hostnames, } @@ -863,13 +966,13 @@ def main(): parser.add_argument( "--example", type=str, - default="all", + default="oss_text_file", help="Test case name, category name, or 'all' to run all cases (default: all)", ) parser.add_argument( "--mode", choices=["fast", "fine"], - default="fast", + default="fine", help="Processing mode: fast (quick) or fine (with LLM) (default: fast)", ) parser.add_argument( diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 535811c42..af0f0473d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -707,6 +707,13 @@ def get_product_default_config() -> dict[str, Any]: }, }, "chat_chunker": reader_config, + "direct_markdown_hostnames": [ + h.strip() + for h in os.getenv( + "FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES", "139.196.232.20" + ).split(",") + if h.strip() + ], }, }, "enable_textual_memory": True, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 34693ea68..9b9bee701 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -48,6 +48,12 @@ class SimpleStructMemReaderConfig(BaseMemReaderConfig): class MultiModalStructMemReaderConfig(BaseMemReaderConfig): """MultiModalStruct MemReader configuration class.""" + direct_markdown_hostnames: list[str] | None = Field( + default=None, + description="List of hostnames that should return markdown directly without parsing. " + "If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES environment variable.", + ) + class StrategyStructMemReaderConfig(BaseMemReaderConfig): """StrategyStruct MemReader configuration class.""" diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 5a78208b9..94ffb5afc 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -29,7 +29,13 @@ def __init__(self, config: MultiModalStructMemReaderConfig): """ from memos.configs.mem_reader import SimpleStructMemReaderConfig + # Extract direct_markdown_hostnames before converting to SimpleStructMemReaderConfig + direct_markdown_hostnames = getattr(config, "direct_markdown_hostnames", None) + + # Create config_dict excluding direct_markdown_hostnames for SimpleStructMemReaderConfig config_dict = config.model_dump(exclude_none=True) + config_dict.pop("direct_markdown_hostnames", None) + simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) @@ -38,6 +44,7 @@ def __init__(self, config: MultiModalStructMemReaderConfig): embedder=self.embedder, llm=self.llm, parser=None, + direct_markdown_hostnames=direct_markdown_hostnames, ) def _concat_multi_modal_memories( @@ -271,7 +278,7 @@ def _process_multi_modal_data( sources = fast_item.metadata.sources for source in sources: items = self.multi_modal_parser.process_transfer( - source, context_items=[fast_item], custom_tags=custom_tags + source, context_items=[fast_item], custom_tags=custom_tags, info=info ) fine_memory_items.extend(items) return fine_memory_items diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 12b44eae8..8a08d6a93 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -1,5 +1,8 @@ """Parser for file content parts (RawMessageList).""" +import os +import tempfile + from typing import Any from memos.embedders.base import BaseEmbedder @@ -10,10 +13,10 @@ TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) -from memos.parsers.factory import ParserFactory from memos.types.openai_chat_completion_types import File from .base import BaseMessageParser, _derive_key +from .utils import get_parser, get_text_splitter logger = get_logger(__name__) @@ -22,11 +25,61 @@ class FileContentParser(BaseMessageParser): """Parser for file content parts.""" + def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None]: + """Download and parse file from URL.""" + try: + from urllib.parse import urlparse + + import requests + + parsed_url = urlparse(url_str) + hostname = parsed_url.hostname or "" + + response = requests.get(url_str, timeout=30) + response.raise_for_status() + + if not filename: + filename = os.path.basename(parsed_url.path) or "downloaded_file" + + if hostname in self.direct_markdown_hostnames: + return response.text, None + + file_ext = os.path.splitext(filename)[1].lower() + if file_ext in [".md", ".markdown", ".txt"]: + return response.text, None + with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_ext) as temp_file: + temp_file.write(response.content) + return "", temp_file.name + except Exception as e: + logger.error(f"[FileContentParser] URL processing error: {e}") + return f"[File URL download failed: {url_str}]", None + + def _is_base64(self, data: str) -> bool: + """Quick heuristic to check base64-like string.""" + return data.startswith("data:") or ( + len(data) > 100 + and all( + c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=" + for c in data[:100] + ) + ) + + def _handle_base64(self, data: str) -> str: + """Base64 not implemented placeholder.""" + logger.info("[FileContentParser] Base64 content detected but decoding is not implemented.") + return "" + + def _handle_local(self, data: str) -> str: + """Base64 not implemented placeholder.""" + logger.info("[FileContentParser] Local file paths are not supported in fine mode.") + return "" + def __init__( self, embedder: BaseEmbedder, llm: BaseLLM | None = None, parser: Any | None = None, + direct_markdown_hostnames: list[str] | None = None, ): """ Initialize FileContentParser. @@ -35,10 +88,53 @@ def __init__( embedder: Embedder for generating embeddings llm: Optional LLM for fine mode processing parser: Optional parser for parsing file contents + direct_markdown_hostnames: List of hostnames that should return markdown directly + without parsing. If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES + environment variable (comma-separated). """ super().__init__(embedder, llm) self.parser = parser + # Get inner markdown hostnames from config or environment + if direct_markdown_hostnames is not None: + self.direct_markdown_hostnames = direct_markdown_hostnames + else: + env_hostnames = os.getenv("FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES", "") + if env_hostnames: + # Support comma-separated list + self.direct_markdown_hostnames = [ + h.strip() for h in env_hostnames.split(",") if h.strip() + ] + else: + self.direct_markdown_hostnames = [] + + def _split_text(self, text: str) -> list[str]: + """ + Split text into chunks using text splitter from utils. + + Args: + text: Text to split + + Returns: + List of text chunks + """ + if not text or not text.strip(): + return [] + + splitter = get_text_splitter() + if not splitter: + # If text splitter is not available, return text as single chunk + return [text] if text.strip() else [] + + try: + chunks = splitter.split_text(text) + logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks") + return chunks + except Exception as e: + logger.error(f"[FileContentParser] Error splitting text: {e}") + # Fallback to single chunk + return [text] if text.strip() else [] + def create_source( self, message: File, @@ -83,21 +179,10 @@ def _parse_file(self, file_info: dict[str, Any]) -> str: Returns: Parsed text content """ - if not self.parser: - # Try to create a default parser - try: - from memos.configs.parser import ParserConfigFactory - - parser_config = ParserConfigFactory.model_validate( - { - "backend": "markitdown", - "config": {}, - } - ) - self.parser = ParserFactory.from_config(parser_config) - except Exception as e: - logger.warning(f"[FileContentParser] Failed to create parser: {e}") - return "" + parser = self.parser or get_parser() + if not parser: + logger.warning("[FileContentParser] Parser not available") + return "" file_path = file_info.get("path") or file_info.get("file_id", "") filename = file_info.get("filename", "unknown") @@ -107,10 +192,8 @@ def _parse_file(self, file_info: dict[str, Any]) -> str: return f"[File: {filename}]" try: - import os - if os.path.exists(file_path): - parsed_text = self.parser.parse(file_path) + parsed_text = parser.parse(file_path) return parsed_text else: logger.warning(f"[FileContentParser] File not found: {file_path}") @@ -197,6 +280,9 @@ def parse_fast( # Combine content parts content = " ".join(content_parts) + # Split content into chunks + content_chunks = self._split_text(content) + # Create source source = self.create_source(message, info) @@ -209,27 +295,59 @@ def parse_fast( # (since we don't have role information at this level) memory_type = "LongTermMemory" - # Create memory item - memory_item = TextualMemoryItem( - memory=content, - metadata=TreeNodeTextualMemoryMetadata( - user_id=user_id, - session_id=session_id, - memory_type=memory_type, - status="activated", - tags=["mode:fast", "multimodal:file"], - key=_derive_key(content), - embedding=self.embedder.embed([content])[0], - usage=[], - sources=[source], - background="", - confidence=0.99, - type="fact", - info=info_, - ), - ) + # Create memory items for each chunk + memory_items = [] + for chunk_idx, chunk_text in enumerate(content_chunks): + if not chunk_text.strip(): + continue + + memory_item = TextualMemoryItem( + memory=chunk_text, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type=memory_type, + status="activated", + tags=[ + "mode:fast", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{len(content_chunks)}", + ], + key=_derive_key(chunk_text), + embedding=self.embedder.embed([chunk_text])[0], + usage=[], + sources=[source], + background="", + confidence=0.99, + type="fact", + info=info_, + ), + ) + memory_items.append(memory_item) + + # If no chunks were created, create a placeholder + if not memory_items: + memory_item = TextualMemoryItem( + memory=content, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type=memory_type, + status="activated", + tags=["mode:fast", "multimodal:file"], + key=_derive_key(content), + embedding=self.embedder.embed([content])[0], + usage=[], + sources=[source], + background="", + confidence=0.99, + type="fact", + info=info_, + ), + ) + memory_items.append(memory_item) - return [memory_item] + return memory_items def parse_fine( self, @@ -237,4 +355,160 @@ def parse_fine( info: dict[str, Any], **kwargs, ) -> list[TextualMemoryItem]: - return [] + """ + Parse file content part in fine mode. + Fine mode downloads and parses file content, especially for URLs. + Handles various file parameter scenarios: + - file_data: URL (http://, https://, or @http://), base64 encoded data, or plain text content + - file_id: ID of an uploaded file + - filename: name of the file + """ + if not isinstance(message, dict): + logger.warning(f"[FileContentParser] Expected dict, got {type(message)}") + return [] + + # Extract file information + file_info = message.get("file", {}) + if not isinstance(file_info, dict): + logger.warning(f"[FileContentParser] Expected file dict, got {type(file_info)}") + return [] + + # Extract file parameters (all are optional) + file_data = file_info.get("file_data", "") + file_id = file_info.get("file_id", "") + filename = file_info.get("filename", "") + + # Use parser from utils + parser = self.parser or get_parser() + if not parser: + logger.warning("[FileContentParser] Parser not available") + return [] + + parsed_text = "" + temp_file_path = None + + try: + # Priority 1: If file_data is provided, process it + if file_data: + if isinstance(file_data, str): + url_str = file_data[1:] if file_data.startswith("@") else file_data + + if url_str.startswith(("http://", "https://")): + parsed_text, temp_file_path = self._handle_url(url_str, filename) + if temp_file_path: + try: + # Use parser from utils + if parser: + parsed_text = parser.parse(temp_file_path) + else: + parsed_text = "[File parsing error: Parser not available]" + except Exception as e: + logger.error( + f"[FileContentParser] Error parsing downloaded file: {e}" + ) + parsed_text = f"[File parsing error: {e!s}]" + + elif os.path.exists(file_data): + parsed_text = self._handle_local(file_data) + + elif self._is_base64(file_data): + parsed_text = self._handle_base64(file_data) + + else: + parsed_text = file_data + # Priority 2: If file_id is provided but no file_data, try to use file_id as path + elif file_id: + logger.warning(f"[FileContentParser] File data not provided for file_id: {file_id}") + parsed_text = f"[File ID: {file_id}]: File data not provided" + + # If no content could be parsed, create a placeholder + if not parsed_text: + if filename: + parsed_text = f"[File: {filename}] File data not provided" + else: + parsed_text = "[File: unknown] File data not provided" + + except Exception as e: + logger.error(f"[FileContentParser] Error in parse_fine: {e}") + parsed_text = f"[File parsing error: {e!s}]" + + finally: + # Clean up temporary file + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + logger.debug(f"[FileContentParser] Cleaned up temporary file: {temp_file_path}") + except Exception as e: + logger.warning( + f"[FileContentParser] Failed to delete temp file {temp_file_path}: {e}" + ) + + # Create source + source = self.create_source(message, info) + + # Extract info fields + if not info: + info = {} + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + + # For file content parts, default to LongTermMemory + memory_type = "LongTermMemory" + + # Split parsed text into chunks + content_chunks = self._split_text(parsed_text) + + # Create memory items for each chunk + memory_items = [] + for chunk_idx, chunk_text in enumerate(content_chunks): + if not chunk_text.strip(): + continue + + memory_item = TextualMemoryItem( + memory=chunk_text, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type=memory_type, + status="activated", + tags=[ + "mode:fine", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{len(content_chunks)}", + ], + key=_derive_key(chunk_text), + embedding=self.embedder.embed([chunk_text])[0], + usage=[], + sources=[source], + background="", + confidence=0.99, + type="fact", + info=info_, + ), + ) + memory_items.append(memory_item) + + # If no chunks were created, create a placeholder + if not memory_items: + memory_item = TextualMemoryItem( + memory=parsed_text, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type=memory_type, + status="activated", + tags=["mode:fine", "multimodal:file"], + key=_derive_key(parsed_text), + embedding=self.embedder.embed([parsed_text])[0], + usage=[], + sources=[source], + background="", + confidence=0.99, + type="fact", + info=info_, + ), + ) + memory_items.append(memory_item) + + return memory_items diff --git a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py index 3c60c3143..d00639005 100644 --- a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py +++ b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py @@ -35,6 +35,7 @@ def __init__( embedder: BaseEmbedder, llm: BaseLLM | None = None, parser: Any | None = None, + direct_markdown_hostnames: list[str] | None = None, ): """ Initialize MultiModalParser. @@ -43,6 +44,9 @@ def __init__( embedder: Embedder for generating embeddings llm: Optional LLM for fine mode processing parser: Optional parser for parsing file contents + direct_markdown_hostnames: List of hostnames that should return markdown directly + without parsing. If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES + environment variable (comma-separated). Default: ["139.196.232.20"] """ self.embedder = embedder self.llm = llm @@ -55,7 +59,9 @@ def __init__( self.assistant_parser = AssistantParser(embedder, llm) self.tool_parser = ToolParser(embedder, llm) self.text_content_parser = TextContentParser(embedder, llm) - self.file_content_parser = FileContentParser(embedder, llm, parser) + self.file_content_parser = FileContentParser( + embedder, llm, parser, direct_markdown_hostnames=direct_markdown_hostnames + ) self.image_parser = ImageParser(embedder, llm) self.audio_parser = None # future diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index bb2e77e38..992011765 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -43,6 +43,151 @@ re.I, ) +# Default configuration for parser and text splitter +DEFAULT_PARSER_CONFIG = { + "backend": "markitdown", + "config": {}, +} + +DEFAULT_CHUNK_SIZE = int(os.getenv("FILE_PARSER_CHUNK_SIZE", "1000")) +DEFAULT_CHUNK_OVERLAP = int(os.getenv("FILE_PARSER_CHUNK_OVERLAP", "200")) + + +def _simple_split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]: + """ + Simple text splitter as fallback when langchain is not available. + + Args: + text: Text to split + chunk_size: Maximum size of chunks + chunk_overlap: Overlap between chunks + + Returns: + List of text chunks + """ + if not text or len(text) <= chunk_size: + return [text] if text.strip() else [] + + chunks = [] + start = 0 + text_len = len(text) + + while start < text_len: + # Calculate end position + end = min(start + chunk_size, text_len) + + # If not the last chunk, try to break at a good position + if end < text_len: + # Try to break at newline, sentence end, or space + for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " "]: + last_sep = text.rfind(separator, start, end) + if last_sep != -1: + end = last_sep + len(separator) + break + + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + + # Move start position with overlap + start = max(start + 1, end - chunk_overlap) + + return chunks + + +# Initialize parser instance +file_parser = None +try: + parser_config = ParserConfigFactory.model_validate(DEFAULT_PARSER_CONFIG) + file_parser = ParserFactory.from_config(parser_config) + logger.debug("[FileContentParser] Initialized parser instance") +except Exception as e: + logger.error(f"[FileContentParser] Failed to create parser: {e}") + file_parser = None + +# Initialize text splitter instance +text_splitter = None +_use_simple_splitter = False + +try: + try: + from langchain.text_splitter import RecursiveCharacterTextSplitter + except ImportError: + try: + from langchain_text_splitters import RecursiveCharacterTextSplitter + except ImportError: + logger.error( + "langchain not available. Install with: pip install langchain or pip install langchain-text-splitters" + ) + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=DEFAULT_CHUNK_SIZE, + chunk_overlap=DEFAULT_CHUNK_OVERLAP, + length_function=len, + separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " ", ""], + ) + logger.debug( + f"[FileContentParser] Initialized langchain text splitter with chunk_size={DEFAULT_CHUNK_SIZE}, " + f"chunk_overlap={DEFAULT_CHUNK_OVERLAP}" + ) +except ImportError as e: + logger.warning( + f"[FileContentParser] langchain not available, using simple text splitter as fallback: {e}. " + "Install with: pip install langchain or pip install langchain-text-splitters" + ) + text_splitter = None + _use_simple_splitter = True +except Exception as e: + logger.error( + f"[FileContentParser] Failed to initialize text splitter: {e}, using simple splitter as fallback" + ) + text_splitter = None + _use_simple_splitter = True + + +def get_parser() -> Any: + """ + Get parser instance. + + Returns: + Parser instance (from ParserFactory) or None if not available + """ + return file_parser + + +def get_text_splitter(chunk_size: int | None = None, chunk_overlap: int | None = None) -> Any: + """ + Get text splitter instance or a callable that uses simple splitter. + + Args: + chunk_size: Maximum size of chunks when splitting text (used for simple splitter fallback) + chunk_overlap: Overlap between chunks when splitting text (used for simple splitter fallback) + + Returns: + Text splitter instance (RecursiveCharacterTextSplitter) or a callable wrapper for simple splitter + """ + if text_splitter is not None: + return text_splitter + + # Return a callable wrapper that uses simple splitter + if _use_simple_splitter: + actual_chunk_size = chunk_size or DEFAULT_CHUNK_SIZE + actual_chunk_overlap = chunk_overlap or DEFAULT_CHUNK_OVERLAP + + class SimpleTextSplitter: + """Simple text splitter wrapper.""" + + def __init__(self, chunk_size: int, chunk_overlap: int): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def split_text(self, text: str) -> list[str]: + return _simple_split_text(text, self.chunk_size, self.chunk_overlap) + + return SimpleTextSplitter(actual_chunk_size, actual_chunk_overlap) + + return None + def extract_role(message: dict[str, Any]) -> str: """Extract role from message."""