diff --git a/src/memos/context/context.py b/src/memos/context/context.py index b5d4c24fe..5c8401732 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -88,13 +88,16 @@ def to_dict(self) -> dict[str, Any]: } -def set_request_context(context: RequestContext) -> None: +def set_request_context(context: RequestContext | None) -> None: """ Set the current request context. This is typically called by the API dependency injection system. """ - _request_context.set(context.to_dict()) + if context: + _request_context.set(context.to_dict()) + else: + _request_context.set(None) def get_current_trace_id() -> str | None: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 610999697..add689336 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -11,7 +11,12 @@ from sqlalchemy.engine import Engine from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig -from memos.context.context import ContextThread +from memos.context.context import ( + ContextThread, + RequestContext, + get_current_context, + set_request_context, +) from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.base import BaseMemCube @@ -775,35 +780,46 @@ def _message_consumer(self) -> None: if messages: now = time.time() for msg in messages: - enqueue_ts_obj = getattr(msg, "timestamp", None) - enqueue_epoch = None - if isinstance(enqueue_ts_obj, int | float): - enqueue_epoch = float(enqueue_ts_obj) - elif hasattr(enqueue_ts_obj, "timestamp"): - dt = enqueue_ts_obj - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - enqueue_epoch = dt.timestamp() - - queue_wait_ms = None - if enqueue_epoch is not None: - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - - # Avoid pydantic field enforcement by using object.__setattr__ - object.__setattr__(msg, "_dequeue_ts", now) - emit_monitor_event( - "dequeue", - msg, - { - "enqueue_ts": to_iso(enqueue_ts_obj), - "dequeue_ts": datetime.fromtimestamp( - now, tz=timezone.utc - ).isoformat(), - "queue_wait_ms": queue_wait_ms, - }, - ) - - self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) + prev_context = get_current_context() + try: + # Set context for this message + msg_context = RequestContext( + trace_id=msg.trace_id, + user_name=msg.user_name, + ) + set_request_context(msg_context) + + enqueue_ts_obj = getattr(msg, "timestamp", None) + enqueue_epoch = None + if isinstance(enqueue_ts_obj, int | float): + enqueue_epoch = float(enqueue_ts_obj) + elif hasattr(enqueue_ts_obj, "timestamp"): + dt = enqueue_ts_obj + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + enqueue_epoch = dt.timestamp() + + queue_wait_ms = None + if enqueue_epoch is not None: + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 + + # Avoid pydantic field enforcement by using object.__setattr__ + object.__setattr__(msg, "_dequeue_ts", now) + emit_monitor_event( + "dequeue", + msg, + { + "enqueue_ts": to_iso(enqueue_ts_obj), + "dequeue_ts": datetime.fromtimestamp( + now, tz=timezone.utc + ).isoformat(), + "queue_wait_ms": queue_wait_ms, + }, + ) + self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) + finally: + # Restore the prior context of the consumer thread + set_request_context(prev_context) try: import contextlib