Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,35 @@ def __init__(
self._stream_keys_lock = threading.Lock()
self._stream_keys_refresh_thread: ContextThread | None = None
self._stream_keys_refresh_stop_event = threading.Event()
self._initial_scan_max_keys = int(
os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_MAX_KEYS", "1000") or 1000
)
self._initial_scan_time_limit_sec = float(
os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_TIME_LIMIT_SEC", "1.0") or 1.0
)

# Start background stream keys refresher if connected
if self._is_connected:
# Refresh once synchronously to seed cache at init
try:
self._refresh_stream_keys()
self._refresh_stream_keys(
max_keys=self._initial_scan_max_keys,
time_limit_sec=self._initial_scan_time_limit_sec,
)
except Exception as e:
logger.debug(f"Initial stream keys refresh failed: {e}")

# Then start background refresher
self._start_stream_keys_refresh_thread()

def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str:
stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}"
return stream_key

# --- Stream keys refresh background thread ---
def _refresh_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
def _refresh_stream_keys(
self,
stream_key_prefix: str | None = None,
max_keys: int | None = None,
time_limit_sec: float | None = None,
) -> list[str]:
"""Scan Redis and refresh cached stream keys for the queue prefix."""
if not self._redis_conn:
return []
Expand All @@ -140,12 +151,29 @@ def _refresh_stream_keys(self, stream_key_prefix: str | None = None) -> list[str

try:
redis_pattern = f"{stream_key_prefix}:*"
raw_keys_iter = self._redis_conn.scan_iter(match=redis_pattern)
raw_keys = list(raw_keys_iter)
collected: list[str] = []
cursor: int | str = 0
start_ts = time.time() if time_limit_sec else None
count_hint = 200
while True:
if (
start_ts is not None
and time_limit_sec is not None
and time.time() - start_ts > time_limit_sec
):
break
cursor, keys = self._redis_conn.scan(
cursor=cursor, match=redis_pattern, count=count_hint
)
collected.extend(keys)
if max_keys is not None and len(collected) >= max_keys:
break
if cursor == 0 or cursor == "0":
break

escaped_prefix = re.escape(stream_key_prefix)
regex_pattern = f"^{escaped_prefix}:"
stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
stream_keys = [key for key in collected if re.match(regex_pattern, key)]

if stream_key_prefix == self.stream_key_prefix:
with self._stream_keys_lock:
Expand Down