Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 162 additions & 1 deletion src/memos/api/handlers/scheduler_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,181 @@
import time
import traceback

from collections import Counter
from datetime import datetime, timezone
from typing import Any

from fastapi import HTTPException
from fastapi.responses import StreamingResponse

# Imports for new implementation
from memos.api.product_models import StatusResponse, StatusResponseItem
from memos.api.product_models import (
AllStatusResponse,
AllStatusResponseData,
StatusResponse,
StatusResponseItem,
TaskSummary,
)
from memos.log import get_logger
from memos.mem_scheduler.base_scheduler import BaseScheduler
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker


logger = get_logger(__name__)


def handle_scheduler_allstatus(
    mem_scheduler: BaseScheduler,
    status_tracker: TaskStatusTracker,
) -> AllStatusResponse:
    """
    Get aggregated scheduler status metrics (no per-task payload).

    Builds two roll-ups: one over every task known to the status tracker,
    and one for scheduler-managed tasks, overlaid with live queue-monitor
    counts when the monitor is available.

    Args:
        mem_scheduler: The BaseScheduler instance.
        status_tracker: The TaskStatusTracker instance.

    Returns:
        AllStatusResponse with aggregated status data.

    Raises:
        HTTPException: 500 if aggregation fails for any reason.
    """

    def _summarize_tasks(task_details: list[dict[str, Any]]) -> TaskSummary:
        """Aggregate counts by status for the provided task details (tracker data)."""
        counter = Counter()
        for detail in task_details:
            status = detail.get("status")
            # Entries with a missing/empty status are not counted.
            if status:
                counter[status] += 1

        total = sum(counter.values())
        return TaskSummary(
            waiting=counter.get("waiting", 0),
            in_progress=counter.get("in_progress", 0),
            completed=counter.get("completed", 0),
            # Tracker data has no explicit "pending" status; mirror "waiting".
            pending=counter.get("pending", counter.get("waiting", 0)),
            failed=counter.get("failed", 0),
            cancelled=counter.get("cancelled", 0),
            total=total,
        )

    def _aggregate_counts_from_redis(
        tracker: TaskStatusTracker, max_age_seconds: float = 86400
    ) -> TaskSummary | None:
        """Stream status counts directly from Redis to avoid loading all task payloads.

        Returns None when the tracker exposes no Redis client, signalling the
        caller to fall back to the full in-memory aggregation path.
        """
        redis_client = getattr(tracker, "redis", None)
        if not redis_client:
            return None

        counter = Counter()
        now = datetime.now(timezone.utc).timestamp()

        # Scan task_meta keys, then hscan each hash in batches
        cursor: int | str = 0
        while True:
            cursor, keys = redis_client.scan(cursor=cursor, match="memos:task_meta:*", count=200)
            for key in keys:
                h_cursor: int | str = 0
                while True:
                    h_cursor, fields = redis_client.hscan(key, cursor=h_cursor, count=500)
                    for value in fields.values():
                        try:
                            # Values may be bytes (decode_responses=False) or str.
                            payload = json.loads(
                                value.decode("utf-8") if isinstance(value, bytes) else value
                            )
                            # Skip stale entries to reduce noise and load
                            ts = payload.get("submitted_at") or payload.get("started_at")
                            if ts:
                                try:
                                    # NOTE(review): fromisoformat may yield a naive
                                    # datetime whose .timestamp() assumes local time,
                                    # while `now` above is UTC — confirm stored
                                    # timestamps are timezone-aware.
                                    ts_dt = datetime.fromisoformat(ts)
                                    ts_seconds = ts_dt.timestamp()
                                except Exception:
                                    ts_seconds = None
                                if ts_seconds and (now - ts_seconds) > max_age_seconds:
                                    continue
                            status = payload.get("status")
                            if status:
                                counter[status] += 1
                        except Exception:
                            # Malformed payloads are skipped rather than aborting the scan.
                            continue
                    # Cursor may come back as int or str depending on client config.
                    if h_cursor == 0 or h_cursor == "0":
                        break
            if cursor == 0 or cursor == "0":
                break

        if not counter:
            return TaskSummary()  # Empty summary if nothing found

        total = sum(counter.values())
        return TaskSummary(
            waiting=counter.get("waiting", 0),
            in_progress=counter.get("in_progress", 0),
            completed=counter.get("completed", 0),
            # No explicit "pending" status in the data; mirror "waiting".
            pending=counter.get("pending", counter.get("waiting", 0)),
            failed=counter.get("failed", 0),
            cancelled=counter.get("cancelled", 0),
            total=total,
        )

    try:
        # Prefer streaming aggregation to avoid pulling all task payloads
        all_tasks_summary = _aggregate_counts_from_redis(status_tracker)
        if all_tasks_summary is None:
            # Fallback: load all details then aggregate
            global_tasks = status_tracker.get_all_tasks_global()
            all_task_details: list[dict[str, Any]] = []
            for _, tasks in global_tasks.items():
                all_task_details.extend(tasks.values())
            all_tasks_summary = _summarize_tasks(all_task_details)

        # Scheduler view: assume tracker contains scheduler tasks; overlay queue monitor for live queue depth
        sched_waiting = all_tasks_summary.waiting
        sched_in_progress = all_tasks_summary.in_progress
        sched_pending = all_tasks_summary.pending
        sched_completed = all_tasks_summary.completed
        sched_failed = all_tasks_summary.failed
        sched_cancelled = all_tasks_summary.cancelled

        # If queue monitor is available, prefer its live waiting/in_progress counts
        if mem_scheduler.task_schedule_monitor:
            queue_status_data = mem_scheduler.task_schedule_monitor.get_tasks_status() or {}
            scheduler_waiting = 0
            scheduler_in_progress = 0
            scheduler_pending = 0
            for key, value in queue_status_data.items():
                # Only scheduler-owned queues contribute to the scheduler view.
                if not key.startswith("scheduler:"):
                    continue
                scheduler_in_progress += int(value.get("running", 0) or 0)
                scheduler_pending += int(value.get("pending", value.get("remaining", 0)) or 0)
                scheduler_waiting += int(value.get("remaining", 0) or 0)
            sched_waiting = scheduler_waiting
            sched_in_progress = scheduler_in_progress
            sched_pending = scheduler_pending

        scheduler_summary = TaskSummary(
            waiting=sched_waiting,
            in_progress=sched_in_progress,
            pending=sched_pending,
            completed=sched_completed,
            failed=sched_failed,
            cancelled=sched_cancelled,
            # `pending` is excluded from the total: in the queue-monitor path both
            # `pending` and `waiting` derive from "remaining", so including it
            # would double count — NOTE(review): confirm this is intentional.
            total=sched_waiting
            + sched_in_progress
            + sched_completed
            + sched_failed
            + sched_cancelled,
        )

        return AllStatusResponse(
            data=AllStatusResponseData(
                scheduler_summary=scheduler_summary,
                all_tasks_summary=all_tasks_summary,
            )
        )
    except Exception as err:
        logger.error(f"Failed to get full scheduler status: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail="Failed to get full scheduler status") from err


def handle_scheduler_status(
user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None
) -> StatusResponse:
Expand Down
31 changes: 31 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,34 @@ class StatusResponse(BaseResponse[list[StatusResponseItem]]):
"""Response model for scheduler status operations."""

message: str = "Memory get status successfully"


class TaskSummary(BaseModel):
    """Per-status task counts rolled up into one summary record."""

    # Every counter defaults to zero so an empty summary is a valid value.
    waiting: int = Field(default=0, description="Number of tasks waiting to run")
    in_progress: int = Field(default=0, description="Number of tasks currently running")
    pending: int = Field(
        default=0, description="Number of tasks fetched by workers but not yet acknowledged"
    )
    completed: int = Field(default=0, description="Number of tasks completed")
    failed: int = Field(default=0, description="Number of tasks failed")
    cancelled: int = Field(default=0, description="Number of tasks cancelled")
    total: int = Field(default=0, description="Total number of tasks counted")


class AllStatusResponseData(BaseModel):
    """Payload carrying the two aggregated task summaries."""

    # Roll-up restricted to tasks the scheduler itself manages.
    scheduler_summary: TaskSummary = Field(..., description="Aggregated status for scheduler-managed tasks")
    # Roll-up over every task known to the status tracker.
    all_tasks_summary: TaskSummary = Field(..., description="Aggregated status for all tracked tasks")


class AllStatusResponse(BaseResponse[AllStatusResponseData]):
    """Envelope returned by the scheduler all-status endpoint; `data` holds the summaries."""

    # Default human-readable status line for successful responses.
    message: str = "Scheduler status summary retrieved successfully"
13 changes: 13 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from memos.api.handlers.feedback_handler import FeedbackHandler
from memos.api.handlers.search_handler import SearchHandler
from memos.api.product_models import (
AllStatusResponse,
APIADDRequest,
APIChatCompleteRequest,
APIFeedbackRequest,
Expand Down Expand Up @@ -114,6 +115,18 @@ def add_memories(add_req: APIADDRequest):
# =============================================================================


# Read-only aggregation endpoint, hence GET rather than POST.
@router.get(
    "/scheduler/allstatus",
    summary="Get detailed scheduler status",
    response_model=AllStatusResponse,
)
def scheduler_allstatus():
    """Get detailed scheduler status including running tasks and queue metrics."""
    # Delegate to the handler, wiring in the module-level scheduler and tracker.
    response = handlers.scheduler_handler.handle_scheduler_allstatus(
        mem_scheduler=mem_scheduler,
        status_tracker=status_tracker,
    )
    return response


@router.get( # Changed from post to get
"/scheduler/status", summary="Get scheduler running status", response_model=StatusResponse
)
Expand Down
20 changes: 15 additions & 5 deletions src/memos/mem_scheduler/monitors/task_schedule_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(

@staticmethod
def init_task_status() -> dict:
return {"running": 0, "remaining": 0}
return {"running": 0, "remaining": 0, "pending": 0}

def get_tasks_status(self) -> dict:
if isinstance(self.queue, SchedulerRedisQueue):
Expand Down Expand Up @@ -154,7 +154,9 @@ def _get_local_tasks_status(self) -> dict:
try:
# remaining is the sum of per-stream qsize
qsize_map = self.queue.qsize()
task_status["remaining"] = sum(v for k, v in qsize_map.items() if isinstance(v, int))
remaining_total = sum(v for k, v in qsize_map.items() if isinstance(v, int))
task_status["remaining"] = remaining_total
task_status["pending"] = remaining_total
# running from dispatcher if available
if self.dispatcher and hasattr(self.dispatcher, "get_running_task_count"):
task_status["running"] = int(self.dispatcher.get_running_task_count())
Expand Down Expand Up @@ -200,11 +202,15 @@ async def _collect_async() -> dict:
if group.get("name") == self.queue.consumer_group:
pending = int(group.get("pending", 0))
break
# Remaining = total messages (xlen) - pending for our group
remaining = max(0, int(xlen_val or 0))
total_messages = max(0, int(xlen_val or 0))
remaining = max(0, total_messages - pending)
# running = in-progress (delivered, not yet acked)
local[stream_key]["running"] += pending
# pending = not yet delivered (remaining)
local[stream_key]["pending"] += remaining
local[stream_key]["remaining"] += remaining
local["running"] += pending
local["pending"] += remaining
local["remaining"] += remaining
return local

Expand Down Expand Up @@ -234,10 +240,14 @@ async def _collect_async() -> dict:
for group in groups_info:
if group.get("name") == self.queue.consumer_group:
pending = int(group.get("pending", 0))
remaining = max(0, xlen_val)
remaining = max(0, xlen_val - pending)
# running = in-progress (delivered, not yet acked)
task_status[stream_key]["running"] += pending
# pending = not yet delivered (remaining)
task_status[stream_key]["pending"] += remaining
task_status[stream_key]["remaining"] += remaining
task_status["running"] += pending
task_status["pending"] += remaining
task_status["remaining"] += remaining
break

Expand Down
28 changes: 28 additions & 0 deletions src/memos/mem_scheduler/utils/status_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,31 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
"item_count": len(item_ids),
"item_statuses": item_statuses,
}

def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
    """
    Retrieve all tasks for all users from Redis.

    Incrementally SCANs ``memos:task_meta:{user_id}`` hash keys (so the
    server is never blocked by a full keyspace sweep) and parses every hash
    field value as a JSON task payload.

    Returns:
        dict: {user_id: {task_id: task_data, ...}, ...}
    """
    all_users_tasks: dict[str, dict[str, dict]] = {}
    cursor: int | str = 0
    while True:
        cursor, keys = self.redis.scan(cursor=cursor, match="memos:task_meta:*", count=100)
        for key in keys:
            # Clients configured without decode_responses return bytes keys;
            # normalize to str so the split below cannot raise TypeError.
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            # key format: memos:task_meta:{user_id}. Split at most twice so a
            # user_id that itself contains ":" is preserved intact.
            parts = key.split(":", 2)
            if len(parts) < 3:
                continue
            user_id = parts[2]

            tasks = self.redis.hgetall(key)
            if tasks:
                # Normalize task ids to str for the same bytes/str reason;
                # json.loads accepts both bytes and str payloads.
                user_tasks = {
                    (tid.decode("utf-8") if isinstance(tid, bytes) else tid): json.loads(t_data)
                    for tid, t_data in tasks.items()
                }
                all_users_tasks[user_id] = user_tasks

        # Redis cursors may come back as int or str depending on client config.
        if cursor == 0 or cursor == "0":
            break

    return all_users_tasks
5 changes: 2 additions & 3 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]:
content=feedback_req_str,
timestamp=datetime.utcnow(),
)
self.mem_scheduler.memos_message_queue.submit_messages(
messages=[message_item_feedback]
)
# Use scheduler submission to ensure tracking and metrics
self.mem_scheduler.submit_messages(messages=[message_item_feedback])
self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted FEEDBACK async")
except Exception as e:
self.logger.error(
Expand Down
Loading