Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 2 additions & 76 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from memos.mem_cube.general import GeneralMemCube
from memos.mem_feedback.simple_feedback import SimpleMemFeedback
from memos.mem_scheduler.general_modules.init_components_for_scheduler import init_components
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
Expand All @@ -37,7 +36,6 @@
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
DEFAULT_MAX_WEB_LOG_QUEUE_SIZE,
DEFAULT_STARTUP_MODE,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
Expand All @@ -51,12 +49,6 @@
)
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
from memos.mem_scheduler.schemas.task_schemas import (
ADD_TASK_LABEL,
ANSWER_TASK_LABEL,
MEM_ARCHIVE_TASK_LABEL,
MEM_ORGANIZE_TASK_LABEL,
MEM_UPDATE_TASK_LABEL,
QUERY_TASK_LABEL,
TaskPriorityLevel,
)
from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
Expand Down Expand Up @@ -121,12 +113,6 @@ def __init__(self, config: BaseSchedulerConfig):
# optional configs
self.disabled_handlers: list | None = self.config.get("disabled_handlers", None)

self.max_web_log_queue_size = self.config.get(
"max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
)
self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
maxsize=self.max_web_log_queue_size
)
self._consumer_thread = None # Reference to our consumer thread/process
self._consumer_process = None # Reference to our consumer process
self._running = False
Expand Down Expand Up @@ -853,11 +839,6 @@ def _submit_web_logs(
return

for message in messages:
try:
self._web_log_message_queue.put(message)
except Exception as e:
logger.warning(f"Failed to put message to web log queue: {e}", stack_info=True)

message_info = message.debug_info()
logger.debug(f"Submitted Scheduling log for web: {message_info}")

Expand All @@ -866,68 +847,13 @@ def _submit_web_logs(
f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message_info}"
)
self.rabbitmq_publish_message(message=message.to_dict())
logger.debug(
f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
)
logger.debug(f"{len(messages)} submitted. additional_log_info: {additional_log_info}")

def get_web_log_messages(self) -> list[dict]:
"""
Retrieve structured log messages from the queue and return JSON-serializable dicts.
"""
raw_items: list[ScheduleLogForWebItem] = []
while True:
try:
raw_items.append(self._web_log_message_queue.get_nowait())
except Exception:
break

def _map_label(label: str) -> str:
mapping = {
QUERY_TASK_LABEL: "addMessage",
ANSWER_TASK_LABEL: "addMessage",
ADD_TASK_LABEL: "addMemory",
MEM_UPDATE_TASK_LABEL: "updateMemory",
MEM_ORGANIZE_TASK_LABEL: "mergeMemory",
MEM_ARCHIVE_TASK_LABEL: "archiveMemory",
}
return mapping.get(label, label)

def _normalize_item(item: ScheduleLogForWebItem) -> dict:
data = item.to_dict()
data["label"] = _map_label(data.get("label"))
memcube_content = getattr(item, "memcube_log_content", None) or []
metadata = getattr(item, "metadata", None) or []

memcube_name = getattr(item, "memcube_name", None)
if not memcube_name and hasattr(self, "_map_memcube_name"):
memcube_name = self._map_memcube_name(item.mem_cube_id)
data["memcube_name"] = memcube_name

memory_len = getattr(item, "memory_len", None)
if memory_len is None:
if data["label"] == "mergeMemory":
memory_len = len([c for c in memcube_content if c.get("type") != "postMerge"])
elif memcube_content:
memory_len = len(memcube_content)
else:
memory_len = 1 if item.log_content else 0

data["memcube_log_content"] = memcube_content
data["memory_len"] = memory_len

def _with_memory_time(meta: dict) -> dict:
enriched = dict(meta)
if "memory_time" not in enriched:
enriched["memory_time"] = enriched.get("updated_at") or enriched.get(
"update_at"
)
return enriched

data["metadata"] = [_with_memory_time(m) for m in metadata]
data["log_title"] = ""
return data

return [_normalize_item(it) for it in raw_items]
return []

