diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py index 697822a77..d12a8ace4 100644 --- a/src/memos/api/handlers/scheduler_handler.py +++ b/src/memos/api/handlers/scheduler_handler.py @@ -9,20 +9,181 @@ import time import traceback +from collections import Counter +from datetime import datetime, timezone from typing import Any from fastapi import HTTPException from fastapi.responses import StreamingResponse # Imports for new implementation -from memos.api.product_models import StatusResponse, StatusResponseItem +from memos.api.product_models import ( + AllStatusResponse, + AllStatusResponseData, + StatusResponse, + StatusResponseItem, + TaskSummary, +) from memos.log import get_logger +from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker logger = get_logger(__name__) +def handle_scheduler_allstatus( + mem_scheduler: BaseScheduler, + status_tracker: TaskStatusTracker, +) -> AllStatusResponse: + """ + Get aggregated scheduler status metrics (no per-task payload). + + Args: + mem_scheduler: The BaseScheduler instance. + status_tracker: The TaskStatusTracker instance. + + Returns: + AllStatusResponse with aggregated status data. + """ + + def _summarize_tasks(task_details: list[dict[str, Any]]) -> TaskSummary: + """Aggregate counts by status for the provided task details (tracker data).""" + counter = Counter() + for detail in task_details: + status = detail.get("status") + if status: + counter[status] += 1 + + total = sum(counter.values()) + return TaskSummary( + waiting=counter.get("waiting", 0), + in_progress=counter.get("in_progress", 0), + completed=counter.get("completed", 0), + pending=counter.get("pending", counter.get("waiting", 0)), + failed=counter.get("failed", 0), + cancelled=counter.get("cancelled", 0), + total=total, + ) + + def _aggregate_counts_from_redis( + tracker: TaskStatusTracker, max_age_seconds: float = 86400 + ) -> TaskSummary | None: + """Stream status counts directly from Redis to avoid loading all task payloads.""" + redis_client = getattr(tracker, "redis", None) + if not redis_client: + return None + + counter = Counter() + now = datetime.now(timezone.utc).timestamp() + + # Scan task_meta keys, then hscan each hash in batches + cursor: int | str = 0 + while True: + cursor, keys = redis_client.scan(cursor=cursor, match="memos:task_meta:*", count=200) + for key in keys: + h_cursor: int | str = 0 + while True: + h_cursor, fields = redis_client.hscan(key, cursor=h_cursor, count=500) + for value in fields.values(): + try: + payload = json.loads( + value.decode("utf-8") if isinstance(value, bytes) else value + ) + # Skip stale entries to reduce noise and load + ts = payload.get("submitted_at") or payload.get("started_at") + if ts: + try: + ts_dt = datetime.fromisoformat(ts) + ts_seconds = ts_dt.timestamp() + except Exception: + ts_seconds = None + if ts_seconds and (now - ts_seconds) > max_age_seconds: + continue + status = payload.get("status") + if status: + counter[status] += 1 + except Exception: + continue + if h_cursor == 0 or h_cursor == "0": + break + if cursor == 0 or cursor == "0": + break + + if not counter: + return TaskSummary() # Empty summary if nothing found + + total = sum(counter.values()) + return TaskSummary( + waiting=counter.get("waiting", 0), + in_progress=counter.get("in_progress", 0), + completed=counter.get("completed", 0), + pending=counter.get("pending", counter.get("waiting", 0)), + failed=counter.get("failed", 0), + cancelled=counter.get("cancelled", 0), + total=total, + ) + + try: + # Prefer streaming aggregation to avoid pulling all task payloads + all_tasks_summary = _aggregate_counts_from_redis(status_tracker) + if all_tasks_summary is None: + # Fallback: load all details then aggregate + global_tasks = status_tracker.get_all_tasks_global() + all_task_details: list[dict[str, Any]] = [] + for _, tasks in global_tasks.items(): + all_task_details.extend(tasks.values()) + all_tasks_summary = _summarize_tasks(all_task_details) + + # Scheduler view: assume tracker contains scheduler tasks; overlay queue monitor for live queue depth + sched_waiting = all_tasks_summary.waiting + sched_in_progress = all_tasks_summary.in_progress + sched_pending = all_tasks_summary.pending + sched_completed = all_tasks_summary.completed + sched_failed = all_tasks_summary.failed + sched_cancelled = all_tasks_summary.cancelled + + # If queue monitor is available, prefer its live waiting/in_progress counts + if mem_scheduler.task_schedule_monitor: + queue_status_data = mem_scheduler.task_schedule_monitor.get_tasks_status() or {} + scheduler_waiting = 0 + scheduler_in_progress = 0 + scheduler_pending = 0 + for key, value in queue_status_data.items(): + if not key.startswith("scheduler:"): + continue + scheduler_in_progress += int(value.get("running", 0) or 0) + scheduler_pending += int(value.get("pending", value.get("remaining", 0)) or 0) + scheduler_waiting += int(value.get("remaining", 0) or 0) + sched_waiting = scheduler_waiting + sched_in_progress = scheduler_in_progress + sched_pending = scheduler_pending + + scheduler_summary = TaskSummary( + waiting=sched_waiting, + in_progress=sched_in_progress, + pending=sched_pending, + completed=sched_completed, + failed=sched_failed, + cancelled=sched_cancelled, + total=sched_waiting + + sched_in_progress + + sched_completed + + sched_failed + + sched_cancelled, + ) + + return AllStatusResponse( + data=AllStatusResponseData( + scheduler_summary=scheduler_summary, + all_tasks_summary=all_tasks_summary, + ) + ) + except Exception as err: + logger.error(f"Failed to get full scheduler status: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail="Failed to get full scheduler status") from err + + def handle_scheduler_status( user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None ) -> StatusResponse: diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 9dfd872b0..e77aee755 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -865,3 +865,34 @@ class StatusResponse(BaseResponse[list[StatusResponseItem]]): """Response model for scheduler status operations.""" message: str = "Memory get status successfully" + + +class TaskSummary(BaseModel): + """Aggregated counts of tasks by status.""" + + waiting: int = Field(0, description="Number of tasks waiting to run") + in_progress: int = Field(0, description="Number of tasks currently running") + pending: int = Field( + 0, description="Number of tasks fetched by workers but not yet acknowledged" + ) + completed: int = Field(0, description="Number of tasks completed") + failed: int = Field(0, description="Number of tasks failed") + cancelled: int = Field(0, description="Number of tasks cancelled") + total: int = Field(0, description="Total number of tasks counted") + + +class AllStatusResponseData(BaseModel): + """Aggregated scheduler status metrics.""" + + scheduler_summary: TaskSummary = Field( + ..., description="Aggregated status for scheduler-managed tasks" + ) + all_tasks_summary: TaskSummary = Field( + ..., description="Aggregated status for all tracked tasks" + ) + + +class AllStatusResponse(BaseResponse[AllStatusResponseData]): + """Response model for full scheduler status operations.""" + + message: str = "Scheduler status summary retrieved successfully" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 5b2107b6c..576cca55e 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -24,6 +24,7 @@ from memos.api.handlers.feedback_handler import FeedbackHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( + AllStatusResponse, APIADDRequest, APIChatCompleteRequest, APIFeedbackRequest, @@ -114,6 +115,18 @@ def add_memories(add_req: APIADDRequest): # ============================================================================= +@router.get( # Changed from post to get + "/scheduler/allstatus", + summary="Get detailed scheduler status", + response_model=AllStatusResponse, +) +def scheduler_allstatus(): + """Get detailed scheduler status including running tasks and queue metrics.""" + return handlers.scheduler_handler.handle_scheduler_allstatus( + mem_scheduler=mem_scheduler, status_tracker=status_tracker + ) + + @router.get( # Changed from post to get "/scheduler/status", summary="Get scheduler running status", response_model=StatusResponse ) diff --git a/src/memos/mem_scheduler/monitors/task_schedule_monitor.py b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py index 82e43d858..14bed8316 100644 --- a/src/memos/mem_scheduler/monitors/task_schedule_monitor.py +++ b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py @@ -29,7 +29,7 @@ def __init__( @staticmethod def init_task_status() -> dict: - return {"running": 0, "remaining": 0} + return {"running": 0, "remaining": 0, "pending": 0} def get_tasks_status(self) -> dict: if isinstance(self.queue, SchedulerRedisQueue): @@ -154,7 +154,9 @@ def _get_local_tasks_status(self) -> dict: try: # remaining is the sum of per-stream qsize qsize_map = self.queue.qsize() - task_status["remaining"] = sum(v for k, v in qsize_map.items() if isinstance(v, int)) + remaining_total = sum(v for k, v in qsize_map.items() if isinstance(v, int)) + task_status["remaining"] = remaining_total + task_status["pending"] = remaining_total # running from dispatcher if available if self.dispatcher and hasattr(self.dispatcher, "get_running_task_count"): task_status["running"] = int(self.dispatcher.get_running_task_count()) @@ -200,11 +202,15 @@ async def _collect_async() -> dict: if group.get("name") == self.queue.consumer_group: pending = int(group.get("pending", 0)) break - # Remaining = total messages (xlen) - pending for our group - remaining = max(0, int(xlen_val or 0)) + total_messages = max(0, int(xlen_val or 0)) + remaining = max(0, total_messages - pending) + # running = in-progress (delivered, not yet acked) local[stream_key]["running"] += pending + # pending = not yet delivered (remaining) + local[stream_key]["pending"] += remaining local[stream_key]["remaining"] += remaining local["running"] += pending + local["pending"] += remaining local["remaining"] += remaining return local @@ -234,10 +240,14 @@ async def _collect_async() -> dict: for group in groups_info: if group.get("name") == self.queue.consumer_group: pending = int(group.get("pending", 0)) - remaining = max(0, xlen_val) + remaining = max(0, xlen_val - pending) + # running = in-progress (delivered, not yet acked) task_status[stream_key]["running"] += pending + # pending = not yet delivered (remaining) + task_status[stream_key]["pending"] += remaining task_status[stream_key]["remaining"] += remaining task_status["running"] += pending + task_status["pending"] += remaining task_status["remaining"] += remaining break diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py index 9a8fa53df..f2edc5aea 100644 --- a/src/memos/mem_scheduler/utils/status_tracker.py +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -168,3 +168,31 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) -> "item_count": len(item_ids), "item_statuses": item_statuses, } + + def get_all_tasks_global(self) -> dict[str, dict[str, dict]]: + """ + Retrieve all tasks for all users from Redis. + + Returns: + dict: {user_id: {task_id: task_data, ...}, ...} + """ + all_users_tasks = {} + cursor: int | str = 0 + while True: + cursor, keys = self.redis.scan(cursor=cursor, match="memos:task_meta:*", count=100) + for key in keys: + # key format: memos:task_meta:{user_id} + parts = key.split(":") + if len(parts) < 3: + continue + user_id = parts[2] + + tasks = self.redis.hgetall(key) + if tasks: + user_tasks = {tid: json.loads(t_data) for tid, t_data in tasks.items()} + all_users_tasks[user_id] = user_tasks + + if cursor == 0 or cursor == "0": + break + + return all_users_tasks diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 88c0f87c7..081056473 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -157,9 +157,8 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: content=feedback_req_str, timestamp=datetime.utcnow(), ) - self.mem_scheduler.memos_message_queue.submit_messages( - messages=[message_item_feedback] - ) + # Use scheduler submission to ensure tracking and metrics + self.mem_scheduler.submit_messages(messages=[message_item_feedback]) self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted FEEDBACK async") except Exception as e: self.logger.error(