From 83d157689c5d1dc4a2a10d8176e81dff97e2b9c6 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 14:38:19 +0800 Subject: [PATCH 01/14] fix bugs: try to fix bugs in _submit_web_logs --- src/memos/mem_scheduler/base_scheduler.py | 37 ++++++++++------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 81defaa0f..9ab356f1d 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -846,28 +846,23 @@ def _submit_web_logs( f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. Message to publish: {message.model_dump_json(indent=2)}" ) - if self.rabbitmq_config is None: - logger.info( - "[DIAGNOSTIC] base_scheduler._submit_web_logs: RabbitMQ config not loaded; skipping publish." - ) - return - - for message in messages: - message_info = message.debug_info() - logger.info(f"[DIAGNOSTIC] base_scheduler._submit_web_logs: submitted {message_info}") + try: + for message in messages: + # Always call publish; the publisher now caches when offline and flushes after reconnect + logger.info( + f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}" + ) + self.rabbitmq_publish_message(message=message.to_dict()) + logger.info( + "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched " + "item_id=%s task_id=%s label=%s", + message.item_id, + message.task_id, + message.label, + ) + except Exception as e: + logger.error(f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True) - # Always call publish; the publisher now caches when offline and flushes after reconnect - logger.info( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message_info}" - ) - self.rabbitmq_publish_message(message=message.to_dict()) - logger.info( - "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched " - "item_id=%s task_id=%s label=%s", - 
message.item_id, - message.task_id, - message.label, - ) logger.debug( f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}" ) From e50c56cf817cb6d63b8e8e882aeaa4de12c444b8 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 15:00:28 +0800 Subject: [PATCH 02/14] fix bugs: try to address bugs --- src/memos/mem_scheduler/base_scheduler.py | 13 +++++-------- .../webservice_modules/rabbitmq_service.py | 6 ++++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 9ab356f1d..1e0ecaadb 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -842,12 +842,7 @@ def _submit_web_logs( messages = [messages] # transform single message to list for message in messages: - logger.info( - f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. Message to publish: {message.model_dump_json(indent=2)}" - ) - - try: - for message in messages: + try: # Always call publish; the publisher now caches when offline and flushes after reconnect logger.info( f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}" @@ -860,8 +855,10 @@ def _submit_web_logs( message.task_id, message.label, ) - except Exception as e: - logger.error(f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True) + except Exception as e: + logger.error( + f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True + ) logger.debug( f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. 
additional_log_info: {additional_log_info}" diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index a8a09760c..db8320879 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -7,6 +7,8 @@ from pathlib import Path from queue import Empty +from pyglet.libs.win32.constants import FALSE + from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig from memos.context.context import ContextThread from memos.dependency import require_python_package @@ -325,14 +327,14 @@ def rabbitmq_publish_message(self, message: dict): f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. " f"Exchange: {exchange_name}, Routing Key: '{routing_key}'." ) - logger.info(f" - Message Content: {json.dumps(message, indent=2)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") elif label == "knowledgeBaseUpdate": # Original diagnostic logging for knowledgeBaseUpdate if NOT in cloud env logger.info( f"[DIAGNOSTIC] Publishing knowledgeBaseUpdate message (Local Env). " f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." 
) - logger.info(f" - Message Content: {json.dumps(message, indent=2)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") with self._rabbitmq_lock: logger.info( From 58eb6b81af34437677e929e629f25dd3ddf0c1ff Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 15:13:21 +0800 Subject: [PATCH 03/14] fix bugs --- .../mem_scheduler/webservice_modules/rabbitmq_service.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index db8320879..43d24c5b9 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -7,8 +7,6 @@ from pathlib import Path from queue import Empty -from pyglet.libs.win32.constants import FALSE - from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig from memos.context.context import ContextThread from memos.dependency import require_python_package @@ -327,14 +325,14 @@ def rabbitmq_publish_message(self, message: dict): f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. " f"Exchange: {exchange_name}, Routing Key: '{routing_key}'." ) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") elif label == "knowledgeBaseUpdate": # Original diagnostic logging for knowledgeBaseUpdate if NOT in cloud env logger.info( f"[DIAGNOSTIC] Publishing knowledgeBaseUpdate message (Local Env). " f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." 
) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") with self._rabbitmq_lock: logger.info( From 0d72ce7669f3a9b30aa6849893a0e6ec6f991063 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 15:59:20 +0800 Subject: [PATCH 04/14] refactor: modify examples --- examples/mem_scheduler/memos_w_scheduler.py | 40 --------------------- 1 file changed, 40 deletions(-) diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 09aec4cba..ef7d853df 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -4,7 +4,6 @@ from datetime import datetime from pathlib import Path -from queue import Queue from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig @@ -12,7 +11,6 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS -from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -160,42 +158,6 @@ def _first_content() -> str: return title, _truncate_with_rules(_first_content()) -def show_web_logs(mem_scheduler: GeneralScheduler): - """Display all web log entries from the scheduler's log queue. 
- - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - collected: list[ScheduleLogForWebItem] = [] - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - collected.append(log_item) - temp_queue.put(log_item) - - for idx, log_item in enumerate(sorted(collected, key=lambda x: x.timestamp, reverse=True), 1): - title, content = _format_entry(log_item) - print(f"\nLog Entry #{idx}:") - print(title) - print(content) - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {len(collected)} web log entries displayed.") - print("=" * 110 + "\n") - - def run_with_scheduler_init(): print("==== run_with_automatic_scheduler_init ====") conversations, questions = init_task() @@ -253,8 +215,6 @@ def run_with_scheduler_init(): response = mos.chat(query=query, user_id=user_id) print(f"Answer:\n {response}\n") - show_web_logs(mem_scheduler=mos.mem_scheduler) - mos.mem_scheduler.stop() From 2fe965be240ea0e68c511b5573d88e9599b7cbd2 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 20:06:40 +0800 Subject: [PATCH 05/14] revise add operation and fix an unbelievable bug --- .../mem_scheduler/try_schedule_modules.py | 47 ------------------- src/memos/mem_reader/simple_struct.py | 2 +- .../webservice_modules/rabbitmq_service.py | 3 +- src/memos/templates/mem_reader_prompts.py | 39 ++++++++------- 4 files changed, 21 insertions(+), 70 deletions(-) diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index a5c5bc737..d942aad4e 100644 --- 
a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -1,8 +1,6 @@ import sys from pathlib import Path -from queue import Queue -from typing import TYPE_CHECKING from tqdm import tqdm @@ -11,18 +9,11 @@ ) from memos.log import get_logger from memos.mem_scheduler.analyzer.api_analyzer import DirectSearchMemoriesAnalyzer -from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import MEM_UPDATE_TASK_LABEL -if TYPE_CHECKING: - from memos.mem_scheduler.schemas import ( - ScheduleLogForWebItem, - ) - - FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory @@ -105,41 +96,6 @@ def init_task(): return conversations, questions -def show_web_logs(mem_scheduler: BaseScheduler): - """Display all web log entries from the scheduler's log queue. 
- - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - log_count = 0 - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - temp_queue.put(log_item) - log_count += 1 - - # Print log entry details - print(f"\nLog Entry #{log_count}:") - print(f'- "{log_item.label}" log: {log_item}') - - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {log_count} web log entries displayed.") - print("=" * 110 + "\n") - - class ScheduleModulesRunner(DirectSearchMemoriesAnalyzer): def __init__(self): super().__init__() @@ -215,6 +171,3 @@ def add_msgs( mem_scheduler._memory_update_consumer( messages=[message], ) - - # Show accumulated web logs - show_web_logs(mem_scheduler) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index ac79c246b..b870bf70a 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -522,7 +522,7 @@ def filter_hallucination_in_memories( raw = self.llm.generate([{"role": "user", "content": prompt}]) success, parsed = self._parse_hallucination_filter_response(raw) logger.info( - f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}" + f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 
43d24c5b9..46b2ad3d1 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -108,8 +108,7 @@ def initialize_rabbitmq( elif Path(config_path).exists(): auth_config = AuthConfig.from_local_config(config_path=config_path) else: - logger.error("Fail to initialize auth_config") - return + auth_config = AuthConfig.from_local_env() self.rabbitmq_config = auth_config.rabbitmq elif isinstance(config, RabbitMQConfig): self.rabbitmq_config = config diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 12c445df7..fef3ee6c0 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -625,21 +625,20 @@ SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. -Your task is to compare each memory against the provided user messages (the ground truth) and produce a corrected version only when necessary. Always preserve the original language of the memory—do not translate. +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. Rules: -1. **Language Consistency**: The rewritten memory must be in the exact same language as the original input memory. Never switch languages. -2. **Strict Grounding**: Only use information explicitly stated in the user messages. Do not introduce external facts, assumptions, or common sense. -3. **Ambiguity Resolution**: - - Replace vague pronouns (e.g., "he", "it", "they") or unclear references with specific, unambiguous entities based solely on the messages. 
- - Convert relative time expressions (e.g., "yesterday", "last week", "in two days") into absolute dates or times **only if the messages provide enough context** (e.g., current date is known or implied). -4. **Handling Assistant Inferences**: - - If a memory contains any content **not directly stated by the user**—such as interpretations, summaries, emotional attributions, predictions, causal claims, or generalizations—this is considered an assistant inference. - - In such cases, you **must** set `need_rewrite = true`. - - The `rewritten` text **must explicitly indicate that the statement is an inference**, using a clear and natural prefix in the memory’s language. For English memories, use: - > "The assistant inferred that [rest of the memory]." - - Do **not** present inferred content as factual user statements. -5. **No Rewrite Needed**: If the memory is factually accurate, fully grounded in the messages, unambiguous, and contains no unsupported content, set `need_rewrite = false` and copy the original memory exactly. +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. +3. **Ambiguity Elimination**: + - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. + - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +4. **Hallucination Removal**: + - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. + - Do **not** rephrase inferences as facts. 
Instead, either: + - Remove the unsupported part and retain only the grounded core, or + - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. Inputs: messages: @@ -649,15 +648,15 @@ {memories_inline} Output Format: -- Return a JSON object with string keys ("0", "1", "2", ...) corresponding to the input memory indices. +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} -- The "reason" should be concise and specific, e.g.: - - "contains assistant inference not stated by user" - - "pronoun 'it' has no clear referent in messages" - - "relative time 'yesterday' converted to 2025-12-16" - - "accurate and directly supported by user message" +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference" + - "vague pronoun with no referent in messages" + - "relative time resolved to 2025-12-16" + - "fully grounded and concise" -Important: Output **only** the JSON. No additional text, explanations, markdown, or fields. +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. 
""" From eecfa5136d9065fefd82867068a8deb12efae8a2 Mon Sep 17 00:00:00 2001 From: chentang Date: Mon, 22 Dec 2025 10:37:34 +0800 Subject: [PATCH 06/14] address the bug issues --- .../task_schedule_modules/redis_queue.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index ed8171ade..1c57f18f0 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -699,27 +699,23 @@ def _batch_claim_pending_messages( results = [] try: results = pipe.execute() - except Exception as e: - err_msg = str(e).lower() - if "nogroup" in err_msg or "no such key" in err_msg: - # Fallback: attempt sequential xautoclaim for robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue - else: - logger.error(f"Pipeline xautoclaim failed: {e}") + except Exception: + # Fallback: attempt sequential xautoclaim for robustness + for stream_key, need_count, label in claims_spec: + try: + self._ensure_consumer_group(stream_key=stream_key) + res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + results.append(res) + except Exception: + continue claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( From 
f2da3a7bd718cf663b29fb285e602e847f4dc91a Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 19:35:40 +0800 Subject: [PATCH 07/14] the doc file has a format problem which has been fixed in this commit --- docs/README.md | 2 +- .../{task_stop_rerun.py => scheduler_for_async_tasks.py} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename examples/mem_scheduler/{task_stop_rerun.py => scheduler_for_async_tasks.py} (98%) diff --git a/docs/README.md b/docs/README.md index bf5fea70d..8be17ffb7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,3 +1,3 @@ All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there. -所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。 +所有文档已迁移至独立仓库 https://github.com/MemTensor/MemOS-Docs 。请在该仓库中编辑文档。 diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/scheduler_for_async_tasks.py similarity index 98% rename from examples/mem_scheduler/task_stop_rerun.py rename to examples/mem_scheduler/scheduler_for_async_tasks.py index b5e62ff8f..a767b57c4 100644 --- a/examples/mem_scheduler/task_stop_rerun.py +++ b/examples/mem_scheduler/scheduler_for_async_tasks.py @@ -25,7 +25,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): task_id = str(msg.item_id) file_path = tmp_dir / f"{task_id}.txt" try: - sleep(1) + sleep(5) file_path.write_text(f"Task {task_id} processed.\n") print(f"writing {file_path} done") except Exception as e: @@ -58,7 +58,7 @@ def submit_tasks(): mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) # 10s to restart -mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000 +mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000 tmp_dir = Path("./tmp") tmp_dir.mkdir(exist_ok=True) @@ -88,6 +88,6 @@ def submit_tasks(): print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})") # 7. 
Stop the scheduler +sleep(20) print("Stopping the scheduler...") -sleep(5) mem_scheduler.stop() From a6881b4b064145f032c9b9e58ed0f9772ef33612 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 19:45:48 +0800 Subject: [PATCH 08/14] add a range of new feats for the add operation --- src/memos/api/config.py | 21 +- src/memos/llms/openai.py | 6 +- src/memos/mem_reader/simple_struct.py | 354 ++++++++++-------- src/memos/mem_reader/utils.py | 210 +++++++++++ .../mem_scheduler/schemas/general_schemas.py | 4 +- .../task_schedule_modules/redis_queue.py | 69 ++-- .../textual/prefer_text_memory/extractor.py | 4 + src/memos/templates/mem_reader_prompts.py | 156 +++++++- 8 files changed, 624 insertions(+), 200 deletions(-) create mode 100644 src/memos/mem_reader/utils.py diff --git a/src/memos/api/config.py b/src/memos/api/config.py index b795c2be6..0cdcb9a92 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -7,16 +7,19 @@ import re import time -from typing import Any +from typing import TYPE_CHECKING, Any import requests from dotenv import load_dotenv -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig from memos.context.context import ContextThread -from memos.mem_cube.general import GeneralMemCube + + +if TYPE_CHECKING: + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube # Load environment variables @@ -805,8 +808,12 @@ def get_start_default_config() -> dict[str, Any]: return config @staticmethod - def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]: + def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]: """Create configuration for a specific user.""" + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube + 
openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() @@ -933,12 +940,14 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General return default_config, default_mem_cube @staticmethod - def get_default_cube_config() -> GeneralMemCubeConfig | None: + def get_default_cube_config() -> "GeneralMemCubeConfig | None": """Get default cube configuration for product initialization. Returns: GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise. """ + from memos.configs.mem_cube import GeneralMemCubeConfig + if not APIConfig.is_default_cube_config_enabled(): return None diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1d180eebd..752386c91 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -57,8 +57,8 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) if reasoning_content: - return reasoning_content + response_content - return response_content + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM", @@ -146,7 +146,7 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) else: - return response_content + return response_content or "" def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from Azure OpenAI LLM with optional reasoning support.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b870bf70a..866b6d988 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -2,7 +2,6 @@ import copy import json import os -import re import traceback from abc import ABC @@ -18,6 +17,13 @@ from memos.llms.factory import 
LLMFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang +from memos.mem_reader.utils import ( + count_tokens_text, + derive_key, + parse_json_result, + parse_keep_filter_response, + parse_rewritten_response, +) from memos.memories.textual.item import ( SourceMessage, TextualMemoryItem, @@ -89,27 +95,6 @@ def from_config(_config): } -try: - import tiktoken - - try: - _ENC = tiktoken.encoding_for_model("gpt-4o-mini") - except Exception: - _ENC = tiktoken.get_encoding("cl100k_base") - - def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "", disallowed_special=())) -except Exception: - # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars - def _count_tokens_text(s: str) -> int: - if not s: - return 0 - zh_chars = re.findall(r"[\u4e00-\u9fff]", s) - zh = len(zh_chars) - rest = len(s) - zh - return zh + max(1, rest // 4) - - def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder): # generate try: @@ -172,14 +157,6 @@ def _build_node(idx, message, info, source_info, llm, parse_json_result, embedde return None -def _derive_key(text: str, max_len: int = 80) -> str: - """default key when without LLM: first max_len words""" - if not text: - return "" - sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] - return (sent[:max_len]).strip() - - class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -197,7 +174,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.memory_max_length = 8000 # Use token-based windowing; default to ~5000 tokens if not configured self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) - self._count_tokens = _count_tokens_text + self._count_tokens = count_tokens_text + self.searcher = None def _make_memory_item( self, @@ -224,7 +202,7 @@ def _make_memory_item( memory_type=memory_type, status="activated", tags=tags or [], - key=key if key 
is not None else _derive_key(value), + key=key if key is not None else derive_key(value), embedding=self.embedder.embed([value])[0], usage=[], sources=sources or [], @@ -254,7 +232,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) + response_json = parse_json_result(response_text) except Exception as e: logger.error(f"[LLM] Exception during chat generation: {e}") response_json = { @@ -456,47 +434,73 @@ def get_memory( standard_scene_data = coerce_scene_data(scene_data, type) return self._read_memory(standard_scene_data, type, info, mode) - @staticmethod - def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: - """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } - Returns (success, parsed_dict) with int keys. - """ + def rewrite_memories( + self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) 
+ if user_only: + template = PROMPT_MAPPING["rewrite_user_only"] + filtered_messages = [m for m in messages if m.get("role") != "assistant"] + if len(filtered_messages) < 1: + return memory_list + else: + template = PROMPT_MAPPING["rewrite"] + filtered_messages = messages + if len(filtered_messages) < 2: + return memory_list + + prompt_args = { + "messages_inline": "\n".join( + [f"- [{message['role']}]: {message['content']}" for message in filtered_messages] + ), + "memories_inline": json.dumps( + {idx: mem.memory for idx, mem in enumerate(memory_list)}, + ensure_ascii=False, + indent=2, + ), + } + prompt = template.format(**prompt_args) + + # Optionally run filter and parse the output try: - data = json.loads(text) - except Exception: - return False, {} + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed = parse_rewritten_response(raw) + logger.info( + f"[rewrite_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" + ) + if success: + logger.info(f"Rewrite filter result: {parsed}") - if not isinstance(data, dict): - return False, {} + new_memory_list = [] + for mem_idx, content in parsed.items(): + if mem_idx < 0 or mem_idx >= len(memory_list): + logger.warning( + f"[rewrite_memories] Invalid memory index {mem_idx} for memory_list {len(memory_list)}, skipping." 
+ ) + continue - result: dict[int, dict] = {} - for k, v in data.items(): - try: - idx = int(k) - except Exception: - # allow integer keys as-is - if isinstance(k, int): - idx = k - else: - continue - if not isinstance(v, dict): - continue - need_rewrite = v.get("need_rewrite") - rewritten = v.get("rewritten", "") - reason = v.get("reason", "") - if ( - isinstance(need_rewrite, bool) - and isinstance(rewritten, str) - and isinstance(reason, str) - ): - result[idx] = { - "need_rewrite": need_rewrite, - "rewritten": rewritten, - "reason": reason, - } + need_rewrite = content.get("need_rewrite", False) + rewritten_text = content.get("rewritten", "") + reason = content.get("reason", "") + original_text = memory_list[mem_idx].memory + + # Replace memory text with rewritten content when rewrite is needed + if need_rewrite and isinstance(rewritten_text, str): + logger.info( + f"[rewrite_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + ) + if len(rewritten_text.strip()) != 0: + memory_list[mem_idx].memory = rewritten_text + new_memory_list.append(memory_list[mem_idx]) + else: + new_memory_list.append(memory_list[mem_idx]) + return new_memory_list + else: + logger.warning("Rewrite filter parsing failed or returned empty result.") + except Exception as e: + logger.error(f"Rewrite filter execution error: {e}", stack_info=True) - return (len(result) > 0), result + return memory_list def filter_hallucination_in_memories( self, messages: list[dict], memory_list: list[TextualMemoryItem] @@ -520,32 +524,32 @@ def filter_hallucination_in_memories( # Optionally run filter and parse the output try: raw = self.llm.generate([{"role": "user", "content": prompt}]) - success, parsed = self._parse_hallucination_filter_response(raw) + success, parsed = parse_keep_filter_response(raw) logger.info( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: 
{success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") - assert len(parsed) == len(memory_list) - for mem_idx, content in parsed.items(): - need_rewrite = content.get("need_rewrite", False) - rewritten_text = content.get("rewritten", "") - reason = content.get("reason", "") - # Replace memory text with rewritten content when rewrite is needed - if ( - need_rewrite - and isinstance(rewritten_text, str) - and len(rewritten_text.strip()) > 0 - ): - original_text = memory_list[mem_idx].memory + filtered_list = [] + for mem_idx, mem in enumerate(memory_list): + content = parsed.get(mem_idx) + if not content: + logger.warning(f"No verdict for memory {mem_idx}, keeping it.") + filtered_list.append(mem) + continue + keep = content.get("keep", True) + reason = content.get("reason", "") + + if keep: + filtered_list.append(mem) + else: logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + f"[filter_hallucination_in_memories] Dropping memory index={mem_idx}, reason='{reason}', memory='{mem.memory}'" ) - memory_list[mem_idx].memory = rewritten_text - return memory_list + return filtered_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") except Exception as e: @@ -553,6 +557,103 @@ def filter_hallucination_in_memories( return memory_list + def add_before_search( + self, + messages: list[dict], + memory_list: list[TextualMemoryItem], + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) 
+ template = PROMPT_MAPPING["add_before_search"] + + if not self.searcher: + try: + from memos.mem_reader.utils import init_searcher + + self.searcher = init_searcher(self.llm, self.embedder) + except Exception as e: + logger.error(f"[add_before_search] Failed to init searcher: {e}") + return memory_list + + # 1. Gather candidates and search for related memories + candidates_data = [] + for idx, mem in enumerate(memory_list): + try: + related_memories = self.searcher.search( + query=mem.memory, top_k=3, mode="fast", info={"user_id": "", "session_id": ""} + ) + related_text = "None" + if related_memories: + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) + + candidates_data.append( + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} + ) + except Exception as e: + logger.error(f"[add_before_search] Search error for memory '{mem.memory}': {e}") + # If search fails, we can either skip this check or treat related as empty + candidates_data.append( + { + "idx": idx, + "new_memory": mem.memory, + "related_memories": "None (Search Failed)", + } + ) + + if not candidates_data: + return memory_list + + # 2. Build Prompt + messages_inline = "\n".join( + [ + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" + for message in messages + ] + ) + + candidates_inline_dict = { + str(item["idx"]): { + "new_memory": item["new_memory"], + "related_memories": item["related_memories"], + } + for item in candidates_data + } + + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) + + prompt = template.format( + messages_inline=messages_inline, candidates_inline=candidates_inline + ) + + # 3. Call LLM + try: + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed_result = parse_keep_filter_response(raw) + + if not success: + logger.warning("[add_before_search] Failed to parse LLM response, keeping all.") + return memory_list + + # 4. 
Filter + filtered_list = [] + for idx, mem in enumerate(memory_list): + res = parsed_result.get(idx) + if not res: + filtered_list.append(mem) + continue + + if res.get("keep", True): + filtered_list.append(mem) + else: + logger.info( + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" + ) + + return filtered_list + + except Exception as e: + logger.error(f"[add_before_search] LLM execution error: {e}") + return memory_list + def _read_memory( self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: @@ -606,29 +707,27 @@ def _read_memory( for group_id in range(len(memory_list)): try: - revised_memory_list = self.filter_hallucination_in_memories( + original_memory_group = copy.deepcopy(memory_list[group_id]) + serialized_origin_memories = json.dumps( + [one.memory for one in original_memory_group], indent=2 + ) + revised_memory_list = self.rewrite_memories( messages=combined_messages, - memory_list=memory_list[group_id], + memory_list=original_memory_group, + user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() + == "true", + ) + serialized_revised_memories = json.dumps( + [one.memory for one in revised_memory_list], indent=2 ) - if len(revised_memory_list) != len(memory_list[group_id]): - original_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in memory_list[group_id] - ] - filtered_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in revised_memory_list - ] - logger.error( - f"Length mismatch after hallucination filtering for group_id={group_id}: " - f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}" - f"\noriginal_memory_list(serialized): {original_serialized}" - f"\nfiltered_memory_list(serialized): {filtered_serialized}" - f"\nmessages: {combined_messages}" - f"\nSkipping update and keeping original memory." 
+ if serialized_origin_memories != serialized_revised_memories: + memory_list[group_id] = revised_memory_list + logger.info( + f"[SIMPLE_STRUCT_ADD_FILTER] Modified the list for group_id={group_id}: " + f"\noriginal={serialized_origin_memories}," + f"\nrevised={serialized_revised_memories}" ) - continue - memory_list[group_id] = revised_memory_list + except Exception as e: group_serialized = [ one.memory if hasattr(one, "memory") else str(one) @@ -847,7 +946,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): info, source_info_list, self.llm, - self.parse_json_result, + parse_json_result, self.embedder, ): idx for idx, msg in enumerate(messages) @@ -870,44 +969,3 @@ def _process_transfer_doc_data( self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None ): raise NotImplementedError - - def parse_json_result(self, response_text: str) -> dict: - s = (response_text or "").strip() - - m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) - s = (m.group(1) if m else s.replace("```", "")).strip() - - i = s.find("{") - if i == -1: - return {} - s = s[i:].strip() - - try: - return json.loads(s) - except json.JSONDecodeError: - pass - - j = max(s.rfind("}"), s.rfind("]")) - if j != -1: - try: - return json.loads(s[: j + 1]) - except json.JSONDecodeError: - pass - - def _cheap_close(t: str) -> str: - t += "}" * max(0, t.count("{") - t.count("}")) - t += "]" * max(0, t.count("[") - t.count("]")) - return t - - t = _cheap_close(s) - try: - return json.loads(t) - except json.JSONDecodeError as e: - if "Invalid \\escape" in str(e): - s = s.replace("\\", "\\\\") - return json.loads(s) - logger.error( - f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ - json: {s}" - ) - return {} diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py new file mode 100644 index 000000000..843345ec4 --- /dev/null +++ b/src/memos/mem_reader/utils.py @@ -0,0 +1,210 @@ +import json +import os +import re + +from typing import 
Any + +from memos import log +from memos.api.config import APIConfig +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.factory import RerankerFactory + + +logger = log.get_logger(__name__) + +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "", disallowed_special=())) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars + def count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + + +def derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + +def parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.error( + 
f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} + + +def parse_rewritten_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from hallucination filter response. + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. + """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + # allow integer keys as-is + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + need_rewrite = v.get("need_rewrite") + rewritten = v.get("rewritten", "") + reason = v.get("reason", "") + if ( + isinstance(need_rewrite, bool) + and isinstance(rewritten, str) + and isinstance(reason, str) + ): + result[idx] = { + "need_rewrite": need_rewrite, + "rewritten": rewritten, + "reason": reason, + } + + return (len(result) > 0), result + + +def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from keep filter response. + Expected shape: { "0": {"keep": bool, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. 
+ """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + keep = v.get("keep") + reason = v.get("reason", "") + if isinstance(keep, bool): + result[idx] = { + "keep": keep, + "reason": reason, + } + return (len(result) > 0), result + + +def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + "polardb": APIConfig.get_polardb_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def build_reranker_config() -> dict[str, Any]: + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def init_searcher(llm, embedder) -> Searcher: + """Initialize a Searcher instance for SimpleStructMemReader.""" + + # Build configs + graph_db_config = build_graph_db_config() + reranker_config = build_reranker_config() + + # Create instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + reranker = RerankerFactory.from_config(reranker_config) + + # Create Searcher + searcher = Searcher( + dispatcher_llm=llm, + graph_store=graph_db, + embedder=embedder, + reranker=reranker, + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + ) + + return searcher diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py 
b/src/memos/mem_scheduler/schemas/general_schemas.py index f4ad9fe48..06910ba17 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,3 +1,5 @@ +import os + from pathlib import Path @@ -21,7 +23,7 @@ DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 DEFAULT_TOP_K = 5 DEFAULT_CONTEXT_WINDOW_SIZE = 5 -DEFAULT_USE_REDIS_QUEUE = True +DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true" DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20 DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1c57f18f0..7923b3750 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -679,11 +679,6 @@ def _batch_claim_pending_messages( if not self._redis_conn or not claims_spec: return [] - # Ensure consumer groups exist to avoid NOGROUP errors during batch claim - for stream_key, _need_count, _label in claims_spec: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - pipe = self._redis_conn.pipeline(transaction=False) for stream_key, need_count, label in claims_spec: pipe.xautoclaim( @@ -696,26 +691,42 @@ def _batch_claim_pending_messages( justid=False, ) - results = [] try: - results = pipe.execute() - except Exception: - # Fallback: attempt sequential xautoclaim for robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue + # 
Execute with raise_on_error=False so we get exceptions in the results list + # instead of aborting the whole batch. + results = pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pipeline execution critical failure: {e}") + results = [e] * len(claims_spec) + + # Handle individual failures (e.g. NOGROUP) by retrying just that stream + final_results = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + err_msg = str(res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + stream_key, need_count, label = claims_spec[i] + try: + self._ensure_consumer_group(stream_key=stream_key) + retry_res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + final_results.append(retry_res) + except Exception as retry_err: + logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}") + final_results.append(None) + else: + final_results.append(None) + else: + final_results.append(res) + + results = final_results claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( @@ -1189,9 +1200,7 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) - logger.info( - f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', " - f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, " - f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, " - f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}" - ) + logger.info( + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." 
+ ) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 144bfad7f..3404c6d4c 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -69,6 +69,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: @@ -92,6 +94,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index fef3ee6c0..40971c77e 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,23 +622,56 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. 
Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. 
Memories must be factual, unambiguous, and free of any inferred or speculative content. +Note: The provided messages contain only user messages. The assistant's responses are intentionally omitted, not because the assistant didn't answer, but to focus strictly on validating memories against user input. + Rules: 1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. -2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. -3. **Ambiguity Elimination**: - - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. - - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. 4. **Hallucination Removal**: - - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. - - Do **not** rephrase inferences as facts. Instead, either: - - Remove the unsupported part and retain only the grounded core, or - - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. 
+- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. 5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. Inputs: messages: @@ -651,16 +684,115 @@ - Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} - The "reason" must be brief and precise, e.g.: - - "contains unsupported inference" - - "vague pronoun with no referent in messages" - - "relative time resolved to 2025-12-16" + - "contains unsupported inference ...." - "fully grounded and concise" Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. 
However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +You are a strict memory validator. +Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. + +Rules: +1. **User-Only Origin**: Verify facts against USER messages ONLY. If the Assistant repeats a User fact, it is VALID. 
If the Assistant introduces a new detail (e.g., 'philanthropy') that the User did not explicitly confirm, it is INVALID. +2. **No Inference Allowed**: Do NOT keep memories based on implication, emotion, preference, or generalization. Only verbatim or direct restatements of user-provided facts are valid. However, minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +3. **Hallucination = Deletion**: If a memory contains any detail not directly expressed by the user, mark it for deletion. +4. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Examples: +Messages: +- [user]: I love coding in Python. +- [assistant]: That's great! I assume you also contribute to open source projects? +Memory: User enjoys Python and contributes to open source. +Result: {{"keep": false, "reason": "User never stated they contribute to open source; this came from Assistant's assumption."}} + +Messages: +- [user]: I am tired. +- [assistant]: I hear you are tired. Rest is important. +Memory: User stated they are tired. +Result: {{"keep": true, "reason": "Direct restatement of user input, even if Assistant repeated it."}} + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true only if the memory is a direct reflection of the user's explicit words. +- "reason": brief, factual, and cites missing or unsupported content. + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + + +SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """ +You are a memory manager. 
+Your task is to decide if a new memory should be added to the long-term memory, given a list of existing related memories.
+
+Rules:
+1. **Redundancy Check**: If the new memory is completely redundant, already known, or covered by the existing memories, discard it.
+2. **New Information**: If the new memory provides new information, details, or updates compared to the existing memories, keep it.
+3. **Contradiction**: If the new memory contradicts existing memories but seems valid/newer, keep it (updates).
+4. **Context Check**: Use the provided conversation messages to verify that the new memory is grounded in the user's explicit statements. If it is not grounded in the messages, discard it.
+
+Inputs:
+Messages:
+{messages_inline}
+
+Candidate Memories (to be evaluated):
+{candidates_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching the input candidate memory indices.
+- Each value must be: {{ "keep": boolean, "reason": string }}
+- "keep": true if the memory should be added.
+- "reason": brief explanation.
+
+Important: Output **only** the JSON. No extra text.
+""" # Prompt mapping for specialized tasks (e.g., hallucination filtering) PROMPT_MAPPING = { "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT, + "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT, + "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT, + "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT, } From 7f39e7ecc052d2e85e7bbeb2ca73f586db143875 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:04:56 +0800 Subject: [PATCH 09/14] address the incompatible issue of local scheduler --- src/memos/mem_scheduler/base_scheduler.py | 20 +++-- .../task_schedule_modules/local_queue.py | 75 +++++++++++++++++-- .../task_schedule_modules/task_queue.py | 23 +----- .../mem_scheduler/utils/status_tracker.py | 26 ++++++- tests/test_local_queue_full.py | 54 +++++++++++++ 5 files changed, 164 insertions(+), 34 deletions(-) create mode 100644 tests/test_local_queue_full.py diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 1e0ecaadb..728203f5b 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1009,14 +1009,24 @@ def _monitor_loop(self): q_sizes = self.memos_message_queue.qsize() for stream_key, queue_length in q_sizes.items(): - # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}" + # Skip aggregate keys like 'total_size' + if stream_key == "total_size": + continue + + # Key format: ...:{user_id}:{mem_cube_id}:{task_label} + # We want to extract user_id, which is the 3rd component from the end. parts = stream_key.split(":") if len(parts) >= 3: - user_id = parts[2] - self.metrics.update_queue_length(queue_length, user_id) - elif not self.use_redis_queue: # local queue - user_id = stream_key + user_id = parts[-3] self.metrics.update_queue_length(queue_length, user_id) + else: + # Fallback for unexpected key formats (e.g. 
legacy or testing) + # Try to use the key itself if it looks like a user_id (no colons) + # or just log a warning? + # For now, let's assume if it's not total_size and short, it might be a direct user_id key + # (though that shouldn't happen with current queue implementations) + if ":" not in stream_key: + self.metrics.update_queue_length(queue_length, stream_key) except Exception as e: logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 69cfc0af9..32d79cef3 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -62,7 +62,7 @@ def put( Exception: Any underlying error during queue.put() operation. """ stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) message.stream_key = stream_key @@ -108,35 +108,95 @@ def get( ) return res - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, stream_key: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ - Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). Returns immediately with available messages or an empty list if queue is empty. Args: + stream_key (str): The stream/queue identifier. batch_size (int | None): Number of messages to retrieve in a batch. If None, retrieves one message. Returns: List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. 
""" - logger.debug(f"get_nowait() called with batch_size: {batch_size}") - return self.get(block=False, batch_size=batch_size) + logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}") + return self.get(stream_key=stream_key, block=False, batch_size=batch_size) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + """ + Get messages from all streams in round-robin or sequential fashion. + Equivalent to SchedulerRedisQueue.get_messages. + """ + messages = [] + # Snapshot keys to avoid runtime modification issues + stream_keys = list(self.queue_streams.keys()) + + # Simple strategy: try to get up to batch_size messages across all streams + # We can just iterate and collect. + + # Calculate how many to get per stream to be fair? + # Or just greedy? Redis implementation uses a complex logic. + # For local, let's keep it simple: just iterate and take what's available (non-blocking) + + for stream_key in stream_keys: + if len(messages) >= batch_size: + break + + needed = batch_size - len(messages) + # Use get_nowait to avoid blocking + fetched = self.get_nowait(stream_key=stream_key, batch_size=needed) + messages.extend(fetched) + + return messages def qsize(self) -> dict: """ Return the current size of all internal queues as a dictionary. Each key is the stream name, and each value is the number of messages in that queue. + Also includes 'total_size'. Returns: Dict[str, int]: Mapping from stream name to current queue size. """ sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + total_size = sum(sizes.values()) + sizes["total_size"] = total_size logger.debug(f"Current queue sizes: {sizes}") return sizes + def size(self) -> int: + """ + Get the current size of the queue (total message count). + Compatible with SchedulerRedisQueue. + """ + return self.unfinished_tasks + + def empty(self) -> bool: + """ + Check if the queue is empty. + Compatible with SchedulerRedisQueue. 
+ """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if the queue is full. + Compatible with SchedulerRedisQueue. + + Returns True if all internal queues are full. + If there are no queues, returns False. + """ + if not self.queue_streams: + return False + + return all(queue.full() for queue in self.queue_streams.values()) + def clear(self) -> None: for queue in self.queue_streams.values(): queue.clear() @@ -151,6 +211,9 @@ def unfinished_tasks(self) -> int: Returns: int: Sum of all message counts in all internal queues. """ - total = sum(self.qsize().values()) + # qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values + # But qsize() implementation above sums values from queue_streams, then adds total_size. + # So sum(self.queue_streams.values().qsize()) is safer. + total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index c20243242..b49db2b36 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -153,28 +153,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: - if isinstance(self.memos_message_queue, SchedulerRedisQueue): - return self.memos_message_queue.get_messages(batch_size=batch_size) - stream_keys = self.get_stream_keys() - - if len(stream_keys) == 0: - return [] - - messages: list[ScheduleMessageItem] = [] - - for stream_key in stream_keys: - fetched = self.memos_message_queue.get( - stream_key=stream_key, - block=False, - batch_size=batch_size, - ) - - messages.extend(fetched) - if len(messages) > 0: - logger.debug( - f"Fetched {len(messages)} messages across users 
with per-user batch_size={batch_size}" - ) - return messages + return self.memos_message_queue.get_messages(batch_size=batch_size) def clear(self): self.memos_message_queue.clear() diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py index d8c8d2cee..2a995b239 100644 --- a/src/memos/mem_scheduler/utils/status_tracker.py +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -13,7 +13,7 @@ class TaskStatusTracker: @require_python_package(import_name="redis", install_command="pip install redis") - def __init__(self, redis_client: "redis.Redis"): + def __init__(self, redis_client: "redis.Redis | None"): self.redis = redis_client def _get_key(self, user_id: str) -> str: @@ -41,6 +41,9 @@ def task_submitted( mem_cube_id: Memory cube identifier business_task_id: Optional business-level task ID (one task_id can have multiple item_ids) """ + if not self.redis: + return + key = self._get_key(user_id) payload = { "status": "waiting", @@ -61,6 +64,9 @@ def task_submitted( self.redis.expire(key, timedelta(days=7)) def task_started(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -77,6 +83,9 @@ def task_started(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_completed(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -91,6 +100,9 @@ def task_completed(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_failed(self, task_id: str, user_id: str, error_message: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -108,11 +120,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str): 
self.redis.expire(key, timedelta(days=7)) def get_task_status(self, task_id: str, user_id: str) -> dict | None: + if not self.redis: + return None + key = self._get_key(user_id) data = self.redis.hget(key, task_id) return json.loads(data) if data else None def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]: + if not self.redis: + return {} + key = self._get_key(user_id) all_tasks = self.redis.hgetall(key) return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()} @@ -132,6 +150,9 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) -> - If any item is 'failed' → 'failed' Returns None if task_id not found. """ + if not self.redis: + return None + # Get all item_ids for this task_id task_items_key = self._get_task_items_key(user_id, business_task_id) item_ids = self.redis.smembers(task_items_key) @@ -180,6 +201,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]: Returns: dict: {user_id: {task_id: task_data, ...}, ...} """ + if not self.redis: + return {} + all_users_tasks = {} cursor: int | str = 0 while True: diff --git a/tests/test_local_queue_full.py b/tests/test_local_queue_full.py new file mode 100644 index 000000000..6c523046a --- /dev/null +++ b/tests/test_local_queue_full.py @@ -0,0 +1,54 @@ +import unittest + +from datetime import datetime, timezone + +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue + + +class TestLocalQueueFull(unittest.TestCase): + def test_full_behavior(self): + # Create a queue with very small maxsize for testing + lq = SchedulerLocalQueue(maxsize=1) + + # Initially empty + self.assertFalse(lq.full()) + + # Add message to stream 1 + msg1 = ScheduleMessageItem( + user_id="u1", + mem_cube_id="c1", + label="l1", + content="m1", + timestamp=datetime.now(timezone.utc), + ) + lq.put(msg1) + + # Now stream 1 is full (maxsize=1). 
+ # Since it's the only stream, and it's full, lq.full() should be True. + self.assertTrue(lq.full()) + + # Add message to stream 2 + msg2 = ScheduleMessageItem( + user_id="u2", + mem_cube_id="c2", + label="l2", + content="m2", + timestamp=datetime.now(timezone.utc), + ) + lq.put(msg2) + + # Now both stream 1 and stream 2 are full. lq.full() should be True. + self.assertTrue(lq.full()) + + # Remove message from stream 1 + stream1_key = lq.get_stream_key("u1", "c1", "l1") + lq.get(stream1_key) + + # Now stream 1 is empty, stream 2 is full. + # "all streams are full" is False. + self.assertFalse(lq.full()) + + +if __name__ == "__main__": + unittest.main() From 3fe9cb09b4f1864db8225be7e64b0959e50c358f Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:25:07 +0800 Subject: [PATCH 10/14] feat(scheduler): optimize redis queue consumer group management - Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams. - Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance. - Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation. - This change reduces unnecessary Redis calls during high-frequency polling. --- .../task_schedule_modules/redis_queue.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 7923b3750..2f4318003 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,7 +5,6 @@ the local memos_message_queue functionality in BaseScheduler. 
""" -import contextlib import os import re import threading @@ -201,6 +200,20 @@ def _refresh_stream_keys( recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, now_sec=now_sec, ) + + # Ensure consumer groups for newly discovered active streams + with self._stream_keys_lock: + # Identify keys we haven't seen yet + new_streams = [k for k in active_stream_keys if k not in self.seen_streams] + + # Create groups outside the lock to avoid blocking + for key in new_streams: + self._ensure_consumer_group(key) + + if new_streams: + with self._stream_keys_lock: + self.seen_streams.update(new_streams) + deleted_count = self._delete_streams(keys_to_delete) self._update_stream_cache_with_log( stream_key_prefix=stream_key_prefix, @@ -560,10 +573,7 @@ def _read_new_messages_batch( return {} # Pre-ensure consumer groups to avoid NOGROUP during batch reads - for stream_key in stream_keys: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - + # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups) pipe = self._redis_conn.pipeline(transaction=False) for stream_key in stream_keys: pipe.xreadgroup( @@ -1170,10 +1180,14 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: del_pipe.delete(key) del_pipe.execute() deleted_count = len(keys_to_delete) - # Clean up empty-tracking state for deleted keys + # Clean up empty-tracking state and seen_streams for deleted keys with self._empty_stream_seen_lock: for key in keys_to_delete: self._empty_stream_seen_times.pop(key, None) + + with self._stream_keys_lock: + for key in keys_to_delete: + self.seen_streams.discard(key) except Exception: for key in keys_to_delete: try: @@ -1181,6 +1195,8 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: deleted_count += 1 with self._empty_stream_seen_lock: self._empty_stream_seen_times.pop(key, None) + with self._stream_keys_lock: + self.seen_streams.discard(key) except Exception: pass return deleted_count From 
b35096fa61d3d3aeea3297354b0d10a78916a0f8 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:34:05 +0800 Subject: [PATCH 11/14] fix(tests): resolve AttributeError in SimpleStructMemReader tests - Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute. - Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'. - Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'. --- tests/mem_reader/test_simple_structure.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index f81356886..fd07fbf41 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -1,4 +1,3 @@ -import json import unittest from unittest.mock import MagicMock, patch @@ -8,6 +7,7 @@ from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem @@ -57,7 +57,6 @@ def test_process_chat_data(self): '"summary": "Tom is currently focused on managing a new project with a tight schedule."}' ) self.reader.llm.generate.return_value = mock_response - self.reader.parse_json_result = lambda x: json.loads(x) result = self.reader._process_chat_data(scene_data_info, info) @@ -105,7 +104,7 @@ def test_get_scene_data_info_with_chat(self): def test_parse_json_result_success(self): """Test successful JSON parsing.""" raw_response = '{"summary": "Test summary", "tags": ["test"]}' - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertIsInstance(result, dict) self.assertIn("summary", result) @@ -113,7 +112,7 @@ def 
test_parse_json_result_success(self): def test_parse_json_result_failure(self): """Test failure in JSON parsing.""" raw_response = "Invalid JSON string" - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertEqual(result, {}) From 8943ba8b437d9b0f2bfe3ec4e93901c36b976314 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:49:32 +0800 Subject: [PATCH 12/14] fix(mem_reader): pass info dict to add_before_search for correct user_id usage - Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict. - Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings. - Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'. - This ensures that memory searches are scoped to the correct user and session. --- src/memos/mem_reader/simple_struct.py | 3 +- tests/mem_reader/test_simple_structure.py | 92 +++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 866b6d988..18bad7ab7 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -561,6 +561,7 @@ def add_before_search( self, messages: list[dict], memory_list: list[TextualMemoryItem], + info: dict[str, Any], ) -> list[TextualMemoryItem]: # Build input objects with memory text and metadata (timestamps, sources, etc.) 
template = PROMPT_MAPPING["add_before_search"] @@ -579,7 +580,7 @@ def add_before_search( for idx, mem in enumerate(memory_list): try: related_memories = self.searcher.search( - query=mem.memory, top_k=3, mode="fast", info={"user_id": "", "session_id": ""} + query=mem.memory, top_k=3, mode="fast", info=info ) related_text = "None" if related_memories: diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index fd07fbf41..987ff25ae 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -116,6 +116,98 @@ def test_parse_json_result_failure(self): self.assertEqual(result, {}) + def test_add_before_search(self): + """Test add_before_search method.""" + import json + + from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata + + # Mock searcher + self.reader.searcher = MagicMock() + self.reader.searcher.search.return_value = [ + TextualMemoryItem( + memory="Related memory 1", + metadata=TreeNodeTextualMemoryMetadata( + user_id="user1", + session_id="session1", + memory_type="LongTermMemory", + status="activated", + tags=[], + key="key1", + embedding=[0.1], + usage=[], + sources=[], + background="", + confidence=0.99, + type="fact", + info={}, + ), + ) + ] + + # Mock LLM response for filter + # The method expects a JSON response with keep/drop decisions + mock_response = json.dumps( + { + "0": {"keep": True, "reason": "Relevant"}, + "1": {"keep": False, "reason": "Duplicate"}, + } + ) + self.reader.llm.generate.return_value = mock_response + + messages = [{"role": "user", "content": "test message"}] + memory_list = [ + TextualMemoryItem( + memory="Mem 1", + metadata=TreeNodeTextualMemoryMetadata( + user_id="user1", + session_id="session1", + memory_type="LongTermMemory", + status="activated", + tags=[], + key="key1", + embedding=[0.1], + usage=[], + sources=[], + background="", + confidence=0.99, + type="fact", + info={}, + ), + ), + 
TextualMemoryItem( + memory="Mem 2", + metadata=TreeNodeTextualMemoryMetadata( + user_id="user1", + session_id="session1", + memory_type="LongTermMemory", + status="activated", + tags=[], + key="key2", + embedding=[0.1], + usage=[], + sources=[], + background="", + confidence=0.99, + type="fact", + info={}, + ), + ), + ] + info = {"user_id": "user1", "session_id": "session1"} + + # Call the method + result = self.reader.add_before_search(messages, memory_list, info) + + # Assertions + # Check if searcher.search was called with correct info + self.reader.searcher.search.assert_called_with( + query="Mem 2", top_k=3, mode="fast", info=info + ) + # Check result + self.assertEqual(len(result), 1) + self.assertEqual(result[0].memory, "Mem 1") + if __name__ == "__main__": unittest.main() From 78a43275f5d9550cda6514ea51b05ee64417d979 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 21:25:35 +0800 Subject: [PATCH 13/14] refactor add_before_search from mem_reader to SingleCubeView --- src/memos/mem_reader/simple_struct.py | 3 +- src/memos/mem_reader/utils.py | 53 ------------- src/memos/multi_mem_cube/single_cube.py | 101 ++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 54 deletions(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 18bad7ab7..fdd109079 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -561,6 +561,7 @@ def add_before_search( self, messages: list[dict], memory_list: list[TextualMemoryItem], + user_name: str, info: dict[str, Any], ) -> list[TextualMemoryItem]: # Build input objects with memory text and metadata (timestamps, sources, etc.) 
@@ -580,7 +581,7 @@ def add_before_search( for idx, mem in enumerate(memory_list): try: related_memories = self.searcher.search( - query=mem.memory, top_k=3, mode="fast", info=info + query=mem.memory, top_k=3, mode="fast", user_nam=user_name, info=info ) related_text = "None" if related_memories: diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py index 843345ec4..4e5a78af2 100644 --- a/src/memos/mem_reader/utils.py +++ b/src/memos/mem_reader/utils.py @@ -1,16 +1,7 @@ import json -import os import re -from typing import Any - from memos import log -from memos.api.config import APIConfig -from memos.configs.graph_db import GraphDBConfigFactory -from memos.configs.reranker import RerankerConfigFactory -from memos.graph_dbs.factory import GraphStoreFactory -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher -from memos.reranker.factory import RerankerFactory logger = log.get_logger(__name__) @@ -164,47 +155,3 @@ def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]: "reason": reason, } return (len(result) > 0), result - - -def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: - graph_db_backend_map = { - "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), - "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), - "polardb": APIConfig.get_polardb_config(user_id=user_id), - } - - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() - return GraphDBConfigFactory.model_validate( - { - "backend": graph_db_backend, - "config": graph_db_backend_map[graph_db_backend], - } - ) - - -def build_reranker_config() -> dict[str, Any]: - return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) - - -def init_searcher(llm, embedder) -> Searcher: - """Initialize a Searcher instance for SimpleStructMemReader.""" - - # Build configs - graph_db_config = build_graph_db_config() - 
reranker_config = build_reranker_config() - - # Create instances - graph_db = GraphStoreFactory.from_config(graph_db_config) - reranker = RerankerFactory.from_config(reranker_config) - - # Create Searcher - searcher = Searcher( - dispatcher_llm=llm, - graph_store=graph_db, - embedder=embedder, - reranker=reranker, - manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - ) - - return searcher diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 57f2cdba1..ab3d0ce03 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -15,6 +15,7 @@ ) from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_reader.utils import parse_keep_filter_response from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -23,6 +24,7 @@ PREF_ADD_TASK_LABEL, ) from memos.multi_mem_cube.views import MemCubeView +from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, FineStrategy, @@ -41,6 +43,7 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler + from memos.memories.textual.item import TextualMemoryItem @dataclass @@ -631,6 +634,104 @@ def _process_pref_mem( for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] + def add_before_search( + self, + messages: list[dict], + memory_list: list[TextualMemoryItem], + user_name: str, + info: dict[str, Any], + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) 
+ template = PROMPT_MAPPING["add_before_search"] + + if not self.searcher: + self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.") + return memory_list + + # 1. Gather candidates and search for related memories + candidates_data = [] + for idx, mem in enumerate(memory_list): + try: + related_memories = self.searcher.search( + query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info + ) + related_text = "None" + if related_memories: + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) + + candidates_data.append( + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} + ) + except Exception as e: + self.logger.error( + f"[add_before_search] Search error for memory '{mem.memory}': {e}" + ) + # If search fails, we can either skip this check or treat related as empty + candidates_data.append( + { + "idx": idx, + "new_memory": mem.memory, + "related_memories": "None (Search Failed)", + } + ) + + if not candidates_data: + return memory_list + + # 2. Build Prompt + messages_inline = "\n".join( + [ + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" + for message in messages + ] + ) + + candidates_inline_dict = { + str(item["idx"]): { + "new_memory": item["new_memory"], + "related_memories": item["related_memories"], + } + for item in candidates_data + } + + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) + + prompt = template.format( + messages_inline=messages_inline, candidates_inline=candidates_inline + ) + + # 3. Call LLM + try: + raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}]) + success, parsed_result = parse_keep_filter_response(raw) + + if not success: + self.logger.warning( + "[add_before_search] Failed to parse LLM response, keeping all." + ) + return memory_list + + # 4. 
Filter + filtered_list = [] + for idx, mem in enumerate(memory_list): + res = parsed_result.get(idx) + if not res: + filtered_list.append(mem) + continue + + if res.get("keep", True): + filtered_list.append(mem) + else: + self.logger.info( + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" + ) + + return filtered_list + + except Exception as e: + self.logger.error(f"[add_before_search] LLM execution error: {e}") + return memory_list + def _process_text_mem( self, add_req: APIADDRequest, From a5fc4c09c94c4ded8d153cf34f5c09dc19cc979a Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 21:36:07 +0800 Subject: [PATCH 14/14] address bugs --- src/memos/mem_reader/simple_struct.py | 99 ----------------------- tests/mem_reader/test_simple_structure.py | 92 --------------------- 2 files changed, 191 deletions(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index fdd109079..70472958e 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -557,105 +557,6 @@ def filter_hallucination_in_memories( return memory_list - def add_before_search( - self, - messages: list[dict], - memory_list: list[TextualMemoryItem], - user_name: str, - info: dict[str, Any], - ) -> list[TextualMemoryItem]: - # Build input objects with memory text and metadata (timestamps, sources, etc.) - template = PROMPT_MAPPING["add_before_search"] - - if not self.searcher: - try: - from memos.mem_reader.utils import init_searcher - - self.searcher = init_searcher(self.llm, self.embedder) - except Exception as e: - logger.error(f"[add_before_search] Failed to init searcher: {e}") - return memory_list - - # 1. 
Gather candidates and search for related memories - candidates_data = [] - for idx, mem in enumerate(memory_list): - try: - related_memories = self.searcher.search( - query=mem.memory, top_k=3, mode="fast", user_nam=user_name, info=info - ) - related_text = "None" - if related_memories: - related_text = "\n".join([f"- {r.memory}" for r in related_memories]) - - candidates_data.append( - {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} - ) - except Exception as e: - logger.error(f"[add_before_search] Search error for memory '{mem.memory}': {e}") - # If search fails, we can either skip this check or treat related as empty - candidates_data.append( - { - "idx": idx, - "new_memory": mem.memory, - "related_memories": "None (Search Failed)", - } - ) - - if not candidates_data: - return memory_list - - # 2. Build Prompt - messages_inline = "\n".join( - [ - f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" - for message in messages - ] - ) - - candidates_inline_dict = { - str(item["idx"]): { - "new_memory": item["new_memory"], - "related_memories": item["related_memories"], - } - for item in candidates_data - } - - candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) - - prompt = template.format( - messages_inline=messages_inline, candidates_inline=candidates_inline - ) - - # 3. Call LLM - try: - raw = self.llm.generate([{"role": "user", "content": prompt}]) - success, parsed_result = parse_keep_filter_response(raw) - - if not success: - logger.warning("[add_before_search] Failed to parse LLM response, keeping all.") - return memory_list - - # 4. 
Filter - filtered_list = [] - for idx, mem in enumerate(memory_list): - res = parsed_result.get(idx) - if not res: - filtered_list.append(mem) - continue - - if res.get("keep", True): - filtered_list.append(mem) - else: - logger.info( - f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" - ) - - return filtered_list - - except Exception as e: - logger.error(f"[add_before_search] LLM execution error: {e}") - return memory_list - def _read_memory( self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index 987ff25ae..fd07fbf41 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -116,98 +116,6 @@ def test_parse_json_result_failure(self): self.assertEqual(result, {}) - def test_add_before_search(self): - """Test add_before_search method.""" - import json - - from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata - - # Mock searcher - self.reader.searcher = MagicMock() - self.reader.searcher.search.return_value = [ - TextualMemoryItem( - memory="Related memory 1", - metadata=TreeNodeTextualMemoryMetadata( - user_id="user1", - session_id="session1", - memory_type="LongTermMemory", - status="activated", - tags=[], - key="key1", - embedding=[0.1], - usage=[], - sources=[], - background="", - confidence=0.99, - type="fact", - info={}, - ), - ) - ] - - # Mock LLM response for filter - # The method expects a JSON response with keep/drop decisions - mock_response = json.dumps( - { - "0": {"keep": True, "reason": "Relevant"}, - "1": {"keep": False, "reason": "Duplicate"}, - } - ) - self.reader.llm.generate.return_value = mock_response - - messages = [{"role": "user", "content": "test message"}] - memory_list = [ - TextualMemoryItem( - memory="Mem 1", - metadata=TreeNodeTextualMemoryMetadata( 
- user_id="user1", - session_id="session1", - memory_type="LongTermMemory", - status="activated", - tags=[], - key="key1", - embedding=[0.1], - usage=[], - sources=[], - background="", - confidence=0.99, - type="fact", - info={}, - ), - ), - TextualMemoryItem( - memory="Mem 2", - metadata=TreeNodeTextualMemoryMetadata( - user_id="user1", - session_id="session1", - memory_type="LongTermMemory", - status="activated", - tags=[], - key="key2", - embedding=[0.1], - usage=[], - sources=[], - background="", - confidence=0.99, - type="fact", - info={}, - ), - ), - ] - info = {"user_id": "user1", "session_id": "session1"} - - # Call the method - result = self.reader.add_before_search(messages, memory_list, info) - - # Assertions - # Check if searcher.search was called with correct info - self.reader.searcher.search.assert_called_with( - query="Mem 2", top_k=3, mode="fast", info=info - ) - # Check result - self.assertEqual(len(result), 1) - self.assertEqual(result[0].memory, "Mem 1") - if __name__ == "__main__": unittest.main()