diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 7dc40b276..5720939e0 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -4,7 +4,7 @@
 import time
 
 from collections.abc import Callable
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import TYPE_CHECKING, Union
 
@@ -49,6 +49,7 @@ from memos.mem_scheduler.utils.filter_utils import (
     transform_name_to_key,
 )
+from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
 from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
@@ -175,6 +176,8 @@ def init_mem_cube(
         searcher: Searcher | None = None,
         feedback_server: Searcher | None = None,
     ):
+        if mem_cube is None:
+            logger.error("mem_cube is None, cannot initialize", stack_info=True)
         self.mem_cube = mem_cube
         self.text_mem: TreeTextMemory = self.mem_cube.text_mem
         self.reranker: HTTPBGEReranker = self.text_mem.reranker
@@ -258,6 +261,15 @@ def _cleanup_on_init_failure(self):
     @property
     def mem_cube(self) -> BaseMemCube:
         """The memory cube associated with this MemChat."""
+        if self.current_mem_cube is None:
+            logger.error("mem_cube is None when accessed", stack_info=True)
+            try:
+                self.components = init_components()
+                self.current_mem_cube: BaseMemCube = self.components["naive_mem_cube"]
+            except Exception:
+                logger.info(
+                    "Failed to initialize fallback naive_mem_cube (no environment available); mem_cube remains None."
+                )
         return self.current_mem_cube
 
     @mem_cube.setter
@@ -757,7 +769,35 @@ def _message_consumer(self) -> None:
             messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch)
 
             if messages:
+                now = time.time()
                 for msg in messages:
+                    enqueue_ts_obj = getattr(msg, "timestamp", None)
+                    enqueue_epoch = None
+                    if isinstance(enqueue_ts_obj, int | float):
+                        enqueue_epoch = float(enqueue_ts_obj)
+                    elif hasattr(enqueue_ts_obj, "timestamp"):
+                        dt = enqueue_ts_obj
+                        if dt.tzinfo is None:
+                            dt = dt.replace(tzinfo=timezone.utc)
+                        enqueue_epoch = dt.timestamp()
+
+                    queue_wait_ms = None
+                    if enqueue_epoch is not None:
+                        queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000
+
+                    msg.dequeue_ts = now
+                    emit_monitor_event(
+                        "dequeue",
+                        msg,
+                        {
+                            "enqueue_ts": to_iso(enqueue_ts_obj),
+                            "dequeue_ts": datetime.fromtimestamp(
+                                now, tz=timezone.utc
+                            ).isoformat(),
+                            "queue_wait_ms": queue_wait_ms,
+                        },
+                    )
+                    self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label)
                     try:
                         import contextlib
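The dequeue instrumentation above normalizes two timestamp shapes (epoch numbers and `datetime` objects, with naive datetimes assumed to be UTC) before computing `queue_wait_ms`. A minimal sketch of that rule, with `normalize_epoch` as a hypothetical helper name:

```python
from datetime import datetime, timezone


def normalize_epoch(ts) -> float | None:
    """Coerce an epoch number or datetime into epoch seconds (UTC)."""
    if isinstance(ts, int | float):
        return float(ts)
    if isinstance(ts, datetime):
        if ts.tzinfo is None:  # naive datetimes are treated as UTC, as in the patch
            ts = ts.replace(tzinfo=timezone.utc)
        return ts.timestamp()
    return None


# 1_700_000_000 epoch seconds == 2023-11-14T22:13:20Z
assert normalize_epoch(1_700_000_000) == 1.7e9
assert normalize_epoch(datetime(2023, 11, 14, 22, 13, 20)) == 1.7e9
```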
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index 954855f90..30cba81b3 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -66,7 +66,12 @@
 DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50
 
 # task queue
-DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.4"
+DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.5"
 exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None)
 if exchange_name is not None:
     DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}"
+
+# pending claim configuration
+# Only claim pending messages whose idle time exceeds this threshold.
+# Unit: milliseconds. Default: 10 minutes.
+DEFAULT_PENDING_CLAIM_MIN_IDLE_MS = 600_000
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 9f39d9888..8b74995d4 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import TypedDict
 
+from memos.context.context import generate_trace_id
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.misc import DictConversionMixin
 from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -36,6 +37,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
     redis_message_id: str = Field(default="", description="the message get from redis stream")
     stream_key: str = Field("", description="stream_key for identifying the queue in line")
     user_id: str = Field(..., description="user id")
+    trace_id: str = Field(default_factory=generate_trace_id, description="trace id for logging")
     mem_cube_id: str = Field(..., description="memcube id")
     session_id: str = Field(default="", description="Session ID for soft-filtering memories")
     label: str = Field(..., description="Label of the schedule message")
@@ -80,6 +82,7 @@ def to_dict(self) -> dict:
             "item_id": self.item_id,
             "user_id": self.user_id,
             "cube_id": self.mem_cube_id,
+            "trace_id": self.trace_id,
             "label": self.label,
             "cube": "Not Applicable",  # Custom cube serialization
             "content": self.content,
@@ -95,6 +98,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
             item_id=data.get("item_id", str(uuid4())),
             user_id=data["user_id"],
             mem_cube_id=data["cube_id"],
+            trace_id=data.get("trace_id", generate_trace_id()),
             label=data["label"],
             content=data["content"],
             timestamp=datetime.fromisoformat(data["timestamp"]),
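The `trace_id` field is meant to round-trip through `to_dict()`/`from_dict()`, with `from_dict` minting a fresh id for legacy payloads that predate the field. A hedged sketch of that contract (field values are illustrative; required fields are taken from the schema shown above):

```python
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem

msg = ScheduleMessageItem(user_id="u1", mem_cube_id="c1", label="add", content="hi")
restored = ScheduleMessageItem.from_dict(msg.to_dict())
assert restored.trace_id == msg.trace_id  # preserved across serialization

# A legacy payload without the field still deserializes, with a generated id.
legacy = {k: v for k, v in msg.to_dict().items() if k != "trace_id"}
assert ScheduleMessageItem.from_dict(legacy).trace_id  # never empty
```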
"trace_id", None) or generate_trace_id() + # Propagate trace_id and user info to logging context for this handler execution + ctx = RequestContext( + trace_id=trace_id, + user_name=getattr(first_msg, "user_name", None), + user_type=None, + ) + set_request_context(ctx) + # --- mark start: record queuing time(now - enqueue_ts)--- now = time.time() - m = messages[0] # All messages in this batch have same user and type - enq_ts = getattr(m, "timestamp", None) + m = first_msg # All messages in this batch have same user and type + enq_ts = getattr(first_msg, "timestamp", None) # Path 1: epoch seconds (preferred) if isinstance(enq_ts, int | float): @@ -149,17 +166,51 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): wait_sec = max(0.0, now - enq_epoch) self.metrics.observe_task_wait_duration(wait_sec, m.user_id, m.label) + dequeue_ts = getattr(first_msg, "dequeue_ts", None) + start_delay_ms = None + if isinstance(dequeue_ts, int | float): + start_delay_ms = max(0.0, start_time - dequeue_ts) * 1000 + + emit_monitor_event( + "start", + first_msg, + { + "start_ts": start_iso, + "start_delay_ms": start_delay_ms, + "enqueue_ts": to_iso(enq_ts), + "dequeue_ts": to_iso( + datetime.fromtimestamp(dequeue_ts, tz=timezone.utc) + if isinstance(dequeue_ts, int | float) + else None + ), + }, + ) + # Execute the original handler result = handler(messages) # --- mark done --- - duration = time.time() - start_time + finish_time = time.time() + duration = finish_time - start_time self.metrics.observe_task_duration(duration, m.user_id, m.label) if self.status_tracker: self.status_tracker.task_completed( task_id=task_item.item_id, user_id=task_item.user_id ) self.metrics.task_completed(user_id=m.user_id, task_type=m.label) + + emit_monitor_event( + "finish", + first_msg, + { + "status": "ok", + "start_ts": start_iso, + "finish_ts": datetime.fromtimestamp( + finish_time, tz=timezone.utc + ).isoformat(), + "exec_duration_ms": duration * 1000, + }, + ) # Redis ack is handled in finally to cover failure cases # Mark task as completed and remove from tracking @@ -172,11 +223,26 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): except Exception as e: m = messages[0] + finish_time = time.time() self.metrics.task_failed(m.user_id, m.label, type(e).__name__) if self.status_tracker: self.status_tracker.task_failed( task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e) ) + emit_monitor_event( + "finish", + m, + { + "status": "fail", + "start_ts": start_iso, + "finish_ts": datetime.fromtimestamp( + finish_time, tz=timezone.utc + ).isoformat(), + "exec_duration_ms": (finish_time - start_time) * 1000, + "error_type": type(e).__name__, + "error_msg": str(e), + }, + ) # Mark task as failed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 5c551b23e..703dd1eb8 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -16,7 +16,10 @@ from memos.context.context import ContextThread from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STREAM_KEY_PREFIX +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_PENDING_CLAIM_MIN_IDLE_MS, + DEFAULT_STREAM_KEY_PREFIX, +) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from 
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index 5c551b23e..703dd1eb8 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -16,7 +16,10 @@
 from memos.context.context import ContextThread
 from memos.log import get_logger
-from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STREAM_KEY_PREFIX
+from memos.mem_scheduler.schemas.general_schemas import (
+    DEFAULT_PENDING_CLAIM_MIN_IDLE_MS,
+    DEFAULT_STREAM_KEY_PREFIX,
+)
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 
@@ -195,57 +198,7 @@ def _ensure_consumer_group(self, stream_key) -> None:
         else:
             logger.error(f"Error creating consumer group: {e}", exc_info=True)
 
-    def _get_pending_lock_key(self, stream_key: str) -> str:
-        """Compose a Redis lock key for pending reads on a specific stream.
-
-        Lock key includes stream prefix and consumer group to avoid collisions
-        across different deployments/groups.
-        """
-        # Use a stable lock namespace; include group to isolate multiple schedulers
-        return f"{self.stream_key_prefix}:lock:pending:{self.consumer_group}:{stream_key}"
-
-    def _acquire_pending_lock(self, stream_key: str, ttl_ms: int = 2000) -> str | None:
-        """Try to acquire a short-lived lock before reading pending messages.
-
-        Returns a unique token if the lock is acquired, otherwise None.
-        """
-        if not self._redis_conn:
-            return None
-        token = uuid4().hex
-        try:
-            ok = self._redis_conn.set(
-                self._get_pending_lock_key(stream_key), token, nx=True, px=ttl_ms
-            )
-            if ok:
-                logger.debug(
-                    f"Acquired pending-read lock for stream '{stream_key}' (ttl_ms={ttl_ms})"
-                )
-                return token
-            else:
-                logger.debug(f"Skip pending-read: lock not acquired for stream '{stream_key}'")
-                return None
-        except Exception as e:
-            logger.warning(f"Failed to acquire pending-read lock for '{stream_key}': {e}")
-            return None
-
-    def _release_pending_lock(self, stream_key: str, token: str) -> None:
-        """Release the pending-read lock only if owned (token matches)."""
-        if not self._redis_conn or not token:
-            return
-        lock_key = self._get_pending_lock_key(stream_key)
-        # Compare-and-delete via Lua to ensure we only release our own lock
-        lua = """
-        if redis.call('get', KEYS[1]) == ARGV[1] then
-            return redis.call('del', KEYS[1])
-        else
-            return 0
-        end
-        """
-        try:
-            self._redis_conn.eval(lua, 1, lock_key, token)
-            logger.debug(f"Released pending-read lock for stream '{stream_key}'")
-        except Exception as e:
-            logger.debug(f"Release lock failed for '{stream_key}': {e}")
+    # Pending lock methods removed as they are unnecessary with idle-threshold claiming
 
     def put(
         self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
@@ -390,46 +343,44 @@ def get(
             need_pending_count = need_pending if need_pending > 0 else 0
 
             if need_pending_count:
-                # Acquire a short-lived lock to avoid multiple processes reading the same pending
-                # messages concurrently when sharing the same consumer_name.
-                ttl_ms = 2000
-                token = self._acquire_pending_lock(stream_key=stream_key, ttl_ms=ttl_ms)
-                if token:
-                    try:
+                # Claim only pending messages whose idle time exceeds the configured threshold
+                try:
+                    # Ensure group exists before claiming
+                    self._ensure_consumer_group(stream_key=stream_key)
+                    # XAUTOCLAIM replies (next_start_id, [(id, fields), ...]);
+                    # Redis 7.0 appends a third element (deleted entry IDs), so
+                    # unpack defensively to support both reply shapes
+                    next_id, claimed, *_ = self._redis_conn.xautoclaim(
+                        name=stream_key,
+                        groupname=self.consumer_group,
+                        consumername=self.consumer_name,
+                        min_idle_time=DEFAULT_PENDING_CLAIM_MIN_IDLE_MS,
+                        start_id="0-0",
+                        count=need_pending_count,
+                        justid=False,
+                    )
+                    pending_messages = [(stream_key, claimed)] if claimed else []
+                except Exception as read_err:
+                    # Handle missing group/stream by creating and retrying once
+                    err_msg = str(read_err).lower()
+                    if "nogroup" in err_msg or "no such key" in err_msg:
+                        logger.warning(
+                            f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)."
+                        )
                         try:
-                            pending_messages = self._redis_conn.xreadgroup(
-                                self.consumer_group,
-                                self.consumer_name,
-                                {stream_key: "0"},  # read only this consumer's pending
+                            self._ensure_consumer_group(stream_key=stream_key)
+                            next_id, claimed, *_ = self._redis_conn.xautoclaim(
+                                name=stream_key,
+                                groupname=self.consumer_group,
+                                consumername=self.consumer_name,
+                                min_idle_time=DEFAULT_PENDING_CLAIM_MIN_IDLE_MS,
+                                start_id="0-0",
                                 count=need_pending_count,
-                                block=None,  # do not block when checking pending
+                                justid=False,
                             )
-                        except Exception as read_err:
-                            # Handle missing group/stream by creating and retrying once
-                            err_msg = str(read_err).lower()
-                            if "nogroup" in err_msg or "no such key" in err_msg:
-                                logger.warning(
-                                    f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (pending)."
-                                )
-                                self._ensure_consumer_group(stream_key=stream_key)
-                                try:
-                                    pending_messages = self._redis_conn.xreadgroup(
-                                        self.consumer_group,
-                                        self.consumer_name,
-                                        {stream_key: "0"},
-                                        count=need_pending_count,
-                                        block=None,
-                                    )
-                                except Exception:
-                                    pending_messages = []
-                            else:
-                                pending_messages = []
-                        finally:
-                            # Always release the lock
-                            self._release_pending_lock(stream_key=stream_key, token=token)
-                    else:
-                        # If lock not acquired, skip pending read in this round
-                        pending_messages = []
+                            pending_messages = [(stream_key, claimed)] if claimed else []
+                        except Exception:
+                            pending_messages = []
+                    else:
+                        pending_messages = []
 
             # Combine: new first, then pending
             messages = []
@@ -486,10 +437,8 @@ def qsize(self) -> dict:
         total_size = 0
         try:
             qsize_stats = {}
-            # Scan for all stream keys matching the prefix
-            redis_pattern = f"{self.stream_key_prefix}:*"
-            for stream_key in self._redis_conn.scan_iter(redis_pattern):
-                # Get the length of each stream and add to total
+            # Use filtered stream keys to avoid WRONGTYPE on non-stream keys
+            for stream_key in self.get_stream_keys():
                 stream_qsize = self._redis_conn.xlen(stream_key)
                 qsize_stats[stream_key] = stream_qsize
                 total_size += stream_qsize
@@ -504,8 +453,12 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
         """
         List all Redis stream keys that match this queue's prefix.
 
+        Only returns keys under this queue's exact prefix. With the pending
+        lock keys removed, every key under the prefix is a Redis Stream,
+        which avoids WRONGTYPE errors when issuing stream commands on them.
+
         Returns:
-            A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}"`.
+            A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}:{task_label}"`.
         """
         if not self._redis_conn:
             return []
@@ -514,7 +467,8 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
             stream_key_prefix = self.stream_key_prefix
         # First, get all keys that might match (using Redis pattern matching)
         redis_pattern = f"{stream_key_prefix}:*"
-        raw_keys = self._redis_conn.scan_iter(match=redis_pattern)
+        raw_keys_iter = self._redis_conn.scan_iter(match=redis_pattern)
+        raw_keys = list(raw_keys_iter)
 
         # Second, filter using Python regex to ensure exact prefix match
         # Escape special regex characters in the prefix, then add :.*
@@ -522,7 +476,6 @@
         regex_pattern = f"^{escaped_prefix}:"
         stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
 
-        logger.debug(f"get stream_keys from redis: {stream_keys}")
         return stream_keys
 
     def size(self) -> int:
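The lock removal leans on XAUTOCLAIM's own semantics: ownership transfers are atomic, and only entries pending longer than `min_idle_time` are moved, so concurrent claimers cannot double-process fresh deliveries. A hedged, standalone redis-py sketch (stream key, group, and consumer names are illustrative):

```python
import redis

r = redis.Redis(decode_responses=True)
reply = r.xautoclaim(
    name="scheduler:messages:stream:v1.5:u1:c1:add",
    groupname="scheduler-group",
    consumername="worker-1",
    min_idle_time=600_000,  # only claim entries idle for >= 10 minutes
    start_id="0-0",
    count=10,
)
# Redis 6.2 replies (next_id, claimed); Redis 7.0 appends deleted entry IDs.
next_id, claimed, *_ = reply
for message_id, fields in claimed:
    print(message_id, fields)
```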
""" if not self._redis_conn: return [] @@ -514,7 +467,8 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: stream_key_prefix = self.stream_key_prefix # First, get all keys that might match (using Redis pattern matching) redis_pattern = f"{stream_key_prefix}:*" - raw_keys = self._redis_conn.scan_iter(match=redis_pattern) + raw_keys_iter = self._redis_conn.scan_iter(match=redis_pattern) + raw_keys = list(raw_keys_iter) # Second, filter using Python regex to ensure exact prefix match # Escape special regex characters in the prefix, then add :.* @@ -522,7 +476,6 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: regex_pattern = f"^{escaped_prefix}:" stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)] - logger.debug(f"get stream_keys from redis: {stream_keys}") return stream_keys def size(self) -> int: diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index a1285098e..2fd8716a3 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -5,12 +5,14 @@ the local memos_message_queue functionality in BaseScheduler. """ +from memos.context.context import get_current_trace_id from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso logger = get_logger(__name__) @@ -63,7 +65,12 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if isinstance(messages, ScheduleMessageItem): messages = [messages] + current_trace_id = get_current_trace_id() + for msg in messages: + if current_trace_id: + # Prefer current request trace_id so logs can be correlated + msg.trace_id = current_trace_id msg.stream_key = self.memos_message_queue.get_stream_key( user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, task_label=msg.label ) @@ -71,6 +78,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if len(messages) < 1: logger.error("Submit empty") elif len(messages) == 1: + enqueue_ts = to_iso(getattr(messages[0], "timestamp", None)) + emit_monitor_event("enqueue", messages[0], {"enqueue_ts": enqueue_ts}) self.memos_message_queue.put(messages[0]) else: user_cube_groups = group_messages_by_user_and_mem_cube(messages) @@ -93,6 +102,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) continue + enqueue_ts = to_iso(getattr(message, "timestamp", None)) + emit_monitor_event("enqueue", message, {"enqueue_ts": enqueue_ts}) self.memos_message_queue.put(message) logger.info( f"Submitted message to local queue: {message.label} - {message.content}" diff --git a/src/memos/mem_scheduler/utils/monitor_event_utils.py b/src/memos/mem_scheduler/utils/monitor_event_utils.py new file mode 100644 index 000000000..e3c09fa29 --- /dev/null +++ b/src/memos/mem_scheduler/utils/monitor_event_utils.py @@ -0,0 +1,67 @@ +import json +import os +import socket + +from datetime import datetime, timezone +from typing import Any + +from memos.log import get_logger +from 
diff --git a/src/memos/mem_scheduler/utils/monitor_event_utils.py b/src/memos/mem_scheduler/utils/monitor_event_utils.py
new file mode 100644
index 000000000..e3c09fa29
--- /dev/null
+++ b/src/memos/mem_scheduler/utils/monitor_event_utils.py
@@ -0,0 +1,67 @@
+import json
+import os
+import socket
+
+from datetime import datetime, timezone
+from typing import Any
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+logger = get_logger(__name__)
+
+
+def _iso_ts_now() -> str:
+    """Return the current UTC timestamp in ISO format with milliseconds."""
+    return datetime.now(timezone.utc).isoformat(timespec="milliseconds")
+
+
+def to_iso(ts) -> str | None:
+    """Convert a datetime or epoch-seconds value to an ISO string; return None if not convertible."""
+    if ts is None:
+        return None
+    if isinstance(ts, datetime):
+        dt = ts
+        if dt.tzinfo is None:
+            dt = dt.replace(tzinfo=timezone.utc)
+        return dt.isoformat()
+    try:
+        return datetime.fromtimestamp(float(ts), tz=timezone.utc).isoformat()
+    except Exception:
+        return None
+
+
+def emit_monitor_event(event: str, msg: ScheduleMessageItem, extra: dict[str, Any] | None = None):
+    """
+    Emit a structured MONITOR_EVENT log line for SLS consumption.
+
+    This must be fire-and-forget: any exception here should never break the scheduler flow.
+    """
+    try:
+        payload: dict[str, Any] = {
+            "event": event,
+            "ts": _iso_ts_now(),
+            "label": getattr(msg, "label", None),
+            "user_id": getattr(msg, "user_id", None),
+            "mem_cube_id": getattr(msg, "mem_cube_id", None),
+            "item_id": getattr(msg, "item_id", None),
+            "task_id": getattr(msg, "task_id", "") or "",
+            "trace_id": getattr(msg, "trace_id", None),
+            "stream_key": getattr(msg, "stream_key", None),
+            "redis_message_id": getattr(msg, "redis_message_id", None),
+            "monitor_flag": None,
+            "host": socket.gethostname(),
+            "env": os.getenv("ENV") or os.getenv("ENVIRONMENT") or "",
+        }
+
+        info = getattr(msg, "info", None)
+        if isinstance(info, dict):
+            payload["monitor_flag"] = info.get("monitor_flag")
+
+        if extra:
+            payload.update(extra)
+
+        logger.info("MONITOR_EVENT " + json.dumps(payload, ensure_ascii=False))
+    except Exception:
+        logger.debug("Failed to emit MONITOR_EVENT", exc_info=True)
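Since every event line is the fixed prefix `MONITOR_EVENT ` followed by a single JSON object, downstream log pipelines (e.g., an SLS parser) can extract payloads with a plain prefix split. A minimal consumer sketch (the log line shown is illustrative):

```python
import json

line = 'MONITOR_EVENT {"event": "finish", "status": "ok", "exec_duration_ms": 42.0}'
PREFIX = "MONITOR_EVENT "
if line.startswith(PREFIX):
    payload = json.loads(line[len(PREFIX):])
    assert payload["event"] == "finish"
```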