diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 000000000..9199ccdf3 Binary files /dev/null and b/dump.rdb differ diff --git a/evaluation/scripts/locomo/locomo_eval.py b/evaluation/scripts/locomo/locomo_eval.py index 24a216b92..6e7dd4083 100644 --- a/evaluation/scripts/locomo/locomo_eval.py +++ b/evaluation/scripts/locomo/locomo_eval.py @@ -311,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4 with open(response_path) as file: locomo_responses = json.load(file) - num_users = 2 + num_users = 10 all_grades = {} total_responses_count = sum( diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index e835dd5d7..157c3f8ea 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -189,7 +189,9 @@ def search(self, query, user_id, top_k): ) response = requests.request("POST", url, data=payload, headers=self.headers) assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "Memory searched successfully", response.text + assert json.loads(response.text)["message"] == "Search completed successfully", ( + response.text + ) return json.loads(response.text)["data"] diff --git a/examples/data/config/mem_scheduler/general_scheduler_config.yaml b/examples/data/config/mem_scheduler/general_scheduler_config.yaml index 2360bb14b..cc3de38a8 100644 --- a/examples/data/config/mem_scheduler/general_scheduler_config.yaml +++ b/examples/data/config/mem_scheduler/general_scheduler_config.yaml @@ -4,7 +4,7 @@ config: act_mem_update_interval: 30 context_window_size: 10 thread_pool_max_workers: 5 - consume_interval_seconds: 1 + consume_interval_seconds: 0.01 working_mem_monitor_capacity: 20 activation_mem_monitor_capacity: 5 enable_parallel_dispatch: true diff --git a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml index 2d3958e60..cfb2a050c 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml @@ -38,7 +38,7 @@ mem_scheduler: act_mem_update_interval: 30 context_window_size: 10 thread_pool_max_workers: 10 - consume_interval_seconds: 1 + consume_interval_seconds: 0.01 working_mem_monitor_capacity: 20 activation_mem_monitor_capacity: 5 enable_parallel_dispatch: true diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml index cdfa49a76..bd9910300 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml @@ -38,7 +38,7 @@ mem_scheduler: act_mem_update_interval: 30 context_window_size: 10 thread_pool_max_workers: 10 - consume_interval_seconds: 1 + consume_interval_seconds: 0.01 working_mem_monitor_capacity: 20 activation_mem_monitor_capacity: 5 enable_parallel_dispatch: true diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py index d3522f8e1..871dd0258 100644 --- a/examples/mem_scheduler/api_w_scheduler.py +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -17,7 +17,6 @@ print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}") print("=====================================\n") -mem_scheduler.memos_message_queue.debug_mode_on() queue = mem_scheduler.memos_message_queue queue.clear() diff --git 
a/examples/mem_scheduler/task_fair_schedule.py b/examples/mem_scheduler/task_fair_schedule.py new file mode 100644 index 000000000..8b02b1931 --- /dev/null +++ b/examples/mem_scheduler/task_fair_schedule.py @@ -0,0 +1,87 @@ +import sys + +from collections import defaultdict +from pathlib import Path + +from memos.api.routers.server_router import mem_scheduler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) + + +def make_message(user_id: str, mem_cube_id: str, label: str, idx: int | str) -> ScheduleMessageItem: + return ScheduleMessageItem( + item_id=f"{user_id}:{mem_cube_id}:{label}:{idx}", + user_id=user_id, + mem_cube_id=mem_cube_id, + label=label, + content=f"msg-{idx} for {user_id}/{mem_cube_id}/{label}", + ) + + +def seed_messages_for_test_fairness(queue, combos, per_stream): + # Flood the queue with a large burst of messages from a single stream + (u, c, label) = combos[0] + task_target = 100 + print(f"{u}:{c}:{label} submit {task_target} messages") + for i in range(task_target): + msg = make_message(u, c, label, f"overwhelm_{i}") + queue.submit_messages(msg) + + for u, c, label in combos: + print(f"{u}:{c}:{label} submit {per_stream} messages") + for i in range(per_stream): + msg = make_message(u, c, label, i) + queue.submit_messages(msg) + print("======= seed_messages Done ===========") + + +def count_by_stream(messages): + counts = defaultdict(int) + for m in messages: + key = f"{m.user_id}:{m.mem_cube_id}:{m.label}" + counts[key] += 1 + return counts + + +def run_fair_redis_schedule(batch_size: int = 3): + print("=== Redis Fairness Demo ===") + print(f"use_redis_queue: {mem_scheduler.use_redis_queue}") + mem_scheduler.consume_batch = batch_size + queue = mem_scheduler.memos_message_queue + + # Isolate and clear queue + queue.clear() + + # Define multiple streams: (user_id, mem_cube_id, task_label) + combos = [ + ("u1", "u1", "labelX"), + ("u1", "u1", "labelY"), + ("u2", "u2", "labelX"), + ("u2", "u2", "labelY"), + ] + per_stream = 5 + + # Seed messages evenly across streams + seed_messages_for_test_fairness(queue, combos, per_stream) + + # Report the requested batch size (fair split across streams) + print(f"Request batch_size={batch_size} for {len(combos)} streams") + + for _ in range(len(combos)): + # Fetch one brokered pack + msgs = queue.get_messages(batch_size=batch_size) + print(f"Fetched {len(msgs)} messages in this pack") + + # Check fairness: counts per stream + counts = count_by_stream(msgs) + for k in sorted(counts): + print(f"{k}: {counts[k]}") + + +if __name__ == "__main__": + # Task 1: fair Redis scheduling + run_fair_redis_schedule() diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/task_stop_rerun.py new file mode 100644 index 000000000..4664e0eaa --- /dev/null +++ b/examples/mem_scheduler/task_stop_rerun.py @@ -0,0 +1,85 @@ +from pathlib import Path +from time import sleep + +# Note: we skip API handler status/wait utilities in this demo +from memos.api.routers.server_router import mem_scheduler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +# Debug: Print scheduler configuration +print("=== Scheduler Configuration Debug ===") +print(f"Scheduler type: {type(mem_scheduler).__name__}") +print(f"Config: {mem_scheduler.config}") +print(f"use_redis_queue: {mem_scheduler.use_redis_queue}") +print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}") +print(f"Queue maxsize: 
{getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}") +print("=====================================\n") + +queue = mem_scheduler.memos_message_queue + + +# Define a handler function +def my_test_handler(messages: list[ScheduleMessageItem]): + print(f"My test handler received {len(messages)} messages: {[one.item_id for one in messages]}") + for msg in messages: + # Create a file named by task_id (use item_id as numeric id 0..99) + task_id = str(msg.item_id) + file_path = tmp_dir / f"{task_id}.txt" + try: + print(f"writing {file_path}...") + file_path.write_text(f"Task {task_id} processed.\n") + except Exception as e: + print(f"Failed to write {file_path}: {e}") + + +def submit_tasks(): + mem_scheduler.memos_message_queue.clear() + + # Create 100 messages (task_id 0..99) + users = ["user_A", "user_B"] + messages_to_send = [ + ScheduleMessageItem( + item_id=str(i), + user_id=users[i % 2], + mem_cube_id="test_mem_cube", + label=TEST_HANDLER_LABEL, + content=f"Create file for task {i}", + ) + for i in range(100) + ] + # Submit messages in batch and print completion + print(f"Submitting {len(messages_to_send)} messages to the scheduler...") + mem_scheduler.memos_message_queue.submit_messages(messages_to_send) + print(f"Task submission done! tasks in queue: {mem_scheduler.get_tasks_status()}") + + +# Register the handler +TEST_HANDLER_LABEL = "test_handler" +mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) + + +tmp_dir = Path("./tmp") +tmp_dir.mkdir(exist_ok=True) + +# Test stop-and-restart: if tmp already has more than one file, skip submission and continue processing +existing_count = len(list(Path("tmp").glob("*.txt"))) if Path("tmp").exists() else 0 +if existing_count > 1: + print(f"Skip submission: found {existing_count} files in tmp (>1), continuing processing") +else: + submit_tasks() + +# Wait until the scheduler reports no remaining tasks +poll_interval = 0.01 +expected = 100 +tmp_dir = Path("tmp") +while mem_scheduler.get_tasks_status()["remaining"] != 0: + count = len(list(tmp_dir.glob("*.txt"))) if tmp_dir.exists() else 0 + tasks_status = mem_scheduler.get_tasks_status() + mem_scheduler.print_tasks_status(tasks_status=tasks_status) + print(f"[Monitor] Files in tmp: {count}/{expected}") + sleep(poll_interval) +print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))}") + +# Stop the scheduler +print("Stopping the scheduler...") +mem_scheduler.stop() diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 46e7fd108..fd0dfc7f8 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -47,7 +47,9 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: Returns: MemoryResponse with added memory information """ - self.logger.info(f"[AddHandler] Add Req is: {add_req}") + self.logger.info( + f"[DIAGNOSTIC] server_router -> add_handler.handle_add_memories called (Modified at 2025-11-29 18:46). 
Full request: {add_req.model_dump_json(indent=2)}" + ) if add_req.info: exclude_fields = list_all_fields() diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index 3c0314235..e071eacb3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -8,7 +8,7 @@ from typing import Any from memos.log import get_logger -from memos.mem_scheduler.base_scheduler import BaseScheduler +from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.memories.textual.tree_text_memory.retrieve.advanced_searcher import AdvancedSearcher @@ -129,7 +129,7 @@ def mem_reader(self): return self.deps.mem_reader @property - def mem_scheduler(self) -> BaseScheduler: + def mem_scheduler(self) -> OptimizedScheduler: """Get scheduler instance.""" return self.deps.mem_scheduler diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 71e384014..609d61124 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -188,6 +188,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): @router.post("/add", summary="add a new memory", response_model=SimpleResponse) def create_memory(memory_req: MemoryCreateRequest): """Create a new memory for a specific user.""" + logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.") # Initialize status_tracker outside try block to avoid NameError in except blocks status_tracker = None diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 9de06cd90..9d0280a83 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -198,7 +198,6 @@ def add_node( self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: logger.info(f"[add_node] metadata: {metadata},info: {metadata.get('info')}") - print(f"[add_node] metadata: {metadata},info: {metadata.get('info')}") user_name = user_name if user_name else self.config.user_name if not self.config.use_multi_db and (self.config.user_name or user_name): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 74dd38fc1..d62dacbc8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -669,7 +669,6 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - print(f"🪶 Creating elabel: {label_name}") conn = self._get_connection() logger.info(f"Creating elabel: {label_name}") try: @@ -1596,7 +1595,6 @@ def search_by_embedding( logger.info( f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" ) - print(f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") where_clauses = [] if scope: where_clauses.append( @@ -1687,13 +1685,12 @@ def search_by_embedding( wrapped_lines = textwrap.wrap( line, width=200, break_long_words=False, break_on_hyphens=False ) - for wrapped_line in wrapped_lines: - print(wrapped_line) + for _wrapped_line in wrapped_lines: + pass else: - print(line) + pass logger.info(f"[search_by_embedding] query: {query}, params: {params}") - print(f"[search_by_embedding] query: {query}, params: {params}") conn = self._get_connection() try: @@ -1714,8 +1711,6 @@ def search_by_embedding( raise results = cursor.fetchall() output = [] - print("=== Raw Results ===:", results) - print(f"=== Results count: {len(results)} ===") for row in 
results: """ polarId = row[0] # id @@ -1763,7 +1758,6 @@ def get_by_metadata( list[str]: Node IDs whose metadata match the filter conditions. (AND logic). """ logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") - print(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") user_name = user_name if user_name else self._get_config_value("user_name") @@ -1851,7 +1845,6 @@ def get_by_metadata( ids = [] conn = self._get_connection() logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") - print(f"[get_by_metadata] cypher_query: {cypher_query}") try: with conn.cursor() as cursor: cursor.execute(cypher_query) @@ -2278,7 +2271,6 @@ def get_all_memory_items( logger.info( f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" ) - print(f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") user_name = user_name if user_name else self._get_config_value("user_name") if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: @@ -2386,7 +2378,6 @@ def get_all_memory_items( nodes = [] conn = self._get_connection() logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") - print(f"[get_all_memory_items] cypher_query: {cypher_query}") try: with conn.cursor() as cursor: cursor.execute(cypher_query) @@ -2762,7 +2753,6 @@ def add_node( ) -> None: """Add a memory node to the graph.""" logger.info(f"[add_node] id: {id}, memory: {memory}, metadata: {metadata}") - print(f"[add_node] metadata: {metadata}, info: {metadata.get('info')}") # user_name comes from metadata; fallback to config if missing metadata["user_name"] = user_name if user_name else self.config.user_name @@ -2847,9 +2837,6 @@ def add_node( logger.info( f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" ) - print( - f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) else: insert_query = f""" INSERT INTO {self.db_name}_graph."Memory"(id, properties) @@ -2862,10 +2849,6 @@ def add_node( logger.info( f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" ) - print( - f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - finally: logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") self._return_connection(conn) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index edf50feb1..75d0976a1 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -788,6 +788,9 @@ def process_textual_memory(): timestamp=datetime.utcnow(), task_id=task_id, ) + logger.info( + f"[DIAGNOSTIC] core.add: Submitting message to scheduler: {message_item.model_dump_json(indent=2)}" + ) self.mem_scheduler.memos_message_queue.submit_messages( messages=[message_item] ) diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index 967654d84..bf9f847d0 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -110,7 +110,7 @@ def get_default_config( "act_mem_update_interval": kwargs.get("scheduler_act_mem_update_interval", 300), "context_window_size": kwargs.get("scheduler_context_window_size", 5), "thread_pool_max_workers": kwargs.get("scheduler_thread_pool_max_workers", 10), - "consume_interval_seconds": kwargs.get("scheduler_consume_interval_seconds", 3), + 
"consume_interval_seconds": kwargs.get("scheduler_consume_interval_seconds", 0.01), "enable_parallel_dispatch": kwargs.get("scheduler_enable_parallel_dispatch", True), "enable_activation_memory": True, }, diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index ed81eeffa..50f21a092 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -21,6 +21,7 @@ from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor +from memos.mem_scheduler.monitors.task_schedule_monitor import TaskScheduleMonitor from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_BATCH, @@ -41,8 +42,6 @@ ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher -from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue -from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils import metrics from memos.mem_scheduler.utils.db_utils import get_utc_now @@ -137,13 +136,19 @@ def __init__(self, config: BaseSchedulerConfig): self.dispatcher = SchedulerDispatcher( config=self.config, memos_message_queue=self.memos_message_queue, - use_redis_queue=self.use_redis_queue, max_workers=self.thread_pool_max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch, status_tracker=self.status_tracker, metrics=self.metrics, submit_web_logs=self._submit_web_logs, ) + # Task schedule monitor: initialize with underlying queue implementation + self.get_status_parallel = self.config.get("get_status_parallel", True) + self.task_schedule_monitor = TaskScheduleMonitor( + memos_message_queue=self.memos_message_queue.memos_message_queue, + dispatcher=self.dispatcher, + get_status_parallel=self.get_status_parallel, + ) # other attributes self._context_lock = threading.Lock() @@ -232,11 +237,6 @@ def initialize_modules( self._cleanup_on_init_failure() raise - # start queue monitor if enabled and a bot is set later - - def debug_mode_on(self): - self.memos_message_queue.debug_mode_on() - def _cleanup_on_init_failure(self): """Clean up resources if initialization fails.""" try: @@ -596,6 +596,11 @@ def _submit_web_logs( Args: messages: Single log message or list of log messages """ + messages_list = [messages] if isinstance(messages, ScheduleLogForWebItem) else messages + for message in messages_list: + logger.info( + f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. 
Message to publish: {message.model_dump_json(indent=2)}" + ) if self.rabbitmq_config is None: return @@ -720,7 +725,7 @@ def _message_consumer(self) -> None: except Exception as e: # Don't log error for "No messages available in Redis queue" as it's expected if "No messages available in Redis queue" not in str(e): - logger.error(f"Unexpected error in message consumer: {e!s}") + logger.error(f"Unexpected error in message consumer: {e!s}", exc_info=True) time.sleep(self._consume_interval) # Prevent tight error loops def _monitor_loop(self): @@ -940,47 +945,13 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di return result - @staticmethod - def init_task_status(): - return { - "running": 0, - "remaining": 0, - "completed": 0, - } - def get_tasks_status(self): - task_status = self.init_task_status() - memos_message_queue = self.memos_message_queue.memos_message_queue - if isinstance(memos_message_queue, SchedulerRedisQueue): - stream_keys = memos_message_queue.get_stream_keys( - stream_key_prefix=memos_message_queue.stream_key_prefix - ) - for stream_key in stream_keys: - if stream_key not in task_status: - task_status[stream_key] = self.init_task_status() - # For Redis queue, prefer XINFO GROUPS to compute pending - groups_info = memos_message_queue.redis.xinfo_groups(stream_key) - if groups_info: - for group in groups_info: - if group.get("name") == memos_message_queue.consumer_group: - task_status[stream_key]["running"] += int(group.get("pending", 0)) - task_status[stream_key]["remaining"] += memos_message_queue.qsize()[ - stream_key - ] - task_status["running"] += int(group.get("pending", 0)) - task_status["remaining"] += task_status[stream_key]["remaining"] - break - - elif isinstance(memos_message_queue, SchedulerLocalQueue): - running_task_count = self.dispatcher.get_running_task_count() - task_status["running"] = running_task_count - task_status["remaining"] = sum(memos_message_queue.qsize().values()) - else: - logger.error( - f"type of self.memos_message_queue is {memos_message_queue}, which is not supported" - ) - raise NotImplementedError() - return task_status + """Delegate status collection to TaskScheduleMonitor.""" + return self.task_schedule_monitor.get_tasks_status() + + def print_tasks_status(self, tasks_status: dict | None = None) -> None: + """Delegate pretty printing to TaskScheduleMonitor.""" + self.task_schedule_monitor.print_tasks_status(tasks_status=tasks_status) def _gather_queue_stats(self) -> dict: """Collect queue/dispatcher stats for reporting.""" diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 89cd9b7ba..9b1153c87 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -48,6 +48,10 @@ def create_autofilled_log_item( mem_cube_id: str, mem_cube: GeneralMemCube, ) -> ScheduleLogForWebItem: + if mem_cube is None: + logger.error( + "mem_cube is None — this should not happen in production!", stack_info=True + ) text_mem_base: TreeTextMemory = mem_cube.text_mem current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id) current_memory_sizes = { @@ -113,9 +117,10 @@ def create_event_log( metadata: list[dict], memory_len: int, memcube_name: str | None = None, + log_content: str | None = None, ) -> ScheduleLogForWebItem: item = self.create_autofilled_log_item( - log_content="", + log_content=log_content or "", label=label, 
from_memory_type=from_memory_type, to_memory_type=to_memory_type, diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index df843e496..f7c8e9d32 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -369,16 +369,19 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: if kb_log_content: event = self.create_event_log( label="knowledgeBaseUpdate", + # 1. Remove the log_content argument + # 2. Add the memory_type arguments from_memory_type=USER_INPUT_TYPE, to_memory_type=LONG_TERM_MEMORY_TYPE, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, mem_cube=self.current_mem_cube, memcube_log_content=kb_log_content, - metadata=None, # Per design doc for KB logs + metadata=None, memory_len=len(kb_log_content), memcube_name=self._map_memcube_name(msg.mem_cube_id), ) + # 3. Assign log_content after creation event.log_content = ( f"Knowledge Base Memory Update: {len(kb_log_content)} changes." ) @@ -534,6 +537,9 @@ def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> logger.error(f"Error processing feedbackMemory message: {e}", exc_info=True) def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info( + f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}" + ) logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.") def process_message(message: ScheduleMessageItem): + user_id = message.user_id mem_cube_id = message.mem_cube_id mem_cube = self.current_mem_cube + if mem_cube is None: + logger.warning( + f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing" + ) + return + content = message.content user_name = message.user_name info = message.info or {} @@ -598,6 +610,9 @@ def _process_memories_with_reader( task_id: str | None = None, info: dict | None = None, ) -> None: + logger.info( + f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}" + ) """ Process memories using mem_reader for enhanced memory processing. @@ -695,6 +710,9 @@ def _process_memories_with_reader( } ) if kb_log_content: + logger.info( + f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}. 
KB content: {json.dumps(kb_log_content, indent=2)}" + ) event = self.create_event_log( label="knowledgeBaseUpdate", from_memory_type=USER_INPUT_TYPE, @@ -833,6 +851,11 @@ process_message(message: ScheduleMessageItem): user_id = message.user_id mem_cube_id = message.mem_cube_id mem_cube = self.current_mem_cube + if mem_cube is None: + logger.warning( + f"mem_cube is None for user_id={user_id}, mem_cube_id={mem_cube_id}, skipping processing" + ) + return content = message.content user_name = message.user_name @@ -1058,6 +1081,11 @@ def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non def process_message(message: ScheduleMessageItem): try: mem_cube = self.current_mem_cube + if mem_cube is None: + logger.warning( + f"mem_cube is None for user_id={message.user_id}, mem_cube_id={message.mem_cube_id}, skipping processing" + ) + return user_id = message.user_id session_id = message.session_id diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index 6cf3a9e58..2278abc2a 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -209,10 +209,9 @@ def _split_batches( def recall_for_missing_memories( self, query: str, - memories: list[TextualMemoryItem], + memories: list[str], ) -> tuple[str, bool]: - text_memories = [one.memory for one in memories] if memories else [] - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(text_memories)]) + text_memories = "\n".join([f"- {mem}" for mem in memories]) prompt = self.build_prompt( template_name="enlarge_recall", diff --git a/src/memos/mem_scheduler/monitors/task_schedule_monitor.py b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py new file mode 100644 index 000000000..82e43d858 --- /dev/null +++ b/src/memos/mem_scheduler/monitors/task_schedule_monitor.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from memos.log import get_logger +from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue + + +logger = get_logger(__name__) + + +class TaskScheduleMonitor: + """ + Monitor for task scheduling queue status. + + Initialize with the underlying `memos_message_queue` implementation + (either SchedulerRedisQueue or SchedulerLocalQueue) and optionally a + dispatcher for local running task counts. + """ + + def __init__( + self, + memos_message_queue: SchedulerRedisQueue | SchedulerLocalQueue, + dispatcher: object | None = None, + get_status_parallel: bool = False, + ) -> None: + self.queue = memos_message_queue + self.dispatcher = dispatcher + self.get_status_parallel = get_status_parallel + + @staticmethod + def init_task_status() -> dict: + return {"running": 0, "remaining": 0} + + def get_tasks_status(self) -> dict: + if isinstance(self.queue, SchedulerRedisQueue): + return self._get_redis_tasks_status() + elif isinstance(self.queue, SchedulerLocalQueue): + return self._get_local_tasks_status() + else: + logger.error( + f"Unsupported queue type for TaskScheduleMonitor: {type(self.queue).__name__}" + ) + raise NotImplementedError() + + def print_tasks_status(self, tasks_status: dict | None = None) -> None: + """ + Nicely print task queue status grouped by "user_id:mem_cube_id". 
+ + For Redis queues, stream keys follow the pattern + "{prefix}:{user_id}:{mem_cube_id}:{task_label}" — group by user/mem + and show per-task_label counts. For local queues, only totals are + available, so print aggregate metrics. + """ + try: + status = tasks_status if isinstance(tasks_status, dict) else self.get_tasks_status() + except Exception as e: + logger.warning(f"Failed to get tasks status: {e}") + return + + if not isinstance(status, dict) or not status: + print("[Tasks] No status available.") + return + + total_running = int(status.get("running", 0) or 0) + total_remaining = int(status.get("remaining", 0) or 0) + + header = f"Task Queue Status | running={total_running}, remaining={total_remaining}" + print(header) + + if isinstance(self.queue, SchedulerRedisQueue): + # Build grouping: {"user_id:mem_cube_id": {task_label: {counts}}} + from collections import defaultdict + + group_stats = defaultdict( + lambda: defaultdict(lambda: {"running": 0, "remaining": 0}) + ) + + # Keys that look like stream entries (exclude the totals keys) + stream_keys = [ + k for k in status if isinstance(k, str) and k not in ("running", "remaining") + ] + + for stream_key in stream_keys: + stream_stat = status.get(stream_key, {}) + if not isinstance(stream_stat, dict): + continue + parts = stream_key.split(":") + # Safely parse from the right to avoid prefix colons + if len(parts) < 3: + # Not enough parts to form user:mem:label — skip + continue + task_label = parts[-1] + mem_cube_id = parts[-2] + user_id = parts[-3] + group_key = f"{user_id}:{mem_cube_id}" + + try: + group_stats[group_key][task_label]["running"] += int( + stream_stat.get("running", 0) or 0 + ) + group_stats[group_key][task_label]["remaining"] += int( + stream_stat.get("remaining", 0) or 0 + ) + except Exception: + # Keep printing robust in face of bad data + pass + + if not group_stats: + print("[Tasks] No per-stream details found.") + return + + # Pretty print per group + for group_key in sorted(group_stats.keys()): + print("") + print(f"[{group_key}]") + + labels = sorted(group_stats[group_key].keys()) + label_width = max(10, max((len(label) for label in labels), default=10)) + # Table header + header_line = f"{'Task Label'.ljust(label_width)} {'Running':>7} {'Remaining':>9}" + sep_line = f"{'-' * label_width} {'-' * 7} {'-' * 9}" + print(header_line) + print(sep_line) + + for label in labels: + counts = group_stats[group_key][label] + line = ( + f"{label.ljust(label_width)} " + f"{int(counts.get('running', 0)):>7} " + f"{int(counts.get('remaining', 0)):>9} " + ) + print(line) + + elif isinstance(self.queue, SchedulerLocalQueue): + # Local queue: only aggregate totals available; print them clearly + print("") + print("[Local Queue Totals]") + label_width = 12 + header_line = f"{'Metric'.ljust(label_width)} {'Value':>7}" + sep_line = f"{'-' * label_width} {'-' * 7}" + print(header_line) + print(sep_line) + print(f"{'Running'.ljust(label_width)} {total_running:>7}") + print(f"{'Remaining'.ljust(label_width)} {total_remaining:>7}") + + def _get_local_tasks_status(self) -> dict: + task_status = self.init_task_status() + + try: + # remaining is the sum of per-stream qsize + qsize_map = self.queue.qsize() + task_status["remaining"] = sum(v for v in qsize_map.values() if isinstance(v, int)) + # running from dispatcher if available + if self.dispatcher and hasattr(self.dispatcher, "get_running_task_count"): + task_status["running"] = 
int(self.dispatcher.get_running_task_count()) + except Exception as e: + logger.warning(f"Failed to collect local queue status: {e}") + return task_status + + def _get_redis_tasks_status(self) -> dict: + task_status = self.init_task_status() + + stream_keys = self.queue.get_stream_keys(stream_key_prefix=self.queue.stream_key_prefix) + + # Parallel path: use asyncio.to_thread for blocking redis calls + if self.get_status_parallel: + try: + import asyncio + + async def _collect_async() -> dict: + # Collect xlen and group info in parallel for each stream + xlen_tasks = [ + asyncio.to_thread(self.queue.redis.xlen, stream_key) + for stream_key in stream_keys + ] + groups_tasks = [ + asyncio.to_thread(self.queue.redis.xinfo_groups, stream_key) + for stream_key in stream_keys + ] + xlen_results = await asyncio.gather(*xlen_tasks, return_exceptions=True) + groups_results = await asyncio.gather(*groups_tasks, return_exceptions=True) + + local = self.init_task_status() + for idx, stream_key in enumerate(stream_keys): + local[stream_key] = self.init_task_status() + groups_info = groups_results[idx] if idx < len(groups_results) else None + xlen_val = xlen_results[idx] if idx < len(xlen_results) else 0 + if isinstance(xlen_val, Exception): + xlen_val = 0 + if isinstance(groups_info, Exception): + continue + pending = 0 + if groups_info: + for group in groups_info: + if group.get("name") == self.queue.consumer_group: + pending = int(group.get("pending", 0)) + break + # Remaining is approximated by the total messages still in the stream (XLEN) + remaining = max(0, int(xlen_val or 0)) + local[stream_key]["running"] += pending + local[stream_key]["remaining"] += remaining + local["running"] += pending + local["remaining"] += remaining + return local + + try: + asyncio.get_running_loop() + loop_running = True + except RuntimeError: + loop_running = False + + if not loop_running: + return asyncio.run(_collect_async()) + except Exception as e: + logger.debug(f"Parallel status collection failed, falling back to sequential: {e}") + + # Sequential fallback + for stream_key in stream_keys: + task_status[stream_key] = self.init_task_status() + try: + groups_info = self.queue.redis.xinfo_groups(stream_key) + except Exception: + groups_info = None + try: + xlen_val = int(self.queue.redis.xlen(stream_key)) + except Exception: + xlen_val = 0 + if groups_info: + for group in groups_info: + if group.get("name") == self.queue.consumer_group: + pending = int(group.get("pending", 0)) + remaining = max(0, xlen_val) + task_status[stream_key]["running"] += pending + task_status[stream_key]["remaining"] += remaining + task_status["running"] += pending + task_status["remaining"] += remaining + break + + return task_status diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index e25c7cb1c..a85c533a0 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -149,12 +149,13 @@ def mix_search_memories( "chat_history": search_req.chat_history, } - fast_retrieved_memories = self.searcher.retrieve( + raw_retrieved_memories = self.searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FAST, + mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, search_filter=search_filter, search_priority=search_priority, info=info, @@ -167,90 +168,24 @@ def mix_search_memories( turns=self.history_memory_turns, ) logger.info(f"Found 
{len(history_memories)} history memories.") - if not history_memories: - memories = self.searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - else: - # if history memories can directly answer - sorted_history_memories = self.reranker.rerank( - query=search_req.query, # Use search_req.query instead of undefined query - graph_results=history_memories, # Pass TextualMemoryItem objects directly - top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k - search_priority=search_priority, - ) - logger.info(f"Reranked {len(sorted_history_memories)} history memories.") - processed_hist_mem = self.searcher.post_retrieve( - retrieved_results=sorted_history_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - - can_answer = self.retriever.evaluate_memory_answer_ability( - query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem] - ) - - if can_answer: - logger.info("History memories can answer the query.") - sorted_results = fast_retrieved_memories + sorted_history_memories - combined_results = self.searcher.post_retrieve( - retrieved_results=sorted_results, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - memories = combined_results[: search_req.top_k] - else: - logger.info("History memories cannot answer the query, enhancing memories.") - sorted_results = fast_retrieved_memories + sorted_history_memories - combined_results = self.searcher.post_retrieve( - retrieved_results=sorted_results, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - enhanced_memories, _ = self.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=combined_results, - ) - if len(enhanced_memories) < search_req.top_k: - logger.info( - f"Enhanced memories ({len(enhanced_memories)}) are less than top_k ({search_req.top_k}). Recalling for more." - ) - missing_info_hint, trigger = self.retriever.recall_for_missing_memories( - query=search_req.query, - memories=combined_results, - ) - retrieval_size = search_req.top_k - len(enhanced_memories) - if trigger: - logger.info(f"Triggering additional search with hint: {missing_info_hint}") - additional_memories = self.searcher.search( - query=missing_info_hint, - user_name=user_context.mem_cube_id, - top_k=retrieval_size, - mode=SearchMode.FAST, - memory_type="All", - search_filter=search_filter, - search_priority=search_priority, - info=info, - ) - else: - logger.info("Not triggering additional search, using combined results.") - additional_memories = combined_results[:retrieval_size] - logger.info( - f"Added {len(additional_memories)} more memories. 
Total enhanced memories: {len(enhanced_memories)}" - ) - enhanced_memories += additional_memories - - memories = enhanced_memories[: search_req.top_k] + # Rerank history memories, then merge them with the freshly retrieved results + sorted_history_memories = self.reranker.rerank( + query=search_req.query, # Use search_req.query instead of undefined query + graph_results=history_memories, # Pass TextualMemoryItem objects directly + top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k + search_filter=search_filter, + ) + logger.info(f"Reranked {len(sorted_history_memories)} history memories.") + merged_memories = self.searcher.post_retrieve( + retrieved_results=raw_retrieved_memories + sorted_history_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + memories = merged_memories[: search_req.top_k] formatted_memories = [format_textual_memory_item(item) for item in memories] - logger.info("Submitted memory history async task.") self.submit_memory_history_async_task( search_req=search_req, user_context=user_context, @@ -259,7 +194,6 @@ def mix_search_memories( "formatted_memories": formatted_memories, }, ) - return formatted_memories def update_search_memories_to_redis( diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index e76728286..71700bc63 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -25,7 +25,7 @@ DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" DEFAULT_THREAD_POOL_MAX_WORKERS = 50 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01 -DEFAULT_CONSUME_BATCH = 1 +DEFAULT_CONSUME_BATCH = 3 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 @@ -62,3 +62,6 @@ DEFAULT_MAX_QUERY_KEY_WORDS = 1000 DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05] DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50 + +# task queue +DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.3" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 87738671c..9f39d9888 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -34,6 +34,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) redis_message_id: str = Field(default="", description="the message get from redis stream") + stream_key: str = Field("", description="stream key identifying the queue this message is enqueued on") user_id: str = Field(..., description="user id") mem_cube_id: str = Field(..., description="memcube id") session_id: str = Field(default="", description="Session ID for soft-filtering memories") @@ -84,6 +85,7 @@ def to_dict(self) -> dict: "content": self.content, "timestamp": self.timestamp.isoformat(), "user_name": self.user_name, + "task_id": self.task_id if self.task_id is not None else "", } @classmethod @@ -97,6 +99,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), user_name=data.get("user_name"), + task_id=data.get("task_id"), ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index c361a77a2..4570461c5 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ 
b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -16,6 +16,8 @@ ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue +from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -39,8 +41,7 @@ class SchedulerDispatcher(BaseSchedulerModule): def __init__( self, max_workers: int = 30, - memos_message_queue: Any | None = None, - use_redis_queue: bool | None = None, + memos_message_queue: ScheduleTaskQueue | None = None, enable_parallel_dispatch: bool = True, config=None, status_tracker: TaskStatusTracker | None = None, @@ -53,8 +54,12 @@ def __init__( # Main dispatcher thread pool self.max_workers = max_workers - self.memos_message_queue = memos_message_queue - self.use_redis_queue = use_redis_queue + # Accept either a ScheduleTaskQueue wrapper or a concrete queue instance + self.memos_message_queue = ( + memos_message_queue.memos_message_queue + if hasattr(memos_message_queue, "memos_message_queue") + else memos_message_queue + ) # Get multi-task timeout from config self.multi_task_running_timeout = ( @@ -87,8 +92,6 @@ def __init__( # Task tracking for monitoring self._running_tasks: dict[str, RunningTaskItem] = {} self._task_lock = threading.Lock() - self._completed_tasks = [] - self.completed_tasks_max_show_size = 10 # Configure shutdown wait behavior from config or default self.stop_wait = ( @@ -159,13 +162,17 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): self.metrics.task_completed(user_id=m.user_id, task_type=m.label) # acknowledge redis messages - if self.use_redis_queue and self.memos_message_queue is not None: + if ( + isinstance(self.memos_message_queue, SchedulerRedisQueue) + and self.memos_message_queue is not None + ): for msg in messages: redis_message_id = msg.redis_message_id # Acknowledge message processing self.memos_message_queue.ack_message( user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, + task_label=msg.label, redis_message_id=redis_message_id, ) @@ -174,9 +181,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_completed(result) del self._running_tasks[task_item.item_id] - self._completed_tasks.append(task_item) - if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks.pop(0) logger.info(f"Task completed: {task_item.get_execution_info()}") return result @@ -192,8 +196,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_failed(str(e)) del self._running_tasks[task_item.item_id] - if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks.pop(0) logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") raise @@ -329,7 +331,8 @@ def stats(self) -> dict[str, int]: except Exception: running = 0 try: - inflight = len(self._futures) + with self._task_lock: + inflight = len(self._futures) except Exception: inflight = 0 try: @@ -342,7 +345,8 @@ def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None: logger.debug(f"Using _default_message_handler to deal with messages: {messages}") def _handle_future_result(self, future): - 
self._futures.remove(future) + with self._task_lock: + self._futures.discard(future) try: future.result() # this will throw exception except Exception as e: @@ -383,7 +387,7 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): messages=msgs, ) - # Add to running tasks + # Uniformly register the task before execution with self._task_lock: self._running_tasks[task_item.item_id] = task_item @@ -393,12 +397,16 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): # dispatch to different handler logger.debug(f"Task started: {task_item.get_execution_info()}") if self.enable_parallel_dispatch and self.dispatcher_executor is not None: - # Capture variables in lambda to avoid loop variable issues - _ = self.dispatcher_executor.submit(wrapped_handler, msgs) + # Submit and track the future + future = self.dispatcher_executor.submit(wrapped_handler, msgs) + with self._task_lock: + self._futures.add(future) + future.add_done_callback(self._handle_future_result) logger.info( f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." ) else: + # For synchronous execution, the wrapper will run and remove the task upon completion wrapped_handler(msgs) def join(self, timeout: float | None = None) -> bool: diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index f7e3eac15..69cfc0af9 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -38,8 +38,8 @@ def __init__( f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" ) - def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: - stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" return stream_key def put( @@ -61,7 +61,11 @@ def put( queue.Full: If the queue is full and block=False or timeout expires. Exception: Any underlying error during queue.put() operation. """ - stream_key = self.get_stream_key(user_id=message.user_id, mem_cube_id=message.mem_cube_id) + stream_key = self.get_stream_key( + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label + ) + + message.stream_key = stream_key # Create the queue if it doesn't exist yet if stream_key not in self.queue_streams: diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py new file mode 100644 index 000000000..d03648bba --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py @@ -0,0 +1,47 @@ +""" +Scheduler Orchestrator for Redis-backed task queues. + +This module provides an orchestrator class that works with `SchedulerRedisQueue` to: +- Broker tasks from Redis streams according to per-user priority weights. +- Maintain a cache of fetched messages and assemble balanced batches across + `(user_id, mem_cube_id, task_label)` groups. + +Stream format: +- Keys follow: `{prefix}:{user_id}:{mem_cube_id}:{task_label}` + +Default behavior: +- All users have priority 1, so fetch sizes are equal per user. +""" + +from __future__ import annotations + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class SchedulerOrchestrator: + def __init__(self, queue): + """ + Args: + queue: An instance of `SchedulerRedisQueue`. 
+ """ + self.queue = queue + # Cache of fetched messages grouped by (user_id, mem_cube_id, task_label) + self._cache = None + + def get_stream_priorities(self) -> None | dict: + return None + + def get_stream_quotas(self, stream_keys, consume_batch_size) -> dict: + stream_priorities = self.get_stream_priorities() + stream_quotas = {} + for stream_key in stream_keys: + if stream_priorities is None: + # Distribute per-stream evenly + stream_quotas[stream_key] = consume_batch_size + else: + # TODO: not implemented yet + stream_quotas[stream_key] = consume_batch_size + return stream_quotas diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index dc2b9af26..22a044358 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -7,13 +7,18 @@ import os import re +import threading import time +from collections import deque from collections.abc import Callable from uuid import uuid4 +from memos.context.context import ContextThread from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STREAM_KEY_PREFIX from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -35,7 +40,8 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, stream_key_prefix: str = os.getenv( - "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", "scheduler:messages:stream" + "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", + DEFAULT_STREAM_KEY_PREFIX, ), consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", @@ -78,20 +84,97 @@ def __init__( # Task tracking for mem_scheduler_wait compatibility self._unfinished_tasks = 0 + # Broker flush threshold and async refill control + self.task_broker_flush_bar = 10 + self._refill_lock = threading.Lock() + self._refill_thread: ContextThread | None = None + + logger.info( + f"[REDIS_QUEUE] Initialized with stream_prefix='{self.stream_key_prefix}', " + f"consumer_group='{self.consumer_group}', consumer_name='{self.consumer_name}'" + ) + # Auto-initialize Redis connection if self.auto_initialize_redis(): self._is_connected = True self.seen_streams = set() - # Task Broker - # Task Orchestrator + self.message_pack_cache = deque() + self.orchestrator = SchedulerOrchestrator(queue=self) - def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: - stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" return stream_key + def task_broker( + self, + consume_batch_size: int, + ) -> list[list[ScheduleMessageItem]]: + stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix) + if not stream_keys: + return [] + + stream_quotas = self.orchestrator.get_stream_quotas( + stream_keys=stream_keys, consume_batch_size=consume_batch_size + ) + cache: list[ScheduleMessageItem] = [] + for stream_key in stream_keys: + messages = self.get( + stream_key=stream_key, + block=False, + batch_size=stream_quotas[stream_key], + ) + cache.extend(messages) + + # pack messages + packed: list[list[ScheduleMessageItem]] = [] + for i in range(0, len(cache), consume_batch_size): + packed.append(cache[i 
: i + consume_batch_size]) + # return packed list without overwriting existing cache + return packed + + def _async_refill_cache(self, batch_size: int) -> None: + """Background thread to refill message cache without blocking get_messages.""" + try: + logger.debug(f"Starting async cache refill with batch_size={batch_size}") + new_packs = self.task_broker(consume_batch_size=batch_size) + logger.debug(f"task_broker returned {len(new_packs)} packs") + with self._refill_lock: + for pack in new_packs: + if pack: # Only add non-empty packs + self.message_pack_cache.append(pack) + logger.debug(f"Added pack with {len(pack)} messages to cache") + logger.debug(f"Cache refill complete, cache size now: {len(self.message_pack_cache)}") + except Exception as e: + logger.warning(f"Async cache refill failed: {e}", exc_info=True) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + if self.message_pack_cache: + # Trigger async refill if below threshold (non-blocking) + if len(self.message_pack_cache) < self.task_broker_flush_bar and ( + self._refill_thread is None or not self._refill_thread.is_alive() + ): + logger.debug( + f"Triggering async cache refill: cache size {len(self.message_pack_cache)} < {self.task_broker_flush_bar}" + ) + self._refill_thread = ContextThread( + target=self._async_refill_cache, args=(batch_size,), name="redis-cache-refill" + ) + self._refill_thread.start() + else: + logger.debug(f"The size of message_pack_cache is {len(self.message_pack_cache)}") + else: + new_packs = self.task_broker(consume_batch_size=batch_size) + for pack in new_packs: + if pack: # Only add non-empty packs + self.message_pack_cache.append(pack) + if len(self.message_pack_cache) == 0: + return [] + else: + return self.message_pack_cache.popleft() + def _ensure_consumer_group(self, stream_key) -> None: """Ensure the consumer group exists for the stream.""" if not self._redis_conn: @@ -135,13 +218,15 @@ def put( try: stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) if stream_key not in self.seen_streams: self.seen_streams.add(stream_key) self._ensure_consumer_group(stream_key=stream_key) + message.stream_key = stream_key + # Convert message to dictionary for Redis storage message_data = message.to_dict() @@ -158,10 +243,31 @@ def put( logger.error(f"Failed to add message to Redis queue: {e}") raise - def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None: - stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + def ack_message( + self, user_id: str, mem_cube_id: str, task_label: str, redis_message_id + ) -> None: + stream_key = self.get_stream_key( + user_id=user_id, mem_cube_id=mem_cube_id, task_label=task_label + ) + # No-op if not connected or message doesn't come from Redis + if not self._redis_conn: + logger.debug( + f"Skip ack: Redis not connected for stream '{stream_key}', msg_id='{redis_message_id}'" + ) + return + if not redis_message_id: + logger.debug( + f"Skip ack: Empty redis_message_id for stream '{stream_key}', user_id='{user_id}', label='{task_label}'" + ) + return - self.redis.xack(stream_key, self.consumer_group, redis_message_id) + try: + self._redis_conn.xack(stream_key, self.consumer_group, redis_message_id) + except Exception as e: + logger.warning( + f"xack failed for stream '{stream_key}', msg_id='{redis_message_id}': {e}" + ) + return # Optionally delete the message from the stream to keep it clean 
if self.auto_delete_acked: @@ -190,12 +296,14 @@ def get( redis_timeout = None # Non-blocking # Read messages from the consumer group + # 1) Read remaining/new messages first (not yet delivered to any consumer) + new_messages: list[tuple[str, list[tuple[str, dict]]]] = [] try: - messages = self._redis_conn.xreadgroup( + new_messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=batch_size if not batch_size else 1, + count=(batch_size if batch_size is not None else None), block=redis_timeout, ) except Exception as read_err: @@ -203,18 +311,69 @@ def get( err_msg = str(read_err).lower() if "nogroup" in err_msg or "no such key" in err_msg: logger.warning( - f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry." + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (new)." ) self._ensure_consumer_group(stream_key=stream_key) - messages = self._redis_conn.xreadgroup( + new_messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=batch_size if not batch_size else 1, + count=(batch_size if batch_size is not None else None), block=redis_timeout, ) else: raise + + # 2) If needed, read pending messages for THIS consumer only + pending_messages: list[tuple[str, list[tuple[str, dict]]]] = [] + need_pending_count = None + if batch_size is None: + # No batch_size: prefer returning a single new message; if none, fetch one pending + if not new_messages: + need_pending_count = 1 + else: + # With batch_size: fill from pending if new insufficient + new_count = sum(len(sm) for _s, sm in new_messages) if new_messages else 0 + need_pending = max(0, batch_size - new_count) + need_pending_count = need_pending if need_pending > 0 else 0 + + if need_pending_count: + try: + pending_messages = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: "0"}, # read only this consumer's pending + count=need_pending_count, + block=None, # do not block when checking pending + ) + except Exception as read_err: + # Handle missing group/stream by creating and retrying once + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (pending)." + ) + self._ensure_consumer_group(stream_key=stream_key) + try: + pending_messages = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {stream_key: "0"}, + count=need_pending_count, + block=None, + ) + except Exception: + pending_messages = [] + else: + pending_messages = [] + + # Combine: new first, then pending + messages = [] + if new_messages: + messages.extend(new_messages) + if pending_messages: + messages.extend(pending_messages) + result_messages = [] for _stream, stream_messages in messages: @@ -247,22 +406,6 @@ def get( logger.error(f"Failed to get message from Redis queue: {e}") raise - def get_nowait( - self, user_id: str, mem_cube_id: str, batch_size: int | None = None - ) -> list[ScheduleMessageItem]: - """ - Get messages from the Redis queue without blocking (Queue-compatible interface). 
- - Returns: - List of SchedulerMessageItem objects - - Raises: - Empty: If no message is available - """ - return self.get( - user_id=user_id, mem_cube_id=mem_cube_id, block=False, batch_size=batch_size - ) - def qsize(self) -> dict: """ Get the current size of the Redis queue (Queue-compatible interface). @@ -320,12 +463,13 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: def size(self) -> int: """ - Get the current size of the Redis queue (alias for qsize). + Get the current size of the Redis queue (total message count from qsize dict). Returns: - Number of messages in the queue + Total number of messages across all streams """ - return self.qsize() + qsize_result = self.qsize() + return qsize_result.get("total_size", 0) def empty(self) -> bool: """ @@ -334,7 +478,7 @@ def empty(self) -> bool: Returns: True if the queue is empty, False otherwise """ - return self.qsize() == 0 + return self.size() == 0 def full(self) -> bool: """ @@ -348,7 +492,7 @@ def full(self) -> bool: """ if self.maxsize <= 0: return False - return self.qsize() >= self.maxsize + return self.size() >= self.maxsize def join(self) -> None: """ @@ -358,18 +502,22 @@ def join(self) -> None: which is complex. For now, this is a no-op. """ - def clear(self) -> None: + def clear(self, stream_key=None) -> None: """Clear all messages from the queue.""" if not self._is_connected or not self._redis_conn: return try: - stream_keys = self.get_stream_keys() - - for stream_key in stream_keys: - # Delete the entire stream + if stream_key is not None: self._redis_conn.delete(stream_key) logger.info(f"Cleared Redis stream: {stream_key}") + else: + stream_keys = self.get_stream_keys() + + for stream_key in stream_keys: + # Delete the entire stream + self._redis_conn.delete(stream_key) + logger.info(f"Cleared Redis stream: {stream_key}") except Exception as e: logger.error(f"Failed to clear Redis queue: {e}") diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 6d824f4b1..a1285098e 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -35,8 +35,9 @@ def __init__( def ack_message( self, - user_id, - mem_cube_id, + user_id: str, + mem_cube_id: str, + task_label: str, redis_message_id, ) -> None: if not isinstance(self.memos_message_queue, SchedulerRedisQueue): @@ -46,14 +47,10 @@ def ack_message( self.memos_message_queue.ack_message( user_id=user_id, mem_cube_id=mem_cube_id, + task_label=task_label, redis_message_id=redis_message_id, ) - def debug_mode_on(self): - self.memos_message_queue.stream_key_prefix = ( - f"debug_mode:{self.memos_message_queue.stream_key_prefix}" - ) - def get_stream_keys(self) -> list[str]: if isinstance(self.memos_message_queue, SchedulerRedisQueue): stream_keys = self.memos_message_queue.get_stream_keys() @@ -66,6 +63,11 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if isinstance(messages, ScheduleMessageItem): messages = [messages] + for msg in messages: + msg.stream_key = self.memos_message_queue.get_stream_key( + user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, task_label=msg.label + ) + if len(messages) < 1: logger.error("Submit empty") elif len(messages) == 1: @@ -97,6 +99,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + if 
isinstance(self.memos_message_queue, SchedulerRedisQueue): + return self.memos_message_queue.get_messages(batch_size=batch_size) stream_keys = self.get_stream_keys() if len(stream_keys) == 0: diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index 7b0bcea34..27ca708c6 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -215,7 +215,7 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - logger.error(f"Error in {func.__name__}: {e}", exc_info=True) + logger.error(f"Error in {func.__name__}: {e}", stack_info=True) return wrapper diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 2762ddaca..68d265f81 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -1,4 +1,5 @@ import json +import os import ssl import threading import time @@ -69,6 +70,16 @@ def initialize_rabbitmq( Establish connection to RabbitMQ using pika. """ try: + # Skip remote initialization in CI/pytest unless explicitly enabled + enable_env = os.getenv("MEMOS_ENABLE_RABBITMQ", "").lower() == "true" + in_ci = os.getenv("CI", "").lower() == "true" + in_pytest = os.getenv("PYTEST_CURRENT_TEST") is not None + if (in_ci or in_pytest) and not enable_env: + logger.info( + "Skipping RabbitMQ initialization in CI/test environment. Set MEMOS_ENABLE_RABBITMQ=true to enable." + ) + return + from pika.adapters.select_connection import SelectConnection if config is None: @@ -270,15 +281,36 @@ def rabbitmq_publish_message(self, message: dict): """ import pika + exchange_name = self.rabbitmq_exchange_name + routing_key = self.rabbit_queue_name + + if message.get("label") == "knowledgeBaseUpdate": + kb_specific_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") + + if kb_specific_exchange_name: + exchange_name = kb_specific_exchange_name + + routing_key = "" # User specified empty routing key for KB updates + + logger.info( + f"[DIAGNOSTIC] Publishing KB Update message. " + f"ENV_EXCHANGE_NAME_USED: {kb_specific_exchange_name is not None}. " + f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." + ) + logger.info(f" - Message Content: {json.dumps(message, indent=2)}") + with self._rabbitmq_lock: if not self.is_rabbitmq_connected(): logger.error("Cannot publish - no active connection") return False + logger.info( + f"[DIAGNOSTIC] rabbitmq_service.rabbitmq_publish_message: Attempting to publish message. 
Exchange: {exchange_name}, Routing Key: {routing_key}, Message Content: {json.dumps(message, indent=2)}" + ) try: self.rabbitmq_channel.basic_publish( - exchange=self.rabbitmq_exchange_name, - routing_key=self.rabbit_queue_name, + exchange=exchange_name, + routing_key=routing_key, body=json.dumps(message), properties=pika.BasicProperties( delivery_mode=2, # Persistent diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index e79553f33..d7ca6565f 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -111,6 +111,16 @@ def auto_initialize_redis(self) -> bool: Returns: bool: True if Redis connection is successfully established, False otherwise """ + # Skip remote initialization in CI/pytest unless explicitly enabled + enable_env = os.getenv("MEMOS_ENABLE_REDIS", "").lower() == "true" + in_ci = os.getenv("CI", "").lower() == "true" + in_pytest = os.getenv("PYTEST_CURRENT_TEST") is not None + if (in_ci or in_pytest) and not enable_env: + logger.info( + "Skipping Redis auto-initialization in CI/test environment. Set MEMOS_ENABLE_REDIS=true to enable." + ) + return False + import redis # Strategy 1: Try to initialize from config diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 9c892d8b8..6a10087f9 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -50,7 +50,7 @@ def __init__( self.stage_retrieve_top = 3 self.process_llm = process_llm - self.thinking_stages = 0 # TODO: to increase thinking depth when the algorithm is reliable + self.thinking_stages = 3 self.max_retry_times = 2 self.deep_search_top_k_bar = 2 @@ -72,8 +72,7 @@ def stage_retrieve( query: str, previous_retrieval_phrases: list[str], text_memories: str, - context: str | None = None, - ) -> tuple[bool, str, str, list[str]]: + ) -> tuple[bool, str, list[str]]: """Run a retrieval-expansion stage and parse structured LLM output. 
Returns a tuple of: @@ -94,8 +93,6 @@ def stage_retrieve( "previous_retrieval_phrases": prev_phrases_text, "memories": text_memories, } - if context is not None: - args["context"] = context prompt = self.build_prompt(**args) max_attempts = max(0, self.max_retry_times) + 1 @@ -112,8 +109,6 @@ def stage_retrieve( reason = result.get("reason", "") - context_out = str(result.get("context", "")) - phrases_val = result.get("retrieval_phrases", result.get("retrival_phrases", [])) if isinstance(phrases_val, list): retrieval_phrases = [str(p).strip() for p in phrases_val if str(p).strip()] @@ -122,7 +117,7 @@ def stage_retrieve( else: retrieval_phrases = [] - return can_answer, reason, context_out, retrieval_phrases + return can_answer, reason, retrieval_phrases except Exception as e: if attempt < max_attempts: @@ -135,39 +130,6 @@ def stage_retrieve( ) raise e - def summarize_memories(self, query: str, context: str, text_memories: str, top_k: int): - args = { - "template_name": "memory_summary", - "query": query, - "context": context, - "memories": text_memories, - "top_k": top_k, - } - - prompt = self.build_prompt(**args) - - max_attempts = max(0, self.max_retry_times) + 1 - for attempt in range(1, max_attempts + 1): - try: - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - result = parse_structured_output(content=llm_response) - context, mem_list = result["context"], result["memories"] - if not isinstance(mem_list, list): - logger.error(f"The result of summarize_memories is {result}") - return context, mem_list - except Exception as e: - if attempt < max_attempts: - logger.debug( - f"[summarize_memories]🔁 retry {attempt}/{max_attempts} failed: {e!s}" - ) - time.sleep(1) - else: - logger.error( - f"[summarize_memories]❌ all {max_attempts} attempts failed: {e!s}; \nprompt: {prompt}", - exc_info=True, - ) - raise e - def judge_memories(self, query: str, text_memories: str): args = { "template_name": "memory_judgement", @@ -226,22 +188,32 @@ def get_final_memories(self, user_id: str, top_k: int, mem_list: list[str]): result_memories = enhanced_memories[:top_k] return result_memories - def recreate_enhancement( + def memory_recreate_enhancement( self, query: str, + top_k: int, text_memories: list[str], retries: int, ) -> list: attempt = 0 text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(text_memories)]) prompt_name = "memory_recreate_enhancement" - prompt = self.build_prompt(template_name=prompt_name, query=query, memories=text_memories) + prompt = self.build_prompt( + template_name=prompt_name, query=query, top_k=top_k, memories=text_memories + ) llm_response = None while attempt <= max(0, retries) + 1: try: llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) processed_text_memories = parse_structured_output(content=llm_response) + logger.debug( + f"[memory_recreate_enhancement]\n " + f"- original memories: \n" + f"{text_memories}\n" + f"- final memories: \n" + f"{processed_text_memories['answer']}" + ) return processed_text_memories["answer"] except Exception as e: attempt += 1 @@ -281,16 +253,15 @@ def deep_search( user_name=user_name, info=info, ) - if top_k < self.deep_search_top_k_bar or len(memories) == 0: + if len(memories) == 0: logger.warning("Requirements not met; returning memories as-is.") return memories user_id = memories[0].metadata.user_id - context = None mem_list, _ = self.tree_memories_to_text_memories(memories=memories) retrieved_memories = copy.deepcopy(retrieved_memories) - 
retrieved_memories_from_deep_search = [] + rewritten_flag = False for current_stage_id in range(self.thinking_stages + 1): try: # at last @@ -306,179 +277,31 @@ def deep_search( f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " f"final can_answer: {can_answer}; reason: {reason}" ) - mem_list = self.recreate_enhancement( - query=query, text_memories=mem_list, retries=self.max_retry_times - ) - enhanced_memories = self.get_final_memories( - user_id=user_id, top_k=top_k, mem_list=mem_list - ) - return enhanced_memories - - can_answer, reason, context, retrieval_phrases = self.stage_retrieve( - stage_id=current_stage_id + 1, - query=query, - previous_retrieval_phrases=previous_retrieval_phrases, - context=context, - text_memories="- " + "\n- ".join(mem_list) + "\n", - ) - if can_answer: - logger.info( - f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", - ) - - enhanced_memories = self.get_final_memories( - user_id=user_id, top_k=top_k, mem_list=mem_list - ) - return enhanced_memories - else: - previous_retrieval_phrases.extend(retrieval_phrases) - logger.info( - f"Start complementary retrieval for Stage {current_stage_id}; " - f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " - f"can_answer: {can_answer}; reason: {reason}" - ) - logger.info( - "Stage %d - Found %d new retrieval phrases", - current_stage_id, - len(retrieval_phrases), - ) - # Search for additional memories based on retrieval phrases - additional_retrieved_memories = [] - for phrase in retrieval_phrases: - _retrieved_memories = self.retrieve( - query=phrase, - user_name=user_name, - top_k=self.stage_retrieve_top, - mode=SearchMode.FAST, - memory_type=memory_type, - search_filter=search_filter, - info=info, - ) - logger.info( - "Found %d additional memories for phrase: '%s'", - len(_retrieved_memories), - phrase[:30] + "..." 
if len(phrase) > 30 else phrase, - ) - additional_retrieved_memories.extend(_retrieved_memories) - retrieved_memories_from_deep_search.extend(additional_retrieved_memories) - merged_memories = self.post_retrieve( - retrieved_results=retrieved_memories + additional_retrieved_memories, - top_k=top_k * 2, - user_name=user_name, - info=info, - ) - - _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) - mem_list = _mem_list - mem_list = list(set(mem_list)) - logger.info( - "After stage %d, total memories in list: %d", - current_stage_id, - len(mem_list), - ) - - # enhance memories - mem_list = self.recreate_enhancement( - query=query, text_memories=mem_list, retries=self.max_retry_times - ) - logger.info("After summarization, memory list contains %d items", len(mem_list)) - - except Exception as e: - logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) - # Continue to next stage instead of failing completely - continue - logger.error("Deep search failed, returning original memories") - return memories - - def deep_search_backup( - self, - query: str, - top_k: int, - info=None, - memory_type="All", - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ): - previous_retrieval_phrases = [query] - retrieved_memories = self.retrieve( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - memory_type=memory_type, - search_filter=search_filter, - info=info, - ) - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - if top_k < self.deep_search_top_k_bar or len(memories) == 0: - logger.warning("Requirements not met; returning memories as-is.") - return memories - - user_id = memories[0].metadata.user_id - context = None - - mem_list, _ = self.tree_memories_to_text_memories(memories=memories) - retrieved_memories = copy.deepcopy(retrieved_memories) - retrieved_memories_from_deep_search = [] - for current_stage_id in range(self.thinking_stages + 1): - try: - # at last - if current_stage_id == self.thinking_stages: - # eval to finish - reason, can_answer = self.judge_memories( - query=query, - text_memories="- " + "\n- ".join(mem_list) + "\n", - ) - - logger.info( - f"Final Stage: Stage {current_stage_id}; " - f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " - f"final can_answer: {can_answer}; reason: {reason}" - ) - if len(retrieved_memories_from_deep_search) == 0: - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - return memories[:top_k] - else: + if rewritten_flag: enhanced_memories = self.get_final_memories( user_id=user_id, top_k=top_k, mem_list=mem_list ) - return enhanced_memories + else: + enhanced_memories = memories + return enhanced_memories[:top_k] - can_answer, reason, context, retrieval_phrases = self.stage_retrieve( + can_answer, reason, retrieval_phrases = self.stage_retrieve( stage_id=current_stage_id + 1, query=query, previous_retrieval_phrases=previous_retrieval_phrases, - context=context, text_memories="- " + "\n- ".join(mem_list) + "\n", ) if can_answer: logger.info( f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", ) - if len(retrieved_memories_from_deep_search) == 0: - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - return memories[:top_k] - else: + if 
rewritten_flag: enhanced_memories = self.get_final_memories( user_id=user_id, top_k=top_k, mem_list=mem_list ) - return enhanced_memories + else: + enhanced_memories = memories + return enhanced_memories[:top_k] else: previous_retrieval_phrases.extend(retrieval_phrases) logger.info( @@ -509,32 +332,28 @@ def deep_search_backup( phrase[:30] + "..." if len(phrase) > 30 else phrase, ) additional_retrieved_memories.extend(_retrieved_memories) - retrieved_memories_from_deep_search.extend(additional_retrieved_memories) merged_memories = self.post_retrieve( retrieved_results=retrieved_memories + additional_retrieved_memories, top_k=top_k * 2, user_name=user_name, info=info, ) - + rewritten_flag = True _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) mem_list = _mem_list mem_list = list(set(mem_list)) + mem_list = self.memory_recreate_enhancement( + query=query, + top_k=top_k, + text_memories=mem_list, + retries=self.max_retry_times, + ) logger.info( "After stage %d, total memories in list: %d", current_stage_id, len(mem_list), ) - # Summarize memories - context, mem_list = self.summarize_memories( - query=query, - context=context, - text_memories="- " + "\n- ".join(mem_list) + "\n", - top_k=top_k, - ) - logger.info("After summarization, memory list contains %d items", len(mem_list)) - except Exception as e: logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) # Continue to next stage instead of failing completely diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index cc577f1bd..b5bd34417 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -57,6 +57,10 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: This is basically your current handle_add_memories logic, but scoped to a single cube_id. """ + sync_mode = add_req.async_mode or self._get_sync_mode() + self.logger.info( + f"[DIAGNOSTIC] single_cube.add_memories called for cube_id: {self.cube_id}. sync_mode: {sync_mode}. Request: {add_req.model_dump_json(indent=2)}" + ) user_context = UserContext( user_id=add_req.user_id, mem_cube_id=self.cube_id, @@ -134,6 +138,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: ) self.logger.info(f"Search memories result: {memories_result}") + self.logger.info(f"Search {len(memories_result)} memories.") return memories_result def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: @@ -195,7 +200,7 @@ def _search_text( user_context: UserContext, search_mode: str, ) -> list[dict[str, Any]]: - """G + """ Search text memories based on mode. 
Args: @@ -322,7 +327,7 @@ def _fine_search( ) missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories( query=search_req.query, - memories=raw_memories, + memories=[mem.memory for mem in enhanced_memories], ) retrieval_size = len(raw_memories) - len(enhanced_memories) logger.info(f"Retrieval size: {retrieval_size}") @@ -370,8 +375,8 @@ def _search_pref( """ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - print(f"search_req.filter for preference memory: {search_req.filter}") - print(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") + logger.info(f"search_req.filter for preference memory: {search_req.filter}") + logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") try: results = self.naive_mem_cube.pref_mem.search( query=search_req.query, @@ -582,7 +587,7 @@ def _process_pref_mem( return [ { - "memory": memory.memory, + "memory": memory.metadata.preference, "memory_id": memory_id, "memory_type": memory.metadata.preference_type, } diff --git a/src/memos/templates/advanced_search_prompts.py b/src/memos/templates/advanced_search_prompts.py index 13e80a79a..baf2f7536 100644 --- a/src/memos/templates/advanced_search_prompts.py +++ b/src/memos/templates/advanced_search_prompts.py @@ -1,54 +1,4 @@ -MEMORY_SUMMARY_PROMPT = """ -# Memory Summary and Context Assembly - -## Role -You are a precise context assembler. Given a user query and a set of retrieved memories (each indexed), your task is to synthesize a factual, concise, and coherent context using only the information explicitly present in the memories. - -## Instructions - -### Core Principles -- Use ONLY facts from the provided memories. Do not invent, infer, guess, or hallucinate. -- Resolve all pronouns (e.g., "he", "it", "they") and vague terms (e.g., "this", "that", "some people") to explicit entities using memory content. -- Merge overlapping or redundant facts. Preserve temporal, spatial, and relational details. -- Each fact must be atomic, unambiguous, and verifiable. -- Preserve all key details: who, what, when, where, why — if present in memory. -- Created a summarized facts for answering query at the first item, and separate logically coherent separate memories. -- Begin the with a single, aggregated summary that directly answers the query using the most relevant facts. -- The total number of facts in must not exceed {top_k}. -- If additional context is relevant, try to weave it together logically—or chronologically—based on how the pieces connect. -- **Must preserve the full timeline of all memories**: if multiple events or states are mentioned with temporal markers (e.g., dates, sequences, phases), their chronological order must be retained in both and . - -### Processing Logic -- Aggregate logically connected memories (e.g., events involving the same person, cause-effect chains, repeated entities). -- Exclude any memory that does not directly support answering the query. -- Prioritize specificity: e.g., "Travis Tang moved to Singapore in 2021" > "He relocated abroad." - -## Input -- Query: {query} -- Current context: -{context} -- Current Memories: -{memories} - -## Output Format (STRICT TAG-BASED) -Respond ONLY with the following XML-style tags. Do NOT include any other text, explanations, or formatting. - - -A single, compact, fluent paragraph synthesizing the above facts into a coherent narrative directly relevant to the query. Use resolved entities and logical flow. No bullet points. No markdown. No commentary. 
-</context>
-<facts>
-- Aggregated summary
-- Fact 1
-- Fact 2
-</facts>
-
-Answer:
-"""
-
-# Stage 1: determine answerability; if not answerable, produce concrete retrieval phrases for missing info
STAGE1_EXPAND_RETRIEVE_PROMPT = """
-# Stage 1 — Answerability and Missing Retrieval Phrases
-
## Goal
Determine whether the current memories can answer the query using concrete, specific facts. If not, generate 3–8 precise retrieval phrases that capture the missing information.
@@ -76,9 +26,6 @@
<can_answer>
true or false
</can_answer>
-<context>
-summary of current memories
-</context>
<reason>
Brief, one-sentence explanation for why the query is or isn't answerable with current memories.
</reason>
@@ -94,27 +41,24 @@
# Stage 2: if Stage 1 phrases still fail, rewrite the retrieval query and phrases to maximize recall
STAGE2_EXPAND_RETRIEVE_PROMPT = """
-# Stage 2 — Rewrite Retrieval Query and Phrases to Improve Recall
-
## Goal
-If Stage 1's retrieval phrases failed to yield an answer, rewrite the original query and expand the phrase list to maximize recall of relevant memories. Use canonicalization, synonym expansion, and constraint enrichment.
+Rewrite the original query and generate an improved list of retrieval phrases to maximize recall of relevant memories. Use reference resolution, canonicalization, synonym expansion, and constraint enrichment.

## Rewrite Strategy
-- Canonicalize entities: use full names, official titles, or known aliases.
-- Normalize time formats: e.g., "last year" → "2024", "in 2021" → "2021".
-- Add discriminative tokens: entity + attribute + time + location where applicable.
-- Split complex queries into focused sub-queries targeting distinct facets.
-- Never include pronouns, vague terms, or subjective language.
+- **Resolve ambiguous references**: Replace pronouns (e.g., “she”, “they”, “it”) and vague terms (e.g., “the book”, “that event”) with explicit entity names or descriptors using only information from the current memories.
+- **Canonicalize entities**: Use full names (e.g., “Melanie Smith”), known roles (e.g., “Caroline’s mentor”), or unambiguous identifiers when available.
+- **Normalize temporal expressions**: Convert relative time references (e.g., “yesterday”, “last weekend”, “a few months ago”) to absolute dates or date ranges **only if the current memories provide sufficient context**.
+- **Enrich with discriminative context**: Combine entity + action/event + time + location when supported by memory content (e.g., “Melanie pottery class July 2023”).
+- **Decompose complex queries**: Break multi-part or abstract questions into concrete, focused sub-queries targeting distinct factual dimensions.
+- **Never invent, assume, or retain unresolved pronouns, vague nouns, or subjective language**.

## Input
- Query: {query}
- Previous retrieval phrases: {previous_retrieval_phrases}
-- Context: {context}
- Current Memories:
{memories}
-
## Output (STRICT TAG-BASED FORMAT)
Respond ONLY with the following structure. Do not add any other text, explanation, or formatting.

<can_answer>
true or false
</can_answer>
<reason>
-Brief explanation (1–2 sentences) of how this rewrite improves recall over Stage 1 phrases.
+Brief explanation (1–2 sentences) of how this rewrite improves recall—e.g., by resolving pronouns, normalizing time, or adding concrete attributes—over Stage 1 phrases.
</reason>
-<context>
-summary of current memories
-</context>
<retrieval_phrases>
-- new phrase 1 (Rewritten version of the original query. More precise, canonical, and retrieval-optimized.)
+- new phrase 1 (Rewritten, canonical, fully grounded in memory content)
- new phrase 2
...
</retrieval_phrases>
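These stage prompts share one strict tag contract: `stage_retrieve` in advanced_searcher.py pulls `can_answer`, `reason`, and `retrieval_phrases` out of the raw LLM text via `parse_structured_output`. The sketch below shows one way such flat, non-nested tag output can be parsed; it is a minimal illustration and not MemOS's actual `parse_structured_output` implementation.

import re


def parse_tagged_output(content: str) -> dict:
    """Illustrative parser for <tag>...</tag> style LLM output.

    Each top-level tag becomes a dict entry; bullet lists ("- item")
    inside a tag are split into Python lists.
    """
    result = {}
    for match in re.finditer(r"<(\w+)>(.*?)</\1>", content, flags=re.DOTALL):
        tag, body = match.group(1), match.group(2).strip()
        bullets = [
            line.strip()[2:].strip()
            for line in body.splitlines()
            if line.strip().startswith("- ")
        ]
        result[tag] = bullets if bullets else body
    return result


# Example against the Stage 2 output contract:
sample = """
<can_answer>false</can_answer>
<reason>The memories never say when the pottery class took place.</reason>
<retrieval_phrases>
- Melanie pottery class July 2023
- Melanie community center painting session
</retrieval_phrases>
"""
parsed = parse_tagged_output(sample)
assert parsed["can_answer"] == "false"
assert len(parsed["retrieval_phrases"]) == 2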
@@ -139,22 +80,19 @@ # Stage 3: generate grounded hypotheses to guide retrieval when still not answerable STAGE3_EXPAND_RETRIEVE_PROMPT = """ -# Stage 3 — Hypothesis Generation for Retrieval - ## Goal -When the query remains unanswerable, generate grounded, plausible hypotheses based ONLY on provided context and memories. Each hypothesis must imply a concrete retrieval target and validation criteria. +As the query remains unanswerable, generate grounded, plausible hypotheses based ONLY on the provided memories. Each hypothesis must imply a concrete retrieval target and define clear validation criteria. ## Rules -- Base hypotheses strictly on facts from the memories. No new entities or assumptions. -- Frame each hypothesis as a testable statement: "If [X] is true, then the query is answered." -- For each hypothesis, define 1–3 specific evidence requirements that would confirm it. -- Do NOT guess. Do NOT invent. Only extrapolate from existing facts. +- Base hypotheses strictly on facts from the memories. Do NOT introduce new entities, events, or assumptions. +- Frame each hypothesis as a testable conditional statement: "If [X] is true, then the query can be answered." +- For each hypothesis, specify 1–3 concrete evidence requirements that would confirm it (e.g., a specific date, name, or event description). +- Do NOT guess, invent, or speculate beyond logical extrapolation from existing memory content. ## Input - Query: {query} - Previous retrieval phrases: {previous_retrieval_phrases} -- Context: {context} - Memories: {memories} @@ -164,24 +102,20 @@ true or false - -summary of current memories - -- statement: - retrieval_query: +- statement: + retrieval_query: validation_criteria: - - - - -- statement: + - + - +- statement: retrieval_query: validation_criteria: - - + - - -- hypothesis retrieval query 1 (searchable query derived from the hypothesis) -- hypothesis retrieval query 2: +- +- ... @@ -229,33 +163,36 @@ """ MEMORY_RECREATE_ENHANCEMENT_PROMPT = """ -You are a knowledgeable and precise AI assistant. +You are a precise and detail-oriented AI assistant specialized in temporal memory reconstruction, reference resolution, and relevance-aware memory fusion. # GOAL -Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. - -# RULES & THINKING STEPS -1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. -2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). -3. Resolve all ambiguities using only memory content: - - Pronouns → full name: “she” → “Melanie” - - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” - - “the user” → identity from context (e.g., “Melanie” if travel/running memories) -4. Never invent, assume, or extrapolate. -5. Each output line must be a standalone, clear, factual statement. -6. Output format: one line per fact, starting with "- ", no extra text. +Transform the original memories into a clean, unambiguous, and consolidated set of factual statements that: +1. **Resolve all vague or relative references** (e.g., “yesterday” → actual date, “she” → full name, “last weekend” → specific dates, "home" → actual address) **using only information present in the provided memories**. +2. **Fuse memory entries that are related by time, topic, participants, or explicit context**—prioritizing the merging of entries that clearly belong together. +3. 
**Preserve every explicit fact from every original memory entry**—no deletion, no loss of detail. Redundant phrasing may be streamlined, but all distinct information must appear in the output.
+4. **Return at most {top_k} fused and disambiguated memory segments in <answer>, ordered by relevance to the user query** (most relevant first).
+
+# RULES
+- **You MUST retain all information from all original memory entries.** Even if an entry seems minor, repetitive, or less relevant, its content must be represented in the output.
+- **Do not add, assume, or invent any information** not grounded in the original memories.
+- **Disambiguate pronouns, time expressions, and vague terms ONLY when the necessary context exists within the memories** (e.g., if “yesterday” appears in a message dated July 3, resolve it to July 2).
+- **If you cannot resolve a vague reference (e.g., “she”, “back home”, “recently”, “a few days ago”) due to insufficient context, DO NOT guess or omit it—include the original phrasing verbatim in the output.**
+- **Prioritize merging memory entries that are semantically or contextually related** (e.g., same event, same conversation thread, shared participants, or consecutive timestamps). Grouping should reflect natural coherence, not just proximity.
+- **The total number of bullets in <answer> must not exceed {top_k}.** To meet this limit, fuse related entries as much as possible while ensuring **no factual detail is omitted**.
+- **Never sacrifice factual completeness for brevity or conciseness.** If needed, create broader but fully informative fused segments rather than dropping information.
+- **Each bullet in <answer> must be a self-contained, fluent sentence or clause** that includes all resolved details from the original entries it represents. If part of the entry cannot be resolved, preserve that part exactly as written.
+- **Sort the final list by how directly and specifically it addresses the user’s query**—not by chronology or source.

# OUTPUT FORMAT (STRICT)
-Return ONLY the following block, with **one enhanced memory per line**.
-Each line MUST start with "- " (dash + space).
+Return ONLY the following structure:

-Wrap the final output inside:
-- enhanced memory 1
-- enhanced memory 2
-...
+<answer>
+- [Fully resolved, fused memory segment most relevant to the query — containing all facts from the original entries it covers; unresolved parts kept verbatim]
+- [Next most relevant resolved and fused segment — again, with no factual loss]
+- [...]
+</answer>
+
## User Query
{query}

## Original Memories
{memories}

Final Output:
"""
-
PROMPT_MAPPING = {
-    "memory_summary": MEMORY_SUMMARY_PROMPT,
    "memory_judgement": MEMORY_JUDGMENT_PROMPT,
    "stage1_expand_retrieve": STAGE1_EXPAND_RETRIEVE_PROMPT,
    "stage2_expand_retrieve": STAGE2_EXPAND_RETRIEVE_PROMPT,
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index 7f7415e79..acbae2281 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -393,6 +393,79 @@
MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
You are a knowledgeable and precise AI assistant.

+# GOAL
+Transform raw memories into clean, complete, and fully disambiguated statements that preserve original meaning and explicit details.
+
+# RULES & THINKING STEPS
+1. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”).
+2. Resolve all ambiguities using only memory content. If disambiguation cannot be performed using only the provided memories, retain the original phrasing exactly as written.
Never guess, infer, or fabricate missing information: + - Pronouns → full name (e.g., “she” → “Caroline”) + - Relative time expressions → concrete dates or full context (e.g., “last night” → “on the evening of November 25, 2025”) + - Vague references → specific, grounded details (e.g., “the event” → “the LGBTQ+ art workshop in Malmö”) + - Incomplete descriptions → full version from memory (e.g., “the activity” → “the abstract painting session at the community center”) +3. Merge memories that are largely repetitive in content but contain complementary or distinct details. Combine them into a single, cohesive statement that preserves all unique information from each original memory. Do not merge memories that describe different events, even if they share a theme. +4. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space). + +Wrap the final output inside: + +- enhanced memory 1 +- enhanced memory 2 +... + + +## User Query +{query_history} + +## Original Memories +{memories} + +Final Output: +""" + +MEMORY_RECREATE_ENHANCEMENT_PROMPT_BACKUP_1 = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform raw memories into clean, complete, and fully disambiguated statements that preserve original meaning and explicit details. + +# RULES & THINKING STEPS +1. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”). +2. Resolve all ambiguities using only memory content. If disambiguation cannot be performed using only the provided memories, retain the original phrasing exactly as written. Never guess, infer, or fabricate missing information: + - Pronouns → full name (e.g., “she” → “Caroline”) + - Relative time expressions → concrete dates or full context (e.g., “last night” → “on the evening of November 25, 2025”) + - Vague references → specific, grounded details (e.g., “the event” → “the LGBTQ+ art workshop in Malmö”) + - Incomplete descriptions → full version from memory (e.g., “the activity” → “the abstract painting session at the community center”) +3. Merge memories that are largely repetitive in content but contain complementary or distinct details. Combine them into a single, cohesive statement that preserves all unique information from each original memory. Do not merge memories that describe different events, even if they share a theme. +4. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space). + +Wrap the final output inside: + +- enhanced memory 1 +- enhanced memory 2 +... + + +## User Query +{query_history} + +## Original Memories +{memories} + +Final Output: +""" + + +MEMORY_RECREATE_ENHANCEMENT_PROMPT_BACKUP_2 = """ +You are a knowledgeable and precise AI assistant. + # GOAL Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. @@ -427,7 +500,6 @@ Final Output: """ -# Rewrite version: return enhanced memories with original IDs MEMORY_REWRITE_ENHANCEMENT_PROMPT = """ You are a knowledgeable and precise AI assistant. @@ -470,10 +542,43 @@ Final Output: """ + # One-sentence prompt for recalling missing information to answer the query (English) ENLARGE_RECALL_PROMPT_ONE_SENTENCE = """ You are a precise AI assistant. 
Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query.
+
+# GOAL
+Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them.
+
+# RULES
+- Analyze the user's query to understand what information is being asked.
+- Review the available memories to see what information is already present.
+- Identify the gap between the user's query and the available memories.
+- Generate a single, concise hint that prompts the user to provide the missing information.
+- The hint should be a direct question or a statement that clearly indicates what is needed.
+
+# OUTPUT FORMAT
+A JSON object with:
+
+trigger_recall: true if information is missing, false if sufficient.
+hint: A clear, specific prompt to retrieve the missing information (or an empty string if trigger_recall is false):
+{{
+    "trigger_recall": <true or false>,
+    "hint": a paraphrase to retrieve supporting memories
+}}
+
+## User Query
+{query}
+
+## Available Memories
+{memories_inline}
+
+Final Output:
+"""
+
+ENLARGE_RECALL_PROMPT_ONE_SENTENCE_BACKUP = """
+You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query.
+
# GOAL
Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them.

# RULES
- Analyze the user's query to understand what information is being asked.
- Review the available memories to see what information is already present.
- Identify the gap between the user's query and the available memories.
- Generate a single, concise hint that prompts the user to provide the missing information.
- The hint should be a direct question or a statement that clearly indicates what is needed.

# OUTPUT FORMAT
A JSON object with:

trigger_retrieval: true if information is missing, false if sufficient.
hint: A clear, specific prompt to retrieve the missing information (or an empty string if trigger_retrieval is false):
@@ -505,7 +610,6 @@
Final Output:
"""
-
PROMPT_MAPPING = {
    "intent_recognizing": INTENT_RECOGNIZING_PROMPT,
    "memory_reranking": MEMORY_RERANKING_PROMPT,
diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py
index f796e682a..3706b49da 100644
--- a/src/memos/types/general_types.py
+++ b/src/memos/types/general_types.py
@@ -36,7 +36,6 @@
    "MessagesType",
    "Permission",
    "PermissionDict",
-    "RawMessageList",
    "SearchMode",
    "UserContext",
    "UserID",
@@ -50,7 +49,7 @@
# Message structure
class MessageDict(TypedDict, total=False):
-    """Typed dictionary for chat message dictionaries, will (Deprecate), use ChatCompletionMessageParam instead."""
+    """Typed dictionary for chat message dictionaries."""

    role: MessageRole
    content: str
@@ -102,11 +101,10 @@ class FineStrategy(str, Enum):
    REWRITE = "rewrite"
    RECREATE = "recreate"
    DEEP_SEARCH = "deep_search"
-    AGENTIC_SEARCH = "agentic_search"


# algorithm strategies
-DEFAULT_FINE_STRATEGY = FineStrategy.DEEP_SEARCH
+DEFAULT_FINE_STRATEGY = FineStrategy.RECREATE
FINE_STRATEGY = DEFAULT_FINE_STRATEGY

# Read fine strategy from environment variable `FINE_STRATEGY`.
diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py
index 2181961d2..42aeec29b 100644
--- a/src/memos/vec_dbs/milvus.py
+++ b/src/memos/vec_dbs/milvus.py
@@ -229,7 +229,7 @@ def search(
            List of search results with distance scores and payloads.
""" # Convert filter to Milvus expression - print(f"filter for milvus: {filter}") + logger.info(f"filter for milvus: {filter}") expr = self._dict_to_expr(filter) if filter else "" search_func_map = { diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index fe889559c..ccc4d77a1 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -157,7 +157,10 @@ def test_dispatch_serial(self): """Test dispatching messages in serial mode.""" # Create a new dispatcher with parallel dispatch disabled serial_dispatcher = SchedulerDispatcher( - max_workers=2, enable_parallel_dispatch=False, metrics=MagicMock() + max_workers=2, + memos_message_queue=self.dispatcher.memos_message_queue, + enable_parallel_dispatch=False, + metrics=MagicMock(), ) # Create fresh mock handlers for this test