def _message_consumer(self) -> None:
"""
Expand Down
26 changes: 16 additions & 10 deletions src/memos/mem_scheduler/general_modules/scheduler_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def create_autofilled_log_item(
user_id: str,
mem_cube_id: str,
mem_cube: GeneralMemCube,
item_id: str | None = None,
) -> ScheduleLogForWebItem:
if mem_cube is None:
logger.error(
Expand Down Expand Up @@ -94,16 +95,19 @@ def create_autofilled_log_item(
)
memory_capacities["parameter_memory_capacity"] = 1

log_message = ScheduleLogForWebItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
label=label,
from_memory_type=from_memory_type,
to_memory_type=to_memory_type,
log_content=log_content,
current_memory_sizes=current_memory_sizes,
memory_capacities=memory_capacities,
)
log_kwargs = {
"user_id": user_id,
"mem_cube_id": mem_cube_id,
"label": label,
"from_memory_type": from_memory_type,
"to_memory_type": to_memory_type,
"log_content": log_content,
"current_memory_sizes": current_memory_sizes,
"memory_capacities": memory_capacities,
}
if item_id:
log_kwargs["item_id"] = item_id
log_message = ScheduleLogForWebItem(**log_kwargs)
return log_message

@log_exceptions(logger=logger)
Expand All @@ -120,6 +124,7 @@ def create_event_log(
memory_len: int,
memcube_name: str | None = None,
log_content: str | None = None,
item_id: str | None = None,
) -> ScheduleLogForWebItem:
item = self.create_autofilled_log_item(
log_content=log_content or "",
Expand All @@ -129,6 +134,7 @@ def create_event_log(
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
item_id=item_id,
)
item.memcube_log_content = memcube_log_content
item.metadata = metadata
Expand Down
12 changes: 12 additions & 0 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
metadata=[],
memory_len=1,
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -322,6 +323,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
metadata=[],
memory_len=1,
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -492,6 +494,7 @@ def send_add_log_messages_to_local_env(
metadata=add_meta_legacy,
memory_len=len(add_content_legacy),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
events.append(event)
Expand All @@ -507,6 +510,7 @@ def send_add_log_messages_to_local_env(
metadata=update_meta_legacy,
memory_len=len(update_content_legacy),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.task_id = msg.task_id
events.append(event)
Expand Down Expand Up @@ -573,6 +577,7 @@ def send_add_log_messages_to_cloud_env(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(msg.mem_cube_id),
item_id=msg.item_id,
)
event.log_content = f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
event.task_id = msg.task_id
Expand Down Expand Up @@ -719,6 +724,7 @@ def _extract_fields(mem_item):
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=message.item_id,
)
event.log_content = (
f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
Expand Down Expand Up @@ -788,6 +794,7 @@ def process_message(message: ScheduleMessageItem):
user_name=user_name,
custom_tags=info.get("custom_tags", None),
task_id=message.task_id,
item_id=message.item_id,
info=info,
)

Expand Down Expand Up @@ -815,6 +822,7 @@ def _process_memories_with_reader(
user_name: str,
custom_tags: list[str] | None = None,
task_id: str | None = None,
item_id: str | None = None,
info: dict | None = None,
) -> None:
logger.info(
Expand Down Expand Up @@ -934,6 +942,7 @@ def _process_memories_with_reader(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.log_content = (
f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
Expand Down Expand Up @@ -979,6 +988,7 @@ def _process_memories_with_reader(
metadata=add_meta_legacy,
memory_len=len(add_content_legacy),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.task_id = task_id
self._submit_web_logs([event])
Expand Down Expand Up @@ -1045,6 +1055,7 @@ def _process_memories_with_reader(
metadata=None,
memory_len=len(kb_log_content),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=item_id,
)
event.log_content = f"Knowledge Base Memory Update failed: {exc!s}"
event.task_id = task_id
Expand Down Expand Up @@ -1212,6 +1223,7 @@ def process_message(message: ScheduleMessageItem):
metadata=meta,
memory_len=len(keys),
memcube_name=self._map_memcube_name(mem_cube_id),
item_id=message.item_id,
)
self._submit_web_logs([event])

Expand Down
6 changes: 6 additions & 0 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _maybe_emit_task_completion(
# messages in one batch can belong to different business task_ids; check each
task_ids = set()
task_id_to_doc_id = {}
task_id_to_item_id = {}

for msg in messages:
tid = getattr(msg, "task_id", None)
Expand All @@ -340,6 +341,8 @@ def _maybe_emit_task_completion(
sid = info.get("source_doc_id")
if sid:
task_id_to_doc_id[tid] = sid
if tid not in task_id_to_item_id:
task_id_to_item_id[tid] = msg.item_id

if not task_ids:
return
Expand All @@ -356,6 +359,7 @@ def _maybe_emit_task_completion(

for task_id in task_ids:
source_doc_id = task_id_to_doc_id.get(task_id)
event_item_id = task_id_to_item_id.get(task_id)
status_data = self.status_tracker.get_task_status_by_business_id(
business_task_id=task_id, user_id=user_id
)
Expand All @@ -369,6 +373,7 @@ def _maybe_emit_task_completion(
# (Although if status is 'completed', local error shouldn't happen theoretically,
# unless status update lags or is inconsistent. We trust status_tracker here.)
event = ScheduleLogForWebItem(
item_id=event_item_id,
task_id=task_id,
user_id=user_id,
mem_cube_id=mem_cube_id,
Expand All @@ -393,6 +398,7 @@ def _maybe_emit_task_completion(
error_msg = "Unknown error (check system logs)"

event = ScheduleLogForWebItem(
item_id=event_item_id,
task_id=task_id,
user_id=user_id,
mem_cube_id=mem_cube_id,
Expand Down
8 changes: 1 addition & 7 deletions tests/mem_scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,9 @@ def test_submit_web_logs(self):
},
)

# Empty the queue by consuming all elements
while not self.scheduler._web_log_message_queue.empty():
self.scheduler._web_log_message_queue.get()

# Submit the log message
self.scheduler._submit_web_logs(messages=log_message)

# Verify the message was added to the queue
self.assertEqual(self.scheduler._web_log_message_queue.qsize(), 1)
# No local web log queue; ensure submission completes without error.

# Get the actual message from the queue
actual_message = self.scheduler._web_log_message_queue.get()
Expand Down
Loading