diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 79c28c32c..58765f055 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -16,6 +16,7 @@ ContextThread, RequestContext, get_current_context, + get_current_trace_id, set_request_context, ) from memos.llms.base import BaseLLM @@ -664,10 +665,16 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if not messages: return + current_trace_id = get_current_trace_id() + immediate_msgs: list[ScheduleMessageItem] = [] queued_msgs: list[ScheduleMessageItem] = [] for msg in messages: + # propagate request trace_id when available so monitor logs align with request logs + if current_trace_id: + msg.trace_id = current_trace_id + # basic metrics and status tracking with suppress(Exception): self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index b32e4588d..ab67c683f 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -210,6 +210,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): finish_time, tz=timezone.utc ).isoformat(), "exec_duration_ms": duration * 1000, + "total_duration_ms": self._calc_total_duration_ms( + finish_time, getattr(first_msg, "timestamp", None) + ), }, ) # Redis ack is handled in finally to cover failure cases @@ -243,6 +246,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): "exec_duration_ms": (finish_time - start_time) * 1000, "error_type": type(e).__name__, "error_msg": str(e), + "total_duration_ms": self._calc_total_duration_ms( + finish_time, getattr(m, "timestamp", None) + ), }, ) # Mark task as failed and remove from tracking @@ -423,6 +429,30 @@ def _handle_future_result(self, future): except Exception as e: logger.error(f"Handler execution failed: {e!s}", exc_info=True) + @staticmethod + def _calc_total_duration_ms(finish_epoch: float, enqueue_ts) -> float | None: + """ + Calculate total duration from enqueue timestamp to finish time in milliseconds. + """ + try: + enq_epoch = None + + if isinstance(enqueue_ts, int | float): + enq_epoch = float(enqueue_ts) + elif hasattr(enqueue_ts, "timestamp"): + dt = enqueue_ts + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enq_epoch = dt.timestamp() + + if enq_epoch is None: + return None + + total_ms = max(0.0, finish_epoch - enq_epoch) * 1000 + return total_ms + except Exception: + return None + def execute_task( self, user_id: str,