diff --git a/examples/mem_scheduler/show_redis_status.py b/examples/mem_scheduler/show_redis_status.py new file mode 100644 index 000000000..04e79ca97 --- /dev/null +++ b/examples/mem_scheduler/show_redis_status.py @@ -0,0 +1,67 @@ +import time + +from memos.api.routers.server_router import mem_scheduler +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue + + +queue = mem_scheduler.memos_message_queue.memos_message_queue + + +def fetch_status(queue: SchedulerRedisQueue) -> dict[str, dict[str, int]]: + """Fetch and print per-user Redis queue status using built-in API. + + Returns a dict mapping user_id -> {"pending": int, "remaining": int}. + """ + # This method will also print a summary and per-user counts. + return queue.show_task_status() + + +def print_diff(prev: dict[str, dict[str, int]], curr: dict[str, dict[str, int]]) -> None: + """Print aggregated totals and per-user changes compared to previous snapshot.""" + ts = time.strftime("%Y-%m-%d %H:%M:%S") + tot_p_prev = sum(v.get("pending", 0) for v in prev.values()) if prev else 0 + tot_r_prev = sum(v.get("remaining", 0) for v in prev.values()) if prev else 0 + tot_p_curr = sum(v.get("pending", 0) for v in curr.values()) + tot_r_curr = sum(v.get("remaining", 0) for v in curr.values()) + + dp_tot = tot_p_curr - tot_p_prev + dr_tot = tot_r_curr - tot_r_prev + + print(f"[{ts}] Total pending={tot_p_curr} ({dp_tot:+d}), remaining={tot_r_curr} ({dr_tot:+d})") + + # Print per-user deltas (current counts are already printed by show_task_status) + all_uids = sorted(set(prev.keys()) | set(curr.keys())) + for uid in all_uids: + p_prev = prev.get(uid, {}).get("pending", 0) + r_prev = prev.get(uid, {}).get("remaining", 0) + p_curr = curr.get(uid, {}).get("pending", 0) + r_curr = curr.get(uid, {}).get("remaining", 0) + dp = p_curr - p_prev + dr = r_curr - r_prev + # Only print when there is any change to reduce noise + if dp != 0 or dr != 0: + print(f" Δ {uid}: pending={dp:+d}, 
remaining={dr:+d}") + + +# Note: queue.show_task_status() handles printing per-user counts internally. + + +def main(interval_sec: float = 5.0) -> None: + prev: dict[str, dict[str, int]] = {} + while True: + try: + curr = fetch_status(queue) + print_diff(prev, curr) + print(f"stream_cache ({len(queue._stream_keys_cache)}): {queue._stream_keys_cache}") + prev = curr + time.sleep(interval_sec) + except KeyboardInterrupt: + print("Stopped.") + break + except Exception as e: + print(f"Error while fetching status: {e}") + time.sleep(interval_sec) + + +if __name__ == "__main__": + main() diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 9a83ab16e..f0833d716 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -453,7 +453,7 @@ def get_memory( @staticmethod def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"delete": bool, "rewritten": str, "reason": str}, ... } + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } Returns (success, parsed_dict) with int keys. 
""" try: @@ -476,27 +476,33 @@ def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dic continue if not isinstance(v, dict): continue - delete_flag = v.get("delete") + need_rewrite = v.get("need_rewrite") rewritten = v.get("rewritten", "") reason = v.get("reason", "") if ( - isinstance(delete_flag, bool) + isinstance(need_rewrite, bool) and isinstance(rewritten, str) and isinstance(reason, str) ): - result[idx] = {"delete": delete_flag, "rewritten": rewritten, "reason": reason} + result[idx] = { + "need_rewrite": need_rewrite, + "rewritten": rewritten, + "reason": reason, + } return (len(result) > 0), result def filter_hallucination_in_memories( - self, user_messages: list[str], memory_list: list[TextualMemoryItem] + self, messages: list[dict], memory_list: list[TextualMemoryItem] ) -> list[TextualMemoryItem]: - flat_memories = [one.memory for one in memory_list] + # Build input objects with memory text and metadata (timestamps, sources, etc.) template = PROMPT_MAPPING["hallucination_filter"] prompt_args = { - "user_messages_inline": "\n".join([f"- {memory}" for memory in user_messages]), + "messages_inline": "\n".join( + [f"- [{message['role']}]: {message['content']}" for message in messages] + ), "memories_inline": json.dumps( - {str(i): memory for i, memory in enumerate(flat_memories)}, + {idx: mem.memory for idx, mem in enumerate(memory_list)}, ensure_ascii=False, indent=2, ), @@ -511,40 +517,25 @@ def filter_hallucination_in_memories( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}" ) if success: + new_mem_list = [] logger.info(f"Hallucination filter result: {parsed}") - total = len(memory_list) - keep_flags = [True] * total + assert len(parsed) == len(memory_list) for mem_idx, content in parsed.items(): - # Validate index bounds - if not isinstance(mem_idx, int) or mem_idx < 0 or mem_idx >= total: - logger.warning( - f"[filter_hallucination_in_memories] Ignoring out-of-range index: {mem_idx}" - 
) - continue - - delete_flag = content.get("delete", False) - rewritten = content.get("rewritten", None) + need_rewrite = content.get("need_rewrite", False) + rewritten = content.get("rewritten", "") reason = content.get("reason", "") - logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, delete={delete_flag}, rewritten='{(rewritten or '')[:100]}', reason='{reason[:120]}'" - ) - - if delete_flag is True and rewritten is not None: - # Mark for deletion - keep_flags[mem_idx] = False - else: - # Apply rewrite if provided (safe-by-default: keep item when not mentioned or delete=False) - try: - if isinstance(rewritten, str): - memory_list[mem_idx].memory = rewritten - except Exception as e: - logger.warning( - f"[filter_hallucination_in_memories] Failed to apply rewrite for index {mem_idx}: {e}" - ) - - # Build result, preserving original order; keep items not mentioned by LLM by default - new_mem_list = [memory_list[i] for i in range(total) if keep_flags[i]] + # Apply rewriting if requested + if ( + need_rewrite + and isinstance(rewritten, str) + and len(rewritten) > len(memory_list[mem_idx].memory) + ): + memory_list[mem_idx].memory = rewritten + logger.info( + f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten}', reason='{reason}', original memory='{memory_list[mem_idx].memory}'" + ) + new_mem_list.append(memory_list[mem_idx]) return new_mem_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") @@ -602,11 +593,8 @@ def _read_memory( # Build inputs new_memory_list = [] for unit_messages, unit_memory_list in zip(messages, memory_list, strict=False): - unit_user_messages = [ - msg["content"] for msg in unit_messages if msg["role"] == "user" - ] unit_memory_list = self.filter_hallucination_in_memories( - user_messages=unit_user_messages, memory_list=unit_memory_list + messages=unit_messages, memory_list=unit_memory_list ) new_memory_list.append(unit_memory_list) 
memory_list = new_memory_list diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 71012d42f..4c7d51a7c 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -144,7 +144,7 @@ def long_memory_update_process( old_memory_texts = [mem.memory for mem in cur_working_memory] new_memory_texts = [mem.memory for mem in new_order_working_memory] - logger.debug( + logger.info( f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': " f"Scheduler replaced working memory based on query history {queries}. " f"Old working memory ({len(old_memory_texts)} items): {old_memory_texts}. " @@ -1413,17 +1413,16 @@ def process_session_turn( logger.info( f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" ) - info = { - "user_id": user_id, - "session_id": "", - } + search_args = {} results: list[TextualMemoryItem] = self.retriever.search( query=item, + user_id=user_id, + mem_cube_id=mem_cube_id, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method, - info=info, + search_args=search_args, ) logger.info( f"[process_session_turn] Search results for missing evidence '{item}': {[one.memory for one in results]}" diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index fdd8a8cfe..f205766f0 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -22,7 +22,11 @@ from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer from memos.memories.textual.item import TextualMemoryMetadata from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.types.general_types import FINE_STRATEGY, FineStrategy +from memos.types.general_types import ( + FINE_STRATEGY, + 
FineStrategy, + SearchMode, +) # Extract JSON response from .memory_filter import MemoryFilter @@ -237,10 +241,12 @@ def recall_for_missing_memories( def search( self, query: str, + user_id: str, + mem_cube_id: str, mem_cube: GeneralMemCube, top_k: int, method: str = TreeTextMemory_SEARCH_METHOD, - info: dict | None = None, + search_args: dict | None = None, ) -> list[TextualMemoryItem]: """Search in text memory with the given query. @@ -253,22 +259,67 @@ def search( Search results or None if not implemented """ text_mem_base = mem_cube.text_mem + # Normalize default for mutable argument + search_args = search_args or {} try: if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: assert isinstance(text_mem_base, TreeTextMemory) - if info is None: - logger.warning( - "Please input 'info' when use tree.search so that " - "the database would store the consume history." - ) - info = {"user_id": "", "session_id": ""} + session_id = search_args.get("session_id", "default_session") + target_session_id = session_id + search_priority = ( + {"session_id": target_session_id} if "session_id" in search_args else None + ) + search_filter = search_args.get("filter") + search_source = search_args.get("source") + plugin = bool(search_source is not None and search_source == "plugin") + user_name = search_args.get("user_name", mem_cube_id) + internet_search = search_args.get("internet_search", False) + chat_history = search_args.get("chat_history") + search_tool_memory = search_args.get("search_tool_memory", False) + tool_mem_top_k = search_args.get("tool_mem_top_k", 6) + playground_search_goal_parser = search_args.get( + "playground_search_goal_parser", False + ) - mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" - results_long_term = text_mem_base.search( - query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info + info = search_args.get( + "info", + { + "user_id": user_id, + "session_id": target_session_id, + 
"chat_history": chat_history, + }, ) - results_user = text_mem_base.search( - query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info + + results_long_term = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="LongTermMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, + ) + + results_user = mem_cube.text_mem.search( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + manual_close_internet=not internet_search, + memory_type="UserMemory", + search_filter=search_filter, + search_priority=search_priority, + info=info, + plugin=plugin, + search_tool_memory=search_tool_memory, + tool_mem_top_k=tool_mem_top_k, + playground_search_goal_parser=playground_search_goal_parser, ) results = results_long_term + results_user else: diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index 5439cf225..af0f2f233 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -45,10 +45,6 @@ class TaskPriorityLevel(Enum): USER_INPUT_TYPE = "UserInput" NOT_APPLICABLE_TYPE = "NotApplicable" -# pending claim configuration -# Only claim pending messages whose idle time exceeds this threshold. -# Unit: milliseconds. Default: 10 minute. 
-DEFAULT_PENDING_CLAIM_MIN_IDLE_MS = 600_000 # scheduler daemon defaults # Interval in seconds for periodically releasing stale pending messages @@ -60,15 +56,22 @@ class TaskPriorityLevel(Enum): # Interval in seconds for batching and cleaning up deletions (xdel) DEFAULT_DELETE_CLEANUP_INTERVAL_SEC = 30.0 -# Inactivity threshold for stream deletion -# Delete streams whose last message ID timestamp is older than this threshold. -# Unit: seconds. Default: 1 day. -DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS = 86_400.0 +# pending claim configuration +# Only claim pending messages whose idle time exceeds this threshold. +# Unit: milliseconds. Default: 1 hour. +DEFAULT_PENDING_CLAIM_MIN_IDLE_MS = 3_600_000 + # Recency threshold for active streams # Consider a stream "active" if its last message is within this window. -# Unit: seconds. Default: 30 minutes. -DEFAULT_STREAM_RECENT_ACTIVE_SECONDS = 1_800.0 +# Unit: seconds. Default: 1 hour. +DEFAULT_STREAM_RECENT_ACTIVE_SECONDS = 3_600.0 + + +# Inactivity threshold for stream deletion +# Delete streams whose last message ID timestamp is older than this threshold. +# Unit: seconds. Default: 2 hours. +DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS = 7_200.0 # task queue diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 36fe3c553..d3268eda8 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,6 +5,7 @@ the local memos_message_queue functionality in BaseScheduler.
""" +import contextlib import os import re import threading @@ -26,6 +27,7 @@ from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule +from memos.utils import timed_with_status logger = get_logger(__name__) @@ -249,6 +251,14 @@ def _stop_stream_keys_refresh_thread(self) -> None: except Exception as e: logger.debug(f"Stopping stream keys refresh thread encountered: {e}") + @timed_with_status( + log_prefix="task_broker", + log_extra_args={ + "stream_prefix": os.getenv( + "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", DEFAULT_STREAM_KEY_PREFIX + ) + }, + ) def task_broker( self, consume_batch_size: int, @@ -257,17 +267,44 @@ def task_broker( if not stream_keys: return [] + # Determine per-stream quotas for this cycle stream_quotas = self.orchestrator.get_stream_quotas( stream_keys=stream_keys, consume_batch_size=consume_batch_size ) - cache: list[ScheduleMessageItem] = [] + + # Step A: batch-read new messages across streams (non-blocking) + new_messages_map: dict[str, list[tuple[str, list[tuple[str, dict]]]]] = ( + self._read_new_messages_batch(stream_keys=stream_keys, stream_quotas=stream_quotas) + ) + + # Step B: compute pending needs per stream + claims_spec: list[tuple[str, int, str]] = [] for stream_key in stream_keys: - messages = self.get( - stream_key=stream_key, - block=False, + need_pending_count = self._compute_pending_need( + new_messages=new_messages_map.get(stream_key), batch_size=stream_quotas[stream_key], ) - cache.extend(messages) + if need_pending_count: + # Derive task label from stream key suffix + task_label = stream_key.rsplit(":", 1)[1] + claims_spec.append((stream_key, need_pending_count, task_label)) + + # Step C: batch claim pending messages across streams + claimed_messages: list[tuple[str, list[tuple[str, dict]]]] = [] + if claims_spec: + claimed_messages = 
self._batch_claim_pending_messages(claims_spec=claims_spec) + + # Step D: assemble and convert to ScheduleMessageItem + messages: list[tuple[str, list[tuple[str, dict]]]] = [] + for stream_key in stream_keys: + nm = new_messages_map.get(stream_key) + if nm: + messages.extend(nm) + + if claimed_messages: + messages.extend(claimed_messages) + + cache: list[ScheduleMessageItem] = self._convert_messages(messages) # pack messages packed: list[list[ScheduleMessageItem]] = [] @@ -360,12 +397,12 @@ def put( user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) - if stream_key not in self.seen_streams: - self.seen_streams.add(stream_key) - self._ensure_consumer_group(stream_key=stream_key) - # Update stream keys cache with newly observed stream key with self._stream_keys_lock: + if stream_key not in self.seen_streams: + self.seen_streams.add(stream_key) + self._ensure_consumer_group(stream_key=stream_key) + if stream_key not in self._stream_keys_cache: self._stream_keys_cache.append(stream_key) self._stream_keys_last_refresh = time.time() @@ -511,6 +548,77 @@ def _read_new_messages( logger.error(f"{read_err}", stack_info=True) raise + def _read_new_messages_batch( + self, stream_keys: list[str], stream_quotas: dict[str, int] + ) -> dict[str, list[tuple[str, list[tuple[str, dict]]]]]: + """Batch-read new messages (non-blocking) across multiple streams. + + Uses a Redis pipeline to reduce round trips while honoring per-stream quotas. + + Args: + stream_keys: List of stream keys to read from. + stream_quotas: Per-stream message upper bounds. + + Returns: + Mapping from stream key to xreadgroup-style result list. 
+ """ + if not self._redis_conn or not stream_keys: + return {} + + # Pre-ensure consumer groups to avoid NOGROUP during batch reads + for stream_key in stream_keys: + with contextlib.suppress(Exception): + self._ensure_consumer_group(stream_key=stream_key) + + pipe = self._redis_conn.pipeline(transaction=False) + for stream_key in stream_keys: + pipe.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: ">"}, + count=stream_quotas.get(stream_key), + block=None, + ) + + try: + res_list = pipe.execute() + except Exception as e: + logger.error(f"Pipeline xreadgroup failed: {e}") + # Fallback to sequential non-blocking reads + res_list = [] + for stream_key in stream_keys: + try: + res = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: ">"}, + count=stream_quotas.get(stream_key), + block=None, + ) + except Exception as read_err: + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + self._ensure_consumer_group(stream_key=stream_key) + try: + res = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: ">"}, + count=stream_quotas.get(stream_key), + block=None, + ) + except Exception: + res = [] + else: + logger.error(f"{read_err}", stack_info=True) + res = [] + res_list.append(res) + + out: dict[str, list[tuple[str, list[tuple[str, dict]]]]] = {} + for stream_key, res in zip(stream_keys, res_list, strict=False): + out[stream_key] = res or [] + return out + def _compute_pending_need( self, new_messages: list[tuple[str, list[tuple[str, dict]]]] | None, batch_size: int | None ) -> int: @@ -573,6 +681,82 @@ def _claim_pending_messages( return [(stream_key, claimed)] if claimed else [] return [] + def _batch_claim_pending_messages( + self, claims_spec: list[tuple[str, int, str]] + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Batch-claim pending messages across multiple streams. 
+ + Args: + claims_spec: List of tuples (stream_key, need_pending_count, task_label) + + Returns: + A list of (stream_key, claimed_entries) pairs for all successful claims. + """ + if not self._redis_conn or not claims_spec: + return [] + + # Ensure consumer groups exist to avoid NOGROUP errors during batch claim + for stream_key, _need_count, _label in claims_spec: + with contextlib.suppress(Exception): + self._ensure_consumer_group(stream_key=stream_key) + + pipe = self._redis_conn.pipeline(transaction=False) + for stream_key, need_count, label in claims_spec: + pipe.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + + results = [] + try: + results = pipe.execute() + except Exception as e: + logger.error(f"Pipeline xautoclaim failed: {e}") + # Fallback: attempt sequential xautoclaim for robustness + results = [] + for stream_key, need_count, label in claims_spec: + try: + res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + results.append(res) + except Exception as se: + logger.warning(f"Sequential xautoclaim failed for '{stream_key}': {se}") + results.append(None) + + claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] + for (stream_key, _need_count, _label), claimed_result in zip( + claims_spec, results, strict=False + ): + try: + if not claimed_result: + continue + if len(claimed_result) == 2: + _next_id, claimed = claimed_result + elif len(claimed_result) == 3: + _next_id, claimed, _deleted_ids = claimed_result + else: + raise ValueError( + f"Unexpected xautoclaim response length: {len(claimed_result)} for '{stream_key}'" + ) + if claimed: + claimed_pairs.append((stream_key, 
claimed)) + except Exception as parse_err: + logger.warning(f"Failed to parse xautoclaim result for '{stream_key}': {parse_err}") + + return claimed_pairs + def _convert_messages( self, messages: list[tuple[str, list[tuple[str, dict]]]] ) -> list[ScheduleMessageItem]: @@ -617,6 +801,62 @@ def qsize(self) -> dict: logger.error(f"Failed to get Redis queue size: {e}", stack_info=True) return {} + def show_task_status(self) -> dict[str, dict[str, int]]: + stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix) + if not stream_keys: + logger.info("No Redis streams found for the configured prefix") + return {} + + consumer_group = self.consumer_group or "scheduler_group" + + grouped: dict[str, dict[str, int]] = {} + + for sk in stream_keys: + uid = sk + if uid not in grouped: + grouped[uid] = {"pending": 0, "remaining": 0} + + # Pending count via XPENDING + pending_count = 0 + try: + pending_info = self._redis_conn.xpending(sk, consumer_group) + # redis-py may return a tuple-like [count, ...] 
+ if pending_info: + try: + pending_count = int(pending_info[0]) + except Exception: + # Fallback if structure differs + pending_count = int(getattr(pending_info, "count", 0) or 0) + except Exception as e: + logger.debug(f"XPENDING failed for '{sk}': {e}") + + # Remaining count via XLEN + remaining_count = 0 + try: + remaining_count = int(self._redis_conn.xlen(sk)) + except Exception as e: + logger.debug(f"XLEN failed for '{sk}': {e}") + + grouped[uid]["pending"] += pending_count + grouped[uid]["remaining"] += remaining_count + + # Pretty-print summary + try: + total_pending = sum(v.get("pending", 0) for v in grouped.values()) + total_remaining = sum(v.get("remaining", 0) for v in grouped.values()) + header = f"Task Queue Status by user_id | pending={total_pending}, remaining={total_remaining}" + print(header) + for uid in sorted(grouped.keys()): + counts = grouped[uid] + print( + f"- {uid}: pending={counts.get('pending', 0)}, remaining={counts.get('remaining', 0)}" + ) + except Exception: + # Printing is best-effort; return grouped regardless + pass + + return grouped + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: """ Return cached Redis stream keys maintained by background refresher. diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 4f4fbb4af..edfd74264 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -315,7 +315,7 @@ def rabbitmq_publish_message(self, message: dict): return False logger.info( - f"[DIAGNOSTIC] rabbitmq_service.rabbitmq_publish_message: Attempting to publish message. Exchange: {exchange_name}, Routing Key: {routing_key}, Message Content: {json.dumps(message, indent=2)}" + f"[DIAGNOSTIC] rabbitmq_service.rabbitmq_publish_message: Attempting to publish message. 
Exchange: {exchange_name}, Routing Key: {routing_key}, Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}" ) try: self.rabbitmq_channel.basic_publish( diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 8f9810cf1..dfeb5d180 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -423,36 +423,42 @@ You are a strict memory validator. # TASK -Validate each memory entry against the user's current messages (ground truth). -Memories that hallucinate unsupported facts or contradict the user must be corrected or marked for deletion. - -# RULES -- Use ONLY facts explicitly stated in the user messages. -- Do NOT invent, assume, or retain unsupported specifics. -- Preserve the original language of each memory when rewriting. -- Output ONLY a JSON object with no extra text. +Review each memory object against the messages (ground truth). +Do NOT alter the original memory content. Instead, append a concise reference-resolution explanation after the original content. +If any part of the memory originates from assistant inference (i.e., not explicitly stated by the user), explicitly note this after the explanation. + +# RULENOTES (strictly enforced) +- NEVER change, delete, or paraphrase the original memory text. +- ALWAYS preserve the original language, structure, and factual phrasing of the memory. +- After the original text, add exactly one sentence starting with "[Ref] " that resolves ambiguous references (e.g., pronouns like 'she', 'it', or vague terms like 'the dog') using only information explicitly present in the user messages or prior memories. +- If the memory contains content that was inferred by the assistant (not directly stated by the user), append an additional sentence starting with "[Source:] Inference by assistant." after the [Ref] sentence. +- Do NOT add any other commentary, formatting, or metadata beyond this.
+- Keep all original timestamps and identifiers intact in the memory object; this rule applies only to the 'text' field. # INPUTS -User messages (ground truth): -{user_messages_inline} +messages (ground truth): +{messages_inline} -Memory list (to validate, in indexed JSON format): +Extracted memory list to validate (indexed JSON objects with text and metadata): {memories_inline} # OUTPUT FORMAT Return a JSON object where: - Keys are the same stringified indices as in the input memory list (e.g., "0", "1"). -- Each value is: {{"delete": boolean, "rewritten": string, "reason": string}} -- If "delete" is true, "rewritten" must be an empty string. -- "reason" must briefly explain the decision (delete or rewrite) based on user messages. -- The number of output entries MUST exactly match the number of input memories. - -# DECISION GUIDE -- Contradicted? → rewrite to match user message, "delete"=false, "rewritten"=corrected memory content. -- Hallucinated (specific fact not in user messages)? → "delete"=true, "rewritten"=dehallucinated rewritten memory. -- Consistent or non-factual (opinion, emotion)? → keep as-is, "delete"=false. - -Additionally, include a concise "reason" for each item explaining your decision. +- Each value is: {{"need_rewrite": boolean, "rewritten": string, "reason": string}} +- Set "need_rewrite" to true ONLY if the memory contains ambiguous references or assistant inference requiring clarification. +- If "need_rewrite" is true, "rewritten" = <original memory text> + " [Ref] <reference resolution sentence>." +- If "need_rewrite" is false (i.e., memory is fully explicit and user-stated), "rewritten" is an empty string. +- "reason" must be brief: e.g., "resolved ambiguous reference with inference", "explicit user statement, no rewrite needed". + +# EXAMPLE +Input memory text: "She loves painting." +User messages include: "Caroline loves painting." +→ Rewritten: "She loves painting. [Ref] 'She' refers to Caroline." + +Input memory text: "The user is a developer." 
+User never stated this, but assistant inferred from context. +→ Rewritten: "The user is a developer. [Ref] 'The user' refers to the person interacting with the assistant. [Source:] Inference by assistant." Final Output: """ diff --git a/src/memos/utils.py b/src/memos/utils.py index a29eaf99d..e4945b7d3 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -6,6 +6,9 @@ logger = get_logger(__name__) +# Global threshold (seconds) for timing logs +DEFAULT_TIME_BAR = 10.0 + def timed_with_status( func=None, @@ -20,7 +23,9 @@ def timed_with_status( - log: enable timing logs (default True) - log_prefix: prefix; falls back to function name - log_args: names to include in logs (str or list/tuple of str). - - log_extra_args: extra arguments to include in logs (dict). + - log_extra_args: extra arguments to include in logs (dict). If it contains + key "time_threshold", use its value (in seconds) as the logging threshold; otherwise + fall back to DEFAULT_TIME_BAR. """ if isinstance(log_args, str): @@ -70,8 +75,15 @@ def wrapper(*args, **kwargs): f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} " f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}" ) + threshold_ms = DEFAULT_TIME_BAR * 1000.0 + if log_extra_args and "time_threshold" in log_extra_args: + try: + threshold_ms = float(log_extra_args["time_threshold"]) * 1000.0 + except Exception: + threshold_ms = DEFAULT_TIME_BAR * 1000.0 - logger.info(msg) + if elapsed_ms >= threshold_ms: + logger.info(msg) return wrapper @@ -90,7 +102,8 @@ def wrapper(*args, **kwargs): if log is not True: return result - logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms") + if elapsed_ms >= (DEFAULT_TIME_BAR * 1000.0): + logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms") return result