diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index c2137a011..a5c5bc737 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -204,19 +204,16 @@ def add_msgs( for item_idx, item in enumerate(tqdm(questions, desc="processing queries")): query = item["question"] - messages_to_send = [ - ScheduleMessageItem( - item_id=f"test_item_{item_idx}", - user_id=trying_modules.current_user_id, - mem_cube_id=trying_modules.current_mem_cube_id, - label=MEM_UPDATE_TASK_LABEL, - content=query, - ) - ] - + message = ScheduleMessageItem( + item_id=f"test_item_{item_idx}", + user_id=trying_modules.current_user_id, + mem_cube_id=trying_modules.current_mem_cube_id, + label=MEM_UPDATE_TASK_LABEL, + content=query, + ) # Run one session turn manually to get search candidates mem_scheduler._memory_update_consumer( - messages=messages_to_send, + messages=[message], ) # Show accumulated web logs diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 0c3645b49..ac79c246b 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -459,7 +459,7 @@ def get_memory( @staticmethod def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"need_rewrite": bool, "rewritten_suffix": str, "reason": str}, ... } + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } Returns (success, parsed_dict) with int keys. """ try: @@ -483,16 +483,16 @@ def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dic if not isinstance(v, dict): continue need_rewrite = v.get("need_rewrite") - rewritten_suffix = v.get("rewritten_suffix", "") + rewritten = v.get("rewritten", "") reason = v.get("reason", "") if ( isinstance(need_rewrite, bool) - and isinstance(rewritten_suffix, str) + and isinstance(rewritten, str) and isinstance(reason, str) ): result[idx] = { "need_rewrite": need_rewrite, - "rewritten_suffix": rewritten_suffix, + "rewritten": rewritten, "reason": reason, } @@ -503,6 +503,8 @@ def filter_hallucination_in_memories( ) -> list[TextualMemoryItem]: # Build input objects with memory text and metadata (timestamps, sources, etc.) template = PROMPT_MAPPING["hallucination_filter"] + if len(messages) < 2: + return memory_list prompt_args = { "messages_inline": "\n".join( [f"- [{message['role']}]: {message['content']}" for message in messages] @@ -523,32 +525,27 @@ def filter_hallucination_in_memories( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}" ) if success: - new_mem_list = [] logger.info(f"Hallucination filter result: {parsed}") assert len(parsed) == len(memory_list) for mem_idx, content in parsed.items(): need_rewrite = content.get("need_rewrite", False) - rewritten_suffix = content.get("rewritten_suffix", "") + rewritten_text = content.get("rewritten", "") reason = content.get("reason", "") - # Append a new memory item instead of replacing the original + # Replace memory text with rewritten content when rewrite is needed if ( need_rewrite - and isinstance(rewritten_suffix, str) - and len(rewritten_suffix.strip()) > 0 + and isinstance(rewritten_text, str) + and len(rewritten_text.strip()) > 0 ): original_text = memory_list[mem_idx].memory logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten_suffix='{rewritten_suffix}', reason='{reason}', original memory='{original_text}', action='append_suffix'" + f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" ) - # Append only the suffix to the original memory text - memory_list[mem_idx].memory = original_text + rewritten_suffix - new_mem_list.append(memory_list[mem_idx]) - else: - new_mem_list.append(memory_list[mem_idx]) - return new_mem_list + memory_list[mem_idx].memory = rewritten_text + return memory_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") except Exception as e: @@ -603,13 +600,46 @@ def _read_memory( if os.getenv("SIMPLE_STRUCT_ADD_FILTER", "false") == "true": # Build inputs - new_memory_list = [] - for unit_messages, unit_memory_list in zip(messages, memory_list, strict=False): - unit_memory_list = self.filter_hallucination_in_memories( - messages=unit_messages, memory_list=unit_memory_list - ) - new_memory_list.append(unit_memory_list) - memory_list = new_memory_list + combined_messages = [] + for group_messages in messages: + combined_messages.extend(group_messages) + + for group_id in range(len(memory_list)): + try: + revised_memory_list = self.filter_hallucination_in_memories( + messages=combined_messages, + memory_list=memory_list[group_id], + ) + if len(revised_memory_list) != len(memory_list[group_id]): + original_serialized = [ + one.memory if hasattr(one, "memory") else str(one) + for one in memory_list[group_id] + ] + filtered_serialized = [ + one.memory if hasattr(one, "memory") else str(one) + for one in revised_memory_list + ] + logger.error( + f"Length mismatch after hallucination filtering for group_id={group_id}: " + f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}" + f"\noriginal_memory_list(serialized): {original_serialized}" + f"\nfiltered_memory_list(serialized): {filtered_serialized}" + f"\nmessages: {combined_messages}" + f"\nSkipping update and keeping original memory." + ) + continue + memory_list[group_id] = revised_memory_list + except Exception as e: + group_serialized = [ + one.memory if hasattr(one, "memory") else str(one) + for one in memory_list[group_id] + ] + logger.error( + f"There is an exception while filtering group_id={group_id}: {e}\n" + f"messages: {combined_messages}\n" + f"memory_list(serialized): {group_serialized}", + exc_info=True, + ) return memory_list def fine_transfer_simple_mem( diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 6256467ba..afe81d61e 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -156,8 +156,8 @@ def long_memory_update_process( logger.info( f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': " f"Scheduler replaced working memory based on query history {queries}. " - f"Old working memory ({len(old_memory_texts)} items): {old_memory_texts}. " - f"New working memory ({len(new_memory_texts)} items): {new_memory_texts}." + f"Old working memory ({len(cur_working_memory)} items): {old_memory_texts}. " + f"New working memory ({len(new_order_working_memory)} items): {new_memory_texts}." ) # update activation memories diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index ae1b44a80..6913429c3 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -723,8 +723,32 @@ def _batch_claim_pending_messages( ) results.append(res) except Exception as se: - logger.warning(f"Sequential xautoclaim failed for '{stream_key}': {se}") - results.append(None) + err_msg = str(se).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Sequential xautoclaim failed for '{stream_key}': {se}. Retrying with _ensure_consumer_group." + ) + with contextlib.suppress(Exception): + self._ensure_consumer_group(stream_key=stream_key) + try: + res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + results.append(res) + except Exception as retry_err: + logger.warning( + f"Retry sequential xautoclaim failed for '{stream_key}': {retry_err}" + ) + results.append(None) + else: + logger.warning(f"Sequential xautoclaim failed for '{stream_key}': {se}") + results.append(None) claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 4ac12eb70..12c445df7 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -623,15 +623,23 @@ SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ -You are a strict memory validator. +You are a strict, language-preserving memory validator and rewriter. -Task: -Check each memory against the user messages (ground truth). Do not modify the original text. Generate ONLY a suffix to append. +Your task is to compare each memory against the provided user messages (the ground truth) and produce a corrected version only when necessary. Always preserve the original language of the memory—do not translate. Rules: -- Append " [Source:] Inference by assistant." if the memory contains assistant inference (not directly stated by the user). -- Otherwise output an empty suffix. -- No other commentary or formatting. +1. **Language Consistency**: The rewritten memory must be in the exact same language as the original input memory. Never switch languages. +2. **Strict Grounding**: Only use information explicitly stated in the user messages. Do not introduce external facts, assumptions, or common sense. +3. **Ambiguity Resolution**: + - Replace vague pronouns (e.g., "he", "it", "they") or unclear references with specific, unambiguous entities based solely on the messages. + - Convert relative time expressions (e.g., "yesterday", "last week", "in two days") into absolute dates or times **only if the messages provide enough context** (e.g., current date is known or implied). +4. **Handling Assistant Inferences**: + - If a memory contains any content **not directly stated by the user**—such as interpretations, summaries, emotional attributions, predictions, causal claims, or generalizations—this is considered an assistant inference. + - In such cases, you **must** set `need_rewrite = true`. + - The `rewritten` text **must explicitly indicate that the statement is an inference**, using a clear and natural prefix in the memory’s language. For English memories, use: + > "The assistant inferred that [rest of the memory]." + - Do **not** present inferred content as factual user statements. +5. **No Rewrite Needed**: If the memory is factually accurate, fully grounded in the messages, unambiguous, and contains no unsupported content, set `need_rewrite = false` and copy the original memory exactly. Inputs: messages: @@ -640,12 +648,16 @@ memories: {memories_inline} -Output JSON: -- Keys: same indices as input ("0", "1", ...). -- Values: {{ "need_rewrite": boolean, "rewritten_suffix": string, "reason": string }} -- need_rewrite = true only when assistant inference is detected. -- rewritten_suffix = " [Source:] Inference by assistant." or "". -- reason: brief, e.g., "assistant inference detected" or "explicit user statement". +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) corresponding to the input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" should be concise and specific, e.g.: + - "contains assistant inference not stated by user" + - "pronoun 'it' has no clear referent in messages" + - "relative time 'yesterday' converted to 2025-12-16" + - "accurate and directly supported by user message" + +Important: Output **only** the JSON. No additional text, explanations, markdown, or fields. """