diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index b7347ae15..c2137a011 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -41,14 +41,26 @@ def init_task(): "role": "user", "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.", }, + { + "role": "assistant", + "content": "Got it — Max is on joint supplements, and you’re relocating to Chicago soon. That’s a big move! Have you looked into how the change in climate or vet access might affect his needs?", + }, { "role": "user", "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.", }, + { + "role": "assistant", + "content": "Thanks for the update! So Bella is 6 years old and has a chicken allergy — good to know. You’ll want to double-check her food and treats, especially during the move. Has she had any reactions recently?", + }, { "role": "user", "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.", }, + { + "role": "assistant", + "content": "Ah, the classic dog-and-cat dynamic! Since Bella chases Whiskers, it might help to give them gradual supervised interactions or create safe zones for the cat—especially important as you settle into a new home in Chicago. Keeping Bella’s routine stable during the move could also reduce her urge to chase. 
How do they usually get along when Whiskers visits?", + }, ] questions = [ @@ -145,18 +157,25 @@ def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", sessi print(f" User ID: {self.current_user_id}") print(f" Mem Cube ID: {self.current_mem_cube_id}") - def add_msgs(self, messages: list[dict]): + def add_msgs( + self, + messages: list[dict], + extract_mode: str = "fine", + async_mode: str = "sync", + ): # Create add request add_req = self.create_test_add_request( user_id=self.current_user_id, mem_cube_id=self.current_mem_cube_id, messages=messages, session_id=self.current_session_id, + extract_mode=extract_mode, + async_mode=async_mode, ) # Add to memory result = self.add_memories(add_req) - print(f" ✅ Added to memory successfully: \n{messages}") + print(f" ✅ Added to memory successfully: \n{result}") return result diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index f43ad01ba..d89df0b38 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -1,6 +1,7 @@ import concurrent.futures import copy import json +import os import re import traceback @@ -25,6 +26,7 @@ from memos.templates.mem_reader_prompts import ( CUSTOM_TAGS_INSTRUCTION, CUSTOM_TAGS_INSTRUCTION_ZH, + PROMPT_MAPPING, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, SIMPLE_STRUCT_MEM_READER_EXAMPLE, @@ -80,6 +82,7 @@ def from_config(_config): "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } + try: import tiktoken @@ -448,6 +451,81 @@ def get_memory( standard_scene_data = coerce_scene_data(scene_data, type) return self._read_memory(standard_scene_data, type, info, mode) + @staticmethod + def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from hallucination filter response. + Expected shape: { "0": {"if_delete": bool, "rewritten memory content": str}, ... 
}
+        Returns (success, parsed_dict) with int keys.
+        """
+        try:
+            data = json.loads(text)
+        except Exception:
+            return False, {}
+
+        if not isinstance(data, dict):
+            return False, {}
+
+        result: dict[int, dict] = {}
+        for k, v in data.items():
+            try:
+                idx = int(k)
+            except Exception:
+                # allow integer keys as-is
+                if isinstance(k, int):
+                    idx = k
+                else:
+                    continue
+            if not isinstance(v, dict):
+                continue
+            delete_flag = v.get("delete_flag")
+            rewritten = v.get("rewritten memory content", "")
+            if isinstance(delete_flag, bool) and isinstance(rewritten, str):
+                result[idx] = {"delete_flag": delete_flag, "rewritten memory content": rewritten}
+
+        return (len(result) > 0), result
+
+    def filter_hallucination_in_memories(
+        self, user_messages: list[str], memory_list: list[list[TextualMemoryItem]]
+    ):
+        filtered_memory_list = []
+        for group in memory_list:
+            try:
+                flat_memories = [one.memory for one in group]
+                template = PROMPT_MAPPING["hallucination_filter"]
+                prompt_args = {
+                    "user_messages_inline": "\n".join(user_messages),
+                    "memories_inline": json.dumps(flat_memories, ensure_ascii=False, indent=2),
+                }
+                prompt = template.format(**prompt_args)
+
+                # Optionally run filter and parse the output
+                try:
+                    raw = self.llm.generate(prompt)
+                    success, parsed = self._parse_hallucination_filter_response(raw)
+                    logger.info(f"Hallucination filter parsed successfully: {success}")
+                    new_mem_list = []
+                    if success:
+                        logger.info(f"Hallucination filter result: {parsed}")
+                        for mem_idx, mem_info in parsed.items():
+                            if not mem_info["delete_flag"]:
+                                group[mem_idx].memory = mem_info["rewritten memory content"]
+                                new_mem_list.append(group[mem_idx])
+                        filtered_memory_list.append(new_mem_list)
+                        logger.info(
+                            f"Successfully transform original memories from {group} to {new_mem_list}."
+                        )
+                    else:
+                        filtered_memory_list.append(group)
+                        logger.warning("Hallucination filter parsing failed or returned empty result."
+ ) + except Exception as e: + logger.error(f"Hallucination filter execution error: {e}", stack_info=True) + filtered_memory_list.append(group) + except Exception: + logger.error("Fail to filter memories", stack_info=True) + filtered_memory_list.append(group) + return filtered_memory_list + def _read_memory( self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: @@ -492,6 +570,14 @@ def _read_memory( except Exception as e: logger.error(f"Task failed with exception: {e}") logger.error(traceback.format_exc()) + + if os.getenv("SIMPLE_STRUCT_ADD_FILTER", "false") == "true": + # Build inputs + user_messages = [msg.content for msg in messages if msg.role == "user"] + memory_list = self.filter_hallucination_in_memories( + user_messages=user_messages, memory_list=memory_list + ) + return memory_list def fine_transfer_simple_mem( diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index 090e13f54..40e34fd4f 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -599,6 +599,8 @@ def create_test_add_request( messages=None, memory_content=None, session_id=None, + extract_mode=None, + async_mode="sync", ): """ Create a test APIADDRequest object with the given parameters. 
@@ -637,6 +639,8 @@ def create_test_add_request( source="api_analyzer_test", chat_history=None, operation=None, + mode=extract_mode, + async_mode=async_mode, ) def run_all_tests(self, mode=SearchMode.MIXTURE): diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 58765f055..8f8ac8b3b 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -140,12 +140,7 @@ def __init__(self, config: BaseSchedulerConfig): "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) self.orchestrator = SchedulerOrchestrator() - self.memos_message_queue = ScheduleTaskQueue( - use_redis_queue=self.use_redis_queue, - maxsize=self.max_internal_message_queue_size, - disabled_handlers=self.disabled_handlers, - orchestrator=self.orchestrator, - ) + self.searcher: Searcher | None = None self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None @@ -155,6 +150,13 @@ def __init__(self, config: BaseSchedulerConfig): self.status_tracker: TaskStatusTracker | None = None self.metrics = metrics self._monitor_thread = None + self.memos_message_queue = ScheduleTaskQueue( + use_redis_queue=self.use_redis_queue, + maxsize=self.max_internal_message_queue_size, + disabled_handlers=self.disabled_handlers, + orchestrator=self.orchestrator, + status_tracker=self.status_tracker, + ) self.dispatcher = SchedulerDispatcher( config=self.config, memos_message_queue=self.memos_message_queue, @@ -228,6 +230,8 @@ def initialize_modules( self.status_tracker = TaskStatusTracker(redis_client) if self.dispatcher: self.dispatcher.status_tracker = self.status_tracker + if self.memos_message_queue: + self.memos_message_queue.status_tracker = self.status_tracker # initialize submodules self.chat_llm = chat_llm self.process_llm = process_llm diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index e4e7edb89..aff725833 100644 
--- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -233,7 +233,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non def get( self, block: bool = True, timeout: float | None = None, batch_size: int | None = None - ) -> list[T] | T: + ) -> list[T]: """Get items from the queue. Args: diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index a147ebee0..fb3a5931a 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -62,10 +62,9 @@ class TaskPriorityLevel(Enum): # task queue -DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.7" -exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None) -if exchange_name is not None: - DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}" +DEFAULT_STREAM_KEY_PREFIX = os.getenv( + "MEMSCHEDULER_STREAM_KEY_PREFIX", "scheduler:messages:stream:v2.0" +) # ============== Running Tasks ============== diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index ab67c683f..928b2f5bd 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -273,6 +273,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): mem_cube_id=msg.mem_cube_id, task_label=msg.label, redis_message_id=redis_message_id, + message=msg, ) except Exception as ack_err: logger.warning(f"Ack in finally failed: {ack_err}") diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 2a2f9b046..c6a8c3d47 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -328,7 +328,12 @@ def put( raise def ack_message( - self, 
user_id: str, mem_cube_id: str, task_label: str, redis_message_id
+        self,
+        user_id: str,
+        mem_cube_id: str,
+        task_label: str,
+        redis_message_id,
+        message: ScheduleMessageItem | None = None,
    ) -> None:
        stream_key = self.get_stream_key(
            user_id=user_id, mem_cube_id=mem_cube_id, task_label=task_label
        )
@@ -347,6 +352,12 @@ def ack_message(
        try:
            self._redis_conn.xack(stream_key, self.consumer_group, redis_message_id)
+
+            if message:
+                tracker = getattr(self, "status_tracker", None)
+                if tracker is not None:
+                    tracker.task_completed(task_id=message.item_id, user_id=message.user_id)
+                logger.info(f"Message {message.item_id} | {message.label} | {message.content} has been acknowledged.")
        except Exception as e:
            logger.warning(
                f"xack failed for stream '{stream_key}', msg_id='{redis_message_id}': {e}"
@@ -364,7 +375,7 @@ def get(
        self,
        stream_key: str,
        block: bool = True,
        timeout: float | None = None,
-        batch_size: int | None = None,
+        batch_size: int | None = 1,
    ) -> list[ScheduleMessageItem]:
        if not self._redis_conn:
            raise ConnectionError("Not connected to Redis. 
Redis connection not available.") @@ -385,7 +396,7 @@ def get( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=(batch_size if batch_size is not None else None), + count=batch_size, block=redis_timeout, ) except Exception as read_err: @@ -400,7 +411,7 @@ def get( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=(batch_size if batch_size is not None else None), + count=batch_size, block=redis_timeout, ) else: @@ -492,7 +503,7 @@ def get( raise Empty("No messages available in Redis queue") - return result_messages if batch_size is not None else result_messages[0] + return result_messages except Exception as e: if "Empty" in str(type(e).__name__): @@ -630,7 +641,7 @@ def start_listening( try: while self._is_listening: - messages = self.get(timeout=poll_interval, count=batch_size) + messages = self.get_messages(batch_size=1) for message in messages: try: diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 7c9139200..7dc19d01d 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -14,6 +14,7 @@ from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker logger = get_logger(__name__) @@ -26,10 +27,12 @@ def __init__( maxsize: int, disabled_handlers: list | None = None, orchestrator: SchedulerOrchestrator | None = None, + status_tracker: TaskStatusTracker | None = None, ): self.use_redis_queue = use_redis_queue self.maxsize = maxsize self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator + self.status_tracker = status_tracker if self.use_redis_queue: if maxsize is None or not 
isinstance(maxsize, int) or maxsize <= 0: @@ -51,6 +54,7 @@ def ack_message( mem_cube_id: str, task_label: str, redis_message_id, + message: ScheduleMessageItem | None, ) -> None: if not isinstance(self.memos_message_queue, SchedulerRedisQueue): logger.warning("ack_message is only supported for Redis queues") @@ -61,6 +65,7 @@ def ack_message( mem_cube_id=mem_cube_id, task_label=task_label, redis_message_id=redis_message_id, + message=message, ) def get_stream_keys(self) -> list[str]: diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 5a9a87acb..4ae0c207e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -93,7 +93,11 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: for item in pref_results: item["cube_id"] = self.cube_id - return text_results + pref_results + all_memories = text_results + pref_results + + # TODO: search existing memories and compare + + return all_memories def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: # Create UserContext object @@ -692,7 +696,7 @@ def _process_text_mem( sync_mode=sync_mode, ) - return [ + text_memories = [ { "memory": memory.memory, "memory_id": memory_id, @@ -700,3 +704,5 @@ def _process_text_mem( } for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) ] + + return text_memories diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 50afb86f2..ffe6db2d0 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -417,3 +417,43 @@ - `memory_type` 保持英文。 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" + + +SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +You are a precise memory consistency auditor. + +# GOAL +Given user messages and an extracted memory list, identify and fix inconsistencies for each memory. 
+
+# RULES
+- Use ONLY information present in the user messages; do not invent.
+- Preserve explicit facts: names, timestamps, quantities, locations.
+- For each memory, keep the language identical to that memory's original language.
+- Output only JSON. No extra commentary.
+
+# INPUTS
+User messages:
+{user_messages_inline}
+
+Current memory list (JSON):
+{memories_inline}
+
+# OUTPUT FORMAT
+Return a JSON object where keys are the 0-based indices of the input memories (string keys allowed), and each value is an object:
+{{
+  "0": {{"delete_flag": false, "rewritten memory content": "..."}},
+  "1": {{"delete_flag": true, "rewritten memory content": ""}},
+  "2": {{"delete_flag": false, "rewritten memory content": "..."}}
+}}
+
+Notes:
+- If a memory is entirely hallucinated or contradicted by user messages, set `delete_flag` to true and leave `rewritten memory content` empty.
+- If a memory conflicts but can be corrected, set `delete_flag` to false and provide the corrected content in `"rewritten memory content"` using the memory's original language.
+- If a memory is valid, set `delete_flag` to false and return the original content.
+"""
+
+
+# Prompt mapping for specialized tasks (e.g., hallucination filtering)
+PROMPT_MAPPING = {
+    "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT,
+}