diff --git a/docs/README.md b/docs/README.md index bf5fea70d..8be17ffb7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,3 +1,3 @@ All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there. -所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。 +所有文档已迁移至独立仓库 https://github.com/MemTensor/MemOS-Docs 。请在该仓库中编辑文档。 diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/scheduler_for_async_tasks.py similarity index 98% rename from examples/mem_scheduler/task_stop_rerun.py rename to examples/mem_scheduler/scheduler_for_async_tasks.py index b5e62ff8f..a767b57c4 100644 --- a/examples/mem_scheduler/task_stop_rerun.py +++ b/examples/mem_scheduler/scheduler_for_async_tasks.py @@ -25,7 +25,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): task_id = str(msg.item_id) file_path = tmp_dir / f"{task_id}.txt" try: - sleep(1) + sleep(5) file_path.write_text(f"Task {task_id} processed.\n") print(f"writing {file_path} done") except Exception as e: @@ -58,7 +58,7 @@ def submit_tasks(): mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) # 10s to restart -mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000 +mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000 tmp_dir = Path("./tmp") tmp_dir.mkdir(exist_ok=True) @@ -88,6 +88,6 @@ def submit_tasks(): print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})") # 7. Stop the scheduler +sleep(20) print("Stopping the scheduler...") -sleep(5) mem_scheduler.stop() diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 48a16a6e2..7298658ff 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -7,16 +7,19 @@ import re import time -from typing import Any +from typing import TYPE_CHECKING, Any import requests from dotenv import load_dotenv -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig from memos.context.context import ContextThread -from memos.mem_cube.general import GeneralMemCube + + +if TYPE_CHECKING: + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube # Load environment variables @@ -805,8 +808,12 @@ def get_start_default_config() -> dict[str, Any]: return config @staticmethod - def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]: + def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]: """Create configuration for a specific user.""" + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube + openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() @@ -933,12 +940,14 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General return default_config, default_mem_cube @staticmethod - def get_default_cube_config() -> GeneralMemCubeConfig | None: + def get_default_cube_config() -> "GeneralMemCubeConfig | None": """Get default cube configuration for product initialization. Returns: GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise. """ + from memos.configs.mem_cube import GeneralMemCubeConfig + if not APIConfig.is_default_cube_config_enabled(): return None diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 563b8723e..ea488329d 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -59,8 +59,8 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) if reasoning_content: - return reasoning_content + response_content - return response_content + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM Stream", @@ -151,7 +151,7 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) else: - return response_content + return response_content or "" def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from Azure OpenAI LLM with optional reasoning support.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b870bf70a..70472958e 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -2,7 +2,6 @@ import copy import json import os -import re import traceback from abc import ABC @@ -18,6 +17,13 @@ from memos.llms.factory import LLMFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang +from memos.mem_reader.utils import ( + count_tokens_text, + derive_key, + parse_json_result, + parse_keep_filter_response, + parse_rewritten_response, +) from memos.memories.textual.item import ( SourceMessage, TextualMemoryItem, @@ -89,27 +95,6 @@ def from_config(_config): } -try: - import tiktoken - - try: - _ENC = tiktoken.encoding_for_model("gpt-4o-mini") - except Exception: - _ENC = tiktoken.get_encoding("cl100k_base") - - def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "", disallowed_special=())) -except Exception: - # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars - def _count_tokens_text(s: str) -> int: - if not s: - return 0 - zh_chars = re.findall(r"[\u4e00-\u9fff]", s) - zh = len(zh_chars) - rest = len(s) - zh - return zh + max(1, rest // 4) - - def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder): # generate try: @@ -172,14 +157,6 @@ def _build_node(idx, message, info, source_info, llm, parse_json_result, embedde return None -def _derive_key(text: str, max_len: int = 80) -> str: - """default key when without LLM: first max_len words""" - if not text: - return "" - sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] - return (sent[:max_len]).strip() - - class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -197,7 +174,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.memory_max_length = 8000 # Use token-based windowing; default to ~5000 tokens if not configured self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) - self._count_tokens = _count_tokens_text + self._count_tokens = count_tokens_text + self.searcher = None def _make_memory_item( self, @@ -224,7 +202,7 @@ def _make_memory_item( memory_type=memory_type, status="activated", tags=tags or [], - key=key if key is not None else _derive_key(value), + key=key if key is not None else derive_key(value), embedding=self.embedder.embed([value])[0], usage=[], sources=sources or [], @@ -254,7 +232,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) + response_json = parse_json_result(response_text) except Exception as e: logger.error(f"[LLM] Exception during chat generation: {e}") response_json = { @@ -456,47 +434,73 @@ def get_memory( standard_scene_data = coerce_scene_data(scene_data, type) return self._read_memory(standard_scene_data, type, info, mode) - @staticmethod - def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: - """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } - Returns (success, parsed_dict) with int keys. - """ + def rewrite_memories( + self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + if user_only: + template = PROMPT_MAPPING["rewrite_user_only"] + filtered_messages = [m for m in messages if m.get("role") != "assistant"] + if len(filtered_messages) < 1: + return memory_list + else: + template = PROMPT_MAPPING["rewrite"] + filtered_messages = messages + if len(filtered_messages) < 2: + return memory_list + + prompt_args = { + "messages_inline": "\n".join( + [f"- [{message['role']}]: {message['content']}" for message in filtered_messages] + ), + "memories_inline": json.dumps( + {idx: mem.memory for idx, mem in enumerate(memory_list)}, + ensure_ascii=False, + indent=2, + ), + } + prompt = template.format(**prompt_args) + + # Optionally run filter and parse the output try: - data = json.loads(text) - except Exception: - return False, {} + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed = parse_rewritten_response(raw) + logger.info( + f"[rewrite_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" + ) + if success: + logger.info(f"Rewrite filter result: {parsed}") - if not isinstance(data, dict): - return False, {} + new_memory_list = [] + for mem_idx, content in parsed.items(): + if mem_idx < 0 or mem_idx >= len(memory_list): + logger.warning( + f"[rewrite_memories] Invalid memory index {mem_idx} for memory_list {len(memory_list)}, skipping." + ) + continue - result: dict[int, dict] = {} - for k, v in data.items(): - try: - idx = int(k) - except Exception: - # allow integer keys as-is - if isinstance(k, int): - idx = k - else: - continue - if not isinstance(v, dict): - continue - need_rewrite = v.get("need_rewrite") - rewritten = v.get("rewritten", "") - reason = v.get("reason", "") - if ( - isinstance(need_rewrite, bool) - and isinstance(rewritten, str) - and isinstance(reason, str) - ): - result[idx] = { - "need_rewrite": need_rewrite, - "rewritten": rewritten, - "reason": reason, - } + need_rewrite = content.get("need_rewrite", False) + rewritten_text = content.get("rewritten", "") + reason = content.get("reason", "") + original_text = memory_list[mem_idx].memory - return (len(result) > 0), result + # Replace memory text with rewritten content when rewrite is needed + if need_rewrite and isinstance(rewritten_text, str): + logger.info( + f"[rewrite_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + ) + if len(rewritten_text.strip()) != 0: + memory_list[mem_idx].memory = rewritten_text + new_memory_list.append(memory_list[mem_idx]) + else: + new_memory_list.append(memory_list[mem_idx]) + return new_memory_list + else: + logger.warning("Rewrite filter parsing failed or returned empty result.") + except Exception as e: + logger.error(f"Rewrite filter execution error: {e}", stack_info=True) + + return memory_list def filter_hallucination_in_memories( self, messages: list[dict], memory_list: list[TextualMemoryItem] @@ -520,32 +524,32 @@ def filter_hallucination_in_memories( # Optionally run filter and parse the output try: raw = self.llm.generate([{"role": "user", "content": prompt}]) - success, parsed = self._parse_hallucination_filter_response(raw) + success, parsed = parse_keep_filter_response(raw) logger.info( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") - assert len(parsed) == len(memory_list) - for mem_idx, content in parsed.items(): - need_rewrite = content.get("need_rewrite", False) - rewritten_text = content.get("rewritten", "") - reason = content.get("reason", "") - # Replace memory text with rewritten content when rewrite is needed - if ( - need_rewrite - and isinstance(rewritten_text, str) - and len(rewritten_text.strip()) > 0 - ): - original_text = memory_list[mem_idx].memory + filtered_list = [] + for mem_idx, mem in enumerate(memory_list): + content = parsed.get(mem_idx) + if not content: + logger.warning(f"No verdict for memory {mem_idx}, keeping it.") + filtered_list.append(mem) + continue + + keep = content.get("keep", True) + reason = content.get("reason", "") + if keep: + filtered_list.append(mem) + else: logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + f"[filter_hallucination_in_memories] Dropping memory index={mem_idx}, reason='{reason}', memory='{mem.memory}'" ) - memory_list[mem_idx].memory = rewritten_text - return memory_list + return filtered_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") except Exception as e: @@ -606,29 +610,27 @@ def _read_memory( for group_id in range(len(memory_list)): try: - revised_memory_list = self.filter_hallucination_in_memories( + original_memory_group = copy.deepcopy(memory_list[group_id]) + serialized_origin_memories = json.dumps( + [one.memory for one in original_memory_group], indent=2 + ) + revised_memory_list = self.rewrite_memories( messages=combined_messages, - memory_list=memory_list[group_id], + memory_list=original_memory_group, + user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() + == "true", ) - if len(revised_memory_list) != len(memory_list[group_id]): - original_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in memory_list[group_id] - ] - filtered_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in revised_memory_list - ] - logger.error( - f"Length mismatch after hallucination filtering for group_id={group_id}: " - f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}" - f"\noriginal_memory_list(serialized): {original_serialized}" - f"\nfiltered_memory_list(serialized): {filtered_serialized}" - f"\nmessages: {combined_messages}" - f"\nSkipping update and keeping original memory." + serialized_revised_memories = json.dumps( + [one.memory for one in revised_memory_list], indent=2 + ) + if serialized_origin_memories != serialized_revised_memories: + memory_list[group_id] = revised_memory_list + logger.info( + f"[SIMPLE_STRUCT_ADD_FILTER] Modified the list for group_id={group_id}: " + f"\noriginal={serialized_origin_memories}," + f"\nrevised={serialized_revised_memories}" ) - continue - memory_list[group_id] = revised_memory_list + except Exception as e: group_serialized = [ one.memory if hasattr(one, "memory") else str(one) @@ -847,7 +849,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): info, source_info_list, self.llm, - self.parse_json_result, + parse_json_result, self.embedder, ): idx for idx, msg in enumerate(messages) @@ -870,44 +872,3 @@ def _process_transfer_doc_data( self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None ): raise NotImplementedError - - def parse_json_result(self, response_text: str) -> dict: - s = (response_text or "").strip() - - m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) - s = (m.group(1) if m else s.replace("```", "")).strip() - - i = s.find("{") - if i == -1: - return {} - s = s[i:].strip() - - try: - return json.loads(s) - except json.JSONDecodeError: - pass - - j = max(s.rfind("}"), s.rfind("]")) - if j != -1: - try: - return json.loads(s[: j + 1]) - except json.JSONDecodeError: - pass - - def _cheap_close(t: str) -> str: - t += "}" * max(0, t.count("{") - t.count("}")) - t += "]" * max(0, t.count("[") - t.count("]")) - return t - - t = _cheap_close(s) - try: - return json.loads(t) - except json.JSONDecodeError as e: - if "Invalid \\escape" in str(e): - s = s.replace("\\", "\\\\") - return json.loads(s) - logger.error( - f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ - json: {s}" - ) - return {} diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py new file mode 100644 index 000000000..4e5a78af2 --- /dev/null +++ b/src/memos/mem_reader/utils.py @@ -0,0 +1,157 @@ +import json +import re + +from memos import log + + +logger = log.get_logger(__name__) + +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "", disallowed_special=())) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars + def count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + + +def derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + +def parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.error( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} + + +def parse_rewritten_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from hallucination filter response. + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. + """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + # allow integer keys as-is + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + need_rewrite = v.get("need_rewrite") + rewritten = v.get("rewritten", "") + reason = v.get("reason", "") + if ( + isinstance(need_rewrite, bool) + and isinstance(rewritten, str) + and isinstance(reason, str) + ): + result[idx] = { + "need_rewrite": need_rewrite, + "rewritten": rewritten, + "reason": reason, + } + + return (len(result) > 0), result + + +def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from keep filter response. + Expected shape: { "0": {"keep": bool, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. + """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + keep = v.get("keep") + reason = v.get("reason", "") + if isinstance(keep, bool): + result[idx] = { + "keep": keep, + "reason": reason, + } + return (len(result) > 0), result diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 1e0ecaadb..728203f5b 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1009,14 +1009,24 @@ def _monitor_loop(self): q_sizes = self.memos_message_queue.qsize() for stream_key, queue_length in q_sizes.items(): - # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}" + # Skip aggregate keys like 'total_size' + if stream_key == "total_size": + continue + + # Key format: ...:{user_id}:{mem_cube_id}:{task_label} + # We want to extract user_id, which is the 3rd component from the end. parts = stream_key.split(":") if len(parts) >= 3: - user_id = parts[2] - self.metrics.update_queue_length(queue_length, user_id) - elif not self.use_redis_queue: # local queue - user_id = stream_key + user_id = parts[-3] self.metrics.update_queue_length(queue_length, user_id) + else: + # Fallback for unexpected key formats (e.g. legacy or testing) + # Try to use the key itself if it looks like a user_id (no colons) + # or just log a warning? + # For now, let's assume if it's not total_size and short, it might be a direct user_id key + # (though that shouldn't happen with current queue implementations) + if ":" not in stream_key: + self.metrics.update_queue_length(queue_length, stream_key) except Exception as e: logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index f4ad9fe48..06910ba17 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,3 +1,5 @@ +import os + from pathlib import Path @@ -21,7 +23,7 @@ DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 DEFAULT_TOP_K = 5 DEFAULT_CONTEXT_WINDOW_SIZE = 5 -DEFAULT_USE_REDIS_QUEUE = True +DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true" DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20 DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 69cfc0af9..eae70f8ef 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -62,7 +62,7 @@ def put( Exception: Any underlying error during queue.put() operation. """ stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) message.stream_key = stream_key @@ -108,35 +108,94 @@ def get( ) return res - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, stream_key: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ - Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). Returns immediately with available messages or an empty list if queue is empty. Args: + stream_key (str): The stream/queue identifier. batch_size (int | None): Number of messages to retrieve in a batch. If None, retrieves one message. Returns: List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. """ - logger.debug(f"get_nowait() called with batch_size: {batch_size}") - return self.get(block=False, batch_size=batch_size) + logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}") + return self.get(stream_key=stream_key, block=False, batch_size=batch_size) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + """ + Get messages from all streams in round-robin or sequential fashion. + Equivalent to SchedulerRedisQueue.get_messages. + """ + messages = [] + # Snapshot keys to avoid runtime modification issues + stream_keys = list(self.queue_streams.keys()) + + # Simple strategy: try to get up to batch_size messages across all streams + # We can just iterate and collect. + + # Calculate how many to get per stream to be fair? + # Or just greedy? Redis implementation uses a complex logic. + # For local, let's keep it simple: just iterate and take what's available (non-blocking) + + for stream_key in stream_keys: + if len(messages) >= batch_size: + break + + needed = batch_size - len(messages) + # Use get_nowait to avoid blocking + fetched = self.get_nowait(stream_key=stream_key, batch_size=needed) + messages.extend(fetched) + + return messages def qsize(self) -> dict: """ Return the current size of all internal queues as a dictionary. Each key is the stream name, and each value is the number of messages in that queue. + Also includes 'total_size'. Returns: Dict[str, int]: Mapping from stream name to current queue size. """ sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + total_size = sum(sizes.values()) + sizes["total_size"] = total_size logger.debug(f"Current queue sizes: {sizes}") return sizes + def size(self) -> int: + """ + Get the current size of the queue (total message count). + Compatible with SchedulerRedisQueue. + """ + return self.unfinished_tasks + + def empty(self) -> bool: + """ + Check if the queue is empty. + Compatible with SchedulerRedisQueue. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if the queue is full. + Compatible with SchedulerRedisQueue. + """ + # Local queue limits are per-stream (max_internal_message_queue_size). + # It is considered full only if all streams are full. + if not self.queue_streams: + return False + + return all(queue.full() for queue in self.queue_streams.values()) + def clear(self) -> None: for queue in self.queue_streams.values(): queue.clear() @@ -151,6 +210,9 @@ def unfinished_tasks(self) -> int: Returns: int: Sum of all message counts in all internal queues. """ - total = sum(self.qsize().values()) + # qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values + # But qsize() implementation above sums values from queue_streams, then adds total_size. + # So sum(self.queue_streams.values().qsize()) is safer. + total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1c57f18f0..2f4318003 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,7 +5,6 @@ the local memos_message_queue functionality in BaseScheduler. """ -import contextlib import os import re import threading @@ -201,6 +200,20 @@ def _refresh_stream_keys( recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, now_sec=now_sec, ) + + # Ensure consumer groups for newly discovered active streams + with self._stream_keys_lock: + # Identify keys we haven't seen yet + new_streams = [k for k in active_stream_keys if k not in self.seen_streams] + + # Create groups outside the lock to avoid blocking + for key in new_streams: + self._ensure_consumer_group(key) + + if new_streams: + with self._stream_keys_lock: + self.seen_streams.update(new_streams) + deleted_count = self._delete_streams(keys_to_delete) self._update_stream_cache_with_log( stream_key_prefix=stream_key_prefix, @@ -560,10 +573,7 @@ def _read_new_messages_batch( return {} # Pre-ensure consumer groups to avoid NOGROUP during batch reads - for stream_key in stream_keys: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - + # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups) pipe = self._redis_conn.pipeline(transaction=False) for stream_key in stream_keys: pipe.xreadgroup( @@ -679,11 +689,6 @@ def _batch_claim_pending_messages( if not self._redis_conn or not claims_spec: return [] - # Ensure consumer groups exist to avoid NOGROUP errors during batch claim - for stream_key, _need_count, _label in claims_spec: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - pipe = self._redis_conn.pipeline(transaction=False) for stream_key, need_count, label in claims_spec: pipe.xautoclaim( @@ -696,26 +701,42 @@ def _batch_claim_pending_messages( justid=False, ) - results = [] try: - results = pipe.execute() - except Exception: - # Fallback: attempt sequential xautoclaim for robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue + # Execute with raise_on_error=False so we get exceptions in the results list + # instead of aborting the whole batch. + results = pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pipeline execution critical failure: {e}") + results = [e] * len(claims_spec) + + # Handle individual failures (e.g. NOGROUP) by retrying just that stream + final_results = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + err_msg = str(res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + stream_key, need_count, label = claims_spec[i] + try: + self._ensure_consumer_group(stream_key=stream_key) + retry_res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + final_results.append(retry_res) + except Exception as retry_err: + logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}") + final_results.append(None) + else: + final_results.append(None) + else: + final_results.append(res) + + results = final_results claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( @@ -1159,10 +1180,14 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: del_pipe.delete(key) del_pipe.execute() deleted_count = len(keys_to_delete) - # Clean up empty-tracking state for deleted keys + # Clean up empty-tracking state and seen_streams for deleted keys with self._empty_stream_seen_lock: for key in keys_to_delete: self._empty_stream_seen_times.pop(key, None) + + with self._stream_keys_lock: + for key in keys_to_delete: + self.seen_streams.discard(key) except Exception: for key in keys_to_delete: try: @@ -1170,6 +1195,8 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: deleted_count += 1 with self._empty_stream_seen_lock: self._empty_stream_seen_times.pop(key, None) + with self._stream_keys_lock: + self.seen_streams.discard(key) except Exception: pass return deleted_count @@ -1189,9 +1216,7 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) - logger.info( - f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', " - f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, " - f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, " - f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}" - ) + logger.info( + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." + ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index c20243242..b49db2b36 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -153,28 +153,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: - if isinstance(self.memos_message_queue, SchedulerRedisQueue): - return self.memos_message_queue.get_messages(batch_size=batch_size) - stream_keys = self.get_stream_keys() - - if len(stream_keys) == 0: - return [] - - messages: list[ScheduleMessageItem] = [] - - for stream_key in stream_keys: - fetched = self.memos_message_queue.get( - stream_key=stream_key, - block=False, - batch_size=batch_size, - ) - - messages.extend(fetched) - if len(messages) > 0: - logger.debug( - f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" - ) - return messages + return self.memos_message_queue.get_messages(batch_size=batch_size) def clear(self): self.memos_message_queue.clear() diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py index d8c8d2cee..2a995b239 100644 --- a/src/memos/mem_scheduler/utils/status_tracker.py +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -13,7 +13,7 @@ class TaskStatusTracker: @require_python_package(import_name="redis", install_command="pip install redis") - def __init__(self, redis_client: "redis.Redis"): + def __init__(self, redis_client: "redis.Redis | None"): self.redis = redis_client def _get_key(self, user_id: str) -> str: @@ -41,6 +41,9 @@ def task_submitted( mem_cube_id: Memory cube identifier business_task_id: Optional business-level task ID (one task_id can have multiple item_ids) """ + if not self.redis: + return + key = self._get_key(user_id) payload = { "status": "waiting", @@ -61,6 +64,9 @@ def task_submitted( self.redis.expire(key, timedelta(days=7)) def task_started(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -77,6 +83,9 @@ def task_started(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_completed(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -91,6 +100,9 @@ def task_completed(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_failed(self, task_id: str, user_id: str, error_message: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -108,11 +120,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str): self.redis.expire(key, timedelta(days=7)) def get_task_status(self, task_id: str, user_id: str) -> dict | None: + if not self.redis: + return None + key = self._get_key(user_id) data = self.redis.hget(key, task_id) return json.loads(data) if data else None def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]: + if not self.redis: + return {} + key = self._get_key(user_id) all_tasks = self.redis.hgetall(key) return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()} @@ -132,6 +150,9 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) -> - If any item is 'failed' → 'failed' Returns None if task_id not found. """ + if not self.redis: + return None + # Get all item_ids for this task_id task_items_key = self._get_task_items_key(user_id, business_task_id) item_ids = self.redis.smembers(task_items_key) @@ -180,6 +201,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]: Returns: dict: {user_id: {task_id: task_data, ...}, ...} """ + if not self.redis: + return {} + all_users_tasks = {} cursor: int | str = 0 while True: diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 144bfad7f..3404c6d4c 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -69,6 +69,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: @@ -92,6 +94,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 57f2cdba1..ab3d0ce03 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -15,6 +15,7 @@ ) from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_reader.utils import parse_keep_filter_response from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -23,6 +24,7 @@ PREF_ADD_TASK_LABEL, ) from memos.multi_mem_cube.views import MemCubeView +from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, FineStrategy, @@ -41,6 +43,7 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler + from memos.memories.textual.item import TextualMemoryItem @dataclass @@ -631,6 +634,104 @@ def _process_pref_mem( for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] + def add_before_search( + self, + messages: list[dict], + memory_list: list[TextualMemoryItem], + user_name: str, + info: dict[str, Any], + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + template = PROMPT_MAPPING["add_before_search"] + + if not self.searcher: + self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.") + return memory_list + + # 1. Gather candidates and search for related memories + candidates_data = [] + for idx, mem in enumerate(memory_list): + try: + related_memories = self.searcher.search( + query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info + ) + related_text = "None" + if related_memories: + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) + + candidates_data.append( + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} + ) + except Exception as e: + self.logger.error( + f"[add_before_search] Search error for memory '{mem.memory}': {e}" + ) + # If search fails, we can either skip this check or treat related as empty + candidates_data.append( + { + "idx": idx, + "new_memory": mem.memory, + "related_memories": "None (Search Failed)", + } + ) + + if not candidates_data: + return memory_list + + # 2. Build Prompt + messages_inline = "\n".join( + [ + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" + for message in messages + ] + ) + + candidates_inline_dict = { + str(item["idx"]): { + "new_memory": item["new_memory"], + "related_memories": item["related_memories"], + } + for item in candidates_data + } + + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) + + prompt = template.format( + messages_inline=messages_inline, candidates_inline=candidates_inline + ) + + # 3. Call LLM + try: + raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}]) + success, parsed_result = parse_keep_filter_response(raw) + + if not success: + self.logger.warning( + "[add_before_search] Failed to parse LLM response, keeping all." + ) + return memory_list + + # 4. Filter + filtered_list = [] + for idx, mem in enumerate(memory_list): + res = parsed_result.get(idx) + if not res: + filtered_list.append(mem) + continue + + if res.get("keep", True): + filtered_list.append(mem) + else: + self.logger.info( + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" + ) + + return filtered_list + + except Exception as e: + self.logger.error(f"[add_before_search] LLM execution error: {e}") + return memory_list + def _process_text_mem( self, add_req: APIADDRequest, diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index fef3ee6c0..40971c77e 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,23 +622,56 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. +Note: The provided messages contain only user messages. The assistant's responses are intentionally omitted, not because the assistant didn't answer, but to focus strictly on validating memories against user input. + Rules: 1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. -2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. -3. **Ambiguity Elimination**: - - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. - - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. 4. **Hallucination Removal**: - - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. - - Do **not** rephrase inferences as facts. Instead, either: - - Remove the unsupported part and retain only the grounded core, or - - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. 5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. Inputs: messages: @@ -651,16 +684,115 @@ - Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} - The "reason" must be brief and precise, e.g.: - - "contains unsupported inference" - - "vague pronoun with no referent in messages" - - "relative time resolved to 2025-12-16" + - "contains unsupported inference ...." - "fully grounded and concise" Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +You are a strict memory validator. +Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. + +Rules: +1. **User-Only Origin**: Verify facts against USER messages ONLY. If the Assistant repeats a User fact, it is VALID. If the Assistant introduces a new detail (e.g., 'philanthropy') that the User did not explicitly confirm, it is INVALID. +2. **No Inference Allowed**: Do NOT keep memories based on implication, emotion, preference, or generalization. Only verbatim or direct restatements of user-provided facts are valid. However, minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +3. **Hallucination = Deletion**: If a memory contains any detail not directly expressed by the user, mark it for deletion. +4. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Examples: +Messages: +- [user]: I love coding in Python. +- [assistant]: That's great! I assume you also contribute to open source projects? +Memory: User enjoys Python and contributes to open source. +Result: {{"keep": false, "reason": "User never stated they contribute to open source; this came from Assistant's assumption."}} + +Messages: +- [user]: I am tired. +- [assistant]: I hear you are tired. Rest is important. +Memory: User stated they are tired. +Result: {{"keep": true, "reason": "Direct restatement of user input, even if Assistant repeated it."}} + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true only if the memory is a direct reflection of the user's explicit words. +- "reason": brief, factual, and cites missing or unsupported content. + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + + +SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """ +You are a memory manager. +Your task is to decide if a new memory should be added to the long-term memory, given a list of existing related memories. + +Rules: +1. **Redundancy Check**: If the new memory is completely redundant, already known, or covered by the existing memories, discard it. +2. **New Information**: If the new memory provides new information, details, or updates compared to the existing memories, keep it. +3. **Contradiction**: If the new memory contradicts existing memories but seems valid/newer, keep it (updates). +4. **Context Check**: Use the provided conversation messages to verify if the new memory is grounded in the user's explicit statements. + +Inputs: +Messages: +{messages_inline} + +Candidate Memories (to be evaluated): +{candidates_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input candidate memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true if the memory should be added. +- "reason": brief explanation. + +Important: Output **only** the JSON. No extra text. +""" # Prompt mapping for specialized tasks (e.g., hallucination filtering) PROMPT_MAPPING = { "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT, + "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT, + "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT, + "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT, } diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index f81356886..fd07fbf41 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -1,4 +1,3 @@ -import json import unittest from unittest.mock import MagicMock, patch @@ -8,6 +7,7 @@ from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem @@ -57,7 +57,6 @@ def test_process_chat_data(self): '"summary": "Tom is currently focused on managing a new project with a tight schedule."}' ) self.reader.llm.generate.return_value = mock_response - self.reader.parse_json_result = lambda x: json.loads(x) result = self.reader._process_chat_data(scene_data_info, info) @@ -105,7 +104,7 @@ def test_get_scene_data_info_with_chat(self): def test_parse_json_result_success(self): """Test successful JSON parsing.""" raw_response = '{"summary": "Test summary", "tags": ["test"]}' - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertIsInstance(result, dict) self.assertIn("summary", result) @@ -113,7 +112,7 @@ def test_parse_json_result_success(self): def test_parse_json_result_failure(self): """Test failure in JSON parsing.""" raw_response = "Invalid JSON string" - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertEqual(result, {})