src/memos/mem_scheduler/base_scheduler.py (3 changes: 2 additions & 1 deletion)

@@ -224,7 +224,8 @@ def initialize_modules(
         if self.dispatcher:
             self.dispatcher.status_tracker = self.status_tracker
         if self.memos_message_queue:
-            self.memos_message_queue.status_tracker = self.status_tracker
+            # Use the setter to propagate to the inner queue (e.g. SchedulerRedisQueue)
+            self.memos_message_queue.set_status_tracker(self.status_tracker)
         # initialize submodules
         self.chat_llm = chat_llm
         self.process_llm = process_llm
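Why the setter rather than plain attribute assignment: the scheduler's message queue is a wrapper around an inner queue implementation, and setting `status_tracker` on the wrapper alone never reaches the wrapped `SchedulerRedisQueue`. A minimal sketch of the difference, using simplified stand-in classes rather than the real implementations:

class InnerQueue:  # stand-in for SchedulerRedisQueue
    def __init__(self):
        self.status_tracker = None

class QueueWrapper:  # stand-in for the task-queue wrapper
    def __init__(self):
        self.status_tracker = None
        self.memos_message_queue = InnerQueue()

    def set_status_tracker(self, tracker):
        self.status_tracker = tracker
        # propagate to the wrapped queue as well
        self.memos_message_queue.status_tracker = tracker

wrapper = QueueWrapper()
wrapper.status_tracker = "tracker"     # inner queue's tracker is still None
wrapper.set_status_tracker("tracker")  # both wrapper and inner queue now hold it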
src/memos/mem_scheduler/task_schedule_modules/dispatcher.py (84 changes: 83 additions & 1 deletion)

@@ -1,4 +1,5 @@
 import concurrent
+import os
 import threading
 import time
 
@@ -19,7 +20,7 @@
 from memos.mem_scheduler.schemas.general_schemas import (
     DEFAULT_STOP_WAIT,
 )
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem, ScheduleMessageItem
 from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
 from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
 from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
@@ -200,6 +201,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
                 if self.status_tracker:
                     for msg in messages:
                         self.status_tracker.task_completed(task_id=msg.item_id, user_id=msg.user_id)
+                self._maybe_emit_task_completion(messages)
                 self.metrics.task_completed(user_id=m.user_id, task_type=m.label)
 
                 emit_monitor_event(
@@ -237,6 +239,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
                     self.status_tracker.task_failed(
                         task_id=msg.item_id, user_id=msg.user_id, error_message=str(e)
                     )
+                self._maybe_emit_task_completion(messages, error=e)
                 emit_monitor_event(
                     "finish",
                     m,
@@ -284,6 +287,85 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
 
         return wrapped_handler
 
+    def _maybe_emit_task_completion(
+        self, messages: list[ScheduleMessageItem], error: Exception | None = None
+    ) -> None:
+        """If all item_ids under a business task are completed, emit a single completion log."""
+        if not self.submit_web_logs or not self.status_tracker:
+            return
+
+        # Messages in one batch can belong to different business task_ids; check each.
+        task_ids = {getattr(msg, "task_id", None) for msg in messages}
+        task_ids.discard(None)
+        if not task_ids:
+            return
+
+        # Use the first message for shared fields; mem_cube_id is the same within a batch.
+        first = messages[0]
+        user_id = first.user_id
+        mem_cube_id = first.mem_cube_id
+
+        try:
+            is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") == "memos-memory-change"
+            if not is_cloud_env:
+                return
+
+            for task_id in task_ids:
+                status_data = self.status_tracker.get_task_status_by_business_id(
+                    business_task_id=task_id, user_id=user_id
+                )
+                if not status_data:
+                    continue
+
+                status = status_data.get("status")
+
+                if status == "completed":
+                    # A 'completed' status should not coincide with a locally caught
+                    # exception; if the tracker's status update lags or is inconsistent,
+                    # we still trust its aggregated status here.
+                    event = ScheduleLogForWebItem(
+                        task_id=task_id,
+                        user_id=user_id,
+                        mem_cube_id=mem_cube_id,
+                        label="taskStatus",
+                        from_memory_type="status",
+                        to_memory_type="status",
+                        log_content=f"Task {task_id} completed",
+                        status="completed",
+                    )
+                    self.submit_web_logs(event)
+
+                elif status == "failed":
+                    # Construct the error message, preferring the locally caught exception.
+                    error_msg = str(error) if error else None
+                    if not error_msg:
+                        # Fall back to the errors aggregated by the status_tracker.
+                        errors = status_data.get("errors", [])
+                        if errors:
+                            error_msg = "; ".join(errors)
+                        else:
+                            error_msg = "Unknown error (check system logs)"
+
+                    event = ScheduleLogForWebItem(
+                        task_id=task_id,
+                        user_id=user_id,
+                        mem_cube_id=mem_cube_id,
+                        label="taskStatus",
+                        from_memory_type="status",
+                        to_memory_type="status",
+                        log_content=f"Task {task_id} failed: {error_msg}",
+                        status="failed",
+                    )
+                    self.submit_web_logs(event)
+        except Exception:
+            logger.warning(
+                "Failed to emit task completion log. user_id=%s mem_cube_id=%s task_ids=%s",
+                user_id,
+                mem_cube_id,
+                list(task_ids),
+                exc_info=True,
+            )
+
     def get_running_tasks(
         self, filter_func: Callable[[RunningTaskItem], bool] | None = None
     ) -> dict[str, RunningTaskItem]:
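For reference, a hedged sketch of the status_data dict that _maybe_emit_task_completion consumes; the keys mirror the get_task_status_by_business_id return value in the status_tracker.py diff below, and the values are illustrative:

status_data = {
    "status": "failed",                    # aggregated status across all item_ids (key assumed)
    "business_task_id": "task-123",
    "item_count": 3,
    "item_statuses": ["completed", "completed", "failed"],
    "errors": ["timeout contacting vector store"],  # new in this PR
}

# The emission logic keys off the aggregated status, one web log per business task:
if status_data.get("status") == "failed":
    error_msg = "; ".join(status_data.get("errors", [])) or "Unknown error (check system logs)"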
src/memos/mem_scheduler/task_schedule_modules/redis_queue.py (5 changes: 4 additions & 1 deletion)

@@ -22,6 +22,7 @@
     DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC,
 )
 from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
+from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 
 
@@ -51,6 +52,7 @@ def __init__(
         consumer_name: str | None = "scheduler_consumer",
         max_len: int | None = None,
         auto_delete_acked: bool = True,  # Whether to automatically delete acknowledged messages
+        status_tracker: TaskStatusTracker | None = None,
     ):
         """
         Initialize the Redis queue.
@@ -62,6 +64,7 @@ def __init__(
             max_len: Maximum length of the stream (for memory management)
             maxsize: Maximum size of the queue (for Queue compatibility, ignored)
             auto_delete_acked: Whether to automatically delete acknowledged messages from the stream
+            status_tracker: TaskStatusTracker instance used to track task status
         """
         super().__init__()
         # Stream configuration
@@ -101,6 +104,7 @@ def __init__(
         self.message_pack_cache = deque()
 
         self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator
+        self.status_tracker = status_tracker
 
         # Cached stream keys and refresh control
         self._stream_keys_cache: list[str] = []
@@ -354,7 +358,6 @@ def ack_message(
         self._redis_conn.xack(stream_key, self.consumer_group, redis_message_id)
 
         if message:
-            self.status_tracker.task_completed(task_id=message.item_id, user_id=message.user_id)
             logger.info(
                 f"Message {message.item_id} | {message.label} | {message.content} has been acknowledged."
             )
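Dropping the task_completed call from ack_message leaves completion reporting to the dispatcher's wrapped handler (which already calls task_completed), so a task is no longer marked complete in two places. A hedged construction sketch showing the new injection point; the status_tracker keyword comes from this diff, while the tracker's constructor signature and the Redis client setup are assumptions for illustration:

import redis

from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

redis_conn = redis.Redis()                     # assumed connection setup
tracker = TaskStatusTracker(redis=redis_conn)  # assumed constructor signature
queue = SchedulerRedisQueue(
    consumer_group="scheduler_group",
    consumer_name="scheduler_consumer",
    status_tracker=tracker,                    # new keyword added by this change
)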
src/memos/mem_scheduler/task_schedule_modules/task_queue.py (14 changes: 14 additions)

@@ -42,12 +42,26 @@ def __init__(
                 consumer_group="scheduler_group",
                 consumer_name="scheduler_consumer",
                 orchestrator=self.orchestrator,
+                status_tracker=self.status_tracker,  # propagate the status tracker
             )
         else:
             self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize)
 
         self.disabled_handlers = disabled_handlers
 
+    def set_status_tracker(self, status_tracker: TaskStatusTracker) -> None:
+        """
+        Set the status tracker for this queue and propagate it to the underlying queue implementation.
+
+        This allows the tracker to be injected after initialization (e.g., once the Redis
+        connection becomes available).
+        """
+        self.status_tracker = status_tracker
+        if self.memos_message_queue and hasattr(self.memos_message_queue, "status_tracker"):
+            # Only queues that already expose a status_tracker attribute are updated
+            # (SchedulerRedisQueue does); queues without one are skipped by this guard.
+            self.memos_message_queue.status_tracker = status_tracker
+            logger.info("Propagated status_tracker to underlying message queue")
+
     def ack_message(
         self,
         user_id: str,
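A minimal runnable sketch of the hasattr guard's semantics, using simplified stand-in classes: propagation only happens when the wrapped queue already exposes a status_tracker attribute.

class WithTracker:     # e.g. SchedulerRedisQueue after this PR
    status_tracker = None

class WithoutTracker:  # a queue type that never defined the attribute
    pass

def propagate(queue, tracker):
    if queue and hasattr(queue, "status_tracker"):
        queue.status_tracker = tracker
        return True
    return False

print(propagate(WithTracker(), "tracker"))     # True: attribute exists and is set
print(propagate(WithoutTracker(), "tracker"))  # False: the guard skips it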
src/memos/mem_scheduler/utils/status_tracker.py (4 changes: 4 additions)

@@ -142,11 +142,14 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
         # Get statuses for all items
         key = self._get_key(user_id)
         item_statuses = []
+        errors = []
         for item_id in item_ids:
             item_data_json = self.redis.hget(key, item_id)
             if item_data_json:
                 item_data = json.loads(item_data_json)
                 item_statuses.append(item_data["status"])
+                if item_data.get("status") == "failed" and "error" in item_data:
+                    errors.append(item_data["error"])
 
         if not item_statuses:
             return None
@@ -167,6 +170,7 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
             "business_task_id": business_task_id,
             "item_count": len(item_ids),
             "item_statuses": item_statuses,
+            "errors": errors,
         }
 
     def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
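A hedged sketch of the per-item records this aggregation reads; the exact Redis field shapes are assumptions consistent with the hget/json.loads calls above, and the list comprehension mirrors the new error collection:

item_records = [
    {"status": "completed"},
    {"status": "failed", "error": "timeout contacting vector store"},
    {"status": "failed", "error": "bad payload"},
]

errors = [r["error"] for r in item_records if r.get("status") == "failed" and "error" in r]
assert errors == ["timeout contacting vector store", "bad payload"]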