Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/memos/mem_scheduler/general_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import concurrent.futures
import contextlib
import json
import os
import traceback

from memos.configs.mem_scheduler import GeneralSchedulerConfig
Expand Down Expand Up @@ -30,7 +29,10 @@
is_all_english,
transform_name_to_key,
)
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.mem_scheduler.utils.misc_utils import (
group_messages_by_user_and_mem_cube,
is_cloud_env,
)
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.preference import PreferenceTextMemory
from memos.memories.textual.tree import TreeTextMemory
Expand Down Expand Up @@ -194,9 +196,9 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
f"prepared_add_items: {prepared_add_items};\n prepared_update_items_with_original: {prepared_update_items_with_original}"
)
# Conditional Logging: Knowledge Base (Cloud Service) vs. Playground/Default
is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
cloud_env = is_cloud_env()

if is_cloud_env:
if cloud_env:
self.send_add_log_messages_to_cloud_env(
msg, prepared_add_items, prepared_update_items_with_original
)
Expand Down Expand Up @@ -615,8 +617,8 @@ def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) ->
f"Successfully processed feedback for user_id={user_id}, mem_cube_id={mem_cube_id}"
)

is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if is_cloud_env:
cloud_env = is_cloud_env()
if cloud_env:
record = feedback_result.get("record") if isinstance(feedback_result, dict) else {}
add_records = record.get("add") if isinstance(record, dict) else []
update_records = record.get("update") if isinstance(record, dict) else []
Expand Down Expand Up @@ -733,7 +735,7 @@ def _extract_fields(mem_item):
else:
logger.info(
"Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)",
is_cloud_env,
cloud_env,
)

except Exception as e:
Expand Down Expand Up @@ -893,8 +895,8 @@ def _process_memories_with_reader(

# LOGGING BLOCK START
# This block is replicated from _add_message_consumer to ensure consistent logging
is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if is_cloud_env:
cloud_env = is_cloud_env()
if cloud_env:
# New: Knowledge Base Logging (Cloud Service)
kb_log_content = []
for item in flattened_memories:
Expand Down Expand Up @@ -1013,8 +1015,8 @@ def _process_memories_with_reader(
f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True
)
with contextlib.suppress(Exception):
is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if is_cloud_env:
cloud_env = is_cloud_env()
if cloud_env:
if not kb_log_content:
trigger_source = (
info.get("trigger_source", "Messages") if info else "Messages"
Expand Down
7 changes: 3 additions & 4 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import concurrent
import os
import threading
import time

Expand All @@ -25,7 +24,7 @@
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env
from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

Expand Down Expand Up @@ -351,8 +350,8 @@ def _maybe_emit_task_completion(
mem_cube_id = first.mem_cube_id

try:
is_cloud_env = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if not is_cloud_env:
cloud_env = is_cloud_env()
if not cloud_env:
return

for task_id in task_ids:
Expand Down
35 changes: 35 additions & 0 deletions src/memos/mem_scheduler/utils/misc_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import re
import traceback

Expand All @@ -17,6 +18,40 @@
logger = get_logger(__name__)


def _normalize_env_value(value: str | None) -> str:
"""Normalize environment variable values for comparison."""
return value.strip().lower() if isinstance(value, str) else ""


def is_playground_env() -> bool:
"""Return True when ENV_NAME indicates a Playground environment."""
env_name = _normalize_env_value(os.getenv("ENV_NAME"))
return env_name.startswith("playground")


def is_cloud_env() -> bool:
"""
Determine whether the scheduler should treat the runtime as a cloud environment.

Rules:
- Any Playground ENV_NAME is explicitly NOT cloud.
- MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME must be set to enable cloud behavior.
- The default memos-fanout/fanout combination is treated as non-cloud.
"""
if is_playground_env():
return False

exchange_name = _normalize_env_value(os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME"))
exchange_type = _normalize_env_value(os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_TYPE"))

if not exchange_name:
return False

return not (
exchange_name == "memos-fanout" and (not exchange_type or exchange_type == "fanout")
)


def extract_json_obj(text: str):
"""
Safely extracts JSON from LLM response text with robust error handling.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue
from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
from memos.mem_scheduler.utils.misc_utils import is_cloud_env


logger = get_logger(__name__)
Expand Down Expand Up @@ -291,7 +292,7 @@ def rabbitmq_publish_message(self, message: dict):

# Cloud environment override: applies to specific message types if MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set
env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
if env_exchange_name and label in ["taskStatus", "knowledgeBaseUpdate"]:
if is_cloud_env() and env_exchange_name and label in ["taskStatus", "knowledgeBaseUpdate"]:
exchange_name = env_exchange_name
routing_key = "" # Routing key is always empty in cloud environment for these types

Expand Down