diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 7d8cf2897..09aec4cba 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -13,15 +13,15 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - ANSWER_LABEL, - MEM_ARCHIVE_LABEL, - MEM_ORGANIZE_LABEL, - MEM_UPDATE_LABEL, - QUERY_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) from memos.mem_scheduler.utils.filter_utils import transform_name_to_key @@ -118,24 +118,24 @@ def _first_content() -> str: return memcube_content[0].get("content", "") or content return content - if label in ("addMessage", QUERY_LABEL, ANSWER_LABEL): + if label in ("addMessage", QUERY_TASK_LABEL, ANSWER_TASK_LABEL): target_cube = cube_display.replace("MemCube", "") title = _format_title(item.timestamp, f"addMessages to {target_cube} MemCube") return title, _truncate_with_rules(_first_content()) - if label in ("addMemory", ADD_LABEL): + if label in ("addMemory", ADD_TASK_LABEL): title = _format_title(item.timestamp, f"{cube_display} added {memory_len} memories") return title, _truncate_with_rules(_first_content()) - if label in ("updateMemory", MEM_UPDATE_LABEL): + if label in ("updateMemory", MEM_UPDATE_TASK_LABEL): title = _format_title(item.timestamp, f"{cube_display} updated {memory_len} memories") return title, _truncate_with_rules(_first_content()) - if label in ("archiveMemory", MEM_ARCHIVE_LABEL): + if label in ("archiveMemory", MEM_ARCHIVE_TASK_LABEL): title = _format_title(item.timestamp, f"{cube_display} archived 
{memory_len} memories") return title, _truncate_with_rules(_first_content()) - if label in ("mergeMemory", MEM_ORGANIZE_LABEL): + if label in ("mergeMemory", MEM_ORGANIZE_TASK_LABEL): title = _format_title(item.timestamp, f"{cube_display} merged {memory_len} memories") merged = [c for c in memcube_content if c.get("type") == "merged"] post = [c for c in memcube_content if c.get("type") == "postMerge"] diff --git a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py index 2c3801539..be6f20bed 100644 --- a/examples/mem_scheduler/redis_example.py +++ b/examples/mem_scheduler/redis_example.py @@ -9,8 +9,8 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import QUERY_LABEL from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import QUERY_TASK_LABEL if TYPE_CHECKING: @@ -55,7 +55,7 @@ def service_run(): message_item = ScheduleMessageItem( user_id=user_id, mem_cube_id="mem_cube_2", - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, mem_cube=mem_cube, content=query, timestamp=datetime.now(), diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index 4aedac711..4ffa6557f 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -14,7 +14,7 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.schemas.general_schemas import ( +from memos.mem_scheduler.schemas.task_schemas import ( NOT_APPLICABLE_TYPE, ) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 
3cfa49d3d..498768c1c 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -30,11 +30,11 @@ prepare_reference_data, process_streaming_references_complete, ) -from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, - QUERY_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + QUERY_TASK_LABEL, +) from memos.templates.mos_prompts import ( FURTHER_SUGGESTION_PROMPT, get_memos_prompt, @@ -244,7 +244,7 @@ def generate_chat_response() -> Generator[str, None, None]: user_id=chat_req.user_id, mem_cube_id=scheduler_cube_id, query=chat_req.query, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, ) # Extract memories from search results memories_list = [] @@ -423,7 +423,7 @@ def generate_chat_response() -> Generator[str, None, None]: user_id=chat_req.user_id, mem_cube_id=scheduler_cube_id, query=chat_req.query, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, ) # Extract memories from search results memories_list = [] @@ -1034,7 +1034,7 @@ async def _post_chat_processing( # Send answer to scheduler self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL + user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL ) self.logger.info(f"Post-chat processing completed for user {user_id}") diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 9b9bee701..f5e1aaba0 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from memos.configs.base import BaseConfig from memos.configs.chunker import ChunkerConfigFactory @@ -44,6 +44,9 @@ def parse_datetime(cls, value): 
class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" + # Allow passing additional fields without raising validation errors + model_config = ConfigDict(extra="allow", strict=True) + class MultiModalStructMemReaderConfig(BaseMemReaderConfig): """MultiModalStruct MemReader configuration class.""" diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 75d0976a1..b411ecb77 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -15,14 +15,14 @@ from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - ANSWER_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, - QUERY_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_READ_TASK_LABEL, + PREF_ADD_TASK_LABEL, + QUERY_TASK_LABEL, +) from memos.mem_user.user_manager import UserManager, UserRole from memos.memories.activation.item import ActivationMemoryItem from memos.memories.parametric.item import ParametricMemoryItem @@ -283,7 +283,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, content=query, timestamp=datetime.utcnow(), ) @@ -343,7 +343,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=ANSWER_LABEL, + label=ANSWER_TASK_LABEL, content=response, timestamp=datetime.utcnow(), ) @@ -771,7 +771,7 @@ def process_textual_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - 
label=MEM_READ_LABEL, + label=MEM_READ_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), task_id=task_id, @@ -783,7 +783,7 @@ def process_textual_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=ADD_LABEL, + label=ADD_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), task_id=task_id, @@ -824,7 +824,7 @@ def process_preference_memory(): user_id=target_user_id, session_id=target_session_id, mem_cube_id=mem_cube_id, - label=PREF_ADD_LABEL, + label=PREF_ADD_TASK_LABEL, content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) @@ -878,7 +878,7 @@ def process_preference_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=MEM_READ_LABEL, + label=MEM_READ_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) @@ -889,7 +889,7 @@ def process_preference_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=ADD_LABEL, + label=ADD_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) @@ -920,7 +920,7 @@ def process_preference_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=ADD_LABEL, + label=ADD_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 969d42c6e..2bec39741 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -29,11 +29,11 @@ prepare_reference_data, process_streaming_references_complete, ) -from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, - QUERY_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + QUERY_TASK_LABEL, +) from memos.mem_user.persistent_factory import PersistentUserManagerFactory from memos.mem_user.user_manager 
import UserRole from memos.memories.textual.item import ( @@ -710,7 +710,7 @@ async def _post_chat_processing( logger.warning(f"Failed to send chat notification (async): {e}") self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL + user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL ) self.add( @@ -1151,7 +1151,7 @@ def chat_with_references( f"time chat: search text_mem time user_id: {user_id} time is: {search_time_end - time_start}" ) self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_LABEL + user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_TASK_LABEL ) if memories_result: memories_list = memories_result[0]["memories"] diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index df504ee75..dd858c86a 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -4,11 +4,13 @@ from memos.log import get_logger from memos.mem_os.main import MOS from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, MONITOR_WORKING_MEMORY_TYPE, - QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + QUERY_TASK_LABEL, +) logger = get_logger(__name__) @@ -427,7 +429,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, content=query, timestamp=datetime.now(), ) @@ -517,7 +519,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - label=ANSWER_LABEL, + label=ANSWER_TASK_LABEL, content=response, 
timestamp=datetime.now(), ) diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py index 6638fa2f5..ae5ae5d47 100644 --- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py +++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py @@ -7,10 +7,10 @@ from memos.log import get_logger from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.schemas.general_schemas import ( +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem +from memos.mem_scheduler.schemas.task_schemas import ( DEFAULT_MAX_QUERY_KEY_WORDS, ) -from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem if TYPE_CHECKING: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 62e1d0242..610999697 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -43,6 +43,7 @@ ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils import metrics from memos.mem_scheduler.utils.db_utils import get_utc_now @@ -121,10 +122,12 @@ def __init__(self, config: BaseSchedulerConfig): self.max_internal_message_queue_size = self.config.get( "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) + self.orchestrator = SchedulerOrchestrator() self.memos_message_queue = ScheduleTaskQueue( use_redis_queue=self.use_redis_queue, maxsize=self.max_internal_message_queue_size, disabled_handlers=self.disabled_handlers, + orchestrator=self.orchestrator, ) self.searcher: Searcher | None = None self.retriever: SchedulerRetriever | None = None @@ -143,6 +146,7 @@ def 
__init__(self, config: BaseSchedulerConfig): status_tracker=self.status_tracker, metrics=self.metrics, submit_web_logs=self._submit_web_logs, + orchestrator=self.orchestrator, ) # Task schedule monitor: initialize with underlying queue implementation self.get_status_parallel = self.config.get("get_status_parallel", True) @@ -697,22 +701,22 @@ def get_web_log_messages(self) -> list[dict]: break def _map_label(label: str) -> str: - from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - ANSWER_LABEL, - MEM_ARCHIVE_LABEL, - MEM_ORGANIZE_LABEL, - MEM_UPDATE_LABEL, - QUERY_LABEL, + from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, ) mapping = { - QUERY_LABEL: "addMessage", - ANSWER_LABEL: "addMessage", - ADD_LABEL: "addMemory", - MEM_UPDATE_LABEL: "updateMemory", - MEM_ORGANIZE_LABEL: "mergeMemory", - MEM_ARCHIVE_LABEL: "archiveMemory", + QUERY_TASK_LABEL: "addMessage", + ANSWER_TASK_LABEL: "addMessage", + ADD_TASK_LABEL: "addMemory", + MEM_UPDATE_TASK_LABEL: "updateMemory", + MEM_ORGANIZE_TASK_LABEL: "mergeMemory", + MEM_ARCHIVE_TASK_LABEL: "archiveMemory", } return mapping.get(label, label) @@ -785,7 +789,7 @@ def _message_consumer(self) -> None: if enqueue_epoch is not None: queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 - # Avoid pydantic attribute enforcement + # Avoid pydantic field enforcement by using object.__setattr__ object.__setattr__(msg, "_dequeue_ts", now) emit_monitor_event( "dequeue", diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 9b1153c87..fa7bb1d15 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -7,19 +7,21 @@ from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from 
memos.mem_scheduler.schemas.general_schemas import ( ACTIVATION_MEMORY_TYPE, - ADD_LABEL, - MEM_ARCHIVE_LABEL, - MEM_UPDATE_LABEL, NOT_INITIALIZED, PARAMETER_MEMORY_TYPE, TEXT_MEMORY_TYPE, - USER_INPUT_TYPE, WORKING_MEMORY_TYPE, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, ScheduleMessageItem, ) +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + MEM_ARCHIVE_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + USER_INPUT_TYPE, +) from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) @@ -271,7 +273,7 @@ def log_adding_memory( """Deprecated: legacy text log. Use create_event_log with structured fields instead.""" log_message = self.create_autofilled_log_item( log_content=memory, - label=ADD_LABEL, + label=ADD_TASK_LABEL, from_memory_type=USER_INPUT_TYPE, to_memory_type=memory_type, user_id=user_id, @@ -297,7 +299,7 @@ def log_updating_memory( """Deprecated: legacy text log. Use create_event_log with structured fields instead.""" log_message = self.create_autofilled_log_item( log_content=memory, - label=MEM_UPDATE_LABEL, + label=MEM_UPDATE_TASK_LABEL, from_memory_type=memory_type, to_memory_type=memory_type, user_id=user_id, @@ -319,7 +321,7 @@ def log_archiving_memory( """Deprecated: legacy text log. 
Use create_event_log with structured fields instead.""" log_message = self.create_autofilled_log_item( log_content=memory, - label=MEM_ARCHIVE_LABEL, + label=MEM_ARCHIVE_TASK_LABEL, from_memory_type=memory_type, to_memory_type=memory_type, user_id=user_id, diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 86718ec82..b3ad8f085 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -9,21 +9,22 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - ANSWER_LABEL, +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, LONG_TERM_MEMORY_TYPE, - MEM_FEEDBACK_LABEL, - MEM_ORGANIZE_LABEL, - MEM_READ_LABEL, + MEM_FEEDBACK_TASK_LABEL, + MEM_ORGANIZE_TASK_LABEL, + MEM_READ_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, NOT_APPLICABLE_TYPE, - PREF_ADD_LABEL, - QUERY_LABEL, + PREF_ADD_TASK_LABEL, + QUERY_TASK_LABEL, USER_INPUT_TYPE, ) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import ( is_all_chinese, is_all_english, @@ -51,13 +52,14 @@ def __init__(self, config: GeneralSchedulerConfig): # register handlers handlers = { - QUERY_LABEL: self._query_message_consumer, - ANSWER_LABEL: self._answer_message_consumer, - ADD_LABEL: self._add_message_consumer, - MEM_READ_LABEL: self._mem_read_message_consumer, - MEM_ORGANIZE_LABEL: self._mem_reorganize_message_consumer, - PREF_ADD_LABEL: self._pref_add_message_consumer, - MEM_FEEDBACK_LABEL: 
self._mem_feedback_message_consumer, + QUERY_TASK_LABEL: self._query_message_consumer, + ANSWER_TASK_LABEL: self._answer_message_consumer, + MEM_UPDATE_TASK_LABEL: self._memory_update_consumer, + ADD_TASK_LABEL: self._add_message_consumer, + MEM_READ_TASK_LABEL: self._mem_read_message_consumer, + MEM_ORGANIZE_TASK_LABEL: self._mem_reorganize_message_consumer, + PREF_ADD_TASK_LABEL: self._pref_add_message_consumer, + MEM_FEEDBACK_TASK_LABEL: self._mem_feedback_message_consumer, } self.dispatcher.register_handlers(handlers) @@ -124,7 +126,7 @@ def long_memory_update_process( top_k=self.top_k, ) logger.info( - f"Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} new candidate memories for user_id={user_id}" + f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} new candidate memories for user_id={user_id}" ) # rerank @@ -135,30 +137,40 @@ def long_memory_update_process( original_memory=cur_working_memory, new_memory=new_candidates, ) - logger.info( - f"Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" + logger.debug( + f"[long_memory_update_process] Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" + ) + + old_memory_texts = [mem.memory for mem in cur_working_memory] + new_memory_texts = [mem.memory for mem in new_order_working_memory] + + logger.debug( + f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': " + f"Scheduler replaced working memory based on query history {queries}. " + f"Old working memory ({len(old_memory_texts)} items): {old_memory_texts}. " + f"New working memory ({len(new_memory_texts)} items): {new_memory_texts}." 
) # update activation memories - logger.info( + logger.debug( f"Activation memory update {'enabled' if self.enable_activation_memory else 'disabled'} " f"(interval: {self.monitor.act_mem_update_interval}s)" ) if self.enable_activation_memory: self.update_activation_memory_periodically( interval_seconds=self.monitor.act_mem_update_interval, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=self.current_mem_cube, ) def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {ADD_TASK_LABEL} handler.") # Process the query in a session turn grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - self.validate_schedule_messages(messages=messages, label=ADD_LABEL) + self.validate_schedule_messages(messages=messages, label=ADD_TASK_LABEL) try: for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -192,6 +204,23 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: except Exception as e: logger.error(f"Error: {e}", exc_info=True) + def _memory_update_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {MEM_UPDATE_TASK_LABEL} handler.") + + grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) + + self.validate_schedule_messages(messages=messages, label=MEM_UPDATE_TASK_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + batch = grouped_messages[user_id][mem_cube_id] + if not batch: + continue + # Process the whole batch once; no need to iterate per message + self.long_memory_update_process( + user_id=user_id, mem_cube_id=mem_cube_id, messages=batch + ) + def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ Process and handle query trigger messages from the queue. 
@@ -199,19 +228,21 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: Args: messages: List of query messages to process """ - logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {QUERY_TASK_LABEL} handler.") grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) + self.validate_schedule_messages(messages=messages, label=QUERY_TASK_LABEL) + mem_update_messages = [] for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: batch = grouped_messages[user_id][mem_cube_id] if not batch: continue - try: - for msg in batch: + + for msg in batch: + try: event = self.create_event_log( label="addMessage", from_memory_type=USER_INPUT_TYPE, @@ -232,11 +263,22 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: ) event.task_id = msg.task_id self._submit_web_logs([event]) - except Exception: - logger.exception("Failed to record addMessage log for query") - self.long_memory_update_process( - user_id=user_id, mem_cube_id=mem_cube_id, messages=batch - ) + except Exception: + logger.exception("Failed to record addMessage log for query") + # Re-submit the message with label changed to mem_update + update_msg = ScheduleMessageItem( + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + label=MEM_UPDATE_TASK_LABEL, + content=msg.content, + session_id=msg.session_id, + user_name=msg.user_name, + info=msg.info, + task_id=msg.task_id, + ) + mem_update_messages.append(update_msg) + + self.submit_messages(messages=mem_update_messages) def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -245,10 +287,10 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: Args: messages: List of answer messages to process """ - logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") + 
logger.info(f"Messages {messages} assigned to {ANSWER_TASK_LABEL} handler.") grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) + self.validate_schedule_messages(messages=messages, label=ANSWER_TASK_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -341,9 +383,8 @@ def log_add_messages(self, msg: ScheduleMessageItem): except Exception: missing_ids.append(memory_id) - logger.warning( - f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation.", - stack_info=True, + logger.debug( + f"This MemoryItem {memory_id} has already been deleted or an error occurred during preparation." ) if missing_ids: @@ -521,42 +562,6 @@ def send_add_log_messages_to_cloud_env( event.task_id = msg.task_id self._submit_web_logs([event]) - def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") - # Process the query in a session turn - grouped_messages = group_messages_by_user_and_mem_cube(messages=messages) - - self.validate_schedule_messages(messages=messages, label=ADD_LABEL) - try: - for user_id in grouped_messages: - for mem_cube_id in grouped_messages[user_id]: - batch = grouped_messages[user_id][mem_cube_id] - if not batch: - continue - - # Process each message in the batch - for msg in batch: - prepared_add_items, prepared_update_items_with_original = ( - self.log_add_messages(msg=msg) - ) - # Conditional Logging: Knowledge Base (Cloud Service) vs. 
Playground/Default - is_cloud_env = ( - os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") - == "memos-memory-change" - ) - - if is_cloud_env: - self.send_add_log_messages_to_cloud_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - else: - self.send_add_log_messages_to_local_env( - msg, prepared_add_items, prepared_update_items_with_original - ) - - except Exception as e: - logger.error(f"Error: {e}", exc_info=True) - def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: try: if not messages: @@ -723,7 +728,7 @@ def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non logger.info( f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}" ) - logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {MEM_READ_TASK_LABEL} handler.") def process_message(message: ScheduleMessageItem): try: @@ -1028,7 +1033,7 @@ def _process_memories_with_reader( self._submit_web_logs([event]) def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_TASK_LABEL} handler.") def process_message(message: ScheduleMessageItem): try: @@ -1260,7 +1265,7 @@ def _process_memories_with_reorganize( ) def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {PREF_ADD_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {PREF_ADD_TASK_LABEL} handler.") def process_message(message: ScheduleMessageItem): try: @@ -1313,54 +1318,6 @@ def process_message(message: ScheduleMessageItem): f"Successfully processed and add preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" ) - # Create and submit log 
for web display - # Only send logs if RabbitMQ is configured with direct exchange (cloud service scenario) - is_cloud_env = ( - os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") == "memos-memory-change" - ) - if pref_ids and is_cloud_env: - pref_content = [] - pref_meta = [] - for i, pref_mem_item in enumerate(pref_memories): - if i < len(pref_ids): - pref_content.append( - { - "content": pref_mem_item.memory, - "ref_id": pref_ids[i], - } - ) - pref_meta.append( - { - "ref_id": pref_ids[i], - "id": pref_ids[i], - "memory": pref_mem_item.memory, - "memory_type": getattr( - pref_mem_item.metadata, "memory_type", "preference" - ), - } - ) - - event = self.create_event_log( - label="addMemory", - from_memory_type=USER_INPUT_TYPE, - to_memory_type=LONG_TERM_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - memcube_log_content=pref_content, - metadata=pref_meta, - memory_len=len(pref_content), - memcube_name=self._map_memcube_name(mem_cube_id), - ) - event.task_id = message.task_id - self._submit_web_logs([event]) - else: - logger.info( - "Skipping web log for pref_add. pref_ids_count=%s is_cloud_env=%s", - len(pref_ids) if pref_ids else 0, - is_cloud_env, - ) - except Exception as e: logger.error(f"Error processing pref_add message: {e}", exc_info=True) @@ -1397,7 +1354,7 @@ def process_session_turn( return logger.info( - f"Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" + f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" ) cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() @@ -1415,18 +1372,18 @@ def process_session_turn( if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): logger.info( - f"Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. 
Intent_result: {intent_result}" + f"[process_session_turn] Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}" ) return elif (not intent_result["trigger_retrieval"]) and time_trigger_flag: logger.info( - f"Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}" + f"[process_session_turn] Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}" ) intent_result["trigger_retrieval"] = True intent_result["missing_evidences"] = queries else: logger.info( - f"Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. " + f"[process_session_turn] Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. " f"Missing evidences: {intent_result['missing_evidences']}" ) @@ -1436,7 +1393,7 @@ def process_session_turn( new_candidates = [] for item in missing_evidences: logger.info( - f"Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" + f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" ) info = { "user_id": user_id, @@ -1451,7 +1408,7 @@ def process_session_turn( info=info, ) logger.info( - f"Search results for missing evidence '{item}': {[one.memory for one in results]}" + f"[process_session_turn] Search results for missing evidence '{item}': {[one.memory for one in results]}" ) new_candidates.extend(results) return cur_working_memory, new_candidates diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index f99360a86..19816c310 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -11,10 +11,10 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule from memos.mem_scheduler.general_scheduler import 
GeneralScheduler -from memos.mem_scheduler.schemas.general_schemas import ( - API_MIX_SEARCH_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + API_MIX_SEARCH_TASK_LABEL, +) from memos.mem_scheduler.utils.api_utils import format_textual_memory_item from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube @@ -49,7 +49,7 @@ def __init__(self, config: GeneralSchedulerConfig): ) self.register_handlers( { - API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, + API_MIX_SEARCH_TASK_LABEL: self._api_mix_search_message_consumer, } ) self.searcher = None @@ -83,7 +83,7 @@ def submit_memory_history_async_task( item_id=async_task_id, user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id, - label=API_MIX_SEARCH_LABEL, + label=API_MIX_SEARCH_TASK_LABEL, content=json.dumps(message_content), timestamp=get_utc_now(), ) @@ -259,12 +259,12 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) Args: messages: List of query messages to process """ - logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_TASK_LABEL} handler.") # Process the query in a session turn grouped_messages = group_messages_by_user_and_mem_cube(messages) - self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL) + self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_TASK_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -303,7 +303,7 @@ def replace_working_memory( # Apply combined filtering (unrelated + redundant) logger.info( - f"Applying combined unrelated and redundant memory filtering to {len(memories_with_new_order)} memories" + f"[optimized replace_working_memory] Applying combined unrelated and redundant memory 
filtering to {len(memories_with_new_order)} memories" ) filtered_memories, filtering_success_flag = ( self.retriever.filter_unrelated_and_redundant_memories( @@ -314,20 +314,20 @@ def replace_working_memory( if filtering_success_flag: logger.info( - f"Combined filtering completed successfully. " + f"[optimized replace_working_memory] Combined filtering completed successfully. " f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories" ) memories_with_new_order = filtered_memories else: logger.warning( - "Combined filtering failed - keeping memories as fallback. " + "[optimized replace_working_memory] Combined filtering failed - keeping memories as fallback. " f"Count: {len(memories_with_new_order)}" ) # Update working memory monitors query_keywords = query_db_manager.obj.get_keywords_collections() logger.info( - f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" + f"[optimized replace_working_memory] Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" ) new_working_memory_monitors = self.transform_working_memories_to_monitors( query_keywords=query_keywords, @@ -338,7 +338,9 @@ def replace_working_memory( for one in new_working_memory_monitors: one.sorting_score = 0 - logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors") + logger.info( + f"[optimized replace_working_memory] update {len(new_working_memory_monitors)} working_memory_monitors" + ) self.monitor.update_working_memory_monitors( new_working_memory_monitors=new_working_memory_monitors, user_id=user_id, @@ -356,7 +358,7 @@ def replace_working_memory( new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] logger.info( - f"The working memory has been replaced with {len(memories_with_new_order)} new memories." + f"[optimized replace_working_memory] The working memory has been replaced with {len(memories_with_new_order)} new memories." 
) self.log_working_memory_replacement( original_memory=original_memory, diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 30cba81b3..8493c596d 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -6,17 +6,6 @@ FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent -QUERY_LABEL = "query" -ANSWER_LABEL = "answer" -ADD_LABEL = "add" -MEM_READ_LABEL = "mem_read" -MEM_ORGANIZE_LABEL = "mem_organize" -MEM_UPDATE_LABEL = "mem_update" -MEM_ARCHIVE_LABEL = "mem_archive" -API_MIX_SEARCH_LABEL = "api_mix_search" -PREF_ADD_LABEL = "pref_add" -MEM_FEEDBACK_LABEL = "mem_feedback" - TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" TextMemory_SEARCH_METHOD = "text_memory_search" @@ -32,7 +21,7 @@ DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 -DEFAULT_TOP_K = 10 +DEFAULT_TOP_K = 5 DEFAULT_CONTEXT_WINDOW_SIZE = 5 DEFAULT_USE_REDIS_QUEUE = True DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 @@ -66,7 +55,7 @@ DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50 # task queue -DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.5" +DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.6" exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None) if exchange_name is not None: DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}" diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py index f148f30d5..fd4204969 100644 --- a/src/memos/mem_scheduler/schemas/monitor_schemas.py +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -12,10 +12,12 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue, DictConversionMixin from 
memos.mem_scheduler.schemas.general_schemas import ( - DEFAULT_MAX_QUERY_KEY_WORDS, DEFAULT_WEIGHT_VECTOR_FOR_RANKING, NOT_INITIALIZED, ) +from memos.mem_scheduler.schemas.task_schemas import ( + DEFAULT_MAX_QUERY_KEY_WORDS, +) from memos.mem_scheduler.utils.filter_utils import transform_name_to_key from memos.memories.textual.tree import TextualMemoryItem diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index 168a25b5d..f82b12d32 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import Enum from pathlib import Path from typing import Any from uuid import uuid4 @@ -16,6 +17,33 @@ BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent +# ============== Schedule Task Definition ============== +class TaskPriorityLevel(Enum): + # priority top + LEVEL_1 = 1 + LEVEL_2 = 2 + LEVEL_3 = 3 + # priority bottom + + +QUERY_TASK_LABEL = "query" +ANSWER_TASK_LABEL = "answer" +ADD_TASK_LABEL = "add" +MEM_READ_TASK_LABEL = "mem_read" +MEM_ORGANIZE_TASK_LABEL = "mem_organize" +MEM_UPDATE_TASK_LABEL = "mem_update" +MEM_ARCHIVE_TASK_LABEL = "mem_archive" +API_MIX_SEARCH_TASK_LABEL = "api_mix_search" +PREF_ADD_TASK_LABEL = "pref_add" +MEM_FEEDBACK_TASK_LABEL = "mem_feedback" + +# Additional constants moved from general_schemas +DEFAULT_MAX_QUERY_KEY_WORDS = 1000 +LONG_TERM_MEMORY_TYPE = "LongTermMemory" +USER_INPUT_TYPE = "UserInput" +NOT_APPLICABLE_TYPE = "NotApplicable" + + # ============== Running Tasks ============== class RunningTaskItem(BaseModel, DictConversionMixin): """Data class for tracking running tasks in SchedulerDispatcher.""" diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index ade2bbfbf..59afd7b61 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ 
b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -20,7 +20,8 @@ DEFAULT_STOP_WAIT, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem, TaskPriorityLevel +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube @@ -53,6 +54,7 @@ def __init__( status_tracker: TaskStatusTracker | None = None, metrics: Any | None = None, submit_web_logs: Callable | None = None, # ADDED + orchestrator: SchedulerOrchestrator | None = None, ): super().__init__() self.config = config @@ -66,7 +68,7 @@ def __init__( if hasattr(memos_message_queue, "memos_message_queue") else memos_message_queue ) - + self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator # Get multi-task timeout from config self.multi_task_running_timeout = ( self.config.get("multi_task_running_timeout") if self.config else None @@ -79,6 +81,7 @@ def __init__( self.dispatcher_executor = ContextThreadPoolExecutor( max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix ) + logger.info(f"Max workers of dispatcher is set to {self.max_workers}") else: self.dispatcher_executor = None logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}") @@ -463,9 +466,19 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): # Create wrapped handler for task tracking wrapped_handler = self._create_task_wrapper(handler, task_item) + task_priority = self.orchestrator.get_task_priority(task_label=label) + # dispatch to different handler logger.debug(f"Task started: {task_item.get_execution_info()}") - if 
self.enable_parallel_dispatch and self.dispatcher_executor is not None: + + # If priority is LEVEL_1, force synchronous execution regardless of thread pool availability + use_thread_pool = ( + self.enable_parallel_dispatch + and self.dispatcher_executor is not None + and task_priority != TaskPriorityLevel.LEVEL_1 + ) + + if use_thread_pool: # Submit and track the future future = self.dispatcher_executor.submit(wrapped_handler, msgs) with self._task_lock: @@ -476,6 +489,9 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): ) else: # For synchronous execution, the wrapper will run and remove the task upon completion + logger.info( + f"Execute {len(msgs)} message(s) synchronously for {label} (priority: {task_priority}) for user {user_id} and mem_cube {mem_cube_id}." + ) wrapped_handler(msgs) def join(self, timeout: float | None = None) -> bool: diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py index d03648bba..19da9c7de 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py +++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py @@ -16,24 +16,42 @@ from __future__ import annotations from memos.log import get_logger +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + ANSWER_TASK_LABEL, + QUERY_TASK_LABEL, + TaskPriorityLevel, +) +from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule logger = get_logger(__name__) -class SchedulerOrchestrator: - def __init__(self, queue): +class SchedulerOrchestrator(RedisSchedulerModule): + def __init__(self): """ Args: queue: An instance of `SchedulerRedisQueue`. 
""" - self.queue = queue # Cache of fetched messages grouped by (user_id, mem_cube_id, task_label) self._cache = None + self.tasks_priorities = { + ADD_TASK_LABEL: TaskPriorityLevel.LEVEL_1, + QUERY_TASK_LABEL: TaskPriorityLevel.LEVEL_1, + ANSWER_TASK_LABEL: TaskPriorityLevel.LEVEL_1, + } def get_stream_priorities(self) -> None | dict: return None + def get_task_priority(self, task_label: str): + task_priority = TaskPriorityLevel.LEVEL_3 + if task_label in self.tasks_priorities: + task_priority = self.tasks_priorities[task_label] + logger.info(f"get_task_priority: {task_priority}") + return task_priority + def get_stream_quotas(self, stream_keys, consume_batch_size) -> dict: stream_priorities = self.get_stream_priorities() stream_quotas = {} diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 703dd1eb8..fb38a0f44 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -46,10 +46,10 @@ def __init__( "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", DEFAULT_STREAM_KEY_PREFIX, ), + orchestrator: SchedulerOrchestrator | None = None, consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", - max_len: int = 10000, - maxsize: int = 0, # For Queue compatibility + max_len: int | None = None, auto_delete_acked: bool = True, # Whether to automatically delete acknowledged messages ): """ @@ -64,17 +64,11 @@ def __init__( auto_delete_acked: Whether to automatically delete acknowledged messages from stream """ super().__init__() - - # If maxsize <= 0, set to None (unlimited queue size) - if maxsize <= 0: - maxsize = 0 - # Stream configuration self.stream_key_prefix = stream_key_prefix self.consumer_group = consumer_group self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}" self.max_len = max_len - self.maxsize = maxsize # For Queue compatibility 
self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages # Consumer state @@ -105,7 +99,8 @@ def __init__( # Task Orchestrator self.message_pack_cache = deque() - self.orchestrator = SchedulerOrchestrator(queue=self) + + self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" @@ -191,11 +186,7 @@ def _ensure_consumer_group(self, stream_key) -> None: except Exception as e: # Check if it's a "consumer group already exists" error error_msg = str(e).lower() - if "busygroup" in error_msg or "already exists" in error_msg: - logger.info( - f"Consumer group '{self.consumer_group}' already exists for stream '{stream_key}'" - ) - else: + if not ("busygroup" in error_msg or "already exists" in error_msg): logger.error(f"Error creating consumer group: {e}", exc_info=True) # Pending lock methods removed as they are unnecessary with idle-threshold claiming @@ -498,18 +489,9 @@ def empty(self) -> bool: return self.size() == 0 def full(self) -> bool: - """ - Check if the Redis queue is full (Queue-compatible interface). - - For Redis streams, we consider the queue full if it exceeds maxsize. - If maxsize is 0 or None, the queue is never considered full. 
- - Returns: - True if the queue is full, False otherwise - """ - if self.maxsize <= 0: + if self.max_len is None: return False - return self.size() >= self.maxsize + return self.size() >= self.max_len def join(self) -> None: """ diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 2fd8716a3..7c9139200 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -9,6 +9,7 @@ from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube @@ -24,12 +25,21 @@ def __init__( use_redis_queue: bool, maxsize: int, disabled_handlers: list | None = None, + orchestrator: SchedulerOrchestrator | None = None, ): self.use_redis_queue = use_redis_queue self.maxsize = maxsize + self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator if self.use_redis_queue: - self.memos_message_queue = SchedulerRedisQueue(maxsize=self.maxsize) + if maxsize is None or not isinstance(maxsize, int) or maxsize <= 0: + maxsize = None + self.memos_message_queue = SchedulerRedisQueue( + max_len=maxsize, + consumer_group="scheduler_group", + consumer_name="scheduler_consumer", + orchestrator=self.orchestrator, + ) else: self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1892849a4..88c0f87c7 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ 
b/src/memos/multi_mem_cube/single_cube.py @@ -15,13 +15,13 @@ ) from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_FEEDBACK_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ADD_TASK_LABEL, + MEM_FEEDBACK_TASK_LABEL, + MEM_READ_TASK_LABEL, + PREF_ADD_TASK_LABEL, +) from memos.multi_mem_cube.views import MemCubeView from memos.types.general_types import ( FINE_STRATEGY, @@ -153,7 +153,7 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: session_id=target_session_id, mem_cube_id=self.cube_id, mem_cube=self.naive_mem_cube, - label=MEM_FEEDBACK_LABEL, + label=MEM_FEEDBACK_TASK_LABEL, content=feedback_req_str, timestamp=datetime.utcnow(), ) @@ -503,7 +503,7 @@ def _schedule_memory_tasks( session_id=target_session_id, mem_cube_id=self.cube_id, mem_cube=self.naive_mem_cube, - label=MEM_READ_LABEL, + label=MEM_READ_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), user_name=self.cube_id, @@ -525,7 +525,7 @@ def _schedule_memory_tasks( session_id=target_session_id, mem_cube_id=self.cube_id, mem_cube=self.naive_mem_cube, - label=ADD_LABEL, + label=ADD_TASK_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), user_name=self.cube_id, @@ -571,7 +571,7 @@ def _process_pref_mem( session_id=target_session_id, mem_cube_id=user_context.mem_cube_id, mem_cube=self.naive_mem_cube, - label=PREF_ADD_LABEL, + label=PREF_ADD_TASK_LABEL, content=json.dumps(messages_list), timestamp=datetime.utcnow(), info=add_req.info, diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index fed1e8500..5b68a8bad 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -17,13 +17,13 @@ from 
memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, - QUERY_LABEL, -) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, ) +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + QUERY_TASK_LABEL, +) from memos.memories.textual.tree import TreeTextMemory @@ -106,8 +106,8 @@ def tearDown(self): def test_initialization(self): """Test that scheduler initializes with correct default values and handlers.""" # Verify handler registration - self.assertTrue(QUERY_LABEL in self.scheduler.dispatcher.handlers) - self.assertTrue(ANSWER_LABEL in self.scheduler.dispatcher.handlers) + self.assertTrue(QUERY_TASK_LABEL in self.scheduler.dispatcher.handlers) + self.assertTrue(ANSWER_TASK_LABEL in self.scheduler.dispatcher.handlers) def test_initialize_modules(self): """Test module initialization with proper component assignments.""" @@ -121,7 +121,7 @@ def test_submit_web_logs(self): log_message = ScheduleLogForWebItem( user_id="test_user", mem_cube_id="test_cube", - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, from_memory_type="WorkingMemory", # New field to_memory_type="LongTermMemory", # New field log_content="Test Content", @@ -155,7 +155,7 @@ def test_submit_web_logs(self): # Verify core fields self.assertEqual(actual_message.user_id, "test_user") self.assertEqual(actual_message.mem_cube_id, "test_cube") - self.assertEqual(actual_message.label, QUERY_LABEL) + self.assertEqual(actual_message.label, QUERY_TASK_LABEL) self.assertEqual(actual_message.from_memory_type, "WorkingMemory") self.assertEqual(actual_message.to_memory_type, "LongTermMemory") self.assertEqual(actual_message.log_content, "Test Content") @@ -225,7 +225,7 @@ def test_activation_memory_update(self): try: 
self.scheduler.update_activation_memory( new_memories=test_memories, - label=QUERY_LABEL, + label=QUERY_TASK_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=self.mem_cube,