diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index a28f3bdce..0e7dcdb44 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -16,6 +16,7 @@
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
+ DEFAULT_MONITOR_INTERVAL_SECONDS,
DEFAULT_MULTI_TASK_RUNNING_TIMEOUT,
DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
@@ -68,6 +69,10 @@ class BaseSchedulerConfig(BaseConfig):
default=DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
description="Maximum size of internal message queue when not using Redis",
)
+ monitor_interval_seconds: int = Field(
+ default=DEFAULT_MONITOR_INTERVAL_SECONDS,
+ description="Interval in seconds for monitoring memory usage",
+ )
multi_task_running_timeout: int = Field(
default=DEFAULT_MULTI_TASK_RUNNING_TIMEOUT,
description="Default timeout for multi-task running operations in seconds",
diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
index ae5ae5d47..2a8d641ad 100644
--- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
+++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
@@ -164,11 +164,33 @@ def update_working_memory_for_eval(
new_candidates = []
for item in missing_evidences:
logger.info(f"missing_evidences: {item}")
- results: list[TextualMemoryItem] = self.retriever.search(
+
+ # Determine search mode from method
+ from memos.mem_scheduler.schemas.general_schemas import (
+ TreeTextMemory_FINE_SEARCH_METHOD,
+ TreeTextMemory_SEARCH_METHOD,
+ )
+ from memos.types.general_types import SearchMode
+
+ # Convert search_method to SearchMode
+ if self.search_method == TreeTextMemory_FINE_SEARCH_METHOD:
+ mode = SearchMode.FINE
+ elif self.search_method == TreeTextMemory_SEARCH_METHOD:
+ mode = SearchMode.FAST
+ else:
+ # Fallback to FAST mode for unknown methods
+ logger.warning(
+ f"Unknown search_method '{self.search_method}', falling back to SearchMode.FAST"
+ )
+ mode = SearchMode.FAST
+
+ # Use unified search service
+ results: list[TextualMemoryItem] = self.search_service.search(
query=item,
+ user_id=user_id,
mem_cube=mem_cube,
top_k=k_per_evidence,
- method=self.search_method,
+ mode=mode,
)
logger.info(
f"search results for {missing_evidences}: {[one.memory for one in results]}"
@@ -226,7 +248,7 @@ def evaluate_memory_answer_ability(
try:
# Extract JSON response
- from memos.mem_scheduler.utils.misc_utils import extract_json_obj
+ from memos.utils import extract_json_obj
result = extract_json_obj(response)
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 5ab524128..5dd1e0eb3 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -25,9 +25,12 @@
from memos.mem_cube.general import GeneralMemCube
from memos.mem_feedback.simple_feedback import SimpleMemFeedback
from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components
-from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
-from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
+from memos.mem_scheduler.memory_manage_modules.activation_memory_manager import (
+ ActivationMemoryManager,
+)
+from memos.mem_scheduler.memory_manage_modules.post_processor import MemoryPostProcessor
+from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.monitors.task_schedule_monitor import TaskScheduleMonitor
@@ -37,7 +40,6 @@
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
- DEFAULT_MAX_WEB_LOG_QUEUE_SIZE,
DEFAULT_STARTUP_MODE,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
@@ -46,17 +48,10 @@
TreeTextMemory_SEARCH_METHOD,
)
from memos.mem_scheduler.schemas.message_schemas import (
- ScheduleLogForWebItem,
ScheduleMessageItem,
)
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
from memos.mem_scheduler.schemas.task_schemas import (
- ADD_TASK_LABEL,
- ANSWER_TASK_LABEL,
- MEM_ARCHIVE_TASK_LABEL,
- MEM_ORGANIZE_TASK_LABEL,
- MEM_UPDATE_TASK_LABEL,
- QUERY_TASK_LABEL,
TaskPriorityLevel,
)
from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
@@ -72,12 +67,10 @@
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
-from memos.memories.activation.kv import KVCacheMemory
-from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
+from memos.mem_scheduler.webservice_modules.web_log_service import WebLogSchedulerModule
from memos.memories.textual.naive import NaiveTextMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
-from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
from memos.types.general_types import (
MemCubeID,
UserID,
@@ -93,7 +86,9 @@
logger = get_logger(__name__)
-class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule):
+class BaseScheduler(
+ RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule, WebLogSchedulerModule
+):
"""Base class for all mem_scheduler."""
def __init__(self, config: BaseSchedulerConfig):
@@ -122,12 +117,8 @@ def __init__(self, config: BaseSchedulerConfig):
# optional configs
self.disabled_handlers: list | None = self.config.get("disabled_handlers", None)
- self.max_web_log_queue_size = self.config.get(
- "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
- )
- self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
- maxsize=self.max_web_log_queue_size
- )
+ self.init_web_log_module(config=self.config)
+
self._consumer_thread = None # Reference to our consumer thread/process
self._consumer_process = None # Reference to our consumer process
self._running = False
@@ -144,14 +135,15 @@ def __init__(self, config: BaseSchedulerConfig):
self.orchestrator = SchedulerOrchestrator()
self.searcher: Searcher | None = None
- self.retriever: SchedulerRetriever | None = None
+ self.search_service: SchedulerSearchService | None = None
+ self.post_processor: MemoryPostProcessor | None = None
+ self.activation_memory_manager: ActivationMemoryManager | None = None
self.db_engine: Engine | None = None
self.monitor: SchedulerGeneralMonitor | None = None
self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None
self.mem_reader = None # Will be set by MOSCore
self._status_tracker: TaskStatusTracker | None = None
self.metrics = metrics
- self._monitor_thread = None
self.memos_message_queue = ScheduleTaskQueue(
use_redis_queue=self.use_redis_queue,
maxsize=self.max_internal_message_queue_size,
@@ -175,6 +167,7 @@ def __init__(self, config: BaseSchedulerConfig):
memos_message_queue=self.memos_message_queue.memos_message_queue,
dispatcher=self.dispatcher,
get_status_parallel=self.get_status_parallel,
+ check_interval=self.config.monitor_interval_seconds,
)
# other attributes
@@ -213,6 +206,9 @@ def init_mem_cube(
self.searcher = searcher
self.feedback_server = feedback_server
+ # Initialize search service with the searcher
+ self.search_service = SchedulerSearchService(searcher=self.searcher)
+
def initialize_modules(
self,
chat_llm: BaseLLM,
@@ -241,7 +237,21 @@ def initialize_modules(
)
self.db_engine = self.monitor.db_engine
self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config)
- self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)
+
+ # Initialize search service (will be updated with searcher when mem_cube is initialized)
+ self.search_service = SchedulerSearchService(searcher=self.searcher)
+
+ # Initialize post-processor for memory enhancement and filtering
+ self.post_processor = MemoryPostProcessor(
+ process_llm=self.process_llm, config=self.config
+ )
+
+ self.activation_memory_manager = ActivationMemoryManager(
+ act_mem_dump_path=self.act_mem_dump_path,
+ monitor=self.monitor,
+ log_func_callback=self._submit_web_logs,
+ log_activation_memory_update_func=self.log_activation_memory_update,
+ )
if mem_reader:
self.mem_reader = mem_reader
@@ -352,7 +362,7 @@ def feedback_server(self, value: SimpleMemFeedback) -> None:
def mem_cube(self, value: BaseMemCube) -> None:
"""The memory cube associated with this MemChat."""
self.current_mem_cube = value
- self.retriever.mem_cube = value
+ # No need to set mem_cube on retriever anymore (it's passed per-search now)
@property
def mem_cubes(self) -> dict[str, BaseMemCube]:
@@ -478,7 +488,7 @@ def replace_working_memory(
original_memory = filtered_original_memory
memories_with_new_order, rerank_success_flag = (
- self.retriever.process_and_rerank_memories(
+ self.post_processor.process_and_rerank_memories(
queries=query_history,
original_memory=original_memory,
new_memory=new_memory,
@@ -488,7 +498,7 @@ def replace_working_memory(
# Filter completely unrelated memories according to query_history
logger.info(f"Filtering memories based on query history: {len(query_history)} queries")
- filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
+ filtered_memories, filter_success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history,
memories=memories_with_new_order,
)
@@ -586,74 +596,16 @@ def update_activation_memory(
Update activation memory by extracting KVCacheItems from new_memory (list of str),
add them to a KVCacheMemory instance, and dump to disk.
"""
- if len(new_memories) == 0:
- logger.error("update_activation_memory: new_memory is empty.")
- return
- if isinstance(new_memories[0], TextualMemoryItem):
- new_text_memories = [mem.memory for mem in new_memories]
- elif isinstance(new_memories[0], str):
- new_text_memories = new_memories
- else:
- logger.error("Not Implemented.")
- return
-
- try:
- if isinstance(mem_cube.act_mem, VLLMKVCacheMemory):
- act_mem: VLLMKVCacheMemory = mem_cube.act_mem
- elif isinstance(mem_cube.act_mem, KVCacheMemory):
- act_mem: KVCacheMemory = mem_cube.act_mem
- else:
- logger.error("Not Implemented.")
- return
-
- new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
- memory_text="".join(
- [
- f"{i + 1}. {sentence.strip()}\n"
- for i, sentence in enumerate(new_text_memories)
- if sentence.strip() # Skip empty strings
- ]
- )
- )
-
- # huggingface or vllm kv cache
- original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all()
- original_text_memories = []
- if len(original_cache_items) > 0:
- pre_cache_item: VLLMKVCacheItem = original_cache_items[-1]
- original_text_memories = pre_cache_item.records.text_memories
- original_composed_text_memory = pre_cache_item.records.composed_text_memory
- if original_composed_text_memory == new_text_memory:
- logger.warning(
- "Skipping memory update - new composition matches existing cache: %s",
- new_text_memory[:50] + "..."
- if len(new_text_memory) > 50
- else new_text_memory,
- )
- return
- act_mem.delete_all()
-
- cache_item = act_mem.extract(new_text_memory)
- cache_item.records.text_memories = new_text_memories
- cache_item.records.timestamp = get_utc_now()
-
- act_mem.add([cache_item])
- act_mem.dump(self.act_mem_dump_path)
-
- self.log_activation_memory_update(
- original_text_memories=original_text_memories,
- new_text_memories=new_text_memories,
+ if self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory(
+ new_memories=new_memories,
label=label,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
- log_func_callback=self._submit_web_logs,
)
-
- except Exception as e:
- logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
- # Re-raise the exception if it's critical for the operation
- # For now, we'll continue execution but this should be reviewed
+ else:
+ logger.warning("Activation memory manager not initialized")
def update_activation_memory_periodically(
self,
@@ -663,73 +615,16 @@ def update_activation_memory_periodically(
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
):
- try:
- if (
- self.monitor.last_activation_mem_update_time == datetime.min
- or self.monitor.timed_trigger(
- last_time=self.monitor.last_activation_mem_update_time,
- interval_seconds=interval_seconds,
- )
- ):
- logger.info(
- f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
- )
-
- if (
- user_id not in self.monitor.working_memory_monitors
- or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
- or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
- == 0
- ):
- logger.warning(
- "No memories found in working_memory_monitors, activation memory update is skipped"
- )
- return
-
- self.monitor.update_activation_memory_monitors(
- user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
- )
-
- # Sync with database to get latest activation memories
- activation_db_manager = self.monitor.activation_memory_monitors[user_id][
- mem_cube_id
- ]
- activation_db_manager.sync_with_orm()
- new_activation_memories = [
- m.memory_text for m in activation_db_manager.obj.memories
- ]
-
- logger.info(
- f"Collected {len(new_activation_memories)} new memory entries for processing"
- )
- # Print the content of each new activation memory
- for i, memory in enumerate(new_activation_memories[:5], 1):
- logger.info(
- f"Part of New Activation Memorires | {i}/{len(new_activation_memories)}: {memory[:20]}"
- )
-
- self.update_activation_memory(
- new_memories=new_activation_memories,
- label=label,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- )
-
- self.monitor.last_activation_mem_update_time = get_utc_now()
-
- logger.debug(
- f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
- )
-
- else:
- logger.info(
- f"Skipping update - {interval_seconds} second interval not yet reached. "
- f"Last update time is {self.monitor.last_activation_mem_update_time} and now is "
- f"{get_utc_now()}"
- )
- except Exception as e:
- logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)
+ if self.activation_memory_manager:
+ self.activation_memory_manager.update_activation_memory_periodically(
+ interval_seconds=interval_seconds,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+ else:
+ logger.warning("Activation memory manager not initialized")
def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
"""Submit messages for processing, with priority-aware dispatch.
@@ -856,103 +751,6 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
if queued_msgs:
self.memos_message_queue.submit_messages(messages=queued_msgs)
- def _submit_web_logs(
- self,
- messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem],
- additional_log_info: str | None = None,
- ) -> None:
- """Submit log messages to the web log queue and optionally to RabbitMQ.
-
- Args:
- messages: Single log message or list of log messages
- """
- if isinstance(messages, ScheduleLogForWebItem):
- messages = [messages] # transform single message to list
-
- for message in messages:
- if self.rabbitmq_config is None:
- return
- try:
- # Always call publish; the publisher now caches when offline and flushes after reconnect
- logger.info(
- f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}"
- )
- self.rabbitmq_publish_message(message=message.to_dict())
- logger.info(
- "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
- "item_id=%s task_id=%s label=%s",
- message.item_id,
- message.task_id,
- message.label,
- )
- except Exception as e:
- logger.error(
- f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True
- )
-
- logger.debug(
- f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
- )
-
- def get_web_log_messages(self) -> list[dict]:
- """
- Retrieve structured log messages from the queue and return JSON-serializable dicts.
- """
- raw_items: list[ScheduleLogForWebItem] = []
- while True:
- try:
- raw_items.append(self._web_log_message_queue.get_nowait())
- except Exception:
- break
-
- def _map_label(label: str) -> str:
- mapping = {
- QUERY_TASK_LABEL: "addMessage",
- ANSWER_TASK_LABEL: "addMessage",
- ADD_TASK_LABEL: "addMemory",
- MEM_UPDATE_TASK_LABEL: "updateMemory",
- MEM_ORGANIZE_TASK_LABEL: "mergeMemory",
- MEM_ARCHIVE_TASK_LABEL: "archiveMemory",
- }
- return mapping.get(label, label)
-
- def _normalize_item(item: ScheduleLogForWebItem) -> dict:
- data = item.to_dict()
- data["label"] = _map_label(data.get("label"))
- memcube_content = getattr(item, "memcube_log_content", None) or []
- metadata = getattr(item, "metadata", None) or []
-
- memcube_name = getattr(item, "memcube_name", None)
- if not memcube_name and hasattr(self, "_map_memcube_name"):
- memcube_name = self._map_memcube_name(item.mem_cube_id)
- data["memcube_name"] = memcube_name
-
- memory_len = getattr(item, "memory_len", None)
- if memory_len is None:
- if data["label"] == "mergeMemory":
- memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"])
- elif memcube_content:
- memory_len = len(memcube_content)
- else:
- memory_len = 1 if item.log_content else 0
-
- data["memcube_log_content"] = memcube_content
- data["memory_len"] = memory_len
-
- def _with_memory_time(meta: dict) -> dict:
- enriched = dict(meta)
- if "memory_time" not in enriched:
- enriched["memory_time"] = enriched.get("updated_at") or enriched.get(
- "update_at"
- )
- return enriched
-
- data["metadata"] = [_with_memory_time(m) for m in metadata]
- data["log_title"] = ""
- return data
-
- return [_normalize_item(it) for it in raw_items]
-
def _message_consumer(self) -> None:
"""
Continuously checks the queue for messages and dispatches them.
@@ -1041,39 +839,6 @@ def _message_consumer(self) -> None:
logger.error(f"Unexpected error in message consumer: {e!s}", exc_info=True)
time.sleep(self._consume_interval) # Prevent tight error loops
- def _monitor_loop(self):
- while self._running:
- try:
- q_sizes = self.memos_message_queue.qsize()
-
- if not isinstance(q_sizes, dict):
- continue
-
- for stream_key, queue_length in q_sizes.items():
- # Skip aggregate keys like 'total_size'
- if stream_key == "total_size":
- continue
-
- # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
- # We want to extract user_id, which is the 3rd component from the end.
- parts = stream_key.split(":")
- if len(parts) >= 3:
- user_id = parts[-3]
- self.metrics.update_queue_length(queue_length, user_id)
- else:
- # Fallback for unexpected key formats (e.g. legacy or testing)
- # Try to use the key itself if it looks like a user_id (no colons)
- # or just log a warning?
- # For now, let's assume if it's not total_size and short, it might be a direct user_id key
- # (though that shouldn't happen with current queue implementations)
- if ":" not in stream_key:
- self.metrics.update_queue_length(queue_length, stream_key)
-
- except Exception as e:
- logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
-
- time.sleep(15) # 每 15 秒采样一次
-
def start(self) -> None:
"""
Start the message consumer thread/process and initialize dispatcher resources.
@@ -1081,6 +846,7 @@ def start(self) -> None:
Initializes and starts:
1. Message consumer thread or process (based on startup_mode)
2. Dispatcher thread pool (if parallel dispatch enabled)
+ 3. Task schedule monitor
"""
# Initialize dispatcher resources
if self.enable_parallel_dispatch:
@@ -1089,16 +855,7 @@ def start(self) -> None:
)
self.start_consumer()
- self.start_background_monitor()
-
- def start_background_monitor(self):
- if self._monitor_thread and self._monitor_thread.is_alive():
- return
- self._monitor_thread = ContextThread(
- target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor"
- )
- self._monitor_thread.start()
- logger.info("Scheduler metrics monitor thread started.")
+ self.task_schedule_monitor.start()
def start_consumer(self) -> None:
"""
@@ -1186,8 +943,7 @@ def stop(self) -> None:
# Stop consumer first
self.stop_consumer()
- if self._monitor_thread:
- self._monitor_thread.join(timeout=2.0)
+ self.task_schedule_monitor.stop()
# Shutdown dispatcher
if self.dispatcher:
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_context.py b/src/memos/mem_scheduler/general_modules/scheduler_context.py
new file mode 100644
index 000000000..712f0ccdc
--- /dev/null
+++ b/src/memos/mem_scheduler/general_modules/scheduler_context.py
@@ -0,0 +1,99 @@
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Engine
+
+ from memos.configs.mem_scheduler import GeneralSchedulerConfig
+ from memos.mem_cube.general import GeneralMemCube
+ from memos.mem_feedback.simple_feedback import SimpleMemFeedback
+ from memos.mem_scheduler.general_scheduler import GeneralScheduler
+ from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
+ from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
+ from memos.mem_scheduler.schemas.message_schemas import (
+ ScheduleLogForWebItem,
+ ScheduleMessageItem,
+ )
+ from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
+
+
+class SchedulerContext:
+ """Context object to hold dependencies for handlers."""
+
+ def __init__(self, scheduler: "GeneralScheduler"):
+ self._scheduler = scheduler
+
+ @property
+ def mem_cube(self) -> "GeneralMemCube":
+ return self._scheduler.mem_cube
+
+ @property
+ def monitor(self) -> "SchedulerGeneralMonitor":
+ return self._scheduler.monitor
+
+ @property
+ def retriever(self) -> "SchedulerRetriever":
+ return self._scheduler.retriever
+
+ @property
+ def config(self) -> "GeneralSchedulerConfig":
+ return self._scheduler.config
+
+ @property
+ def dispatcher(self) -> "SchedulerDispatcher":
+ return self._scheduler.dispatcher
+
+ @property
+ def db_engine(self) -> "Engine | None":
+ return self._scheduler.db_engine
+
+ @property
+ def feedback_server(self) -> "SimpleMemFeedback | None":
+ return self._scheduler.feedback_server
+
+ @property
+ def mem_reader(self) -> Any:
+ return self._scheduler.mem_reader
+
+ # Methods
+ def submit_web_logs(self, messages: list["ScheduleLogForWebItem"]) -> None:
+ self._scheduler._submit_web_logs(messages)
+
+ def create_event_log(self, *args, **kwargs) -> "ScheduleLogForWebItem":
+ return self._scheduler.create_event_log(*args, **kwargs)
+
+ def validate_schedule_messages(self, messages: list["ScheduleMessageItem"], label: str) -> bool:
+ return self._scheduler.validate_schedule_messages(messages, label)
+
+ def submit_messages(self, messages: list["ScheduleMessageItem"]) -> None:
+ self._scheduler.submit_messages(messages)
+
+ def long_memory_update_process(self, *args, **kwargs) -> None:
+ self._scheduler.long_memory_update_process(*args, **kwargs)
+
+ def process_session_turn(self, *args, **kwargs) -> Any:
+ return self._scheduler.process_session_turn(*args, **kwargs)
+
+ def replace_working_memory(self, *args, **kwargs) -> Any:
+ return self._scheduler.replace_working_memory(*args, **kwargs)
+
+ def update_activation_memory_periodically(self, *args, **kwargs) -> None:
+ self._scheduler.update_activation_memory_periodically(*args, **kwargs)
+
+ def map_memcube_name(self, mem_cube_id: str) -> str:
+ return self._scheduler._map_memcube_name(mem_cube_id)
+
+ def log_add_messages(self, *args, **kwargs) -> Any:
+ return self._scheduler.log_add_messages(*args, **kwargs)
+
+ def send_add_log_messages_to_cloud_env(self, *args, **kwargs) -> None:
+ self._scheduler.send_add_log_messages_to_cloud_env(*args, **kwargs)
+
+ def send_add_log_messages_to_local_env(self, *args, **kwargs) -> None:
+ self._scheduler.send_add_log_messages_to_local_env(*args, **kwargs)
+
+ def process_memories_with_reader(self, *args, **kwargs) -> None:
+ self._scheduler._process_memories_with_reader(*args, **kwargs)
+
+ def process_memories_with_reorganize(self, *args, **kwargs) -> None:
+ self._scheduler._process_memories_with_reorganize(*args, **kwargs)
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index d4ac09cc3..c86bd93fc 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -1,13 +1,12 @@
-import concurrent.futures
import contextlib
import json
import traceback
from memos.configs.mem_scheduler import GeneralSchedulerConfig
-from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.base_scheduler import BaseScheduler
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem
from memos.mem_scheduler.schemas.task_schemas import (
@@ -19,23 +18,34 @@
MEM_ORGANIZE_TASK_LABEL,
MEM_READ_TASK_LABEL,
MEM_UPDATE_TASK_LABEL,
- NOT_APPLICABLE_TYPE,
PREF_ADD_TASK_LABEL,
QUERY_TASK_LABEL,
USER_INPUT_TYPE,
)
+from memos.mem_scheduler.task_schedule_modules.handlers.add_handler import AddHandler
+from memos.mem_scheduler.task_schedule_modules.handlers.answer_handler import AnswerHandler
+from memos.mem_scheduler.task_schedule_modules.handlers.mem_feedback_handler import (
+ MemFeedbackHandler,
+)
+from memos.mem_scheduler.task_schedule_modules.handlers.mem_read_handler import MemReadHandler
+from memos.mem_scheduler.task_schedule_modules.handlers.mem_reorganize_handler import (
+ MemReorganizeHandler,
+)
+from memos.mem_scheduler.task_schedule_modules.handlers.memory_update_handler import (
+ MemoryUpdateHandler,
+)
+from memos.mem_scheduler.task_schedule_modules.handlers.pref_add_handler import PrefAddHandler
+from memos.mem_scheduler.task_schedule_modules.handlers.query_handler import QueryHandler
from memos.mem_scheduler.utils.filter_utils import (
is_all_chinese,
is_all_english,
transform_name_to_key,
)
from memos.mem_scheduler.utils.misc_utils import (
- group_messages_by_user_and_mem_cube,
is_cloud_env,
)
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.naive import NaiveTextMemory
-from memos.memories.textual.preference import PreferenceTextMemory
from memos.memories.textual.tree import TreeTextMemory
from memos.types import (
MemCubeID,
@@ -53,16 +63,19 @@ def __init__(self, config: GeneralSchedulerConfig):
self.query_key_words_limit = self.config.get("query_key_words_limit", 20)
+ # Initialize context
+ context = SchedulerContext(self)
+
# register handlers
handlers = {
- QUERY_TASK_LABEL: self._query_message_consumer,
- ANSWER_TASK_LABEL: self._answer_message_consumer,
- MEM_UPDATE_TASK_LABEL: self._memory_update_consumer,
- ADD_TASK_LABEL: self._add_message_consumer,
- MEM_READ_TASK_LABEL: self._mem_read_message_consumer,
- MEM_ORGANIZE_TASK_LABEL: self._mem_reorganize_message_consumer,
- PREF_ADD_TASK_LABEL: self._pref_add_message_consumer,
- MEM_FEEDBACK_TASK_LABEL: self._mem_feedback_message_consumer,
+ QUERY_TASK_LABEL: QueryHandler(context),
+ ANSWER_TASK_LABEL: AnswerHandler(context),
+ MEM_UPDATE_TASK_LABEL: MemoryUpdateHandler(context),
+ ADD_TASK_LABEL: AddHandler(context),
+ MEM_READ_TASK_LABEL: MemReadHandler(context),
+ MEM_ORGANIZE_TASK_LABEL: MemReorganizeHandler(context),
+ PREF_ADD_TASK_LABEL: PrefAddHandler(context),
+ MEM_FEEDBACK_TASK_LABEL: MemFeedbackHandler(context),
}
self.dispatcher.register_handlers(handlers)
@@ -175,160 +188,6 @@ def long_memory_update_process(
mem_cube=self.mem_cube,
)
- def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.")
- # Process the query in a session turn
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=ADD_TASK_LABEL)
- try:
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
-
- # Process each message in the batch
- for msg in batch:
- prepared_add_items, prepared_update_items_with_original = (
- self.log_add_messages(msg=msg)
- )
- logger.info(
- f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}"
- )
- # Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default
- cloud_env = is_cloud_env()
-
- if cloud_env:
- self.send_add_log_messages_to_cloud_env(
- msg, prepared_add_items, prepared_update_items_with_original
- )
- else:
- self.send_add_log_messages_to_local_env(
- msg, prepared_add_items, prepared_update_items_with_original
- )
-
- except Exception as e:
- logger.error(f"Error: {e}", exc_info=True)
-
- def _memory_update_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.")
-
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL)
-
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
- # Process the whole batch once; no need to iterate per message
- self.long_memory_update_process(
- user_id=user_id, mem_cube_id=mem_cube_id, messages=batch
- )
-
- def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- """
- Process and handle query trigger messages from the queue.
-
- Args:
- messages: List of query messages to process
- """
- logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.")
-
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=QUERY_TASK_LABEL)
-
- mem_update_messages = []
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
-
- for msg in batch:
- try:
- event = self.create_event_log(
- label="addMessage",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=NOT_APPLICABLE_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=[
- {
- "content": f"[User] {msg.content}",
- "ref_id": msg.item_id,
- "role": "user",
- }
- ],
- metadata=[],
- memory_len=1,
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- self._submit_web_logs([event])
- except Exception:
- logger.exception("Failed to record addMessage log for query")
- # Re-submit the message with label changed to mem_update
- update_msg = ScheduleMessageItem(
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- label=MEM_UPDATE_TASK_LABEL,
- content=msg.content,
- session_id=msg.session_id,
- user_name=msg.user_name,
- info=msg.info,
- task_id=msg.task_id,
- )
- mem_update_messages.append(update_msg)
-
- self.submit_messages(messages=mem_update_messages)
-
- def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- """
- Process and handle answer trigger messages from the queue.
-
- Args:
- messages: List of answer messages to process
- """
- logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.")
- grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
-
- self.validate_schedule_messages(messages=messages, label=ANSWER_TASK_LABEL)
-
- for user_id in grouped_messages:
- for mem_cube_id in grouped_messages[user_id]:
- batch = grouped_messages[user_id][mem_cube_id]
- if not batch:
- continue
- try:
- for msg in batch:
- event = self.create_event_log(
- label="addMessage",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=NOT_APPLICABLE_TYPE,
- user_id=msg.user_id,
- mem_cube_id=msg.mem_cube_id,
- mem_cube=self.mem_cube,
- memcube_log_content=[
- {
- "content": f"[Assistant] {msg.content}",
- "ref_id": msg.item_id,
- "role": "assistant",
- }
- ],
- metadata=[],
- memory_len=1,
- memcube_name=self._map_memcube_name(msg.mem_cube_id),
- )
- event.task_id = msg.task_id
- self._submit_web_logs([event])
- except Exception:
- logger.exception("Failed to record addMessage log for answer")
-
def log_add_messages(self, msg: ScheduleMessageItem):
try:
userinput_memory_ids = json.loads(msg.content)
@@ -579,234 +438,6 @@ def send_add_log_messages_to_cloud_env(
event.task_id = msg.task_id
self._submit_web_logs([event])
- def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- try:
- if not messages:
- return
- message = messages[0]
- mem_cube = self.mem_cube
-
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- content = message.content
-
- try:
- feedback_data = json.loads(content) if isinstance(content, str) else content
- if not isinstance(feedback_data, dict):
- logger.error(
- f"Failed to decode feedback_data or it is not a dict: {feedback_data}"
- )
- return
- except json.JSONDecodeError:
- logger.error(f"Invalid JSON content for feedback message: {content}", exc_info=True)
- return
-
- task_id = feedback_data.get("task_id") or message.task_id
- feedback_result = self.feedback_server.process_feedback(
- user_id=user_id,
- user_name=mem_cube_id,
- session_id=feedback_data.get("session_id"),
- chat_history=feedback_data.get("history", []),
- retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []),
- feedback_content=feedback_data.get("feedback_content"),
- feedback_time=feedback_data.get("feedback_time"),
- task_id=task_id,
- info=feedback_data.get("info", None),
- )
-
- logger.info(
- f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- cloud_env = is_cloud_env()
- if cloud_env:
- record = feedback_result.get("record") if isinstance(feedback_result, dict) else {}
- add_records = record.get("add") if isinstance(record, dict) else []
- update_records = record.get("update") if isinstance(record, dict) else []
-
- def _extract_fields(mem_item):
- mem_id = (
- getattr(mem_item, "id", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("id")
- )
- mem_memory = (
- getattr(mem_item, "memory", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("memory") or mem_item.get("text")
- )
- if mem_memory is None and isinstance(mem_item, dict):
- mem_memory = mem_item.get("text")
- original_content = (
- getattr(mem_item, "origin_memory", None)
- if not isinstance(mem_item, dict)
- else mem_item.get("origin_memory")
- or mem_item.get("old_memory")
- or mem_item.get("original_content")
- )
- source_doc_id = None
- if isinstance(mem_item, dict):
- source_doc_id = mem_item.get("source_doc_id", None)
-
- return mem_id, mem_memory, original_content, source_doc_id
-
- kb_log_content: list[dict] = []
-
- for mem_item in add_records or []:
- mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item)
- if mem_id and mem_memory:
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": "Feedback",
- "operation": "ADD",
- "memory_id": mem_id,
- "content": mem_memory,
- "original_content": None,
- "source_doc_id": source_doc_id,
- }
- )
- else:
- logger.warning(
- "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
- user_id,
- mem_cube_id,
- task_id,
- mem_item,
- stack_info=True,
- )
-
- for mem_item in update_records or []:
- mem_id, mem_memory, original_content, source_doc_id = _extract_fields(mem_item)
- if mem_id and mem_memory:
- kb_log_content.append(
- {
- "log_source": "KNOWLEDGE_BASE_LOG",
- "trigger_source": "Feedback",
- "operation": "UPDATE",
- "memory_id": mem_id,
- "content": mem_memory,
- "original_content": original_content,
- "source_doc_id": source_doc_id,
- }
- )
- else:
- logger.warning(
- "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
- user_id,
- mem_cube_id,
- task_id,
- mem_item,
- stack_info=True,
- )
-
- logger.info(f"[Feedback Scheduler] kb_log_content: {kb_log_content!s}")
- if kb_log_content:
- logger.info(
- "[DIAGNOSTIC] general_scheduler._mem_feedback_message_consumer: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s",
- user_id,
- mem_cube_id,
- task_id,
- len(kb_log_content),
- )
- event = self.create_event_log(
- label="knowledgeBaseUpdate",
- from_memory_type=USER_INPUT_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- memcube_log_content=kb_log_content,
- metadata=None,
- memory_len=len(kb_log_content),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- event.log_content = (
- f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
- )
- event.task_id = task_id
- self._submit_web_logs([event])
- else:
- logger.warning(
- "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s",
- user_id,
- mem_cube_id,
- task_id,
- stack_info=True,
- )
- else:
- logger.info(
- "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
- cloud_env,
- )
-
- except Exception as e:
- logger.error(f"Error processing feedbackMemory message: {e}", exc_info=True)
-
- def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(
- f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}"
- )
- logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.error(
- f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing",
- stack_info=True,
- )
- return
-
- content = message.content
- user_name = message.user_name
- info = message.info or {}
-
- # Parse the memory IDs from content
- mem_ids = json.loads(content) if isinstance(content, str) else content
- if not mem_ids:
- return
-
- logger.info(
- f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
- )
-
- # Get the text memory from the mem_cube
- text_mem = mem_cube.text_mem
- if not isinstance(text_mem, TreeTextMemory):
- logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
- return
-
- # Use mem_reader to process the memories
- self._process_memories_with_reader(
- mem_ids=mem_ids,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- text_mem=text_mem,
- user_name=user_name,
- custom_tags=info.get("custom_tags", None),
- task_id=message.task_id,
- info=info,
- )
-
- logger.info(
- f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- except Exception as e:
- logger.error(f"Error processing mem_read message: {e}", stack_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", stack_info=True)
-
def _process_memories_with_reader(
self,
mem_ids: list[str],
@@ -1087,185 +718,6 @@ def _process_memories_with_reader(
event.status = "failed"
self._submit_web_logs([event])
- def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- user_id = message.user_id
- mem_cube_id = message.mem_cube_id
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.warning(
- f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing"
- )
- return
- content = message.content
- user_name = message.user_name
-
- # Parse the memory IDs from content
- mem_ids = json.loads(content) if isinstance(content, str) else content
- if not mem_ids:
- return
-
- logger.info(
- f"Processing mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
- )
-
- # Get the text memory from the mem_cube
- text_mem = mem_cube.text_mem
- if not isinstance(text_mem, TreeTextMemory):
- logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
- return
-
- # Use mem_reader to process the memories
- self._process_memories_with_reorganize(
- mem_ids=mem_ids,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- text_mem=text_mem,
- user_name=user_name,
- )
-
- with contextlib.suppress(Exception):
- mem_items: list[TextualMemoryItem] = []
- for mid in mem_ids:
- with contextlib.suppress(Exception):
- mem_items.append(text_mem.get(mid, user_name=user_name))
- if len(mem_items) > 1:
- keys: list[str] = []
- memcube_content: list[dict] = []
- meta: list[dict] = []
- merged_target_ids: set[str] = set()
- with contextlib.suppress(Exception):
- if hasattr(text_mem, "graph_store"):
- for mid in mem_ids:
- edges = text_mem.graph_store.get_edges(
- mid, type="MERGED_TO", direction="OUT"
- )
- for edge in edges:
- target = (
- edge.get("to") or edge.get("dst") or edge.get("target")
- )
- if target:
- merged_target_ids.add(target)
- for item in mem_items:
- key = getattr(
- getattr(item, "metadata", {}), "key", None
- ) or transform_name_to_key(getattr(item, "memory", ""))
- keys.append(key)
- memcube_content.append(
- {"content": key or "(no key)", "ref_id": item.id, "type": "merged"}
- )
- meta.append(
- {
- "ref_id": item.id,
- "id": item.id,
- "key": key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
- )
- combined_key = keys[0] if keys else ""
- post_ref_id = None
- post_meta = {
- "ref_id": None,
- "id": None,
- "key": None,
- "memory": None,
- "memory_type": None,
- "status": None,
- "confidence": None,
- "tags": None,
- "updated_at": None,
- }
- if merged_target_ids:
- post_ref_id = next(iter(merged_target_ids))
- with contextlib.suppress(Exception):
- merged_item = text_mem.get(post_ref_id, user_name=user_name)
- combined_key = (
- getattr(getattr(merged_item, "metadata", {}), "key", None)
- or combined_key
- )
- post_meta = {
- "ref_id": post_ref_id,
- "id": post_ref_id,
- "key": getattr(
- getattr(merged_item, "metadata", {}), "key", None
- ),
- "memory": getattr(merged_item, "memory", None),
- "memory_type": getattr(
- getattr(merged_item, "metadata", {}), "memory_type", None
- ),
- "status": getattr(
- getattr(merged_item, "metadata", {}), "status", None
- ),
- "confidence": getattr(
- getattr(merged_item, "metadata", {}), "confidence", None
- ),
- "tags": getattr(
- getattr(merged_item, "metadata", {}), "tags", None
- ),
- "updated_at": getattr(
- getattr(merged_item, "metadata", {}), "updated_at", None
- )
- or getattr(
- getattr(merged_item, "metadata", {}), "update_at", None
- ),
- }
- if not post_ref_id:
- import hashlib
-
- post_ref_id = f"merge-{hashlib.md5(''.join(sorted(mem_ids)).encode()).hexdigest()}"
- post_meta["ref_id"] = post_ref_id
- post_meta["id"] = post_ref_id
- if not post_meta.get("key"):
- post_meta["key"] = combined_key
- if not keys:
- keys = [item.id for item in mem_items]
- memcube_content.append(
- {
- "content": combined_key if combined_key else "(no key)",
- "ref_id": post_ref_id,
- "type": "postMerge",
- }
- )
- meta.append(post_meta)
- event = self.create_event_log(
- label="mergeMemory",
- from_memory_type=LONG_TERM_MEMORY_TYPE,
- to_memory_type=LONG_TERM_MEMORY_TYPE,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- memcube_log_content=memcube_content,
- metadata=meta,
- memory_len=len(keys),
- memcube_name=self._map_memcube_name(mem_cube_id),
- )
- self._submit_web_logs([event])
-
- logger.info(
- f"Successfully processed mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}"
- )
-
- except Exception as e:
- logger.error(f"Error processing mem_reorganize message: {e}", exc_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", exc_info=True)
-
def _process_memories_with_reorganize(
self,
mem_ids: list[str],
@@ -1319,71 +771,6 @@ def _process_memories_with_reorganize(
exc_info=True,
)
- def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
- logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.")
-
- def process_message(message: ScheduleMessageItem):
- try:
- mem_cube = self.mem_cube
- if mem_cube is None:
- logger.warning(
- f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing"
- )
- return
-
- user_id = message.user_id
- session_id = message.session_id
- mem_cube_id = message.mem_cube_id
- content = message.content
- messages_list = json.loads(content)
- info = message.info or {}
-
- logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}")
-
- # Get the preference memory from the mem_cube
- pref_mem = mem_cube.pref_mem
- if pref_mem is None:
- logger.warning(
- f"Preference memory not initialized for mem_cube_id={mem_cube_id}, "
- f"skipping pref_add processing"
- )
- return
- if not isinstance(pref_mem, PreferenceTextMemory):
- logger.error(
- f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} "
- f"for mem_cube_id={mem_cube_id}"
- )
- return
-
- # Use pref_mem.get_memory to process the memories
- pref_memories = pref_mem.get_memory(
- messages_list,
- type="chat",
- info={
- **info,
- "user_id": user_id,
- "session_id": session_id,
- "mem_cube_id": mem_cube_id,
- },
- )
- # Add pref_mem to vector db
- pref_ids = pref_mem.add(pref_memories)
-
- logger.info(
- f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}"
- )
-
- except Exception as e:
- logger.error(f"Error processing pref_add message: {e}", exc_info=True)
-
- with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor:
- futures = [executor.submit(process_message, msg) for msg in messages]
- for future in concurrent.futures.as_completed(futures):
- try:
- future.result()
- except Exception as e:
- logger.error(f"Thread task failed: {e}", exc_info=True)
-
def process_session_turn(
self,
queries: str | list[str],
@@ -1463,28 +850,41 @@ def process_session_turn(
f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}"
)
- search_args = {}
+ # Determine search mode from method
+ from memos.mem_scheduler.schemas.general_schemas import (
+ TreeTextMemory_FINE_SEARCH_METHOD,
+ TreeTextMemory_SEARCH_METHOD,
+ )
+ from memos.types.general_types import SearchMode
+
+ # Convert search_method to SearchMode
+ if self.search_method == TreeTextMemory_FINE_SEARCH_METHOD:
+ mode = SearchMode.FINE
+ elif self.search_method == TreeTextMemory_SEARCH_METHOD:
+ mode = SearchMode.FAST
+ else:
+ # Fallback to FAST mode for unknown methods
+ logger.warning(
+ f"Unknown search_method '{self.search_method}', falling back to SearchMode.FAST"
+ )
+ mode = SearchMode.FAST
+
if isinstance(text_mem_base, NaiveTextMemory):
- # NaiveTextMemory doesn't support complex search args usually, but let's see
- # self.retriever.search calls mem_cube.text_mem.search
- # NaiveTextMemory.search takes query and top_k
- # SchedulerRetriever.search handles method dispatch
- # For NaiveTextMemory, we might need to bypass retriever or extend it
- # But let's try calling naive memory directly if retriever fails or doesn't support it
+ # NaiveTextMemory: Use direct search as fallback
try:
results = text_mem_base.search(query=item, top_k=k_per_evidence)
except Exception as e:
logger.warning(f"NaiveTextMemory search failed: {e}")
results = []
else:
- results: list[TextualMemoryItem] = self.retriever.search(
+ # Use unified search service
+ results: list[TextualMemoryItem] = self.search_service.search(
query=item,
user_id=user_id,
- mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
top_k=k_per_evidence,
- method=self.search_method,
- search_args=search_args,
+ mode=mode,
+ mem_cube_id=mem_cube_id,
)
logger.info(
diff --git a/src/memos/mem_scheduler/memory_manage_modules/__init__.py b/src/memos/mem_scheduler/memory_manage_modules/__init__.py
index 94d704294..edb9f2a42 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/__init__.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/__init__.py
@@ -1,5 +1,10 @@
from .memory_filter import MemoryFilter
-from .retriever import SchedulerRetriever
+from .post_processor import MemoryPostProcessor
+from .search_service import SchedulerSearchService
-__all__ = ["MemoryFilter", "SchedulerRetriever"]
+__all__ = [
+ "MemoryFilter",
+ "MemoryPostProcessor",
+ "SchedulerSearchService",
+]
diff --git a/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py
new file mode 100644
index 000000000..589d0e421
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/activation_memory_manager.py
@@ -0,0 +1,186 @@
+from collections.abc import Callable
+from datetime import datetime
+
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.memories.activation.kv import KVCacheMemory
+from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
+from memos.memories.textual.tree import TextualMemoryItem
+from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
+from memos.types.general_types import MemCubeID, UserID
+
+
+logger = get_logger(__name__)
+
+
+class ActivationMemoryManager:
+ def __init__(
+ self,
+ act_mem_dump_path: str,
+ monitor: SchedulerGeneralMonitor,
+ log_func_callback: Callable,
+ log_activation_memory_update_func: Callable,
+ ):
+ self.act_mem_dump_path = act_mem_dump_path
+ self.monitor = monitor
+ self.log_func_callback = log_func_callback
+ self.log_activation_memory_update_func = log_activation_memory_update_func
+
+ def update_activation_memory(
+ self,
+ new_memories: list[str | TextualMemoryItem],
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube: GeneralMemCube,
+ ) -> None:
+ """
+ Update activation memory by extracting KVCacheItems from new_memory (list of str),
+ add them to a KVCacheMemory instance, and dump to disk.
+ """
+ if len(new_memories) == 0:
+ logger.error("update_activation_memory: new_memory is empty.")
+ return
+ if isinstance(new_memories[0], TextualMemoryItem):
+ new_text_memories = [mem.memory for mem in new_memories]
+ elif isinstance(new_memories[0], str):
+ new_text_memories = new_memories
+ else:
+ logger.error("Not Implemented.")
+ return
+
+ try:
+ if isinstance(mem_cube.act_mem, VLLMKVCacheMemory):
+ act_mem: VLLMKVCacheMemory = mem_cube.act_mem
+ elif isinstance(mem_cube.act_mem, KVCacheMemory):
+ act_mem: KVCacheMemory = mem_cube.act_mem
+ else:
+ logger.error("Not Implemented.")
+ return
+
+ new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
+ memory_text="".join(
+ [
+ f"{i + 1}. {sentence.strip()}\n"
+ for i, sentence in enumerate(new_text_memories)
+ if sentence.strip() # Skip empty strings
+ ]
+ )
+ )
+
+ # huggingface or vllm kv cache
+ original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all()
+ original_text_memories = []
+ if len(original_cache_items) > 0:
+ pre_cache_item: VLLMKVCacheItem = original_cache_items[-1]
+ original_text_memories = pre_cache_item.records.text_memories
+ original_composed_text_memory = pre_cache_item.records.composed_text_memory
+ if original_composed_text_memory == new_text_memory:
+ logger.warning(
+ "Skipping memory update - new composition matches existing cache: %s",
+ new_text_memory[:50] + "..."
+ if len(new_text_memory) > 50
+ else new_text_memory,
+ )
+ return
+ act_mem.delete_all()
+
+ cache_item = act_mem.extract(new_text_memory)
+ cache_item.records.text_memories = new_text_memories
+ cache_item.records.timestamp = get_utc_now()
+
+ act_mem.add([cache_item])
+ act_mem.dump(self.act_mem_dump_path)
+
+ self.log_activation_memory_update_func(
+ original_text_memories=original_text_memories,
+ new_text_memories=new_text_memories,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ log_func_callback=self.log_func_callback,
+ )
+
+ except Exception as e:
+ logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
+ # Re-raise the exception if it's critical for the operation
+ # For now, we'll continue execution but this should be reviewed
+
+ def update_activation_memory_periodically(
+ self,
+ interval_seconds: int,
+ label: str,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube: GeneralMemCube,
+ ):
+ try:
+ if (
+ self.monitor.last_activation_mem_update_time == datetime.min
+ or self.monitor.timed_trigger(
+ last_time=self.monitor.last_activation_mem_update_time,
+ interval_seconds=interval_seconds,
+ )
+ ):
+ logger.info(
+ f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
+ )
+
+ if (
+ user_id not in self.monitor.working_memory_monitors
+ or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
+ or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
+ == 0
+ ):
+ logger.warning(
+ "No memories found in working_memory_monitors, activation memory update is skipped"
+ )
+ return
+
+ self.monitor.update_activation_memory_monitors(
+ user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
+ )
+
+ # Sync with database to get latest activation memories
+ activation_db_manager = self.monitor.activation_memory_monitors[user_id][
+ mem_cube_id
+ ]
+ activation_db_manager.sync_with_orm()
+ new_activation_memories = [
+ m.memory_text for m in activation_db_manager.obj.memories
+ ]
+
+ logger.info(
+ f"Collected {len(new_activation_memories)} new memory entries for processing"
+ )
+ # Print the content of each new activation memory
+ for i, memory in enumerate(new_activation_memories[:5], 1):
+ logger.info(
+                    f"Part of New Activation Memories | {i}/{len(new_activation_memories)}: {memory[:20]}"
+ )
+
+ self.update_activation_memory(
+ new_memories=new_activation_memories,
+ label=label,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+
+ self.monitor.last_activation_mem_update_time = get_utc_now()
+
+ logger.debug(
+ f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
+ )
+
+ else:
+ logger.info(
+ f"Skipping update - {interval_seconds} second interval not yet reached. "
+ f"Last update time is {self.monitor.last_activation_mem_update_time} and now is "
+ f"{get_utc_now()}"
+ )
+ except Exception as e:
+ logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
index 25b9a98f3..5a495fad2 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
@@ -2,8 +2,8 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TextualMemoryItem
+from memos.utils import extract_json_obj
logger = get_logger(__name__)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/post_processor.py b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
new file mode 100644
index 000000000..2e1821e1e
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/post_processor.py
@@ -0,0 +1,307 @@
+"""
+Memory Post-Processor - Handles post-retrieval memory filtering and reranking.
+
+This module provides post-processing operations for retrieved memories,
+including filtering and reranking operations specific to the scheduler's needs.
+
+Note: Memory enhancement operations (enhance_memories_with_query, recall_for_missing_memories)
+have been moved to AdvancedSearcher for better architectural separation.
+"""
+
+from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.llms.base import BaseLLM
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+)
+from memos.mem_scheduler.utils.filter_utils import (
+ filter_too_short_memories,
+ filter_vector_based_similar_memories,
+ transform_name_to_key,
+)
+from memos.memories.textual.item import TextualMemoryItem
+from memos.utils import extract_json_obj
+
+from .memory_filter import MemoryFilter
+
+
+logger = get_logger(__name__)
+
+
+class MemoryPostProcessor(BaseSchedulerModule):
+ """
+ Post-processor for retrieved memories.
+
+ This class handles scheduler-specific post-retrieval operations:
+ - Memory filtering: Remove unrelated or redundant memories
+ - Memory reranking: Reorder memories by relevance
+ - Memory evaluation: Assess memory's ability to answer queries
+
+ Design principles:
+ - Single Responsibility: Only handles filtering/reranking, not enhancement or retrieval
+ - Composable: Can be used independently or chained together
+ - Testable: Each operation can be tested in isolation
+
+ Note: Memory enhancement operations have been moved to AdvancedSearcher.
+
+ Usage:
+ processor = MemoryPostProcessor(process_llm=llm, config=config)
+
+ # Filter out unrelated memories
+ filtered, _ = processor.filter_unrelated_memories(
+ query_history=["What is Python?"],
+ memories=raw_memories
+ )
+
+ # Rerank memories by relevance
+ reranked, _ = processor.process_and_rerank_memories(
+ queries=["What is Python?"],
+ original_memory=filtered,
+ new_memory=[],
+ top_k=10
+ )
+ """
+
+ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+ """
+ Initialize the post-processor.
+
+ Args:
+ process_llm: LLM instance for enhancement and filtering operations
+ config: Scheduler configuration containing batch sizes and retry settings
+ """
+ super().__init__()
+
+ # Core dependencies
+ self.process_llm = process_llm
+ self.config = config
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+
+ # Configuration
+ self.filter_similarity_threshold = 0.75
+ self.filter_min_length_threshold = 6
+
+ # NOTE: Config keys still use "scheduler_retriever_*" prefix for backward compatibility
+ # TODO: Consider renaming to "post_processor_*" in future config refactor
+ self.batch_size: int | None = getattr(
+ config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ )
+ self.retries: int = getattr(
+ config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ )
+
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int | None = None
+ ) -> bool:
+ """
+ Evaluate whether the given memories can answer the query.
+
+ This method uses LLM to assess if the provided memories contain
+ sufficient information to answer the given query.
+
+ Args:
+ query: The query to be answered
+ memory_texts: List of memory text strings
+ top_k: Optional limit on number of memories to consider
+
+ Returns:
+ Boolean indicating whether memories can answer the query
+ """
+ limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+
+ # Build prompt using the template
+ prompt = self.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ # Use the process LLM to generate response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ result = extract_json_obj(response)
+
+ # Validate response structure
+ if "result" in result:
+ logger.info(
+ f"[Answerability] result={result['result']}; "
+ f"reason={result.get('reason', 'n/a')}; "
+ f"evaluated={len(limited_memories)}"
+ )
+ return result["result"]
+ else:
+ logger.warning(f"[Answerability] invalid LLM JSON structure; payload={result}")
+ return False
+
+ except Exception as e:
+ logger.error(f"[Answerability] parse failed; err={e}; raw={str(response)[:200]}...")
+ return False
+
+ def rerank_memories(
+ self, queries: list[str], original_memories: list[str], top_k: int
+ ) -> tuple[list[str], bool]:
+ """
+ Rerank memories based on relevance to given queries using LLM.
+
+ Args:
+ queries: List of query strings to determine relevance
+ original_memories: List of memory strings to be reranked
+ top_k: Number of top memories to return after reranking
+
+ Returns:
+ Tuple of (reranked_memories, success_flag)
+ - reranked_memories: List of reranked memory strings (length <= top_k)
+ - success_flag: True if reranking succeeded
+
+ Note:
+ If LLM reranking fails, falls back to original order (truncated to top_k)
+ """
+ logger.info(f"Starting memory reranking for {len(original_memories)} memories")
+
+ # Build LLM prompt for memory reranking
+ prompt = self.build_prompt(
+ "memory_reranking",
+ queries=[f"[0] {queries[0]}"],
+ current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
+ )
+ logger.debug(f"Generated reranking prompt: {prompt[:200]}...")
+
+ # Get LLM response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(f"Received LLM response: {response[:200]}...")
+
+ try:
+ # Parse JSON response
+ response = extract_json_obj(response)
+ new_order = response["new_order"][:top_k]
+ text_memories_with_new_order = [original_memories[idx] for idx in new_order]
+ logger.info(
+ f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items; "
+ f"Ranking reasoning: {response['reasoning']}"
+ )
+ success_flag = True
+ except Exception as e:
+ logger.error(
+ f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ",
+ exc_info=True,
+ )
+ text_memories_with_new_order = original_memories[:top_k]
+ success_flag = False
+
+ return text_memories_with_new_order, success_flag
+
+ def process_and_rerank_memories(
+ self,
+ queries: list[str],
+ original_memory: list[TextualMemoryItem],
+ new_memory: list[TextualMemoryItem],
+ top_k: int = 10,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Process and rerank memory items by combining, filtering, and reranking.
+
+ This is a higher-level method that combines multiple post-processing steps:
+ 1. Merge original and new memories
+ 2. Apply similarity filtering
+ 3. Apply length filtering
+ 4. Remove duplicates
+ 5. Rerank by relevance
+
+ Args:
+ queries: List of query strings to rerank memories against
+ original_memory: List of original TextualMemoryItem objects
+ new_memory: List of new TextualMemoryItem objects to merge
+ top_k: Maximum number of memories to return after reranking
+
+ Returns:
+ Tuple of (reranked_memories, success_flag)
+ - reranked_memories: List of reranked TextualMemoryItem objects
+ - success_flag: True if reranking succeeded
+ """
+ # Combine original and new memories
+ combined_memory = original_memory + new_memory
+
+ # Create mapping from normalized text to memory objects
+ memory_map = {
+ transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
+ }
+
+ # Extract text representations
+ combined_text_memory = [m.memory for m in combined_memory]
+
+ # Apply similarity filter
+ filtered_combined_text_memory = filter_vector_based_similar_memories(
+ text_memories=combined_text_memory,
+ similarity_threshold=self.filter_similarity_threshold,
+ )
+
+ # Apply length filter
+ filtered_combined_text_memory = filter_too_short_memories(
+ text_memories=filtered_combined_text_memory,
+ min_length_threshold=self.filter_min_length_threshold,
+ )
+
+ # Remove duplicates (preserving order)
+ unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
+
+ # Rerank memories
+ text_memories_with_new_order, success_flag = self.rerank_memories(
+ queries=queries,
+ original_memories=unique_memory,
+ top_k=top_k,
+ )
+
+ # Map reranked texts back to memory objects
+ memories_with_new_order = []
+ for text in text_memories_with_new_order:
+ normalized_text = transform_name_to_key(name=text)
+ if normalized_text in memory_map:
+ memories_with_new_order.append(memory_map[normalized_text])
+ else:
+ logger.warning(
+ f"Memory text not found in memory map. text: {text};\n"
+ f"Keys of memory_map: {memory_map.keys()}"
+ )
+
+ return memories_with_new_order, success_flag
+
+ def filter_unrelated_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out memories unrelated to the query history.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_unrelated_memories(query_history, memories)
+
+ def filter_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out redundant memories from the list.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_redundant_memories(query_history, memories)
+
+ def filter_unrelated_and_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Filter out both unrelated and redundant memories using LLM analysis.
+
+ Delegates to MemoryFilter for the actual filtering logic.
+ """
+ return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
deleted file mode 100644
index f205766f0..000000000
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ /dev/null
@@ -1,547 +0,0 @@
-import time
-
-from concurrent.futures import as_completed
-
-from memos.configs.mem_scheduler import BaseSchedulerConfig
-from memos.context.context import ContextThreadPoolExecutor
-from memos.llms.base import BaseLLM
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.schemas.general_schemas import (
- DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
- DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
- TreeTextMemory_FINE_SEARCH_METHOD,
- TreeTextMemory_SEARCH_METHOD,
-)
-from memos.mem_scheduler.utils.filter_utils import (
- filter_too_short_memories,
- filter_vector_based_similar_memories,
- transform_name_to_key,
-)
-from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
-from memos.memories.textual.item import TextualMemoryMetadata
-from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
-from memos.types.general_types import (
- FINE_STRATEGY,
- FineStrategy,
- SearchMode,
-)
-
-# Extract JSON response
-from .memory_filter import MemoryFilter
-
-
-logger = get_logger(__name__)
-
-
-class SchedulerRetriever(BaseSchedulerModule):
- def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
- super().__init__()
-
- # hyper-parameters
- self.filter_similarity_threshold = 0.75
- self.filter_min_length_threshold = 6
- self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
- self.process_llm = process_llm
- self.config = config
-
- # Configure enhancement batching & retries from config with safe defaults
- self.batch_size: int | None = getattr(
- config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
- )
- self.retries: int = getattr(
- config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
- )
-
- def evaluate_memory_answer_ability(
- self, query: str, memory_texts: list[str], top_k: int | None = None
- ) -> bool:
- limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
- # Build prompt using the template
- prompt = self.build_prompt(
- template_name="memory_answer_ability_evaluation",
- query=query,
- memory_list="\n".join([f"- {memory}" for memory in limited_memories])
- if limited_memories
- else "No memories available",
- )
-
- # Use the process LLM to generate response
- response = self.process_llm.generate([{"role": "user", "content": prompt}])
-
- try:
- result = extract_json_obj(response)
-
- # Validate response structure
- if "result" in result:
- logger.info(
- f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}"
- )
- return result["result"]
- else:
- logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}")
- return False
-
- except Exception as e:
- logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...")
- # Fallback: return False if we can't determine answer ability
- return False
-
- # ---------------------- Enhancement helpers ----------------------
- def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
- if len(query_history) == 1:
- query_history = query_history[0]
- else:
- query_history = (
- [f"[{i}] {query}" for i, query in enumerate(query_history)]
- if len(query_history) > 1
- else query_history[0]
- )
- # Include numbering for rewrite mode to help LLM reference original memory IDs
- if FINE_STRATEGY == FineStrategy.REWRITE:
- text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
- prompt_name = "memory_rewrite_enhancement"
- else:
- text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)])
- prompt_name = "memory_recreate_enhancement"
- return self.build_prompt(
- prompt_name,
- query_history=query_history,
- memories=text_memories,
- )
-
- def _process_enhancement_batch(
- self,
- batch_index: int,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- retries: int,
- ) -> tuple[list[TextualMemoryItem], bool]:
- attempt = 0
- text_memories = [one.memory for one in memories]
-
- prompt = self._build_enhancement_prompt(
- query_history=query_history, batch_texts=text_memories
- )
-
- llm_response = None
- while attempt <= max(0, retries) + 1:
- try:
- llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
- processed_text_memories = extract_list_items_in_answer(llm_response)
- if len(processed_text_memories) > 0:
- # create new
- enhanced_memories = []
- user_id = memories[0].metadata.user_id
- if FINE_STRATEGY == FineStrategy.RECREATE:
- for new_mem in processed_text_memories:
- enhanced_memories.append(
- TextualMemoryItem(
- memory=new_mem,
- metadata=TextualMemoryMetadata(
- user_id=user_id, memory_type="LongTermMemory"
- ), # TODO add memory_type
- )
- )
- elif FINE_STRATEGY == FineStrategy.REWRITE:
- # Parse index from each processed line and rewrite corresponding original memory
- def _parse_index_and_text(s: str) -> tuple[int | None, str]:
- import re
-
- s = (s or "").strip()
- # Preferred: [index] text
- m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
- if m:
- return int(m.group(1)), m.group(2).strip()
- # Fallback: index: text or index - text
- m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
- if m:
- return int(m.group(1)), m.group(2).strip()
- return None, s
-
- idx_to_original = dict(enumerate(memories))
- for j, item in enumerate(processed_text_memories):
- idx, new_text = _parse_index_and_text(item)
- if idx is not None and idx in idx_to_original:
- orig = idx_to_original[idx]
- else:
- # Fallback: align by order if index missing/invalid
- orig = memories[j] if j < len(memories) else None
- if not orig:
- continue
- enhanced_memories.append(
- TextualMemoryItem(
- id=orig.id,
- memory=new_text,
- metadata=orig.metadata,
- )
- )
- else:
- logger.error(f"Fine search strategy {FINE_STRATEGY} not exists")
-
- logger.info(
- f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}"
- )
- return enhanced_memories, True
- else:
- raise ValueError(
- f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}"
- )
- except Exception as e:
- attempt += 1
- time.sleep(1)
- logger.debug(
- f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}"
- )
- logger.error(
- f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}",
- exc_info=True,
- )
- return memories, False
-
- @staticmethod
- def _split_batches(
- memories: list[TextualMemoryItem], batch_size: int
- ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
- batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
- start = 0
- n = len(memories)
- while start < n:
- end = min(start + batch_size, n)
- batches.append((start, end, memories[start:end]))
- start = end
- return batches
-
- def recall_for_missing_memories(
- self,
- query: str,
- memories: list[str],
- ) -> tuple[str, bool]:
- text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)])
-
- prompt = self.build_prompt(
- template_name="enlarge_recall",
- query=query,
- memories_inline=text_memories,
- )
- llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
-
- json_result: dict = extract_json_obj(llm_response)
-
- logger.info(
- f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}"
- )
-
- hint = json_result.get("hint", "")
- if len(hint) == 0:
- return hint, False
- return hint, json_result.get("trigger_recall", False)
-
- def search(
- self,
- query: str,
- user_id: str,
- mem_cube_id: str,
- mem_cube: GeneralMemCube,
- top_k: int,
- method: str = TreeTextMemory_SEARCH_METHOD,
- search_args: dict | None = None,
- ) -> list[TextualMemoryItem]:
- """Search in text memory with the given query.
-
- Args:
- query: The search query string
- top_k: Number of top results to return
- method: Search method to use
-
- Returns:
- Search results or None if not implemented
- """
- text_mem_base = mem_cube.text_mem
- # Normalize default for mutable argument
- search_args = search_args or {}
- try:
- if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
- assert isinstance(text_mem_base, TreeTextMemory)
- session_id = search_args.get("session_id", "default_session")
- target_session_id = session_id
- search_priority = (
- {"session_id": target_session_id} if "session_id" in search_args else None
- )
- search_filter = search_args.get("filter")
- search_source = search_args.get("source")
- plugin = bool(search_source is not None and search_source == "plugin")
- user_name = search_args.get("user_name", mem_cube_id)
- internet_search = search_args.get("internet_search", False)
- chat_history = search_args.get("chat_history")
- search_tool_memory = search_args.get("search_tool_memory", False)
- tool_mem_top_k = search_args.get("tool_mem_top_k", 6)
- playground_search_goal_parser = search_args.get(
- "playground_search_goal_parser", False
- )
-
- info = search_args.get(
- "info",
- {
- "user_id": user_id,
- "session_id": target_session_id,
- "chat_history": chat_history,
- },
- )
-
- results_long_term = mem_cube.text_mem.search(
- query=query,
- user_name=user_name,
- top_k=top_k,
- mode=SearchMode.FAST,
- manual_close_internet=not internet_search,
- memory_type="LongTermMemory",
- search_filter=search_filter,
- search_priority=search_priority,
- info=info,
- plugin=plugin,
- search_tool_memory=search_tool_memory,
- tool_mem_top_k=tool_mem_top_k,
- playground_search_goal_parser=playground_search_goal_parser,
- )
-
- results_user = mem_cube.text_mem.search(
- query=query,
- user_name=user_name,
- top_k=top_k,
- mode=SearchMode.FAST,
- manual_close_internet=not internet_search,
- memory_type="UserMemory",
- search_filter=search_filter,
- search_priority=search_priority,
- info=info,
- plugin=plugin,
- search_tool_memory=search_tool_memory,
- tool_mem_top_k=tool_mem_top_k,
- playground_search_goal_parser=playground_search_goal_parser,
- )
- results = results_long_term + results_user
- else:
- raise NotImplementedError(str(type(text_mem_base)))
- except Exception as e:
- logger.error(f"Fail to search. The exeption is {e}.", exc_info=True)
- results = []
- return results
-
- def enhance_memories_with_query(
- self,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- """
- Enhance memories by adding context and making connections to better answer queries.
-
- Args:
- query_history: List of user queries in chronological order
- memories: List of memory items to enhance
-
- Returns:
- Tuple of (enhanced_memories, success_flag)
- """
- if not memories:
- logger.warning("[Enhance] ⚠️ skipped (no memories to process)")
- return memories, True
-
- batch_size = self.batch_size
- retries = self.retries
- num_of_memories = len(memories)
- try:
- # no parallel
- if batch_size is None or num_of_memories <= batch_size:
- # Single batch path with retry
- enhanced_memories, success_flag = self._process_enhancement_batch(
- batch_index=0,
- query_history=query_history,
- memories=memories,
- retries=retries,
- )
-
- all_success = success_flag
- else:
- # parallel running batches
- # Split into batches preserving order
- batches = self._split_batches(memories=memories, batch_size=batch_size)
-
- # Process batches concurrently
- all_success = True
- failed_batches = 0
- with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
- future_map = {
- executor.submit(
- self._process_enhancement_batch, bi, query_history, texts, retries
- ): (bi, s, e)
- for bi, (s, e, texts) in enumerate(batches)
- }
- enhanced_memories = []
- for fut in as_completed(future_map):
- bi, s, e = future_map[fut]
-
- batch_memories, ok = fut.result()
- enhanced_memories.extend(batch_memories)
- if not ok:
- all_success = False
- failed_batches += 1
- logger.info(
- f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |"
- f" failed_batches={failed_batches} | success={all_success}"
- )
-
- except Exception as e:
- logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True)
- all_success = False
- enhanced_memories = memories
-
- if len(enhanced_memories) == 0:
- enhanced_memories = []
- logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True)
- return enhanced_memories, all_success
-
- def rerank_memories(
- self, queries: list[str], original_memories: list[str], top_k: int
- ) -> (list[str], bool):
- """
- Rerank memories based on relevance to given queries using LLM.
-
- Args:
- queries: List of query strings to determine relevance
- original_memories: List of memory strings to be reranked
- top_k: Number of top memories to return after reranking
-
- Returns:
- List of reranked memory strings (length <= top_k)
-
- Note:
- If LLM reranking fails, falls back to original order (truncated to top_k)
- """
-
- logger.info(f"Starting memory reranking for {len(original_memories)} memories")
-
- # Build LLM prompt for memory reranking
- prompt = self.build_prompt(
- "memory_reranking",
- queries=[f"[0] {queries[0]}"],
- current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)],
- )
- logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars
-
- # Get LLM response
- response = self.process_llm.generate([{"role": "user", "content": prompt}])
- logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars
-
- try:
- # Parse JSON response
- response = extract_json_obj(response)
- new_order = response["new_order"][:top_k]
- text_memories_with_new_order = [original_memories[idx] for idx in new_order]
- logger.info(
- f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;"
- f"Ranking reasoning: {response['reasoning']}"
- )
- success_flag = True
- except Exception as e:
- logger.error(
- f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ",
- exc_info=True,
- )
- text_memories_with_new_order = original_memories[:top_k]
- success_flag = False
- return text_memories_with_new_order, success_flag
-
- def process_and_rerank_memories(
- self,
- queries: list[str],
- original_memory: list[TextualMemoryItem],
- new_memory: list[TextualMemoryItem],
- top_k: int = 10,
- ) -> list[TextualMemoryItem] | None:
- """
- Process and rerank memory items by combining original and new memories,
- applying filters, and then reranking based on relevance to queries.
-
- Args:
- queries: List of query strings to rerank memories against
- original_memory: List of original TextualMemoryItem objects
- new_memory: List of new TextualMemoryItem objects to merge
- top_k: Maximum number of memories to return after reranking
-
- Returns:
- List of reranked TextualMemoryItem objects, or None if processing fails
- """
- # Combine original and new memories into a single list
- combined_memory = original_memory + new_memory
-
- # Create a mapping from normalized text to memory objects
- memory_map = {
- transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory
- }
-
- # Extract normalized text representations from all memory items
- combined_text_memory = [m.memory for m in combined_memory]
-
- # Apply similarity filter to remove overly similar memories
- filtered_combined_text_memory = filter_vector_based_similar_memories(
- text_memories=combined_text_memory,
- similarity_threshold=self.filter_similarity_threshold,
- )
-
- # Apply length filter to remove memories that are too short
- filtered_combined_text_memory = filter_too_short_memories(
- text_memories=filtered_combined_text_memory,
- min_length_threshold=self.filter_min_length_threshold,
- )
-
- # Ensure uniqueness of memory texts using dictionary keys (preserves order)
- unique_memory = list(dict.fromkeys(filtered_combined_text_memory))
-
- # Rerank the filtered memories based on relevance to the queries
- text_memories_with_new_order, success_flag = self.rerank_memories(
- queries=queries,
- original_memories=unique_memory,
- top_k=top_k,
- )
-
- # Map reranked text entries back to their original memory objects
- memories_with_new_order = []
- for text in text_memories_with_new_order:
- normalized_text = transform_name_to_key(name=text)
- if normalized_text in memory_map: # Ensure correct key matching
- memories_with_new_order.append(memory_map[normalized_text])
- else:
- logger.warning(
- f"Memory text not found in memory map. text: {text};\n"
- f"Keys of memory_map: {memory_map.keys()}"
- )
-
- return memories_with_new_order, success_flag
-
- def filter_unrelated_memories(
- self,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- return self.memory_filter.filter_unrelated_memories(query_history, memories)
-
- def filter_redundant_memories(
- self,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- return self.memory_filter.filter_redundant_memories(query_history, memories)
-
- def filter_unrelated_and_redundant_memories(
- self,
- query_history: list[str],
- memories: list[TextualMemoryItem],
- ) -> (list[TextualMemoryItem], bool):
- """
- Filter out both unrelated and redundant memories using LLM analysis.
-
- This method delegates to the MemoryFilter class.
- """
- return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_service.py b/src/memos/mem_scheduler/memory_manage_modules/search_service.py
new file mode 100644
index 000000000..43a768ac4
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/search_service.py
@@ -0,0 +1,265 @@
+"""
+Scheduler Search Service - Unified search interface for the scheduler.
+
+This module provides a clean abstraction over the Searcher class,
+adapting it for scheduler-specific use cases while maintaining compatibility.
+"""
+
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.tree import TreeTextMemory
+from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
+from memos.types.general_types import SearchMode
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerSearchService:
+ """
+ Unified search service for the scheduler.
+
+ This service provides a clean interface for memory search operations,
+ delegating to the Searcher class while handling scheduler-specific
+ parameter adaptations.
+
+ Design principles:
+ - Single Responsibility: Only handles search coordination
+ - Dependency Injection: Searcher is injected, not created
+ - Fail-safe: Falls back to direct text_mem.search() if Searcher unavailable
+
+ Usage:
+ service = SchedulerSearchService(searcher=searcher)
+ results = service.search(
+ query="user query",
+ user_id="user_123",
+ mem_cube=mem_cube,
+ top_k=10
+ )
+ """
+
+ def __init__(self, searcher: Searcher | None = None):
+ """
+ Initialize the search service.
+
+ Args:
+ searcher: Optional Searcher instance. If None, will fall back to
+ direct mem_cube.text_mem.search() calls.
+ """
+ self.searcher = searcher
+
+ def search(
+ self,
+ query: str,
+ user_id: str,
+ mem_cube: GeneralMemCube,
+ top_k: int,
+ mode: SearchMode = SearchMode.FAST,
+ search_filter: dict | None = None,
+ search_priority: dict | None = None,
+ session_id: str = "default_session",
+ internet_search: bool = False,
+ chat_history: list | None = None,
+ plugin: bool = False,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
+ playground_search_goal_parser: bool = False,
+ mem_cube_id: str | None = None,
+ ) -> list[TextualMemoryItem]:
+ """
+ Search for memories across both LongTermMemory and UserMemory.
+
+ This method provides a unified interface for memory search, automatically
+ handling the search across different memory types and merging results.
+
+ Args:
+ query: The search query string
+ user_id: User identifier
+ mem_cube: Memory cube instance containing text memory
+            top_k: Total number of top results to return across all memory types
+ mode: Search mode (FAST or FINE)
+ search_filter: Optional metadata filters for search results
+ search_priority: Optional metadata priority for search results
+ session_id: Session identifier for session-scoped search
+ internet_search: Whether to enable internet search
+ chat_history: Chat history for context
+ plugin: Whether this is a plugin-initiated search
+ search_tool_memory: Whether to search tool memory
+ tool_mem_top_k: Top-k for tool memory search
+ playground_search_goal_parser: Whether to use playground goal parser
+ mem_cube_id: Memory cube identifier (defaults to user_id if not provided)
+
+ Returns:
+ List of TextualMemoryItem objects sorted by relevance
+
+        Note:
+            Exceptions from underlying searches are caught and logged; an empty list is returned on failure
+ """
+ mem_cube_id = mem_cube_id or user_id
+ user_name = mem_cube_id
+ text_mem_base = mem_cube.text_mem
+
+ # Build info dict for tracking
+ info = {
+ "user_id": user_id,
+ "session_id": session_id,
+ "chat_history": chat_history,
+ }
+
+ try:
+ if self.searcher:
+ # Use injected Searcher (preferred path)
+ results = self._search_with_searcher(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ internet_search=internet_search,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+ logger.info(
+ f"[SchedulerSearchService] Searched via Searcher: "
+ f"query='{query}' results={len(results)}"
+ )
+ else:
+ # Fallback: Direct text_mem.search() call
+ results = self._search_with_text_mem(
+ text_mem_base=text_mem_base,
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ internet_search=internet_search,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+ logger.info(
+ f"[SchedulerSearchService] Searched via text_mem (fallback): "
+ f"query='{query}' results={len(results)}"
+ )
+
+ return results
+
+ except Exception as e:
+ logger.error(
+ f"[SchedulerSearchService] Search failed for query='{query}': {e}",
+ exc_info=True,
+ )
+ return []
+
+ def _search_with_searcher(
+ self,
+ query: str,
+ user_name: str,
+ top_k: int,
+ mode: SearchMode,
+ search_filter: dict | None,
+ search_priority: dict | None,
+ info: dict,
+ internet_search: bool,
+ plugin: bool,
+ search_tool_memory: bool,
+ tool_mem_top_k: int,
+ playground_search_goal_parser: bool,
+ ) -> list[TextualMemoryItem]:
+ """
+ Search using the injected Searcher instance.
+
+ IMPORTANT: This method searches "All" memory types in a single call to avoid
+ the bug where calling search() twice (for LongTermMemory and UserMemory separately)
+ would return 2*top_k results due to Searcher.search() applying deduplication and
+ top_k limiting on each call.
+
+ This ensures the final result is properly deduplicated and limited to top_k items.
+ """
+ # Preserve original internet search setting
+ original_manual_close = getattr(self.searcher, "manual_close_internet", None)
+
+ try:
+ # Configure internet search
+ if original_manual_close is not None:
+ self.searcher.manual_close_internet = not internet_search
+
+ # Search all memory types in a single call
+ # This ensures proper deduplication and top_k limiting
+ results = self.searcher.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ memory_type="All", # Search all types together to avoid 2*top_k bug
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ return results
+
+ finally:
+ # Restore original setting
+ if original_manual_close is not None:
+ self.searcher.manual_close_internet = original_manual_close
+
+ def _search_with_text_mem(
+ self,
+ text_mem_base: TreeTextMemory,
+ query: str,
+ user_name: str,
+ top_k: int,
+ mode: SearchMode,
+ search_filter: dict | None,
+ search_priority: dict | None,
+ info: dict,
+ internet_search: bool,
+ plugin: bool,
+ search_tool_memory: bool,
+ tool_mem_top_k: int,
+ playground_search_goal_parser: bool,
+ ) -> list[TextualMemoryItem]:
+ """
+ Fallback: Search using direct text_mem.search() calls.
+
+ This is used when no Searcher instance is available, providing
+ backward compatibility with the original implementation.
+
+ NOTE: TreeTextMemory.search() with memory_type="All" will internally
+ search both LongTermMemory and UserMemory and properly merge results.
+ """
+ assert isinstance(text_mem_base, TreeTextMemory), (
+ f"Fallback search requires TreeTextMemory, got {type(text_mem_base)}"
+ )
+
+ # Search all memory types together (same as Searcher behavior)
+ results = text_mem_base.search(
+ query=query,
+ user_name=user_name,
+ top_k=top_k,
+ mode=mode,
+ manual_close_internet=not internet_search,
+ memory_type="All", # Search all types to match Searcher behavior
+ search_filter=search_filter,
+ search_priority=search_priority,
+ info=info,
+ plugin=plugin,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
+ playground_search_goal_parser=playground_search_goal_parser,
+ )
+
+ return results
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index d75d6ee75..2db8eb785 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -27,9 +27,9 @@
QueryMonitorQueue,
)
from memos.mem_scheduler.utils.db_utils import get_utc_now
-from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TreeTextMemory
from memos.types import MemCubeID, UserID
+from memos.utils import extract_json_obj
logger = get_logger(__name__)
diff --git a/src/memos/mem_scheduler/monitors/task_schedule_monitor.py b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py
index 14bed8316..b022ce080 100644
--- a/src/memos/mem_scheduler/monitors/task_schedule_monitor.py
+++ b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py
@@ -1,8 +1,13 @@
from __future__ import annotations
+import time
+
+from memos.context.context import ContextThread
from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import DEFAULT_MONITOR_INTERVAL_SECONDS
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
+from memos.mem_scheduler.utils import metrics
logger = get_logger(__name__)
@@ -22,10 +27,68 @@ def __init__(
memos_message_queue: SchedulerRedisQueue | SchedulerLocalQueue,
dispatcher: object | None = None,
get_status_parallel: bool = False,
+ check_interval: int = DEFAULT_MONITOR_INTERVAL_SECONDS,
) -> None:
self.queue = memos_message_queue
self.dispatcher = dispatcher
self.get_status_parallel = get_status_parallel
+ self.check_interval = check_interval
+
+ # Monitor thread
+ self._monitor_thread: ContextThread | None = None
+ self._running = False
+
+ def start(self) -> None:
+ """Start the background monitor thread."""
+ if self._monitor_thread and self._monitor_thread.is_alive():
+ logger.info("TaskScheduleMonitor is already running")
+ return
+
+ self._running = True
+ self._monitor_thread = ContextThread(
+ target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor"
+ )
+ self._monitor_thread.start()
+ logger.info("TaskScheduleMonitor started")
+
+ def stop(self) -> None:
+ """Stop the background monitor thread."""
+ if not self._running:
+ return
+
+ self._running = False
+ if self._monitor_thread:
+ self._monitor_thread.join(timeout=2.0)
+ self._monitor_thread = None
+ logger.info("TaskScheduleMonitor stopped")
+
+ def _monitor_loop(self) -> None:
+ """Periodically check queue sizes and update metrics."""
+ while self._running:
+ try:
+ q_sizes = self.queue.qsize()
+
+ if isinstance(q_sizes, dict):
+ for stream_key, queue_length in q_sizes.items():
+ # Skip aggregate keys like 'total_size'
+ if stream_key == "total_size":
+ continue
+
+ # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
+ # We want to extract user_id, which is the 3rd component from the end.
+ parts = stream_key.split(":")
+ if len(parts) >= 3:
+ user_id = parts[-3]
+ metrics.update_queue_length(queue_length, user_id)
+ else:
+                            # Fallback: treat a colon-free key as the user_id itself
+ if ":" not in stream_key:
+ metrics.update_queue_length(queue_length, stream_key)
+
+ except Exception as e:
+ logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
+
+ time.sleep(self.check_interval)
@staticmethod
def init_task_status() -> dict:
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index 497d19ac6..02dd5a5ae 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -104,27 +104,30 @@ def search_memories(
mem_cube: NaiveMemCube,
mode: SearchMode,
):
- """Fine search memories function copied from server_router to avoid circular import"""
- target_session_id = search_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
+ """
+ Search memories using unified Searcher interface.
+
+ This method now uses self.searcher (via SchedulerSearchService) instead of
+ directly calling mem_cube.text_mem.search(), ensuring consistency with the
+ API search path.
+ """
+ target_session_id = search_req.session_id or "default_session"
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
search_filter = search_req.filter
- # Create MemCube and perform search
- search_results = mem_cube.text_mem.search(
+ # Use unified search service
+ search_results = self.search_service.search(
query=search_req.query,
- user_name=user_context.mem_cube_id,
+ user_id=search_req.user_id,
+ mem_cube=mem_cube,
top_k=search_req.top_k,
mode=mode,
- manual_close_internet=not search_req.internet_search,
search_filter=search_filter,
search_priority=search_priority,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
+ session_id=target_session_id,
+ internet_search=search_req.internet_search,
+ chat_history=search_req.chat_history,
+ mem_cube_id=user_context.mem_cube_id,
)
return search_results
@@ -322,7 +325,7 @@ def replace_working_memory(
query_history = query_db_manager.obj.get_queries_with_timesort()
memories_with_new_order, rerank_success_flag = (
- self.retriever.process_and_rerank_memories(
+ self.post_processor.process_and_rerank_memories(
queries=query_history,
original_memory=original_memory,
new_memory=new_memory,
@@ -335,7 +338,7 @@ def replace_working_memory(
f"[optimized replace_working_memory] Applying combined unrelated and redundant memory filtering to {len(memories_with_new_order)} memories"
)
filtered_memories, filtering_success_flag = (
- self.retriever.filter_unrelated_and_redundant_memories(
+ self.post_processor.filter_unrelated_and_redundant_memories(
query_history=query_history,
memories=memories_with_new_order,
)
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index 06910ba17..9ee568c7d 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -17,6 +17,7 @@
DEFAULT_THREAD_POOL_MAX_WORKERS = 50
DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01
DEFAULT_CONSUME_BATCH = 3
+DEFAULT_MONITOR_INTERVAL_SECONDS = 15
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
DEFAULT_STUCK_THREAD_TOLERANCE = 10
diff --git a/src/memos/mem_scheduler/task_schedule_modules/base_handler.py b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py
new file mode 100644
index 000000000..3a2200d15
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/base_handler.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
+
+
+logger = get_logger(__name__)
+
+
+class BaseHandler(ABC):
+ def __init__(self, context: SchedulerContext):
+ self.context = context
+ self.expected_task_label = None
+
+ def validate_and_log_messages(self, messages: list[ScheduleMessageItem], label: str) -> None:
+ """
+ Log the assignment of messages to the handler and validate them if a validator is present in the context.
+ """
+ logger.info(f"Messages {messages} assigned to {label} handler.")
+ if self.context.validate_schedule_messages:
+ self.context.validate_schedule_messages(messages, label)
+
+ def handle_exception(self, e: Exception, message: str = "Error processing messages") -> None:
+ """
+ Log an exception with a custom message and stack trace.
+ """
+ logger.error(f"{message}: {e}", exc_info=True)
+
+ def process_grouped_messages(
+ self,
+ messages: list[ScheduleMessageItem],
+ message_handler: Callable[[str, str, list[ScheduleMessageItem]], None],
+ ) -> None:
+ """
+ Group messages and process them in batches.
+ """
+ grouped_messages = group_messages_by_user_and_mem_cube(messages=messages)
+ for user_id, user_batches in grouped_messages.items():
+ for mem_cube_id, batch in user_batches.items():
+ if not batch:
+ continue
+ message_handler(user_id, mem_cube_id, batch)
+
+ @abstractmethod
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ pass
+
+ def __call__(self, messages: list[ScheduleMessageItem]) -> None:
+ """
+ Process the messages.
+ """
+ self.validate_and_log_messages(messages=messages, label=self.expected_task_label)
+
+ self.process_grouped_messages(
+ messages=messages,
+ message_handler=self.batch_handler,
+ )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
new file mode 100644
index 000000000..64567f2cd
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
@@ -0,0 +1,43 @@
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import ADD_TASK_LABEL
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+from memos.mem_scheduler.utils.misc_utils import is_cloud_env
+
+
+logger = get_logger(__name__)
+
+
+class AddHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = ADD_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ # Process each message in the batch
+ for msg in batch:
+ if self.context.log_add_messages:
+ prepared_add_items, prepared_update_items_with_original = (
+ self.context.log_add_messages(msg=msg)
+ )
+ else:
+ logger.error("log_add_messages not available in context")
+ continue
+
+ logger.info(
+ f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}"
+ )
+ # Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default
+ cloud_env = is_cloud_env()
+
+ if cloud_env:
+ if self.context.send_add_log_messages_to_cloud_env:
+ self.context.send_add_log_messages_to_cloud_env(
+ msg, prepared_add_items, prepared_update_items_with_original
+ )
+ else:
+ if self.context.send_add_log_messages_to_local_env:
+ self.context.send_add_log_messages_to_local_env(
+ msg, prepared_add_items, prepared_update_items_with_original
+ )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py
new file mode 100644
index 000000000..db58cce9a
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/answer_handler.py
@@ -0,0 +1,47 @@
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ ANSWER_TASK_LABEL,
+ NOT_APPLICABLE_TYPE,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+
+
+logger = get_logger(__name__)
+
+
+class AnswerHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = ANSWER_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ try:
+ for msg in batch:
+ if self.context.create_event_log and self.context.submit_web_logs:
+ event = self.context.create_event_log(
+ label="addMessage",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=NOT_APPLICABLE_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.context.mem_cube,
+ memcube_log_content=[
+ {
+ "content": f"[Assistant] {msg.content}",
+ "ref_id": msg.item_id,
+ "role": "assistant",
+ }
+ ],
+ metadata=[],
+ memory_len=1,
+ memcube_name=self.context.map_memcube_name(msg.mem_cube_id)
+ if self.context.map_memcube_name
+ else None,
+ )
+ event.task_id = msg.task_id
+ self.context.submit_web_logs([event])
+ except Exception as e:
+ self.handle_exception(e, "Failed to record addMessage log for answer")
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_feedback_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_feedback_handler.py
new file mode 100644
index 000000000..2531ebb15
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_feedback_handler.py
@@ -0,0 +1,197 @@
+import json
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ LONG_TERM_MEMORY_TYPE,
+ MEM_FEEDBACK_TASK_LABEL,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+from memos.mem_scheduler.utils.misc_utils import is_cloud_env
+
+
+logger = get_logger(__name__)
+
+
+class MemFeedbackHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = MEM_FEEDBACK_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ try:
+ for message in batch:
+ mem_cube = self.context.mem_cube
+
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ content = message.content
+
+ try:
+ feedback_data = json.loads(content) if isinstance(content, str) else content
+ if not isinstance(feedback_data, dict):
+ logger.error(
+ f"Failed to decode feedback_data or it is not a dict: {feedback_data}"
+ )
+ continue
+ except json.JSONDecodeError:
+ logger.error(
+ f"Invalid JSON content for feedback message: {content}", exc_info=True
+ )
+ continue
+
+ task_id = feedback_data.get("task_id") or message.task_id
+
+ if self.context.feedback_server:
+ feedback_result = self.context.feedback_server.process_feedback(
+ user_id=user_id,
+ user_name=mem_cube_id,
+ session_id=feedback_data.get("session_id"),
+ chat_history=feedback_data.get("history", []),
+ retrieved_memory_ids=feedback_data.get("retrieved_memory_ids", []),
+ feedback_content=feedback_data.get("feedback_content"),
+ feedback_time=feedback_data.get("feedback_time"),
+ task_id=task_id,
+ info=feedback_data.get("info", None),
+ )
+ else:
+ logger.error("feedback_server not available in context")
+ continue
+
+ logger.info(
+ f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}"
+ )
+
+ cloud_env = is_cloud_env()
+ if cloud_env:
+ record = (
+ feedback_result.get("record") if isinstance(feedback_result, dict) else {}
+ )
+ add_records = record.get("add") if isinstance(record, dict) else []
+ update_records = record.get("update") if isinstance(record, dict) else []
+
+ def _extract_fields(mem_item):
+ mem_id = (
+ getattr(mem_item, "id", None)
+ if not isinstance(mem_item, dict)
+ else mem_item.get("id")
+ )
+ mem_memory = (
+ getattr(mem_item, "memory", None)
+ if not isinstance(mem_item, dict)
+ else mem_item.get("memory") or mem_item.get("text")
+ )
+ if mem_memory is None and isinstance(mem_item, dict):
+ mem_memory = mem_item.get("text")
+ original_content = (
+ getattr(mem_item, "origin_memory", None)
+ if not isinstance(mem_item, dict)
+ else mem_item.get("origin_memory")
+ or mem_item.get("old_memory")
+ or mem_item.get("original_content")
+ )
+ source_doc_id = None
+ if isinstance(mem_item, dict):
+ source_doc_id = mem_item.get("source_doc_id", None)
+
+ return mem_id, mem_memory, original_content, source_doc_id
+
+ kb_log_content: list[dict] = []
+
+ for mem_item in add_records or []:
+ mem_id, mem_memory, _, source_doc_id = _extract_fields(mem_item)
+ if mem_id and mem_memory:
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": "Feedback",
+ "operation": "ADD",
+ "memory_id": mem_id,
+ "content": mem_memory,
+ "original_content": None,
+ "source_doc_id": source_doc_id,
+ }
+ )
+ else:
+ logger.warning(
+ "Skipping malformed feedback add item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ mem_item,
+ stack_info=True,
+ )
+
+ for mem_item in update_records or []:
+ mem_id, mem_memory, original_content, source_doc_id = _extract_fields(
+ mem_item
+ )
+ if mem_id and mem_memory:
+ kb_log_content.append(
+ {
+ "log_source": "KNOWLEDGE_BASE_LOG",
+ "trigger_source": "Feedback",
+ "operation": "UPDATE",
+ "memory_id": mem_id,
+ "content": mem_memory,
+ "original_content": original_content,
+ "source_doc_id": source_doc_id,
+ }
+ )
+ else:
+ logger.warning(
+ "Skipping malformed feedback update item. user_id=%s mem_cube_id=%s task_id=%s item=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ mem_item,
+ stack_info=True,
+ )
+
+ logger.info(f"[Feedback Scheduler] kb_log_content: {kb_log_content!s}")
+ if kb_log_content:
+ if self.context.create_event_log and self.context.submit_web_logs:
+ logger.info(
+ "[DIAGNOSTIC] general_scheduler._mem_feedback_message_consumer: Creating knowledgeBaseUpdate event for feedback. user_id=%s mem_cube_id=%s task_id=%s items=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ len(kb_log_content),
+ )
+ event = self.context.create_event_log(
+ label="knowledgeBaseUpdate",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ memcube_log_content=kb_log_content,
+ metadata=None,
+ memory_len=len(kb_log_content),
+ memcube_name=self.context.map_memcube_name(mem_cube_id)
+ if self.context.map_memcube_name
+ else None,
+ )
+ event.log_content = (
+ f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
+ )
+ event.task_id = task_id
+ self.context.submit_web_logs([event])
+ else:
+ logger.warning(
+ "No valid feedback content generated for web log. user_id=%s mem_cube_id=%s task_id=%s",
+ user_id,
+ mem_cube_id,
+ task_id,
+ stack_info=True,
+ )
+ else:
+ logger.info(
+ "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
+ cloud_env,
+ )
+
+ except Exception as e:
+ self.handle_exception(e, "Error processing feedbackMemory message")
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
new file mode 100644
index 000000000..0c523bfc4
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
@@ -0,0 +1,71 @@
+import json
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import MEM_READ_TASK_LABEL
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = get_logger(__name__)
+
+
+class MemReadHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = MEM_READ_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ for message in batch:
+ try:
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ mem_cube = self.context.mem_cube
+ if mem_cube is None:
+ logger.error(
+ f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing",
+ stack_info=True,
+ )
+ continue
+
+ content = message.content
+ user_name = message.user_name
+ info = message.info or {}
+
+ # Parse the memory IDs from content
+ mem_ids = json.loads(content) if isinstance(content, str) else content
+ if not mem_ids:
+ continue
+
+ logger.info(
+ f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
+ )
+
+ # Get the text memory from the mem_cube
+ text_mem = mem_cube.text_mem
+ if not isinstance(text_mem, TreeTextMemory):
+ logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
+ continue
+
+ # Use mem_reader to process the memories
+ if self.context.process_memories_with_reader:
+ self.context.process_memories_with_reader(
+ mem_ids=mem_ids,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ text_mem=text_mem,
+ user_name=user_name,
+ custom_tags=info.get("custom_tags", None),
+ task_id=message.task_id,
+ info=info,
+ )
+ else:
+                    logger.error("process_memories_with_reader not available in context"); continue
+
+ logger.info(
+ f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}"
+ )
+
+ except Exception as e:
+                logger.error(f"Error processing mem_read message: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py
new file mode 100644
index 000000000..5816fa3c5
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_reorganize_handler.py
@@ -0,0 +1,203 @@
+import contextlib
+import hashlib
+import json
+
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ LONG_TERM_MEMORY_TYPE,
+ MEM_ORGANIZE_TASK_LABEL,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
+from memos.memories.textual.tree import TreeTextMemory
+
+
+if TYPE_CHECKING:
+ from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class MemReorganizeHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = MEM_ORGANIZE_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ for message in batch:
+ try:
+ user_id = message.user_id
+ mem_cube_id = message.mem_cube_id
+ mem_cube = self.context.mem_cube
+ if mem_cube is None:
+ logger.warning(
+ f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing"
+ )
+ continue
+ content = message.content
+ user_name = message.user_name
+
+ # Parse the memory IDs from content
+ mem_ids = json.loads(content) if isinstance(content, str) else content
+ if not mem_ids:
+ continue
+
+ logger.info(
+ f"Processing mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}"
+ )
+
+ # Get the text memory from the mem_cube
+ text_mem = mem_cube.text_mem
+ if not isinstance(text_mem, TreeTextMemory):
+ logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}")
+ continue
+
+ # Use mem_reader to process the memories
+ if self.context.process_memories_with_reorganize:
+ self.context.process_memories_with_reorganize(
+ mem_ids=mem_ids,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ text_mem=text_mem,
+ user_name=user_name,
+ )
+ else:
+                    logger.error("process_memories_with_reorganize not available in context"); continue
+
+ with contextlib.suppress(Exception):
+ mem_items: list[TextualMemoryItem] = []
+ for mid in mem_ids:
+ with contextlib.suppress(Exception):
+ mem_items.append(text_mem.get(mid, user_name=user_name))
+ if len(mem_items) > 1:
+ keys: list[str] = []
+ memcube_content: list[dict] = []
+ meta: list[dict] = []
+ merged_target_ids: set[str] = set()
+ with contextlib.suppress(Exception):
+ if hasattr(text_mem, "graph_store"):
+ for mid in mem_ids:
+ edges = text_mem.graph_store.get_edges(
+ mid, type="MERGED_TO", direction="OUT"
+ )
+ for edge in edges:
+ target = (
+ edge.get("to") or edge.get("dst") or edge.get("target")
+ )
+ if target:
+ merged_target_ids.add(target)
+ for item in mem_items:
+ key = getattr(
+ getattr(item, "metadata", {}), "key", None
+ ) or transform_name_to_key(getattr(item, "memory", ""))
+ keys.append(key)
+ memcube_content.append(
+ {"content": key or "(no key)", "ref_id": item.id, "type": "merged"}
+ )
+ meta.append(
+ {
+ "ref_id": item.id,
+ "id": item.id,
+ "key": key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ )
+ combined_key = keys[0] if keys else ""
+ post_ref_id = None
+ post_meta = {
+ "ref_id": None,
+ "id": None,
+ "key": None,
+ "memory": None,
+ "memory_type": None,
+ "status": None,
+ "confidence": None,
+ "tags": None,
+ "updated_at": None,
+ }
+ if merged_target_ids:
+ post_ref_id = next(iter(merged_target_ids))
+ with contextlib.suppress(Exception):
+ merged_item = text_mem.get(post_ref_id, user_name=user_name)
+ combined_key = (
+ getattr(getattr(merged_item, "metadata", {}), "key", None)
+ or combined_key
+ )
+ post_meta = {
+ "ref_id": post_ref_id,
+ "id": post_ref_id,
+ "key": getattr(
+ getattr(merged_item, "metadata", {}), "key", None
+ ),
+ "memory": getattr(merged_item, "memory", None),
+ "memory_type": getattr(
+ getattr(merged_item, "metadata", {}), "memory_type", None
+ ),
+ "status": getattr(
+ getattr(merged_item, "metadata", {}), "status", None
+ ),
+ "confidence": getattr(
+ getattr(merged_item, "metadata", {}), "confidence", None
+ ),
+ "tags": getattr(
+ getattr(merged_item, "metadata", {}), "tags", None
+ ),
+ "updated_at": getattr(
+ getattr(merged_item, "metadata", {}), "updated_at", None
+ )
+ or getattr(
+ getattr(merged_item, "metadata", {}), "update_at", None
+ ),
+ }
+ if not post_ref_id:
+ post_ref_id = f"merge-{hashlib.md5(''.join(sorted(mem_ids)).encode()).hexdigest()}"
+ post_meta["ref_id"] = post_ref_id
+ post_meta["id"] = post_ref_id
+ if not post_meta.get("key"):
+ post_meta["key"] = combined_key
+ if not keys:
+ keys = [item.id for item in mem_items]
+ memcube_content.append(
+ {
+ "content": combined_key if combined_key else "(no key)",
+ "ref_id": post_ref_id,
+ "type": "postMerge",
+ }
+ )
+ meta.append(post_meta)
+
+ if self.context.create_event_log and self.context.submit_web_logs:
+ event = self.context.create_event_log(
+ label="mergeMemory",
+ from_memory_type=LONG_TERM_MEMORY_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ memcube_log_content=memcube_content,
+ metadata=meta,
+ memory_len=len(keys),
+ memcube_name=self.context.map_memcube_name(mem_cube_id)
+ if self.context.map_memcube_name
+ else None,
+ )
+ self.context.submit_web_logs([event])
+
+ logger.info(
+ f"Successfully processed mem_reorganize for user_id={user_id}, mem_cube_id={mem_cube_id}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing mem_reorganize message: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py
new file mode 100644
index 000000000..354bc70b0
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/memory_update_handler.py
@@ -0,0 +1,23 @@
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import MEM_UPDATE_TASK_LABEL
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+
+
+logger = get_logger(__name__)
+
+
+class MemoryUpdateHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = MEM_UPDATE_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ # Process the whole batch once; no need to iterate per message
+ if self.context.long_memory_update_process:
+ self.context.long_memory_update_process(
+ user_id=user_id, mem_cube_id=mem_cube_id, messages=batch
+ )
+ else:
+ logger.error("long_memory_update_process is not available in context")
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py
new file mode 100644
index 000000000..a9adb9b9b
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/pref_add_handler.py
@@ -0,0 +1,72 @@
+import json
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import PREF_ADD_TASK_LABEL
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+from memos.memories.textual.preference import PreferenceTextMemory
+
+
+logger = get_logger(__name__)
+
+
+class PrefAddHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = PREF_ADD_TASK_LABEL
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ for message in batch:
+ try:
+ mem_cube = self.context.mem_cube
+ if mem_cube is None:
+ logger.warning(
+ f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing"
+ )
+ continue
+
+ user_id = message.user_id
+ session_id = message.session_id
+ mem_cube_id = message.mem_cube_id
+ content = message.content
+ messages_list = json.loads(content)
+ info = message.info or {}
+
+ logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}")
+
+ # Get the preference memory from the mem_cube
+ pref_mem = mem_cube.pref_mem
+ if pref_mem is None:
+ logger.warning(
+ f"Preference memory not initialized for mem_cube_id={mem_cube_id}, "
+ f"skipping pref_add processing"
+ )
+ continue
+ if not isinstance(pref_mem, PreferenceTextMemory):
+ logger.error(
+ f"Expected PreferenceTextMemory but got {type(pref_mem).__name__} "
+ f"for mem_cube_id={mem_cube_id}"
+ )
+ continue
+
+ # Use pref_mem.get_memory to process the memories
+ pref_memories = pref_mem.get_memory(
+ messages_list,
+ type="chat",
+ info={
+ **info,
+ "user_id": user_id,
+ "session_id": session_id,
+ "mem_cube_id": mem_cube_id,
+ },
+ )
+ # Add pref_mem to vector db
+ pref_ids = pref_mem.add(pref_memories)
+
+ logger.info(
+ f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing pref_add message: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
new file mode 100644
index 000000000..92496d32d
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/query_handler.py
@@ -0,0 +1,82 @@
+import threading
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.scheduler_context import SchedulerContext
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ MEM_UPDATE_TASK_LABEL,
+ NOT_APPLICABLE_TYPE,
+ QUERY_TASK_LABEL,
+ USER_INPUT_TYPE,
+)
+from memos.mem_scheduler.task_schedule_modules.base_handler import BaseHandler
+
+
+logger = get_logger(__name__)
+
+
+class QueryHandler(BaseHandler):
+ def __init__(self, context: SchedulerContext):
+ super().__init__(context)
+ self.expected_task_label = QUERY_TASK_LABEL
+ self._local = threading.local()
+
+ def __call__(self, messages: list[ScheduleMessageItem]) -> None:
+ """
+ Process and handle query trigger messages from the queue.
+
+ Args:
+ messages: List of query messages to process
+ """
+ self._local.mem_update_messages = []
+ try:
+ super().__call__(messages)
+
+ if self.context.submit_messages and self._local.mem_update_messages:
+ self.context.submit_messages(self._local.mem_update_messages)
+ finally:
+ self._local.mem_update_messages = []
+
+ def batch_handler(self, user_id: str, mem_cube_id: str, batch: list[ScheduleMessageItem]):
+ if not hasattr(self._local, "mem_update_messages"):
+ self._local.mem_update_messages = []
+
+ for msg in batch:
+ try:
+ if self.context.create_event_log and self.context.submit_web_logs:
+ event = self.context.create_event_log(
+ label="addMessage",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=NOT_APPLICABLE_TYPE,
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ mem_cube=self.context.mem_cube,
+ memcube_log_content=[
+ {
+ "content": f"[User] {msg.content}",
+ "ref_id": msg.item_id,
+ "role": "user",
+ }
+ ],
+ metadata=[],
+ memory_len=1,
+ memcube_name=self.context.map_memcube_name(msg.mem_cube_id)
+ if self.context.map_memcube_name
+ else None,
+ )
+ event.task_id = msg.task_id
+ self.context.submit_web_logs([event])
+ except Exception as e:
+ self.handle_exception(e, "Failed to record addMessage log for query")
+ # Re-submit the message with label changed to mem_update
+ update_msg = ScheduleMessageItem(
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ label=MEM_UPDATE_TASK_LABEL,
+ content=msg.content,
+ session_id=msg.session_id,
+ user_name=msg.user_name,
+ info=msg.info,
+ task_id=msg.task_id,
+ )
+ self._local.mem_update_messages.append(update_msg)
diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py
index 3ce727b5c..019ecbe70 100644
--- a/src/memos/mem_scheduler/utils/misc_utils.py
+++ b/src/memos/mem_scheduler/utils/misc_utils.py
@@ -1,7 +1,4 @@
-import json
import os
-import re
-import traceback
from collections import defaultdict
from functools import wraps
@@ -52,170 +49,6 @@ def is_cloud_env() -> bool:
)
-def extract_json_obj(text: str):
- """
- Safely extracts JSON from LLM response text with robust error handling.
-
- Args:
- text: Raw text response from LLM that may contain JSON
-
- Returns:
- Parsed JSON data (dict or list)
-
- Raises:
- ValueError: If no valid JSON can be extracted
- """
- if not text:
- raise ValueError("Empty input text")
-
- # Normalize the text
- text = text.strip()
-
- # Remove common code block markers
- patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
- for pattern in patterns_to_remove:
- text = text.replace(pattern, "")
-
- # Try: direct JSON parse first
- try:
- return json.loads(text.strip())
- except json.JSONDecodeError as e:
- logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
-
- # Fallback 1: Extract JSON using regex
- json_pattern = r"\{[\s\S]*\}|\[[\s\S]*\]"
- matches = re.findall(json_pattern, text)
- if matches:
- try:
- return json.loads(matches[0])
- except json.JSONDecodeError as e:
- logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
-
- # Fallback 2: Handle malformed JSON (common LLM issues)
- try:
- # Try adding missing quotes around keys
- text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text)
- return json.loads(text)
- except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}")
- logger.error("Full traceback:\n" + traceback.format_exc())
- raise ValueError(text) from e
-
-
-def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]:
- """
- Extract bullet list items from LLM output where each item is on a single line
- starting with a given bullet prefix (default: "- ").
-
- This function is designed to be robust to common LLM formatting variations,
- following similar normalization practices as `extract_json_obj`.
-
- Behavior:
- - Strips common code-fence markers (```json, ```python, ``` etc.).
- - Collects all lines that start with any of the provided `bullet_prefixes`.
- - Tolerates the "• " bullet as a loose fallback.
- - Unescapes common sequences like "\\n" and "\\t" within items.
- - If no bullet lines are found, falls back to attempting to parse a JSON array
- (using `extract_json_obj`) and returns its string elements.
-
- Args:
- text: Raw text response from LLM.
- bullet_prefixes: Tuple of accepted bullet line prefixes.
-
- Returns:
- List of extracted items (strings). Returns an empty list if none can be parsed.
- """
- if not text:
- return []
-
- # Normalize the text similar to extract_json_obj
- normalized = text.strip()
- patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
- for pattern in patterns_to_remove:
- normalized = normalized.replace(pattern, "")
- normalized = normalized.replace("\r\n", "\n")
-
- lines = normalized.splitlines()
- items: list[str] = []
- seen: set[str] = set()
-
- for raw in lines:
- line = raw.strip()
- if not line:
- continue
-
- matched = False
- for prefix in bullet_prefixes:
- if line.startswith(prefix):
- content = line[len(prefix) :].strip()
- content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
- if content and content not in seen:
- items.append(content)
- seen.add(content)
- matched = True
- break
-
- if matched:
- continue
-
- if items:
- return items
- else:
- logger.error(f"Fail to parse {text}")
-
- return []
-
-
-def extract_list_items_in_answer(
- text: str, bullet_prefixes: tuple[str, ...] = ("- ",)
-) -> list[str]:
- """
- Extract list items specifically from content enclosed within `...` tags.
-
- - When one or more `...` blocks are present, concatenates their inner
- contents with newlines and parses using `extract_list_items`.
- - When no `` block is found, falls back to parsing the entire input with
- `extract_list_items`.
- - Case-insensitive matching of the `` tag.
-
- Args:
- text: Raw text that may contain `...` blocks.
- bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`).
-
- Returns:
- List of extracted items (strings), or an empty list when nothing is parseable.
- """
- if not text:
- return []
-
- try:
- normalized = text.strip().replace("\r\n", "\n")
- # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER
- tag_variants = ["answer", "Answer", "ANSWER"]
- matches: list[str] = []
- for tag in tag_variants:
- matches = re.findall(rf"<{tag}>([\\s\\S]*?){tag}>", normalized)
- if matches:
- break
- # Fallback: case-insensitive matching if none of the exact-case variants matched
- if not matches:
- matches = re.findall(r"([\\s\\S]*?)", normalized, flags=re.IGNORECASE)
-
- if matches:
- combined = "\n".join(m.strip() for m in matches if m is not None)
- return extract_list_items(combined, bullet_prefixes=bullet_prefixes)
-
- # Fallback: parse the whole text if tags are absent
- return extract_list_items(normalized, bullet_prefixes=bullet_prefixes)
- except Exception as e:
- logger.info(f"Failed to extract items within tags: {e!s}", exc_info=True)
- # Final fallback: attempt direct list extraction
- try:
- return extract_list_items(text, bullet_prefixes=bullet_prefixes)
- except Exception:
- return []
-
-
def parse_yaml(yaml_file: str | Path):
yaml_path = Path(yaml_file)
if not yaml_path.is_file():
diff --git a/src/memos/mem_scheduler/webservice_modules/web_log_service.py b/src/memos/mem_scheduler/webservice_modules/web_log_service.py
new file mode 100644
index 000000000..bd1e5c9fb
--- /dev/null
+++ b/src/memos/mem_scheduler/webservice_modules/web_log_service.py
@@ -0,0 +1,146 @@
+from typing import TYPE_CHECKING
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
+from memos.mem_scheduler.schemas.general_schemas import DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
+from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ ADD_TASK_LABEL,
+ ANSWER_TASK_LABEL,
+ MEM_ARCHIVE_TASK_LABEL,
+ MEM_ORGANIZE_TASK_LABEL,
+ MEM_UPDATE_TASK_LABEL,
+ QUERY_TASK_LABEL,
+)
+
+
+if TYPE_CHECKING:
+ from memos.configs.mem_scheduler import BaseSchedulerConfig
+
+logger = get_logger(__name__)
+
+
class WebLogSchedulerModule(BaseSchedulerModule):
    """Buffers scheduler log items for the web frontend and mirrors them to RabbitMQ.

    Items are stored in a bounded auto-dropping queue that `get_web_log_messages`
    drains; when a RabbitMQ configuration is available (provided by
    RabbitMQSchedulerModule), each item is also published there.
    """

    def __init__(self):
        super().__init__()
        # Created lazily by init_web_log_module; None means "not initialized yet".
        self._web_log_message_queue: Queue[ScheduleLogForWebItem] | None = None
        self.max_web_log_queue_size = DEFAULT_MAX_WEB_LOG_QUEUE_SIZE

    def init_web_log_module(self, config: "BaseSchedulerConfig"):
        """Initialize the web log queue from the scheduler configuration."""
        self.max_web_log_queue_size = config.get(
            "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
        )
        self._web_log_message_queue = Queue(maxsize=self.max_web_log_queue_size)

    def _submit_web_logs(
        self,
        messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem],
        additional_log_info: str | None = None,
    ) -> None:
        """Submit log messages to the web log queue and optionally to RabbitMQ.

        Args:
            messages: Single log message or list of log messages.
            additional_log_info: Extra context appended to the debug log line.
        """
        if self._web_log_message_queue is None:
            logger.warning("Web log queue is not initialized. Dropping logs.")
            return

        if isinstance(messages, ScheduleLogForWebItem):
            messages = [messages]  # transform single message to list

        for message in messages:
            # Fix: actually buffer the item so get_web_log_messages can drain it.
            # The AutoDroppingQueue discards old entries when full, so this never blocks.
            self._web_log_message_queue.put(message)

            # Check if rabbitmq_config is available (provided by RabbitMQSchedulerModule)
            if getattr(self, "rabbitmq_config", None) is None:
                continue
            try:
                # Always call publish; the publisher now caches when offline and flushes after reconnect
                logger.info(
                    f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}"
                )
                # Assumes rabbitmq_publish_message is available via mixin
                if hasattr(self, "rabbitmq_publish_message"):
                    self.rabbitmq_publish_message(message=message.to_dict())
                else:
                    logger.warning("rabbitmq_publish_message method not found.")

                logger.info(
                    "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
                    "item_id=%s task_id=%s label=%s",
                    message.item_id,
                    message.task_id,
                    message.label,
                )
            except Exception as e:
                logger.error(
                    f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True
                )

        logger.debug(
            f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
        )

    def get_web_log_messages(self) -> list[dict]:
        """
        Drain the web log queue and return JSON-serializable dicts.

        Labels are mapped to frontend action names, memcube names are resolved
        when possible, and each metadata entry gains a `memory_time` field.
        """
        if self._web_log_message_queue is None:
            return []

        raw_items: list[ScheduleLogForWebItem] = []
        while True:
            try:
                raw_items.append(self._web_log_message_queue.get_nowait())
            except Exception:
                # Queue drained (queue.Empty expected from get_nowait); stop collecting.
                break

        def _map_label(label: str) -> str:
            # Map internal task labels to frontend action names; unknown labels pass through.
            mapping = {
                QUERY_TASK_LABEL: "addMessage",
                ANSWER_TASK_LABEL: "addMessage",
                ADD_TASK_LABEL: "addMemory",
                MEM_UPDATE_TASK_LABEL: "updateMemory",
                MEM_ORGANIZE_TASK_LABEL: "mergeMemory",
                MEM_ARCHIVE_TASK_LABEL: "archiveMemory",
            }
            return mapping.get(label, label)

        def _normalize_item(item: ScheduleLogForWebItem) -> dict:
            # Flatten one log item into the frontend's expected dict shape.
            data = item.to_dict()
            data["label"] = _map_label(data.get("label"))
            memcube_content = getattr(item, "memcube_log_content", None) or []
            metadata = getattr(item, "metadata", None) or []

            memcube_name = getattr(item, "memcube_name", None)
            if not memcube_name and hasattr(self, "_map_memcube_name"):
                # _map_memcube_name is provided by SchedulerLoggerModule
                memcube_name = self._map_memcube_name(item.mem_cube_id)
            data["memcube_name"] = memcube_name

            memory_len = getattr(item, "memory_len", None)
            if memory_len is None:
                if data["label"] == "mergeMemory":
                    # Merged views exclude the synthetic "postMerge" entry from the count.
                    memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"])
                elif memcube_content:
                    memory_len = len(memcube_content)
                else:
                    memory_len = 1 if item.log_content else 0

            data["memcube_log_content"] = memcube_content
            data["memory_len"] = memory_len

            def _with_memory_time(meta: dict) -> dict:
                # Ensure every metadata entry carries a memory_time field
                # (falls back to updated_at / update_at when absent).
                enriched = dict(meta)
                if "memory_time" not in enriched:
                    enriched["memory_time"] = enriched.get("updated_at") or enriched.get(
                        "update_at"
                    )
                return enriched

            data["metadata"] = [_with_memory_time(m) for m in metadata]
            data["log_title"] = ""
            return data

        return [_normalize_item(it) for it in raw_items]
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
index e58ebcdd1..4df4f97a3 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py
@@ -1,8 +1,10 @@
import copy
import time
+from concurrent.futures import as_completed
from typing import Any
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.factory import Neo4jGraphDB
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -16,7 +18,8 @@
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.reranker.base import BaseReranker
from memos.templates.advanced_search_prompts import PROMPT_MAPPING
-from memos.types.general_types import SearchMode
+from memos.types.general_types import FINE_STRATEGY, FineStrategy, SearchMode
+from memos.utils import extract_json_obj, extract_list_items_in_answer
logger = get_logger(__name__)
@@ -51,7 +54,9 @@ def __init__(
)
self.stage_retrieve_top = 3
- self.process_llm = process_llm
+ if process_llm is None:
+ logger.warning("[AdvancedSearcher] process_llm is None, falling back to dispatcher_llm")
+ self.process_llm = process_llm or dispatcher_llm
self.thinking_stages = 3
self.max_retry_times = 2
self.deep_search_top_k_bar = 2
@@ -362,3 +367,293 @@ def deep_search(
continue
logger.error("Deep search failed, returning original memories")
return memories
+
+ def enhance_memories_with_query(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ batch_size: int | None = None,
+ retries: int = 2,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Enhance memories by adding context and making connections to queries.
+
+ This method uses LLM to rewrite or recreate memories to better align
+ with the given query history, making them more relevant and contextual.
+
+ Args:
+ query_history: List of user queries in chronological order
+ memories: List of memory items to enhance
+ batch_size: Optional batch size for parallel processing
+ retries: Number of retries for LLM calls
+
+ Returns:
+ Tuple of (enhanced_memories, success_flag)
+ - enhanced_memories: Enhanced memory items
+ - success_flag: True if all batches processed successfully
+ """
+ if not memories:
+ logger.warning("[Enhance] ⚠️ skipped (no memories to process)")
+ return memories, True
+
+ num_of_memories = len(memories)
+
+ try:
+ # Single batch path (no parallelization)
+ if batch_size is None or num_of_memories <= batch_size:
+ enhanced_memories, success_flag = self._process_enhancement_batch(
+ batch_index=0,
+ query_history=query_history,
+ memories=memories,
+ retries=retries,
+ )
+ all_success = success_flag
+ else:
+ # Parallel batch processing
+ batches = self._split_batches(memories=memories, batch_size=batch_size)
+ all_success = True
+ failed_batches = 0
+
+ with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+ future_map = {
+ executor.submit(
+ self._process_enhancement_batch, bi, query_history, texts, retries
+ ): (bi, s, e)
+ for bi, (s, e, texts) in enumerate(batches)
+ }
+
+ batch_results = []
+ for fut in as_completed(future_map):
+ bi, s, e = future_map[fut]
+ batch_memories, ok = fut.result()
+ batch_results.append((bi, batch_memories))
+
+ if not ok:
+ all_success = False
+ failed_batches += 1
+
+ # Sort by batch index to preserve original order
+ batch_results.sort(key=lambda x: x[0])
+
+ enhanced_memories = []
+ for _, batch_mem in batch_results:
+ enhanced_memories.extend(batch_mem)
+
+ logger.info(
+ f"[Enhance] ✅ multi-batch done | batches={len(batches)} | "
+ f"enhanced={len(enhanced_memories)} | failed_batches={failed_batches} | "
+ f"success={all_success}"
+ )
+
+ except Exception as e:
+ logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True)
+ all_success = False
+ enhanced_memories = memories
+
+ if len(enhanced_memories) == 0:
+ enhanced_memories = []
+ logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True)
+
+ return enhanced_memories, all_success
+
+ def _process_enhancement_batch(
+ self,
+ batch_index: int,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ retries: int,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Process a single batch of memories for enhancement.
+
+ This method handles retry logic and strategy-specific enhancement
+ (REWRITE vs RECREATE).
+ """
+ attempt = 0
+ text_memories = [one.memory for one in memories]
+
+ prompt = self._build_enhancement_prompt(
+ query_history=query_history, batch_texts=text_memories
+ )
+
+ llm_response = None
+ while attempt <= max(0, retries) + 1:
+ try:
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ processed_text_memories = extract_list_items_in_answer(llm_response)
+
+ if len(processed_text_memories) > 0:
+ enhanced_memories = self._create_enhanced_memories(
+ processed_text_memories=processed_text_memories,
+ original_memories=memories,
+ )
+
+ logger.info(
+ f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | "
+ f"batch={batch_index}"
+ )
+ return enhanced_memories, True
+ else:
+ raise ValueError(
+ f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; "
+ f"processed_text_memories: {processed_text_memories}"
+ )
+
+ except Exception as e:
+ attempt += 1
+ time.sleep(1)
+ logger.debug(
+ f"[enhance_memories_with_query][batch={batch_index}] "
+ f"🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}"
+ )
+
+ logger.error(
+ f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}",
+ exc_info=True,
+ )
+ return memories, False
+
+ def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+ """Build the LLM prompt for memory enhancement."""
+ if len(query_history) == 1:
+ query_history_formatted = query_history[0]
+ else:
+ query_history_formatted = (
+ [f"[{i}] {query}" for i, query in enumerate(query_history)]
+ if len(query_history) > 1
+ else query_history[0]
+ )
+
+ # Include numbering for rewrite mode to help LLM reference original memory IDs
+ if FINE_STRATEGY == FineStrategy.REWRITE:
+ text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
+ prompt_name = "memory_rewrite_enhancement"
+ else:
+ text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)])
+ prompt_name = "memory_recreate_enhancement"
+
+ return self.build_prompt(
+ prompt_name,
+ query_history=query_history_formatted,
+ memories=text_memories,
+ )
+
+ def _create_enhanced_memories(
+ self,
+ processed_text_memories: list[str],
+ original_memories: list[TextualMemoryItem],
+ ) -> list[TextualMemoryItem]:
+ """
+ Create enhanced memory items based on the processing strategy.
+
+ Supports two strategies:
+ - RECREATE: Create new memory items with enhanced text
+ - REWRITE: Rewrite existing memories while preserving metadata
+ """
+ enhanced_memories = []
+ user_id = original_memories[0].metadata.user_id
+
+ if FINE_STRATEGY == FineStrategy.RECREATE:
+ for new_mem in processed_text_memories:
+ enhanced_memories.append(
+ TextualMemoryItem(
+ memory=new_mem,
+ metadata=TextualMemoryMetadata(
+ user_id=user_id, memory_type="LongTermMemory"
+ ),
+ )
+ )
+
+ elif FINE_STRATEGY == FineStrategy.REWRITE:
+ # Parse index from each processed line and rewrite corresponding original memory
+ def _parse_index_and_text(s: str) -> tuple[int | None, str]:
+ import re
+
+ s = (s or "").strip()
+ # Preferred: [index] text
+ m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ # Fallback: index: text or index - text
+ m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ return None, s
+
+ idx_to_original = dict(enumerate(original_memories))
+ for j, item in enumerate(processed_text_memories):
+ idx, new_text = _parse_index_and_text(item)
+ if idx is not None and idx in idx_to_original:
+ orig = idx_to_original[idx]
+ else:
+ # Fallback: align by order if index missing/invalid
+ orig = original_memories[j] if j < len(original_memories) else None
+
+ if not orig:
+ continue
+
+ enhanced_memories.append(
+ TextualMemoryItem(
+ id=orig.id,
+ memory=new_text,
+ metadata=orig.metadata,
+ )
+ )
+ else:
+ logger.error(f"Fine search strategy {FINE_STRATEGY} not exists")
+
+ return enhanced_memories
+
+ @staticmethod
+ def _split_batches(
+ memories: list[TextualMemoryItem], batch_size: int
+ ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
+ """Split memories into batches for parallel processing."""
+ batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
+ start = 0
+ n = len(memories)
+ while start < n:
+ end = min(start + batch_size, n)
+ batches.append((start, end, memories[start:end]))
+ start = end
+ return batches
+
+ def recall_for_missing_memories(
+ self,
+ query: str,
+ memories: list[str],
+ ) -> tuple[str, bool]:
+ """
+ Analyze memories and generate hint for additional recall.
+
+ This method uses LLM to determine if the current memories are sufficient
+ or if additional recall is needed, along with a hint for the recall query.
+
+ Args:
+ query: Original user query
+ memories: List of currently retrieved memory texts
+
+ Returns:
+ Tuple of (hint, trigger_recall)
+ - hint: Suggested query for additional recall
+ - trigger_recall: Whether to trigger additional recall
+ """
+ text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)])
+
+ prompt = self.build_prompt(
+ template_name="enlarge_recall",
+ query=query,
+ memories_inline=text_memories,
+ )
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ json_result: dict = extract_json_obj(llm_response)
+
+ logger.info(
+ f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}"
+ )
+
+ hint = json_result.get("hint", "")
+ if len(hint) == 0:
+ return hint, False
+ return hint, json_result.get("trigger_recall", False)
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 426cf32be..75f25c95c 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -16,6 +16,10 @@
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_reader.utils import parse_keep_filter_response
+from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+)
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import (
ADD_TASK_LABEL,
@@ -270,9 +274,7 @@ def _deep_search(
]
return formatted_memories
- def _agentic_search(
- self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
- ) -> list:
+ def _agentic_search(self, search_req: APISearchRequest, user_context: UserContext) -> list:
deepsearch_results = self.deepsearch_agent.run(
search_req.query, user_id=user_context.mem_cube_id
)
@@ -337,17 +339,26 @@ def _fine_search(
dedup=search_req.dedup,
)
- # Enhance with query
- enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query(
- query_history=[search_req.query],
- memories=raw_memories,
- )
+ if hasattr(self.searcher, "enhance_memories_with_query"):
+ enhanced_memories, _ = self.searcher.enhance_memories_with_query(
+ query_history=[search_req.query],
+ memories=raw_memories,
+ batch_size=DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ retries=DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+ )
+ else:
+ logger.warning(
+ "Searcher does not support enhance_memories_with_query; skipping enhancement."
+ )
+ enhanced_memories = raw_memories
- if len(enhanced_memories) < len(raw_memories):
+ if len(enhanced_memories) < len(raw_memories) and hasattr(
+ self.searcher, "recall_for_missing_memories"
+ ):
logger.info(
f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more."
)
- missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
+ missing_info_hint, trigger = self.searcher.recall_for_missing_memories(
query=search_req.query,
memories=[mem.memory for mem in enhanced_memories],
)
diff --git a/src/memos/templates/advanced_search_prompts.py b/src/memos/templates/advanced_search_prompts.py
index baf2f7536..0257dea02 100644
--- a/src/memos/templates/advanced_search_prompts.py
+++ b/src/memos/templates/advanced_search_prompts.py
@@ -163,38 +163,34 @@
"""
MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
-You are a precise and detail-oriented AI assistant specialized in temporal memory reconstruction, reference resolution, and relevance-aware memory fusion.
+You are a knowledgeable and precise AI assistant.
# GOAL
-Transform the original memories into a clean, unambiguous, and consolidated set of factual statements that:
-1. **Resolve all vague or relative references** (e.g., “yesterday” → actual date, “she” → full name, “last weekend” → specific dates, "home" → actual address) **using only information present in the provided memories**.
-2. **Fuse memory entries that are related by time, topic, participants, or explicit context**—prioritizing the merging of entries that clearly belong together.
-3. **Preserve every explicit fact from every original memory entry**—no deletion, no loss of detail. Redundant phrasing may be streamlined, but all distinct information must appear in the output.
-4. **Return at most {top_k} fused and disambiguated memory segments in , ordered by relevance to the user query** (most relevant first).
-
-# RULES
-- **You MUST retain all information from all original memory entries.** Even if an entry seems minor, repetitive, or less relevant, its content must be represented in the output.
-- **Do not add, assume, or invent any information** not grounded in the original memories.
-- **Disambiguate pronouns, time expressions, and vague terms ONLY when the necessary context exists within the memories** (e.g., if “yesterday” appears in a message dated July 3, resolve it to July 2).
-- **If you cannot resolve a vague reference (e.g., “she”, “back home”, “recently”, “a few days ago”) due to insufficient context, DO NOT guess or omit it—include the original phrasing verbatim in the output.**
-- **Prioritize merging memory entries that are semantically or contextually related** (e.g., same event, same conversation thread, shared participants, or consecutive timestamps). Grouping should reflect natural coherence, not just proximity.
-- **The total number of bullets in must not exceed {top_k}.** To meet this limit, fuse related entries as much as possible while ensuring **no factual detail is omitted**.
-- **Never sacrifice factual completeness for brevity or conciseness.** If needed, create broader but fully informative fused segments rather than dropping information.
-- **Each bullet in must be a self-contained, fluent sentence or clause** that includes all resolved details from the original entries it represents. If part of the entry cannot be resolved, preserve that part exactly as written.
-- **Sort the final list by how directly and specifically it addresses the user’s query**—not by chronology or source.
+Transform raw memories into clean, complete, and fully disambiguated statements that preserve original meaning and explicit details.
+
+# RULES & THINKING STEPS
+1. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”).
+2. Resolve all ambiguities using only memory content. If disambiguation cannot be performed using only the provided memories, retain the original phrasing exactly as written. Never guess, infer, or fabricate missing information:
+ - Pronouns → full name (e.g., “she” → “Caroline”)
+ - Relative time expressions → concrete dates or full context (e.g., “last night” → “on the evening of November 25, 2025”)
+ - Vague references → specific, grounded details (e.g., “the event” → “the LGBTQ+ art workshop in Malmö”)
+ - Incomplete descriptions → full version from memory (e.g., “the activity” → “the abstract painting session at the community center”)
+3. Merge memories that are largely repetitive in content but contain complementary or distinct details. Combine them into a single, cohesive statement that preserves all unique information from each original memory. Do not merge memories that describe different events, even if they share a theme.
+4. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
# OUTPUT FORMAT (STRICT)
-Return ONLY the following structure:
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+Wrap the final output inside:
-- [Fully resolved, fused memory segment most relevant to the query — containing all facts from the original entries it covers; unresolved parts kept verbatim]
-- [Next most relevant resolved and fused segment — again, with no factual loss]
-- [...]
+- enhanced memory 1
+- enhanced memory 2
+...
-
## User Query
-{query}
+{query_history}
## Original Memories
{memories}
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index acbae2281..653815bf0 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -618,7 +618,6 @@
"memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT,
"memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT,
"memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT,
- "memory_recreate_enhancement": MEMORY_RECREATE_ENHANCEMENT_PROMPT,
"memory_rewrite_enhancement": MEMORY_REWRITE_ENHANCEMENT_PROMPT,
"enlarge_recall": ENLARGE_RECALL_PROMPT_ONE_SENTENCE,
}
diff --git a/src/memos/utils.py b/src/memos/utils.py
index fd6d4eaf9..cc530f0ef 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -1,4 +1,6 @@
import functools
+import json
+import re
import time
import traceback
@@ -121,3 +123,167 @@ def wrapper(*args, **kwargs):
if func is None:
return decorator
return decorator(func)
+
+
def extract_json_obj(text: str):
    """
    Safely extract a JSON object or array from raw LLM output.

    Tries, in order: a direct parse of the fence-stripped text, a regex-located
    ``{...}``/``[...]`` span, and finally a parse after quoting bare keys.

    Args:
        text: Raw text response from LLM that may contain JSON

    Returns:
        Parsed JSON data (dict or list)

    Raises:
        ValueError: If no valid JSON can be extracted
    """
    if not text:
        raise ValueError("Empty input text")

    text = text.strip()

    # Strip common code-fence markers before parsing.
    for marker in ("json```", "```python", "```json", "latex```", "```latex", "```"):
        text = text.replace(marker, "")

    # Attempt 1: parse the cleaned text directly.
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError as e:
        logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)

    # Attempt 2: locate the first {...} or [...] span via regex and parse that.
    candidates = re.findall(r"\{[\s\S]*\}|\[[\s\S]*\]", text)
    if candidates:
        try:
            return json.loads(candidates[0])
        except json.JSONDecodeError as e:
            logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)

    # Attempt 3: repair a common LLM mistake (unquoted keys), then parse.
    try:
        text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text)
        return json.loads(text)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}")
        logger.error("Full traceback:\n" + traceback.format_exc())
        raise ValueError(text) from e
+
+
def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]:
    """
    Extract bullet list items from LLM output where each item is on a single line
    starting with a given bullet prefix (default: "- ").

    This function is designed to be robust to common LLM formatting variations,
    following similar normalization practices as `extract_json_obj`.

    Behavior:
    - Strips common code-fence markers (```json, ```python, ``` etc.).
    - Collects all lines that start with any of the provided `bullet_prefixes`.
    - Tolerates the "• " bullet as a loose fallback.
    - Unescapes common sequences like "\\n" and "\\t" within items.
    - If no bullet lines are found, falls back to attempting to parse a JSON array
      (using `extract_json_obj`) and returns its string elements.

    Args:
        text: Raw text response from LLM.
        bullet_prefixes: Tuple of accepted bullet line prefixes.

    Returns:
        List of extracted items (strings). Returns an empty list if none can be parsed.
    """
    if not text:
        return []

    # Normalize the text similar to extract_json_obj
    normalized = text.strip()
    patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
    for pattern in patterns_to_remove:
        normalized = normalized.replace(pattern, "")
    normalized = normalized.replace("\r\n", "\n")

    # Accept the "• " bullet as a loose fallback in addition to the strict prefixes
    # (previously documented but not implemented).
    accepted_prefixes = tuple(bullet_prefixes) + ("• ",)

    items: list[str] = []
    seen: set[str] = set()

    for raw in normalized.splitlines():
        line = raw.strip()
        if not line:
            continue

        for prefix in accepted_prefixes:
            if line.startswith(prefix):
                content = line[len(prefix) :].strip()
                # Unescape literal escape sequences the LLM may have emitted.
                content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
                if content and content not in seen:
                    items.append(content)
                    seen.add(content)
                break

    if items:
        return items

    # Fallback (previously documented but not implemented): try to parse a JSON
    # array and return its non-empty string elements.
    try:
        parsed = extract_json_obj(normalized)
        if isinstance(parsed, list):
            fallback_items = [str(el).strip() for el in parsed if str(el).strip()]
            if fallback_items:
                return fallback_items
    except Exception:
        pass

    logger.error(f"Fail to parse {text}")
    return []
+
+
def extract_list_items_in_answer(
    text: str, bullet_prefixes: tuple[str, ...] = ("- ",)
) -> list[str]:
    """
    Extract list items specifically from content enclosed within <answer>...</answer> tags.

    - When one or more <answer>...</answer> blocks are present, concatenates their
      inner contents with newlines and parses using `extract_list_items`.
    - When no <answer> block is found, falls back to parsing the entire input with
      `extract_list_items`.
    - Exact-case variants are tried first, then a case-insensitive match.

    Args:
        text: Raw text that may contain <answer>...</answer> blocks.
        bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`).

    Returns:
        List of extracted items (strings), or an empty list when nothing is parseable.
    """
    if not text:
        return []

    try:
        normalized = text.strip().replace("\r\n", "\n")
        # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER.
        # Fix: the closing-tag pattern previously lacked "</", so captures
        # carried a trailing "</" and the fallback matched empty strings.
        tag_variants = ["answer", "Answer", "ANSWER"]
        matches: list[str] = []
        for tag in tag_variants:
            matches = re.findall(rf"<{tag}>([\s\S]*?)</{tag}>", normalized)
            if matches:
                break
        # Fallback: case-insensitive matching if none of the exact-case variants matched
        if not matches:
            matches = re.findall(
                r"<answer>([\s\S]*?)</answer>", normalized, flags=re.IGNORECASE
            )

        if matches:
            combined = "\n".join(m.strip() for m in matches if m is not None)
            return extract_list_items(combined, bullet_prefixes=bullet_prefixes)

        # Fallback: parse the whole text if tags are absent
        return extract_list_items(normalized, bullet_prefixes=bullet_prefixes)
    except Exception as e:
        logger.info(f"Failed to extract items within tags: {e!s}", exc_info=True)
        # Final fallback: attempt direct list extraction
        try:
            return extract_list_items(text, bullet_prefixes=bullet_prefixes)
        except Exception:
            return []
diff --git a/tests/mem_scheduler/test_post_processor.py b/tests/mem_scheduler/test_post_processor.py
new file mode 100644
index 000000000..5a6a58ba4
--- /dev/null
+++ b/tests/mem_scheduler/test_post_processor.py
@@ -0,0 +1,215 @@
+"""
+Unit tests for MemoryPostProcessor.
+
+These tests verify the post-processing operations including filtering and reranking.
+
+Note: Memory enhancement tests have been moved to AdvancedSearcher tests.
+"""
+
+from unittest.mock import Mock
+
+import pytest
+
+from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.mem_scheduler.memory_manage_modules.post_processor import MemoryPostProcessor
+from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
+
+
+class TestMemoryPostProcessor:
+ """Test suite for MemoryPostProcessor."""
+
+ @pytest.fixture
+ def mock_llm(self):
+ """Create a mock LLM instance."""
+ llm = Mock()
+ llm.generate = Mock(return_value='{"result": true, "reason": "test"}')
+ return llm
+
+ @pytest.fixture
+ def mock_config(self):
+ """Create a mock config instance."""
+ config = Mock(spec=BaseSchedulerConfig)
+ config.scheduler_retriever_batch_size = 10
+ config.scheduler_retriever_enhance_retries = 2
+ return config
+
+ @pytest.fixture
+ def processor(self, mock_llm, mock_config):
+ """Create a MemoryPostProcessor instance."""
+ return MemoryPostProcessor(process_llm=mock_llm, config=mock_config)
+
+ @pytest.fixture
+ def sample_memories(self):
+ """Create sample memory items for testing."""
+ return [
+ TextualMemoryItem(
+ memory="Python is a programming language",
+ metadata=TextualMemoryMetadata(user_id="user1", memory_type="LongTermMemory"),
+ ),
+ TextualMemoryItem(
+ memory="JavaScript is also a programming language",
+ metadata=TextualMemoryMetadata(user_id="user1", memory_type="LongTermMemory"),
+ ),
+ ]
+
+ def test_init(self, mock_llm, mock_config):
+ """Test initialization of MemoryPostProcessor."""
+ processor = MemoryPostProcessor(process_llm=mock_llm, config=mock_config)
+
+ assert processor.process_llm is mock_llm
+ assert processor.config is mock_config
+ assert processor.filter_similarity_threshold == 0.75
+ assert processor.filter_min_length_threshold == 6
+
+ def test_evaluate_memory_answer_ability_true(self, processor, mock_llm):
+ """Test evaluate_memory_answer_ability when memories can answer query."""
+ mock_llm.generate.return_value = (
+ '{"result": true, "reason": "Memories contain relevant info"}'
+ )
+
+ result = processor.evaluate_memory_answer_ability(
+ query="What is Python?",
+ memory_texts=["Python is a programming language"],
+ )
+
+ assert result is True
+ assert mock_llm.generate.called
+
+ def test_evaluate_memory_answer_ability_false(self, processor, mock_llm):
+ """Test evaluate_memory_answer_ability when memories cannot answer query."""
+ mock_llm.generate.return_value = '{"result": false, "reason": "No relevant info"}'
+
+ result = processor.evaluate_memory_answer_ability(
+ query="What is the capital of France?",
+ memory_texts=["Python is a programming language"],
+ )
+
+ assert result is False
+
+ def test_evaluate_memory_answer_ability_with_top_k(self, processor, mock_llm):
+ """Test evaluate_memory_answer_ability with top_k limit."""
+ mock_llm.generate.return_value = '{"result": true}'
+
+ memories = ["memory 1", "memory 2", "memory 3", "memory 4", "memory 5"]
+ processor.evaluate_memory_answer_ability(
+ query="test query",
+ memory_texts=memories,
+ top_k=3,
+ )
+
+ # Verify only top 3 memories were used
+ call_args = mock_llm.generate.call_args[0][0]
+ prompt_content = call_args[0]["content"]
+ # Should contain only 3 memories
+ assert prompt_content.count("- memory") == 3
+
+ def test_rerank_memories_success(self, processor, mock_llm):
+ """Test successful memory reranking."""
+ original_memories = ["memory A", "memory B", "memory C"]
+ mock_llm.generate.return_value = (
+ '{"new_order": [2, 0, 1], "reasoning": "C is most relevant"}'
+ )
+
+ reranked, success = processor.rerank_memories(
+ queries=["test query"],
+ original_memories=original_memories,
+ top_k=3,
+ )
+
+ assert success is True
+ assert reranked == ["memory C", "memory A", "memory B"]
+
+ def test_rerank_memories_failure_fallback(self, processor, mock_llm):
+ """Test reranking fallback when LLM fails."""
+ original_memories = ["memory A", "memory B", "memory C"]
+ mock_llm.generate.return_value = '{"invalid": "response"}' # Missing new_order
+
+ reranked, success = processor.rerank_memories(
+ queries=["test query"],
+ original_memories=original_memories,
+ top_k=2,
+ )
+
+ assert success is False
+ assert reranked == ["memory A", "memory B"] # Original order, truncated to top_k
+
+ def test_rerank_memories_respects_top_k(self, processor, mock_llm):
+ """Test that reranking respects top_k limit."""
+ original_memories = ["A", "B", "C", "D", "E"]
+ mock_llm.generate.return_value = '{"new_order": [4, 3, 2, 1, 0], "reasoning": "reversed"}'
+
+ reranked, success = processor.rerank_memories(
+ queries=["test"],
+ original_memories=original_memories,
+ top_k=3,
+ )
+
+ assert len(reranked) == 3
+ assert reranked == ["E", "D", "C"]
+
+ def test_process_and_rerank_memories(self, processor, mock_llm, sample_memories):
+ """Test combined processing and reranking of memories."""
+ mock_llm.generate.return_value = '{"new_order": [0, 1], "reasoning": "test"}'
+
+ original = sample_memories[:1]
+ new = sample_memories[1:]
+
+ reranked, success = processor.process_and_rerank_memories(
+ queries=["programming languages"],
+ original_memory=original,
+ new_memory=new,
+ top_k=2,
+ )
+
+ # Should have combined and reranked both memories
+ assert len(reranked) <= 2
+ assert all(isinstance(m, TextualMemoryItem) for m in reranked)
+
+ def test_filter_unrelated_memories_delegation(self, processor, sample_memories):
+ """Test that filter_unrelated_memories delegates to MemoryFilter."""
+ mock_memory_filter = Mock()
+ processor.memory_filter = mock_memory_filter
+ mock_memory_filter.filter_unrelated_memories = Mock(return_value=(sample_memories, True))
+
+ filtered, success = processor.filter_unrelated_memories(
+ query_history=["test"],
+ memories=sample_memories,
+ )
+
+ assert mock_memory_filter.filter_unrelated_memories.called
+ assert filtered == sample_memories
+ assert success is True
+
+ def test_filter_redundant_memories_delegation(self, processor, sample_memories):
+ """Test that filter_redundant_memories delegates to MemoryFilter."""
+ mock_memory_filter = Mock()
+ processor.memory_filter = mock_memory_filter
+ mock_memory_filter.filter_redundant_memories = Mock(
+ return_value=(sample_memories[:1], True)
+ )
+
+ filtered, success = processor.filter_redundant_memories(
+ query_history=["test"],
+ memories=sample_memories,
+ )
+
+ assert mock_memory_filter.filter_redundant_memories.called
+
+ def test_filter_unrelated_and_redundant_memories_delegation(self, processor, sample_memories):
+ """Test combined filtering delegation."""
+ mock_memory_filter = Mock()
+ processor.memory_filter = mock_memory_filter
+ mock_memory_filter.filter_unrelated_and_redundant_memories = Mock(
+ return_value=(sample_memories[:1], True)
+ )
+
+ filtered, success = processor.filter_unrelated_and_redundant_memories(
+ query_history=["test"],
+ memories=sample_memories,
+ )
+
+ assert mock_memory_filter.filter_unrelated_and_redundant_memories.called
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/mem_scheduler/test_retriever.py b/tests/mem_scheduler/test_retriever.py
index 35c8b7f3a..7069d4268 100644
--- a/tests/mem_scheduler/test_retriever.py
+++ b/tests/mem_scheduler/test_retriever.py
@@ -74,7 +74,7 @@ def setUp(self):
self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm)
self.scheduler.mem_cube = self.mem_cube
- self.retriever = self.scheduler.retriever
+ self.post_processor = self.scheduler.post_processor
# Mock logging to verify messages
self.logging_warning_patch = patch("logging.warning")
@@ -176,7 +176,7 @@ def test_filter_unrelated_memories_empty_memories(self):
"""Test filter_unrelated_memories with empty memories list."""
query_history = ["What is the weather like?", "Tell me about Python programming"]
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=[]
)
@@ -192,7 +192,7 @@ def test_filter_unrelated_memories_empty_query_history(self):
TextualMemoryItem(memory="Data science involves statistics"),
]
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=[], memories=memories
)
@@ -221,7 +221,7 @@ def test_filter_unrelated_memories_successful_filtering(self):
# Convert to proper JSON string
self.llm.generate.return_value = json.dumps(mock_llm_response)
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
@@ -251,7 +251,7 @@ def test_filter_unrelated_memories_llm_failure_fallback(self):
# Mock LLM to return an invalid response that will trigger error handling
self.llm.generate.return_value = "Invalid response that cannot be parsed"
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
@@ -275,7 +275,7 @@ def test_filter_unrelated_memories_invalid_json_response(self):
# Mock LLM to return invalid JSON
self.llm.generate.return_value = "This is not valid JSON"
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
@@ -301,7 +301,7 @@ def test_filter_unrelated_memories_invalid_indices(self):
# Convert to proper JSON string
self.llm.generate.return_value = json.dumps(mock_llm_response)
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
@@ -327,7 +327,7 @@ def test_filter_unrelated_memories_missing_required_fields(self):
# Convert to proper JSON string
self.llm.generate.return_value = json.dumps(mock_llm_response)
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
@@ -353,7 +353,7 @@ def test_filter_unrelated_memories_conservative_filtering(self):
self.llm.generate.return_value = json.dumps(mock_llm_response)
- result, success_flag = self.retriever.filter_unrelated_memories(
+ result, success_flag = self.post_processor.filter_unrelated_memories(
query_history=query_history, memories=memories
)
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 523d5d108..86abc3fae 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -14,7 +14,7 @@
)
from memos.llms.base import BaseLLM
from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
+from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.schemas.message_schemas import (
@@ -113,7 +113,7 @@ def test_initialize_modules(self):
"""Test module initialization with proper component assignments."""
self.assertEqual(self.scheduler.chat_llm, self.llm)
self.assertIsInstance(self.scheduler.monitor, SchedulerGeneralMonitor)
- self.assertIsInstance(self.scheduler.retriever, SchedulerRetriever)
+ self.assertIsInstance(self.scheduler.search_service, SchedulerSearchService)
def test_submit_web_logs(self):
"""Test submission of web logs with updated data structure."""
diff --git a/tests/mem_scheduler/test_search_service.py b/tests/mem_scheduler/test_search_service.py
new file mode 100644
index 000000000..49db43ded
--- /dev/null
+++ b/tests/mem_scheduler/test_search_service.py
@@ -0,0 +1,193 @@
+"""
+Unit tests for SchedulerSearchService.
+
+These tests verify that the SchedulerSearchService correctly delegates
+search operations to the Searcher class and provides proper fallback behavior.
+"""
+
+from unittest.mock import Mock
+
+import pytest
+
+from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService
+from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
+from memos.memories.textual.tree import TreeTextMemory
+from memos.types.general_types import SearchMode
+
+
+class TestSchedulerSearchService:
+ """Test suite for SchedulerSearchService."""
+
+ @pytest.fixture
+ def mock_searcher(self):
+ """Create a mock Searcher instance."""
+ searcher = Mock()
+ searcher.manual_close_internet = True
+ searcher.search = Mock(
+ return_value=[
+ TextualMemoryItem(
+ memory="Test memory 1",
+ metadata=TextualMemoryMetadata(user_id="user1", memory_type="LongTermMemory"),
+ )
+ ]
+ )
+ return searcher
+
+ @pytest.fixture
+ def mock_mem_cube(self):
+ """Create a mock MemCube instance."""
+ mem_cube = Mock()
+ mem_cube.text_mem = Mock(spec=TreeTextMemory)
+ mem_cube.text_mem.search = Mock(
+ return_value=[
+ TextualMemoryItem(
+ memory="Fallback memory",
+ metadata=TextualMemoryMetadata(user_id="user1", memory_type="LongTermMemory"),
+ )
+ ]
+ )
+ return mem_cube
+
+ def test_init_with_searcher(self, mock_searcher):
+ """Test initialization with a Searcher instance."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+ assert service.searcher is mock_searcher
+
+ def test_init_without_searcher(self):
+ """Test initialization without a Searcher instance."""
+ service = SchedulerSearchService(searcher=None)
+ assert service.searcher is None
+
+ def test_search_with_searcher(self, mock_searcher, mock_mem_cube):
+ """Test search operation using Searcher (preferred path)."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ results = service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ mode=SearchMode.FAST,
+ )
+
+ # Verify Searcher.search() was called ONCE with memory_type="All"
+ # (This avoids the 2*top_k bug)
+ assert mock_searcher.search.call_count == 1
+
+ # Verify correct parameters were passed
+ call_args = mock_searcher.search.call_args[1]
+ assert call_args["query"] == "test query"
+ assert call_args["memory_type"] == "All" # Should search all types together
+ assert call_args["top_k"] == 10
+
+ # Verify results were returned
+ assert len(results) >= 1
+
+ def test_search_without_searcher_fallback(self, mock_mem_cube):
+ """Test search operation without Searcher (fallback path)."""
+ service = SchedulerSearchService(searcher=None)
+
+ results = service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ mode=SearchMode.FAST,
+ )
+
+ # Verify text_mem.search() was called once as fallback (with memory_type="All")
+ assert mock_mem_cube.text_mem.search.call_count == 1
+
+ # Verify results were returned
+ assert len(results) >= 1
+
+ def test_search_internet_search_toggle(self, mock_searcher, mock_mem_cube):
+ """Test that internet_search parameter correctly toggles manual_close_internet."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ # Test with internet_search=True
+ service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ internet_search=True,
+ )
+
+ # Verify manual_close_internet was set to False (enable internet search)
+ # Note: This is tested during the call, then restored
+ assert mock_searcher.manual_close_internet # Restored after call
+
+ def test_search_mode_fine(self, mock_searcher, mock_mem_cube):
+ """Test search with FINE mode."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ mode=SearchMode.FINE,
+ )
+
+ # Verify FINE mode was passed
+ call_args = mock_searcher.search.call_args_list[0][1]
+ assert call_args["mode"] == SearchMode.FINE
+
+ def test_search_with_filters(self, mock_searcher, mock_mem_cube):
+ """Test search with search_filter and search_priority."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ search_filter = {"source": "document"}
+ search_priority = {"session_id": "session123"}
+
+ service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ search_filter=search_filter,
+ search_priority=search_priority,
+ )
+
+ # Verify filters were passed
+ call_args = mock_searcher.search.call_args_list[0][1]
+ assert call_args["search_filter"] == search_filter
+ assert call_args["search_priority"] == search_priority
+
+ def test_search_exception_handling(self, mock_mem_cube):
+ """Test that exceptions are caught and empty list is returned."""
+ service = SchedulerSearchService(searcher=None)
+ mock_mem_cube.text_mem.search.side_effect = Exception("Search failed")
+
+ results = service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ )
+
+ # Verify empty list is returned on exception
+ assert results == []
+
+ def test_search_preserves_searcher_state(self, mock_searcher, mock_mem_cube):
+ """Test that the original searcher state is preserved after search."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ original_state = True
+ mock_searcher.manual_close_internet = original_state
+
+ service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=10,
+ internet_search=True, # This should temporarily change the state
+ )
+
+ # Verify original state was restored
+ assert mock_searcher.manual_close_internet == original_state
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/mem_scheduler/test_search_service_bug_fix.py b/tests/mem_scheduler/test_search_service_bug_fix.py
new file mode 100644
index 000000000..3fdc3f542
--- /dev/null
+++ b/tests/mem_scheduler/test_search_service_bug_fix.py
@@ -0,0 +1,144 @@
+"""
+Critical bug fix test: Verify that search returns correct number of results.
+
+This test verifies the fix for the bug where calling Searcher.search() twice
+(once for LongTermMemory, once for UserMemory) would return 2*top_k results
+because each call applies deduplication and top_k limiting independently.
+"""
+
+from unittest.mock import Mock
+
+import pytest
+
+from memos.mem_scheduler.memory_manage_modules.search_service import SchedulerSearchService
+from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
+from memos.types.general_types import SearchMode
+
+
+class TestSearchServiceBugFix:
+ """Test suite for the critical 2*top_k bug fix."""
+
+ @pytest.fixture
+ def mock_searcher(self):
+ """Create a mock Searcher that simulates real behavior."""
+ searcher = Mock()
+ searcher.manual_close_internet = True
+
+ # Simulate Searcher.search() behavior:
+ # Returns exactly top_k results after deduplication
+ def search_side_effect(*args, **kwargs):
+ top_k = kwargs.get("top_k", 10)
+ memory_type = kwargs.get("memory_type", "All")
+
+ # Simulate returning top_k results
+ results = [
+ TextualMemoryItem(
+ memory=f"Memory {i} ({memory_type})",
+ metadata=TextualMemoryMetadata(user_id="user1", memory_type=memory_type),
+ )
+ for i in range(top_k)
+ ]
+ return results
+
+ searcher.search = Mock(side_effect=search_side_effect)
+ return searcher
+
+ @pytest.fixture
+ def mock_mem_cube(self):
+ """Create a mock MemCube."""
+ mem_cube = Mock()
+ mem_cube.text_mem = Mock()
+ return mem_cube
+
+ def test_search_returns_correct_count_not_double(self, mock_searcher, mock_mem_cube):
+ """
+ CRITICAL TEST: Verify search returns top_k results, not 2*top_k.
+
+ This test verifies the fix for the bug where:
+ - OLD (buggy): Called search() twice → returned 2*top_k results
+ - NEW (fixed): Calls search() once with memory_type="All" → returns top_k results
+ """
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ top_k = 10
+ results = service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=top_k,
+ mode=SearchMode.FAST,
+ )
+
+ # CRITICAL ASSERTION: Should return exactly top_k results, not 2*top_k
+ assert len(results) == top_k, (
+ f"Expected exactly {top_k} results, but got {len(results)}. "
+ f"This indicates the 2*top_k bug is NOT fixed!"
+ )
+
+ # Verify search was called only ONCE with memory_type="All"
+ assert mock_searcher.search.call_count == 1, (
+ f"Expected search() to be called once, but was called {mock_searcher.search.call_count} times. "
+ f"Multiple calls would cause the 2*top_k bug!"
+ )
+
+ # Verify the call used memory_type="All"
+ call_kwargs = mock_searcher.search.call_args[1]
+ assert call_kwargs["memory_type"] == "All", (
+ f"Expected memory_type='All', but got '{call_kwargs['memory_type']}'. "
+ f"Separate calls for LongTermMemory and UserMemory would cause the 2*top_k bug!"
+ )
+
+ def test_old_buggy_behavior_would_return_double(self):
+ """
+ Documentation test: Show what the OLD buggy behavior would have been.
+
+ This test documents the bug for future reference.
+ """
+ # Simulate the OLD buggy implementation
+ mock_searcher = Mock()
+
+ def buggy_search(*args, **kwargs):
+ # Each call returns top_k results
+ top_k = kwargs.get("top_k", 10)
+ return [Mock() for _ in range(top_k)]
+
+ mock_searcher.search = Mock(side_effect=buggy_search)
+
+ # OLD buggy code would do:
+ top_k = 10
+ results_long_term = mock_searcher.search(memory_type="LongTermMemory", top_k=top_k)
+ results_user = mock_searcher.search(memory_type="UserMemory", top_k=top_k)
+ buggy_results = results_long_term + results_user
+
+ # This would return 2*top_k results!
+ assert len(buggy_results) == 2 * top_k, (
+ f"OLD buggy behavior: Expected {2 * top_k} results (2*top_k), "
+ f"but got {len(buggy_results)}"
+ )
+
+ # This is the BUG we fixed!
+ print(
+ f"✅ Confirmed: OLD buggy behavior would return {len(buggy_results)} results (2*top_k)"
+ )
+
+ def test_search_with_different_top_k_values(self, mock_searcher, mock_mem_cube):
+ """Test that the fix works correctly with different top_k values."""
+ service = SchedulerSearchService(searcher=mock_searcher)
+
+ for top_k in [1, 5, 10, 20, 50]:
+ results = service.search(
+ query="test query",
+ user_id="user1",
+ mem_cube=mock_mem_cube,
+ top_k=top_k,
+ mode=SearchMode.FAST,
+ )
+
+ # Should always return exactly top_k, never 2*top_k
+ assert len(results) == top_k, (
+ f"For top_k={top_k}, expected {top_k} results, but got {len(results)}"
+ )
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "-s"])