diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/task_stop_rerun.py index 4664e0eaa..5bd344651 100644 --- a/examples/mem_scheduler/task_stop_rerun.py +++ b/examples/mem_scheduler/task_stop_rerun.py @@ -28,6 +28,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): try: print(f"writing {file_path}...") file_path.write_text(f"Task {task_id} processed.\n") + sleep(5) except Exception as e: print(f"Failed to write {file_path}: {e}") @@ -57,6 +58,8 @@ def submit_tasks(): TEST_HANDLER_LABEL = "test_handler" mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) +# 10s to restart +mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000 tmp_dir = Path("./tmp") tmp_dir.mkdir(exist_ok=True) @@ -69,10 +72,15 @@ def submit_tasks(): submit_tasks() # 6. Wait until tmp has 100 files or timeout -poll_interval = 0.01 +poll_interval = 1 expected = 100 tmp_dir = Path("tmp") -while mem_scheduler.get_tasks_status()["remaining"] != 0: +tasks_status = mem_scheduler.get_tasks_status() +mem_scheduler.print_tasks_status(tasks_status=tasks_status) +while ( + mem_scheduler.get_tasks_status()["remaining"] != 0 + or mem_scheduler.get_tasks_status()["running"] != 0 +): count = len(list(tmp_dir.glob("*.txt"))) if tmp_dir.exists() else 0 tasks_status = mem_scheduler.get_tasks_status() mem_scheduler.print_tasks_status(tasks_status=tasks_status) diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py index d12a8ace4..e2eefb9d8 100644 --- a/src/memos/api/handlers/scheduler_handler.py +++ b/src/memos/api/handlers/scheduler_handler.py @@ -22,10 +22,13 @@ AllStatusResponseData, StatusResponse, StatusResponseItem, + TaskQueueData, + TaskQueueResponse, TaskSummary, ) from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler +from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -243,6 +246,96 @@ def handle_scheduler_status( raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err +def handle_task_queue_status( + user_id: str, mem_scheduler: OptimizedScheduler, task_id: str | None = None +) -> TaskQueueResponse: + try: + queue = getattr(mem_scheduler, "memos_message_queue", None) + if queue is None: + raise HTTPException(status_code=503, detail="Scheduler queue is not available") + + # Only support Redis-backed queue for now; try lazy init if not connected + redis_conn = getattr(queue, "_redis_conn", None) + if redis_conn is None: + try: + if hasattr(queue, "auto_initialize_redis"): + queue.auto_initialize_redis() + redis_conn = getattr(queue, "_redis_conn", None) + if redis_conn and hasattr(queue, "connect"): + queue.connect() + except Exception: + redis_conn = None + + if redis_conn is None: + raise HTTPException(status_code=503, detail="Scheduler queue not connected to Redis") + + stream_keys = queue.get_stream_keys() + # Filter by user_id; stream key format: {prefix}:{user_id}:{mem_cube_id}:{task_label} + user_stream_keys = [sk for sk in stream_keys if f":{user_id}:" in sk] + + if not user_stream_keys: + raise HTTPException( + status_code=404, detail=f"No scheduler streams found for user {user_id}" + ) + + def _parse_user_id_from_stream(stream_key: str) -> str | None: + try: + parts = stream_key.split(":") + if len(parts) < 3: + return None + # prefix may contain multiple segments; user_id is the 2nd segment from the end - 1 + return parts[-3] + except Exception: + return None + + user_ids_present = { + uid for uid in (_parse_user_id_from_stream(sk) for sk in stream_keys) if uid + } + + pending_total = 0 + pending_detail: list[str] = [] + remaining_total = 0 + remaining_detail: list[str] = [] + + consumer_group = getattr(queue, "consumer_group", None) or "scheduler_group" + for sk in user_stream_keys: + try: + pending_info = redis_conn.xpending(sk, consumer_group) + pending_count = pending_info[0] if pending_info else 0 + except Exception: + pending_count = 0 + pending_total += pending_count + pending_detail.append(f"{sk}:{pending_count}") + + try: + remaining_count = redis_conn.xlen(sk) + except Exception: + remaining_count = 0 + remaining_total += remaining_count + remaining_detail.append(f"{sk}:{remaining_count}") + + data = TaskQueueData( + user_id=user_id, + user_name=None, + mem_cube_id=None, + stream_keys=user_stream_keys, + users_count=len(user_ids_present), + pending_tasks_count=pending_total, + remaining_tasks_count=remaining_total, + pending_tasks_detail=pending_detail, + remaining_tasks_detail=remaining_detail, + ) + return TaskQueueResponse(data=data) + except HTTPException: + # Re-raise HTTPException directly to preserve its status code (e.g., 404) + raise + except Exception as err: + logger.error( + f"Failed to get task queue status for user {user_id}: {traceback.format_exc()}" + ) + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + def handle_scheduler_wait( user_name: str, status_tracker: TaskStatusTracker, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 1f5a582fc..06cc29729 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -884,6 +884,34 @@ class StatusResponse(BaseResponse[list[StatusResponseItem]]): message: str = "Memory get status successfully" +class TaskQueueData(BaseModel): + """Queue-level metrics for scheduler tasks.""" + + user_id: str = Field(..., description="User ID the query is scoped to") + user_name: str | None = Field(None, description="User name if available") + mem_cube_id: str | None = Field( + None, description="MemCube ID if a single cube is targeted; otherwise None" + ) + stream_keys: list[str] = Field(..., description="Matched Redis stream keys for this user") + users_count: int = Field(..., description="Distinct users currently present in queue streams") + pending_tasks_count: int = Field( + ..., description="Count of pending (delivered, not acked) tasks" + ) + remaining_tasks_count: int = Field(..., description="Count of enqueued tasks (xlen)") + pending_tasks_detail: list[str] = Field( + ..., description="Per-stream pending counts, formatted as '{stream_key}:{count}'" + ) + remaining_tasks_detail: list[str] = Field( + ..., description="Per-stream remaining counts, formatted as '{stream_key}:{count}'" + ) + + +class TaskQueueResponse(BaseResponse[TaskQueueData]): + """Response model for scheduler task queue status.""" + + message: str = "Scheduler task queue status retrieved successfully" + + class TaskSummary(BaseModel): """Aggregated counts of tasks by status.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e8acf2e38..fcb70a64c 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -41,6 +41,7 @@ StatusResponse, SuggestionRequest, SuggestionResponse, + TaskQueueResponse, ) from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler @@ -143,6 +144,20 @@ def scheduler_status( ) +@router.get( # Changed from post to get + "/scheduler/task_queue_status", + summary="Get scheduler task queue status", + response_model=TaskQueueResponse, +) +def scheduler_task_queue_status( + user_id: str = Query(..., description="User ID whose queue status is requested"), +): + """Get scheduler task queue backlog/pending status for a user.""" + return handlers.scheduler_handler.handle_task_queue_status( + user_id=user_id, mem_scheduler=mem_scheduler + ) + + @router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user") def scheduler_wait( user_name: str, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index a0b72efd1..eaaa71461 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -44,6 +44,7 @@ def parse_datetime(cls, value): class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" + # Allow passing additional fields without raising validation errors model_config = ConfigDict(extra="allow", strict=True) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index add689336..79c28c32c 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -4,6 +4,7 @@ import time from collections.abc import Callable +from contextlib import suppress from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Union @@ -47,6 +48,15 @@ ScheduleMessageItem, ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, + TaskPriorityLevel, +) from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue @@ -55,6 +65,7 @@ from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule @@ -642,19 +653,115 @@ def update_activation_memory_periodically( logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + """Submit messages for processing, with priority-aware dispatch. + + - LEVEL_1 tasks dispatch immediately to the appropriate handler. + - Lower-priority tasks are enqueued via the configured message queue. + """ if isinstance(messages, ScheduleMessageItem): messages = [messages] - for message in messages: - self.metrics.task_enqueued(user_id=message.user_id, task_type=message.label) + + if not messages: + return + + immediate_msgs: list[ScheduleMessageItem] = [] + queued_msgs: list[ScheduleMessageItem] = [] + + for msg in messages: + # basic metrics and status tracking + with suppress(Exception): + self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) + + # ensure timestamp exists for monitoring + if getattr(msg, "timestamp", None) is None: + msg.timestamp = get_utc_now() + if self.status_tracker: - self.status_tracker.task_submitted( - task_id=message.item_id, - user_id=message.user_id, - task_type=message.label, - mem_cube_id=message.mem_cube_id, - business_task_id=message.task_id, # Pass business task_id if provided + try: + self.status_tracker.task_submitted( + task_id=msg.item_id, + user_id=msg.user_id, + task_type=msg.label, + mem_cube_id=msg.mem_cube_id, + business_task_id=msg.task_id, + ) + except Exception: + logger.warning("status_tracker.task_submitted failed", exc_info=True) + + # honor disabled handlers + if self.disabled_handlers and msg.label in self.disabled_handlers: + logger.info(f"Skipping disabled handler: {msg.label} - {msg.content}") + continue + + # decide priority path + task_priority = self.orchestrator.get_task_priority(task_label=msg.label) + if task_priority == TaskPriorityLevel.LEVEL_1: + immediate_msgs.append(msg) + else: + queued_msgs.append(msg) + + # Dispatch high-priority tasks immediately + if immediate_msgs: + # emit enqueue events for consistency + for m in immediate_msgs: + emit_monitor_event( + "enqueue", m, {"enqueue_ts": to_iso(getattr(m, "timestamp", None))} ) - self.memos_message_queue.submit_messages(messages=messages) + + # simulate dequeue for immediately dispatched messages so monitor logs stay complete + for m in immediate_msgs: + try: + now = time.time() + enqueue_ts_obj = getattr(m, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + object.__setattr__(m, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + m, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(), + "queue_wait_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label) + except Exception: + logger.debug("Failed to emit dequeue for immediate task", exc_info=True) + + user_cube_groups = group_messages_by_user_and_mem_cube(immediate_msgs) + for user_id, cube_groups in user_cube_groups.items(): + for mem_cube_id, user_cube_msgs in cube_groups.items(): + label_groups: dict[str, list[ScheduleMessageItem]] = {} + for m in user_cube_msgs: + label_groups.setdefault(m.label, []).append(m) + + for label, msgs_by_label in label_groups.items(): + handler = self.dispatcher.handlers.get( + label, self.dispatcher._default_message_handler + ) + self.dispatcher.execute_task( + user_id=user_id, + mem_cube_id=mem_cube_id, + task_label=label, + msgs=msgs_by_label, + handler_call_back=handler, + ) + + # Enqueue lower-priority tasks + if queued_msgs: + self.memos_message_queue.submit_messages(messages=queued_msgs) def _submit_web_logs( self, @@ -706,15 +813,6 @@ def get_web_log_messages(self) -> list[dict]: break def _map_label(label: str) -> str: - from memos.mem_scheduler.schemas.task_schemas import ( - ADD_TASK_LABEL, - ANSWER_TASK_LABEL, - MEM_ARCHIVE_TASK_LABEL, - MEM_ORGANIZE_TASK_LABEL, - MEM_UPDATE_TASK_LABEL, - QUERY_TASK_LABEL, - ) - mapping = { QUERY_TASK_LABEL: "addMessage", ANSWER_TASK_LABEL: "addMessage", diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 080a76389..dc64f5a45 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -2,7 +2,6 @@ import contextlib import json import os -import time import traceback from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -339,17 +338,9 @@ def log_add_messages(self, msg: ScheduleMessageItem): try: # This mem_item represents the NEW content that was just added/processed mem_item: TextualMemoryItem | None = None - for attempt in range(3): - try: - mem_item = self.current_mem_cube.text_mem.get( - memory_id=memory_id, user_name=msg.mem_cube_id - ) - break - except Exception: - if attempt < 2: - time.sleep(0.5) - else: - raise + mem_item = self.current_mem_cube.text_mem.get( + memory_id=memory_id, user_name=msg.mem_cube_id + ) if mem_item is None: raise ValueError(f"Memory {memory_id} not found after retries") # Check if a memory with the same key already exists (determining if it's an update) diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 8493c596d..f4ad9fe48 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,5 +1,3 @@ -import os - from pathlib import Path @@ -53,14 +51,3 @@ DEFAULT_MAX_QUERY_KEY_WORDS = 1000 DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05] DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50 - -# task queue -DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.6" -exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None) -if exchange_name is not None: - DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}" - -# pending claim configuration -# Only claim pending messages whose idle time exceeds this threshold. -# Unit: milliseconds. Default: 10 minute. -DEFAULT_PENDING_CLAIM_MIN_IDLE_MS = 600_000 diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index f82b12d32..a147ebee0 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -1,3 +1,5 @@ +import os + from datetime import datetime from enum import Enum from pathlib import Path @@ -43,6 +45,28 @@ class TaskPriorityLevel(Enum): USER_INPUT_TYPE = "UserInput" NOT_APPLICABLE_TYPE = "NotApplicable" +# pending claim configuration +# Only claim pending messages whose idle time exceeds this threshold. +# Unit: milliseconds. Default: 10 minute. +DEFAULT_PENDING_CLAIM_MIN_IDLE_MS = 600_000 + +# scheduler daemon defaults +# Interval in seconds for periodically releasing stale pending messages +DEFAULT_PENDING_REQUEUE_INTERVAL_SEC = 30.0 + +# Interval in seconds for refreshing cached Redis stream keys +DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC = 30.0 + +# Interval in seconds for batching and cleaning up deletions (xdel) +DEFAULT_DELETE_CLEANUP_INTERVAL_SEC = 30.0 + + +# task queue +DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.7" +exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None) +if exchange_name is not None: + DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}" + # ============== Running Tasks ============== class RunningTaskItem(BaseModel, DictConversionMixin): diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 59afd7b61..b32e4588d 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -20,7 +20,7 @@ DEFAULT_STOP_WAIT, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem, TaskPriorityLevel +from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue @@ -132,9 +132,8 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): start_time = time.time() start_iso = datetime.fromtimestamp(start_time, tz=timezone.utc).isoformat() if self.status_tracker: - self.status_tracker.task_started( - task_id=task_item.item_id, user_id=task_item.user_id - ) + for msg in messages: + self.status_tracker.task_started(task_id=msg.item_id, user_id=msg.user_id) try: first_msg = messages[0] trace_id = getattr(first_msg, "trace_id", None) or generate_trace_id() @@ -197,9 +196,8 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): duration = finish_time - start_time self.metrics.observe_task_duration(duration, m.user_id, m.label) if self.status_tracker: - self.status_tracker.task_completed( - task_id=task_item.item_id, user_id=task_item.user_id - ) + for msg in messages: + self.status_tracker.task_completed(task_id=msg.item_id, user_id=msg.user_id) self.metrics.task_completed(user_id=m.user_id, task_type=m.label) emit_monitor_event( @@ -229,9 +227,10 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): finish_time = time.time() self.metrics.task_failed(m.user_id, m.label, type(e).__name__) if self.status_tracker: - self.status_tracker.task_failed( - task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e) - ) + for msg in messages: + self.status_tracker.task_failed( + task_id=msg.item_id, user_id=msg.user_id, error_message=str(e) + ) emit_monitor_event( "finish", m, @@ -262,7 +261,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): ): try: for msg in messages: - redis_message_id = getattr(msg, "redis_message_id", "") + redis_message_id = msg.redis_message_id self.memos_message_queue.ack_message( user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, @@ -424,6 +423,54 @@ def _handle_future_result(self, future): except Exception as e: logger.error(f"Handler execution failed: {e!s}", exc_info=True) + def execute_task( + self, + user_id: str, + mem_cube_id: str, + task_label: str, + msgs: list[ScheduleMessageItem], + handler_call_back: Callable[[list[ScheduleMessageItem]], Any], + ): + if isinstance(msgs, ScheduleMessageItem): + msgs = [msgs] + # Create task tracking item for this dispatch + task_item = RunningTaskItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + task_info=f"Processing {len(msgs)} message(s) with label '{task_label}' for user {user_id} and mem_cube {mem_cube_id}", + task_name=f"{task_label}_handler", + messages=msgs, + ) + + # Uniformly register the task before execution + with self._task_lock: + self._running_tasks[task_item.item_id] = task_item + + # Create wrapped handler for task tracking + wrapped_handler = self._create_task_wrapper(handler_call_back, task_item) + + # dispatch to different handler + logger.debug(f"Task started: {task_item.get_execution_info()}") + + # If priority is LEVEL_1, force synchronous execution regardless of thread pool availability + use_thread_pool = self.enable_parallel_dispatch and self.dispatcher_executor is not None + + if use_thread_pool: + # Submit and track the future + future = self.dispatcher_executor.submit(wrapped_handler, msgs) + with self._task_lock: + self._futures.add(future) + future.add_done_callback(self._handle_future_result) + logger.info( + f"Dispatch {len(msgs)} message(s) to {task_label} handler for user {user_id} and mem_cube {mem_cube_id}." + ) + else: + # For synchronous execution, the wrapper will run and remove the task upon completion + logger.info( + f"Execute {len(msgs)} message(s) synchronously for {task_label} for user {user_id} and mem_cube {mem_cube_id}." + ) + wrapped_handler(msgs) + def dispatch(self, msg_list: list[ScheduleMessageItem]): """ Dispatch a list of messages to their respective handlers. @@ -449,51 +496,14 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): # Process each label group within this user/mem_cube combination for label, msgs in label_groups.items(): handler = self.handlers.get(label, self._default_message_handler) - - # Create task tracking item for this dispatch - task_item = RunningTaskItem( + self.execute_task( user_id=user_id, mem_cube_id=mem_cube_id, - task_info=f"Processing {len(msgs)} message(s) with label '{label}' for user {user_id} and mem_cube {mem_cube_id}", - task_name=f"{label}_handler", - messages=msgs, + task_label=label, + msgs=msgs, + handler_call_back=handler, ) - # Uniformly register the task before execution - with self._task_lock: - self._running_tasks[task_item.item_id] = task_item - - # Create wrapped handler for task tracking - wrapped_handler = self._create_task_wrapper(handler, task_item) - - task_priority = self.orchestrator.get_task_priority(task_label=label) - - # dispatch to different handler - logger.debug(f"Task started: {task_item.get_execution_info()}") - - # If priority is LEVEL_1, force synchronous execution regardless of thread pool availability - use_thread_pool = ( - self.enable_parallel_dispatch - and self.dispatcher_executor is not None - and task_priority != TaskPriorityLevel.LEVEL_1 - ) - - if use_thread_pool: - # Submit and track the future - future = self.dispatcher_executor.submit(wrapped_handler, msgs) - with self._task_lock: - self._futures.add(future) - future.add_done_callback(self._handle_future_result) - logger.info( - f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." - ) - else: - # For synchronous execution, the wrapper will run and remove the task upon completion - logger.info( - f"Execute {len(msgs)} message(s) synchronously for {label} (priority: {task_priority}) for user {user_id} and mem_cube {mem_cube_id}." - ) - wrapped_handler(msgs) - def join(self, timeout: float | None = None) -> bool: """Wait for all dispatched tasks to complete. diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py index 19da9c7de..d655c6919 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py +++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py @@ -19,6 +19,8 @@ from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, ANSWER_TASK_LABEL, + DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, + PREF_ADD_TASK_LABEL, QUERY_TASK_LABEL, TaskPriorityLevel, ) @@ -42,15 +44,22 @@ def __init__(self): ANSWER_TASK_LABEL: TaskPriorityLevel.LEVEL_1, } + # Per-task minimum idle time (ms) before claiming pending messages + # Default fallback handled in `get_task_idle_min`. + self.tasks_min_idle_ms = { + # Preferential add tasks: allow claiming pending sooner (1 minute) + PREF_ADD_TASK_LABEL: 60_000, + } + def get_stream_priorities(self) -> None | dict: return None def get_task_priority(self, task_label: str): - task_priority = TaskPriorityLevel.LEVEL_3 - if task_label in self.tasks_priorities: - task_priority = self.tasks_priorities[task_label] - logger.info(f"get_task_priority: {task_priority}") - return task_priority + return self.tasks_priorities.get(task_label, TaskPriorityLevel.LEVEL_3) + + def get_task_idle_min(self, task_label: str) -> int: + idle_min = self.tasks_min_idle_ms.get(task_label, DEFAULT_PENDING_CLAIM_MIN_IDLE_MS) + return idle_min def get_stream_quotas(self, stream_keys, consume_batch_size) -> dict: stream_priorities = self.get_stream_priorities() diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index fb38a0f44..b937ba8de 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -16,11 +16,11 @@ from memos.context.context import ContextThread from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import ( - DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( DEFAULT_STREAM_KEY_PREFIX, + DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC, ) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -67,7 +67,7 @@ def __init__( # Stream configuration self.stream_key_prefix = stream_key_prefix self.consumer_group = consumer_group - self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}" + self.consumer_name = f"{consumer_name}_{uuid4().hex[:8]}" self.max_len = max_len self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages @@ -102,10 +102,92 @@ def __init__( self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator + # Cached stream keys and refresh control + self._stream_keys_cache: list[str] = [] + self._stream_keys_last_refresh: float = 0.0 + self._stream_keys_refresh_interval_sec: float = DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC + self._stream_keys_lock = threading.Lock() + self._stream_keys_refresh_thread: ContextThread | None = None + self._stream_keys_refresh_stop_event = threading.Event() + + # Start background stream keys refresher if connected + if self._is_connected: + # Refresh once synchronously to seed cache at init + try: + self._refresh_stream_keys() + except Exception as e: + logger.debug(f"Initial stream keys refresh failed: {e}") + + # Then start background refresher + self._start_stream_keys_refresh_thread() + def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" return stream_key + # --- Stream keys refresh background thread --- + def _refresh_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: + """Scan Redis and refresh cached stream keys for the queue prefix.""" + if not self._redis_conn: + return [] + + if stream_key_prefix is None: + stream_key_prefix = self.stream_key_prefix + + try: + redis_pattern = f"{stream_key_prefix}:*" + raw_keys_iter = self._redis_conn.scan_iter(match=redis_pattern) + raw_keys = list(raw_keys_iter) + + escaped_prefix = re.escape(stream_key_prefix) + regex_pattern = f"^{escaped_prefix}:" + stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)] + + if stream_key_prefix == self.stream_key_prefix: + with self._stream_keys_lock: + self._stream_keys_cache = stream_keys + self._stream_keys_last_refresh = time.time() + return stream_keys + except Exception as e: + logger.warning(f"Failed to refresh stream keys: {e}") + return [] + + def _stream_keys_refresh_loop(self) -> None: + """Background loop to periodically refresh Redis stream keys cache.""" + # Seed cache immediately + self._refresh_stream_keys() + logger.debug( + f"Stream keys refresher started with interval={self._stream_keys_refresh_interval_sec}s" + ) + while not self._stream_keys_refresh_stop_event.is_set(): + try: + self._refresh_stream_keys() + except Exception as e: + logger.warning(f"Stream keys refresh iteration failed: {e}") + # Wait with ability to be interrupted + self._stream_keys_refresh_stop_event.wait(self._stream_keys_refresh_interval_sec) + + logger.debug("Stream keys refresher stopped") + + def _start_stream_keys_refresh_thread(self) -> None: + if self._stream_keys_refresh_thread and self._stream_keys_refresh_thread.is_alive(): + return + self._stream_keys_refresh_stop_event.clear() + self._stream_keys_refresh_thread = ContextThread( + target=self._stream_keys_refresh_loop, + name="redis-stream-keys-refresher", + daemon=True, + ) + self._stream_keys_refresh_thread.start() + + def _stop_stream_keys_refresh_thread(self) -> None: + try: + self._stream_keys_refresh_stop_event.set() + if self._stream_keys_refresh_thread and self._stream_keys_refresh_thread.is_alive(): + self._stream_keys_refresh_thread.join(timeout=2.0) + except Exception as e: + logger.debug(f"Stopping stream keys refresh thread encountered: {e}") + def task_broker( self, consume_batch_size: int, @@ -221,6 +303,12 @@ def put( self.seen_streams.add(stream_key) self._ensure_consumer_group(stream_key=stream_key) + # Update stream keys cache with newly observed stream key + with self._stream_keys_lock: + if stream_key not in self._stream_keys_cache: + self._stream_keys_cache.append(stream_key) + self._stream_keys_last_refresh = time.time() + message.stream_key = stream_key # Convert message to dictionary for Redis storage @@ -263,10 +351,8 @@ def ack_message( logger.warning( f"xack failed for stream '{stream_key}', msg_id='{redis_message_id}': {e}" ) - return - - # Optionally delete the message from the stream to keep it clean if self.auto_delete_acked: + # Optionally delete the message from the stream to keep it clean try: self._redis_conn.xdel(stream_key, redis_message_id) logger.info(f"Successfully delete acknowledged message {redis_message_id}") @@ -333,6 +419,7 @@ def get( need_pending = max(0, batch_size - new_count) need_pending_count = need_pending if need_pending > 0 else 0 + task_label = stream_key.rsplit(":", 1)[1] if need_pending_count: # Claim only pending messages whose idle time exceeds configured threshold try: @@ -343,7 +430,8 @@ def get( name=stream_key, groupname=self.consumer_group, consumername=self.consumer_name, - min_idle_time=DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, + # Derive task_label from stream_key suffix: {prefix}:{user_id}:{mem_cube_id}:{task_label} + min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), start_id="0-0", count=need_pending_count, justid=False, @@ -356,20 +444,19 @@ def get( logger.warning( f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)." ) - try: - self._ensure_consumer_group(stream_key=stream_key) - next_id, claimed = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, - start_id="0-0", - count=need_pending_count, - justid=False, - ) - pending_messages = [(stream_key, claimed)] if claimed else [] - except Exception: - pending_messages = [] + self._ensure_consumer_group(stream_key=stream_key) + next_id, claimed = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min( + task_label=task_label + ), + start_id="0-0", + count=need_pending_count, + justid=False, + ) + pending_messages = [(stream_key, claimed)] if claimed else [] else: pending_messages = [] @@ -381,7 +468,6 @@ def get( messages.extend(pending_messages) result_messages = [] - for _stream, stream_messages in messages: for message_id, fields in stream_messages: try: @@ -392,7 +478,7 @@ def get( result_messages.append(message) except Exception as e: - logger.error(f"Failed to parse message {message_id}: {e}") + logger.error(f"Failed to parse message {message_id}: {e}", stack_info=True) # Always return a list for consistency if not result_messages: @@ -437,37 +523,34 @@ def qsize(self) -> dict: return qsize_stats except Exception as e: - logger.error(f"Failed to get Redis queue size: {e}") + logger.error(f"Failed to get Redis queue size: {e}", stack_info=True) return {} def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: """ - List all Redis stream keys that match this queue's prefix. + Return cached Redis stream keys maintained by background refresher. - Only returns actual Redis Stream keys, excluding auxiliary keys - (e.g., any lock or string/hash keys). This avoids WRONGTYPE errors - when issuing stream commands on non-stream keys. + The cache is updated periodically by a background thread and also + appended immediately on new stream creation via `put`. - Returns: - A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}:{task_label}"`. + Before returning, validate that all cached keys match the given + `stream_key_prefix` (or the queue's configured prefix if None). + If any key does not match, log an error. """ - if not self._redis_conn: - return [] + effective_prefix = stream_key_prefix or self.stream_key_prefix + with self._stream_keys_lock: + cache_snapshot = list(self._stream_keys_cache) - if stream_key_prefix is None: - stream_key_prefix = self.stream_key_prefix - # First, get all keys that might match (using Redis pattern matching) - redis_pattern = f"{stream_key_prefix}:*" - raw_keys_iter = self._redis_conn.scan_iter(match=redis_pattern) - raw_keys = list(raw_keys_iter) - - # Second, filter using Python regex to ensure exact prefix match - # Escape special regex characters in the prefix, then add :.* - escaped_prefix = re.escape(stream_key_prefix) + # Validate that cached keys conform to the expected prefix + escaped_prefix = re.escape(effective_prefix) regex_pattern = f"^{escaped_prefix}:" - stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)] + for key in cache_snapshot: + if not re.match(regex_pattern, key): + logger.error( + f"[REDIS_QUEUE] Cached stream key '{key}' does not match prefix '{effective_prefix}:'" + ) - return stream_keys + return cache_snapshot def size(self) -> int: """ @@ -578,6 +661,8 @@ def connect(self) -> None: self._redis_conn.ping() self._is_connected = True logger.debug("Redis connection established successfully") + # Start stream keys refresher when connected + self._start_stream_keys_refresh_thread() except Exception as e: logger.error(f"Failed to connect to Redis: {e}") self._is_connected = False @@ -588,6 +673,8 @@ def connect(self) -> None: def disconnect(self) -> None: """Disconnect from Redis and clean up resources.""" self._is_connected = False + # Stop background refresher + self._stop_stream_keys_refresh_thread() if self._is_listening: self.stop_listening() logger.debug("Disconnected from Redis") @@ -604,6 +691,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __del__(self): """Cleanup when object is destroyed.""" + self._stop_stream_keys_refresh_thread() if self._is_connected: self.disconnect()