diff --git a/docker/requirements.txt b/docker/requirements.txt index 4846f1832..873cb4d22 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -158,3 +158,4 @@ watchfiles==1.1.0 websockets==15.0.1 xlrd==2.0.2 xlsxwriter==3.2.5 +prometheus-client==0.23.1 diff --git a/evaluation/scripts/locomo/locomo_eval.py b/evaluation/scripts/locomo/locomo_eval.py index b431e7768..24a216b92 100644 --- a/evaluation/scripts/locomo/locomo_eval.py +++ b/evaluation/scripts/locomo/locomo_eval.py @@ -3,6 +3,7 @@ import json import logging import os +import re import time import nltk @@ -47,6 +48,29 @@ class LLMGrade(BaseModel): llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.") +def extract_label_json(text: str) -> str | None: + """ + Extracts a JSON object of the form {"label": "VALUE"} from a given text string. + This function is designed to handle cases where the LLM response contains + natural language alongside a final JSON snippet, ensuring robust parsing. + + Supports both single and double quotes around the label value. + Ignores surrounding whitespace and formatting. + + Returns: + The full matching JSON string (e.g., '{"label": "CORRECT"}') if found. + None if no valid label JSON is found. + """ + # Regex pattern to match: { "label": "value" } with optional whitespace + # Matches both single and double quotes, allows spaces around keys and values + pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}' + match = re.search(pattern, text) + if match: + # Return the complete matched JSON string for safe json.loads() + return match.group(0) + return None + + async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool: system_prompt = """ You are an expert grader that determines if answers to questions match a gold standard answer @@ -77,20 +101,23 @@ async def locomo_grader(llm_client, question: str, gold_answer: str, response: s Just return the label CORRECT or WRONG in a json format with the key as "label". """ - - response = await llm_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": accuracy_prompt}, - ], - temperature=0, - ) - message_content = response.choices[0].message.content - label = json.loads(message_content)["label"] - parsed = LLMGrade(llm_judgment=label, llm_reasoning="") - - return parsed.llm_judgment.strip().lower() == "correct" + try: + response = await llm_client.chat.completions.create( + model=os.getenv("EVAL_MODEL", "gpt-4o-mini"), + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": accuracy_prompt}, + ], + temperature=0, + ) + message_content = response.choices[0].message.content + message_content = extract_label_json(text=message_content) + label = json.loads(message_content)["label"] + parsed = LLMGrade(llm_judgment=label, llm_reasoning="") + return parsed.llm_judgment.strip().lower() == "correct" + except Exception as e: + print(f"======== {e}, {response} ===========") + exit() def calculate_rouge_scores(gold_answer, response): @@ -284,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4 with open(response_path) as file: locomo_responses = json.load(file) - num_users = 10 + num_users = 2 all_grades = {} total_responses_count = sum( diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 157c3f8ea..e835dd5d7 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -189,9 +189,7 @@ def search(self, query, user_id, top_k): ) response = requests.request("POST", url, data=payload, headers=self.headers) assert response.status_code == 200, response.text - assert json.loads(response.text)["message"] == "Search completed successfully", ( - response.text - ) + assert json.loads(response.text)["message"] == "Memory searched successfully", response.text return json.loads(response.text)["data"] diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py index 1b59543f3..d3522f8e1 100644 --- a/examples/mem_scheduler/api_w_scheduler.py +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -1,8 +1,10 @@ +from time import sleep + from memos.api.handlers.scheduler_handler import ( handle_scheduler_status, handle_scheduler_wait, ) -from memos.api.routers.server_router import mem_scheduler +from memos.api.routers.server_router import mem_scheduler, status_tracker from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -26,26 +28,25 @@ def my_test_handler(messages: list[ScheduleMessageItem]): for msg in messages: print(f" my_test_handler - {msg.item_id}: {msg.content}") user_status_running = handle_scheduler_status( - user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" + user_id=msg.user_id, status_tracker=status_tracker ) - print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) + print("[Monitor] Status after submit:", user_status_running) # 2. Register the handler TEST_HANDLER_LABEL = "test_handler" +TEST_USER_ID = "test_user" mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) # 2.1 Monitor global scheduler status before submitting tasks -global_status_before = handle_scheduler_status( - user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" -) +global_status_before = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker) print("[Monitor] Global status before submit:", global_status_before) # 3. Create messages messages_to_send = [ ScheduleMessageItem( item_id=f"test_item_{i}", - user_id="test_user", + user_id=TEST_USER_ID, mem_cube_id="test_mem_cube", label=TEST_HANDLER_LABEL, content=f"This is test message {i}", @@ -56,28 +57,28 @@ def my_test_handler(messages: list[ScheduleMessageItem]): # 5. Submit messages for mes in messages_to_send: print(f"Submitting message {mes.item_id} to the scheduler...") - mem_scheduler.memos_message_queue.submit_messages([mes]) + mem_scheduler.submit_messages([mes]) + sleep(1) # 5.1 Monitor status for specific mem_cube while running USER_MEM_CUBE = "test_mem_cube" # 6. Wait for messages to be processed (limited to 100 checks) -print("Waiting for messages to be consumed (max 100 checks)...") -mem_scheduler.mem_scheduler_wait() + +user_status_running = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker) +print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) # 6.1 Wait until idle for specific mem_cube via handler wait_result = handle_scheduler_wait( - user_name=USER_MEM_CUBE, + user_name=TEST_USER_ID, + status_tracker=status_tracker, timeout_seconds=120.0, - poll_interval=0.2, - mem_scheduler=mem_scheduler, + poll_interval=0.5, ) print(f"[Monitor] Wait result for {USER_MEM_CUBE}:", wait_result) # 6.2 Monitor global scheduler status after processing -global_status_after = handle_scheduler_status( - user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" -) +global_status_after = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker) print("[Monitor] Global status after processing:", global_status_after) # 7. Stop the scheduler diff --git a/poetry.lock b/poetry.lock index 926d580fb..e5e3bc1bd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -3591,6 +3591,21 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "prometheus-client" +version = "0.23.1" +description = "Python client for the Prometheus monitoring system." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "prometheus_client-0.23.1-py3-none-any.whl", hash = "sha256:dd1913e6e76b59cfe44e7a4b83e01afc9873c1bdfd2ed8739f1e76aeca115f99"}, + {file = "prometheus_client-0.23.1.tar.gz", hash = "sha256:6ae8f9081eaaaf153a2e959d2e6c4f4fb57b12ef76c8c7980202f1e57b48b2ce"}, +] + +[package.extras] +twisted = ["twisted"] + [[package]] name = "protobuf" version = "6.31.1" @@ -6406,4 +6421,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "ec17679a44205ada4494fbc485ac592883281fde273d5e73d6b8cbc6f7f9ed10" +content-hash = "a98b5ddffb4c031342ef1314a93666460ce0903e207bc79d23478b80a99b7f40" diff --git a/pyproject.toml b/pyproject.toml index 29a29cca8..7efd77d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "scikit-learn (>=1.7.0,<2.0.0)", # Machine learning "fastmcp (>=2.10.5,<3.0.0)", "python-dateutil (>=2.9.0.post0,<3.0.0)", + "prometheus-client (>=0.23.1,<0.24.0)", ] [project.urls] diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index 7a47f05e3..9df3310ec 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -9,7 +9,7 @@ from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.memories.textual.tree_text_memory.retrieve.advanced_searcher import AdvancedSearcher logger = get_logger(__name__) @@ -132,7 +132,7 @@ def mem_scheduler(self) -> BaseScheduler: return self.deps.mem_scheduler @property - def searcher(self) -> Searcher: + def searcher(self) -> AdvancedSearcher: """Get scheduler instance.""" return self.deps.searcher diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 7b34fcfae..706269b52 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -129,6 +129,21 @@ def init_server() -> dict[str, Any]: """ logger.info("Initializing MemOS server components...") + # Initialize Redis client first as it is a core dependency for features like scheduler status tracking + try: + from memos.mem_scheduler.orm_modules.api_redis_model import APIRedisDBManager + + redis_client = APIRedisDBManager.load_redis_engine_from_env() + if redis_client: + logger.info("Redis client initialized successfully.") + else: + logger.error( + "Failed to initialize Redis client. Check REDIS_HOST etc. in environment variables." + ) + except Exception as e: + logger.error(f"Failed to initialize Redis client: {e}", exc_info=True) + redis_client = None # Ensure redis_client exists even on failure + # Get default cube configuration default_cube_config = APIConfig.get_default_cube_config() @@ -272,6 +287,8 @@ def init_server() -> dict[str, Any]: tree_mem: TreeTextMemory = naive_mem_cube.text_mem searcher: Searcher = tree_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + process_llm=mem_reader.llm, ) logger.debug("Searcher created") @@ -286,6 +303,7 @@ def init_server() -> dict[str, Any]: process_llm=mem_reader.llm, db_engine=BaseDBManager.create_default_sqlite_engine(), mem_reader=mem_reader, + redis_client=redis_client, ) mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher) logger.debug("Scheduler initialized") @@ -335,5 +353,6 @@ def init_server() -> dict[str, Any]: "text_mem": text_mem, "pref_mem": pref_mem, "online_bot": online_bot, + "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, } diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py index 32b312f8a..4596889ac 100644 --- a/src/memos/api/handlers/scheduler_handler.py +++ b/src/memos/api/handlers/scheduler_handler.py @@ -14,196 +14,203 @@ from fastapi import HTTPException from fastapi.responses import StreamingResponse -from memos.api.handlers.formatters_handler import to_iter +# Imports for new implementation +from memos.api.product_models import StatusResponse, StatusResponseItem from memos.log import get_logger +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker logger = get_logger(__name__) def handle_scheduler_status( - mem_cube_id: str | None = None, - mem_scheduler: Any | None = None, - instance_id: str = "", -) -> dict[str, Any]: + user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None +) -> StatusResponse: """ - Get scheduler running status. + Get scheduler running status for one or all tasks of a user. - Retrieves the number of running tasks for a specific user or globally. + Retrieves task statuses from the persistent TaskStatusTracker. Args: - user_name: Optional specific user name to filter tasks - mem_scheduler: Scheduler instance - instance_id: Instance ID for response + user_id: User ID to query for. + status_tracker: The TaskStatusTracker instance. + task_id: Optional Task ID to query a specific task. Returns: - Dictionary with status information + StatusResponse with a list of task statuses. Raises: - HTTPException: If status retrieval fails + HTTPException: If a specific task is not found. """ + response_data: list[StatusResponseItem] = [] + try: - if mem_cube_id: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: getattr(task, "mem_cube_id", None) == mem_cube_id - ) - tasks_iter = to_iter(running) - running_count = len(tasks_iter) - return { - "message": "ok", - "data": { - "scope": "user", - "mem_cube_id": mem_cube_id, - "running_tasks": running_count, - "timestamp": time.time(), - "instance_id": instance_id, - }, - } + if task_id: + task_data = status_tracker.get_task_status(task_id, user_id) + if not task_data: + raise HTTPException( + status_code=404, detail=f"Task {task_id} not found for user {user_id}" + ) + response_data.append(StatusResponseItem(task_id=task_id, status=task_data["status"])) else: - running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) - tasks_iter = to_iter(running_all) - running_count = len(tasks_iter) - - task_count_per_user: dict[str, int] = {} - for task in tasks_iter: - cube = getattr(task, "mem_cube_id", "unknown") - task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 - - try: - metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() - except Exception: - metrics_snapshot = {} - - return { - "message": "ok", - "data": { - "scope": "global", - "running_tasks": running_count, - "task_count_per_user": task_count_per_user, - "timestamp": time.time(), - "instance_id": instance_id, - "metrics": metrics_snapshot, - }, - } + all_tasks = status_tracker.get_all_tasks_for_user(user_id) + # The plan returns an empty list, which is good. + # No need to check "if not all_tasks" explicitly before the list comprehension + response_data = [ + StatusResponseItem(task_id=tid, status=t_data["status"]) + for tid, t_data in all_tasks.items() + ] + + return StatusResponse(data=response_data) + except HTTPException: + # Re-raise HTTPException directly to preserve its status code (e.g., 404) + raise except Exception as err: - logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + logger.error(f"Failed to get scheduler status for user {user_id}: {traceback.format_exc()}") raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err def handle_scheduler_wait( user_name: str, + status_tracker: TaskStatusTracker, timeout_seconds: float = 120.0, - poll_interval: float = 0.2, - mem_scheduler: Any | None = None, + poll_interval: float = 0.5, ) -> dict[str, Any]: """ - Wait until scheduler is idle for a specific user. + Wait until the scheduler is idle for a specific user. - Blocks until scheduler has no running tasks for the given user, or timeout. + Blocks and polls the new /scheduler/status endpoint until no tasks are in + 'waiting' or 'in_progress' state, or until a timeout is reached. Args: - user_name: User name to wait for - timeout_seconds: Maximum wait time in seconds - poll_interval: Polling interval in seconds - mem_scheduler: Scheduler instance + user_name: User name to wait for. + status_tracker: The TaskStatusTracker instance. + timeout_seconds: Maximum wait time in seconds. + poll_interval: Polling interval in seconds. Returns: - Dictionary with wait result and statistics + Dictionary with wait result and statistics. Raises: - HTTPException: If wait operation fails + HTTPException: If wait operation fails. """ - start = time.time() + start_time = time.time() try: - while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name + while time.time() - start_time < timeout_seconds: + # Directly call the new, reliable status logic + status_response = handle_scheduler_status( + user_id=user_name, status_tracker=status_tracker + ) + + # System is idle if the data list is empty or no tasks are active + is_idle = not status_response.data or all( + task.status in ["completed", "failed", "cancelled"] for task in status_response.data ) - running_count = len(running) - elapsed = time.time() - start - # success -> scheduler is idle - if running_count == 0: + if is_idle: return { "message": "idle", "data": { - "running_tasks": 0, - "waited_seconds": round(elapsed, 3), + "running_tasks": 0, # Kept for compatibility + "waited_seconds": round(time.time() - start_time, 3), "timed_out": False, "user_name": user_name, }, } - # timeout check - if elapsed > timeout_seconds: - return { - "message": "timeout", - "data": { - "running_tasks": running_count, - "waited_seconds": round(elapsed, 3), - "timed_out": True, - "user_name": user_name, - }, - } - time.sleep(poll_interval) + # Timeout occurred + final_status = handle_scheduler_status(user_id=user_name, status_tracker=status_tracker) + active_tasks = [t for t in final_status.data if t.status in ["waiting", "in_progress"]] + + return { + "message": "timeout", + "data": { + "running_tasks": len(active_tasks), # A more accurate count of active tasks + "waited_seconds": round(time.time() - start_time, 3), + "timed_out": True, + "user_name": user_name, + }, + } + except HTTPException: + # Re-raise HTTPException directly to preserve its status code + raise except Exception as err: - logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) + logger.error( + f"Failed while waiting for scheduler for user {user_name}: {traceback.format_exc()}" + ) raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err def handle_scheduler_wait_stream( user_name: str, + status_tracker: TaskStatusTracker, timeout_seconds: float = 120.0, - poll_interval: float = 0.2, - mem_scheduler: Any | None = None, + poll_interval: float = 0.5, instance_id: str = "", ) -> StreamingResponse: """ - Stream scheduler progress via Server-Sent Events (SSE). + Stream scheduler progress via Server-Sent Events (SSE) using the new status endpoint. - Emits periodic heartbeat frames while tasks are running, then final + Emits periodic heartbeat frames while tasks are active, then a final status frame indicating idle or timeout. Args: - user_name: User name to monitor - timeout_seconds: Maximum stream duration in seconds - poll_interval: Polling interval between updates - mem_scheduler: Scheduler instance - instance_id: Instance ID for response + user_name: User name to monitor. + status_tracker: The TaskStatusTracker instance. + timeout_seconds: Maximum stream duration in seconds. + poll_interval: Polling interval between updates. + instance_id: Instance ID for response. Returns: - StreamingResponse with SSE formatted progress updates - - Example: - curl -N "http://localhost:8000/product/scheduler/wait/stream?timeout_seconds=10" + StreamingResponse with SSE formatted progress updates. """ def event_generator(): - start = time.time() + start_time = time.time() try: while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name + elapsed = time.time() - start_time + if elapsed > timeout_seconds: + # Send timeout message and break + final_status = handle_scheduler_status( + user_id=user_name, status_tracker=status_tracker + ) + active_tasks = [ + t for t in final_status.data if t.status in ["waiting", "in_progress"] + ] + payload = { + "user_name": user_name, + "active_tasks": len(active_tasks), + "elapsed_seconds": round(elapsed, 3), + "status": "timeout", + "timed_out": True, + "instance_id": instance_id, + } + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" + break + + # Get status + status_response = handle_scheduler_status( + user_id=user_name, status_tracker=status_tracker ) - running_count = len(running) - elapsed = time.time() - start + active_tasks = [ + t for t in status_response.data if t.status in ["waiting", "in_progress"] + ] + num_active = len(active_tasks) payload = { "user_name": user_name, - "running_tasks": running_count, + "active_tasks": num_active, "elapsed_seconds": round(elapsed, 3), - "status": "running" if running_count > 0 else "idle", + "status": "running" if num_active > 0 else "idle", "instance_id": instance_id, } yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - if running_count == 0 or elapsed > timeout_seconds: - payload["status"] = "idle" if running_count == 0 else "timeout" - payload["timed_out"] = running_count > 0 - yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - break + if num_active == 0: + break # Exit loop if idle time.sleep(poll_interval) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7d547d4ba..ea5f8d136 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -6,8 +6,7 @@ # Import message types from core types module from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import SearchMode -from memos.types import MessageDict, MessagesType, PermissionDict +from memos.types import MessageDict, MessagesType, PermissionDict, SearchMode logger = get_logger(__name__) @@ -678,3 +677,28 @@ class MemOSAddResponse(BaseModel): def success(self) -> bool: """Convenient access to success status.""" return self.data.success + + +# ─── Scheduler Status Models ─────────────────────────────────────────────────── + + +class StatusRequest(BaseRequest): + """Request model for querying scheduler task status.""" + + user_id: str = Field(..., description="User ID") + task_id: str | None = Field(None, description="Optional Task ID to query a specific task") + + +class StatusResponseItem(BaseModel): + """Individual task status item.""" + + task_id: str = Field(..., description="The ID of the task") + status: Literal["in_progress", "completed", "waiting", "failed", "cancelled"] = Field( + ..., description="The current status of the task" + ) + + +class StatusResponse(BaseResponse[list[StatusResponseItem]]): + """Response model for scheduler status operations.""" + + message: str = "Memory get status successfully" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 0067d6e2f..b40547fa4 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -15,7 +15,7 @@ import random as _random import socket -from fastapi import APIRouter +from fastapi import APIRouter, Query from memos.api import handlers from memos.api.handlers.add_handler import AddHandler @@ -23,8 +23,6 @@ from memos.api.handlers.chat_handler import ChatHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( - AddStatusRequest, - AddStatusResponse, APIADDRequest, APIChatCompleteRequest, APISearchRequest, @@ -36,11 +34,13 @@ GetMemoryResponse, MemoryResponse, SearchResponse, + StatusResponse, SuggestionRequest, SuggestionResponse, ) from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker logger = get_logger(__name__) @@ -72,6 +72,8 @@ mem_scheduler: BaseScheduler = components["mem_scheduler"] llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] +redis_client = components["redis_client"] +status_tracker = TaskStatusTracker(redis_client=redis_client) # ============================================================================= @@ -86,7 +88,8 @@ def search_memories(search_req: APISearchRequest): This endpoint uses the class-based SearchHandler for better code organization. """ - return search_handler.handle_search_memories(search_req) + search_results = search_handler.handle_search_memories(search_req) + return search_results # ============================================================================= @@ -109,17 +112,18 @@ def add_memories(add_req: APIADDRequest): # ============================================================================= -@router.get( - "/scheduler/status", summary="Get scheduler running status", response_model=AddStatusResponse +@router.get( # Changed from post to get + "/scheduler/status", summary="Get scheduler running status", response_model=StatusResponse ) -def scheduler_status(add_status_req: AddStatusRequest): +def scheduler_status( + user_id: str = Query(..., description="User ID"), + task_id: str | None = Query(None, description="Optional Task ID to query a specific task"), +): """Get scheduler running status.""" return handlers.scheduler_handler.handle_scheduler_status( - mem_cube_id=add_status_req.mem_cube_id, - user_id=add_status_req.user_id, - session_id=add_status_req.session_id, - mem_scheduler=mem_scheduler, - instance_id=INSTANCE_ID, + user_id=user_id, + task_id=task_id, + status_tracker=status_tracker, ) @@ -127,14 +131,14 @@ def scheduler_status(add_status_req: AddStatusRequest): def scheduler_wait( user_name: str, timeout_seconds: float = 120.0, - poll_interval: float = 0.2, + poll_interval: float = 0.5, ): """Wait until scheduler is idle for a specific user.""" return handlers.scheduler_handler.handle_scheduler_wait( user_name=user_name, + status_tracker=status_tracker, timeout_seconds=timeout_seconds, poll_interval=poll_interval, - mem_scheduler=mem_scheduler, ) @@ -142,14 +146,14 @@ def scheduler_wait( def scheduler_wait_stream( user_name: str, timeout_seconds: float = 120.0, - poll_interval: float = 0.2, + poll_interval: float = 0.5, ): """Stream scheduler progress via Server-Sent Events (SSE).""" return handlers.scheduler_handler.handle_scheduler_wait_stream( user_name=user_name, + status_tracker=status_tracker, timeout_seconds=timeout_seconds, poll_interval=poll_interval, - mem_scheduler=mem_scheduler, instance_id=INSTANCE_ID, ) diff --git a/src/memos/log.py b/src/memos/log.py index c98f95f2e..874f2c6a7 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -188,7 +188,7 @@ def close(self): }, "handlers": { "console": { - "level": "DEBUG", + "level": selected_log_level, "class": "logging.StreamHandler", "stream": stdout, "formatter": "no_datetime", diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index 085025b7f..090e13f54 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -13,8 +13,10 @@ import requests +from memos.api.product_models import APIADDRequest, APISearchRequest +from memos.api.routers.server_router import add_memories, search_memories from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.types import MessageDict, SearchMode, UserContext logger = get_logger(__name__) @@ -84,7 +86,7 @@ def _close_connection(self): self._connection = None def search( - self, user_id: str, mem_cube_id: str, query: str, top: int = 50, use_requests: bool = True + self, user_id: str, mem_cube_id: str, query: str, top_k: int = 50, use_requests: bool = True ) -> dict[str, Any]: """ Search for memories using the product/search API endpoint. @@ -93,13 +95,13 @@ def search( user_id: User identifier mem_cube_id: Memory cube identifier query: Search query string - top: Number of top results to return + top_k: Number of top_k results to return use_requests: Whether to use requests library (True) or http.client (False) Returns: Dictionary containing the API response """ - payload = {"user_id": user_id, "mem_cube_id": mem_cube_id, "query": query, "top": top} + payload = {"user_id": user_id, "mem_cube_id": mem_cube_id, "query": query, "top_k": top_k} try: if use_requests: @@ -326,7 +328,7 @@ def analyze_service(self): user_id="test_user_id", mem_cube_id="test_mem_cube_id", query="What are some good places to celebrate New Year's Eve in Shanghai?", - top=50, + top_k=50, ) print("Search result:", search_result) @@ -337,7 +339,7 @@ def analyze_features(self): user_id="test_user_id", mem_cube_id="test_mem_cube_id", query="What are some good places to celebrate New Year's Eve in Shanghai?", - top=50, + top_k=50, ) print("Search result:", search_result) except Exception as e: @@ -353,28 +355,20 @@ class DirectSearchMemoriesAnalyzer: def __init__(self): """Initialize the analyzer""" # Import necessary modules - try: - from memos.api.product_models import APIADDRequest, APISearchRequest - from memos.api.routers.server_router import add_memories, search_memories - from memos.types import MessageDict, UserContext - - self.APISearchRequest = APISearchRequest - self.APIADDRequest = APIADDRequest - self.search_memories = search_memories - self.add_memories = add_memories - self.UserContext = UserContext - self.MessageDict = MessageDict - - # Initialize conversation history for continuous conversation support - self.conversation_history = [] - self.current_session_id = None - self.current_user_id = None - self.current_mem_cube_id = None - - logger.info("DirectSearchMemoriesAnalyzer initialized successfully") - except ImportError as e: - logger.error(f"Failed to import modules: {e}") - raise + self.APISearchRequest = APISearchRequest + self.APIADDRequest = APIADDRequest + self.search_memories = search_memories + self.add_memories = add_memories + self.UserContext = UserContext + self.MessageDict = MessageDict + + # Initialize conversation history for continuous conversation support + self.conversation_history = [] + self.current_session_id = None + self.current_user_id = None + self.current_mem_cube_id = None + + logger.info("DirectSearchMemoriesAnalyzer initialized successfully") def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", session_id=None): """ @@ -681,7 +675,7 @@ def run_all_tests(self, mode=SearchMode.MIXTURE): print("Using direct test mode") try: direct_analyzer = DirectSearchMemoriesAnalyzer() - direct_analyzer.run_all_tests(mode=SearchMode.MIXTURE) + direct_analyzer.run_all_tests(mode=SearchMode.FINE) except Exception as e: print(f"Direct test mode failed: {e}") import traceback @@ -711,6 +705,6 @@ def run_all_tests(self, mode=SearchMode.MIXTURE): user_id="test_user_id", mem_cube_id="test_mem_cube_id", query="What are some good places to celebrate New Year's Eve in Shanghai?", - top=50, + top_k=10, ) print("Search result:", search_result) diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py index cf0b8f1dd..49a382ce6 100644 --- a/src/memos/mem_scheduler/analyzer/eval_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py @@ -14,10 +14,7 @@ from openai import OpenAI -from memos.api.routers.server_router import mem_scheduler from memos.log import get_logger -from memos.memories.textual.item import TextualMemoryMetadata -from memos.memories.textual.tree import TextualMemoryItem FILE_PATH = Path(__file__).absolute() @@ -143,1106 +140,6 @@ def extract_bad_cases(self, judged_file: str, search_results_file: str) -> list[ logger.info(f"Extracted {len(bad_cases)} bad cases") return bad_cases - def analyze_memory_sufficiency( - self, query: str, golden_answer: str, memories: str - ) -> dict[str, Any]: - """ - Use LLM to analyze whether memories contain sufficient information to answer the golden answer. - - Args: - query: The original query - golden_answer: The correct answer - memories: The memory context - - Returns: - Analysis result containing sufficiency judgment and relevant memory indices - """ - prompt = f""" -You are an expert analyst tasked with determining whether the provided memories contain sufficient information to answer a specific question correctly. - -**Question:** {query} - -**Golden Answer (Correct Answer):** {golden_answer} - -**Available Memories:** -{memories} - -**Task:** -1. Analyze whether the memories contain enough information to derive the golden answer -2. Identify which specific memory entries (if any) contain relevant information -3. Provide a clear judgment: True if sufficient, False if insufficient - -**Response Format (JSON):** -{{ - "sufficient": true/false, - "confidence": 0.0-1.0, - "relevant_memories": ["memory_1", "memory_2", ...], - "reasoning": "Detailed explanation of your analysis", - "missing_information": "What key information is missing (if insufficient)" -}} - -**Guidelines:** -- Be strict in your evaluation - only mark as sufficient if the memories clearly contain the information needed -- Consider both direct and indirect information that could lead to the golden answer -- Pay attention to dates, names, events, and specific details -- If information is ambiguous or requires significant inference, lean towards insufficient -""" - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[ - { - "role": "system", - "content": "You are a precise analyst who evaluates information sufficiency.", - }, - {"role": "user", "content": prompt}, - ], - temperature=0.1, - max_tokens=1000, - ) - - content = response.choices[0].message.content.strip() - - # Try to parse JSON response - try: - # Remove markdown code blocks if present - if content.startswith("```json"): - content = content[7:] - if content.endswith("```"): - content = content[:-3] - content = content.strip() - - analysis = json.loads(content) - return analysis - - except json.JSONDecodeError: - logger.warning(f"Failed to parse LLM response as JSON: {content}") - return { - "sufficient": False, - "confidence": 0.0, - "relevant_memories": [], - "reasoning": f"Failed to parse LLM response: {content}", - "missing_information": "Analysis failed", - } - - except Exception as e: - logger.error(f"Error in LLM analysis: {e}") - return { - "sufficient": False, - "confidence": 0.0, - "relevant_memories": [], - "reasoning": f"Error occurred: {e!s}", - "missing_information": "Analysis failed due to error", - } - - def process_memories_with_llm( - self, memories: str, query: str, processing_type: str = "summarize" - ) -> dict[str, Any]: - """ - Use LLM to process memories for better question answering. - - Args: - memories: The raw memory content - query: The query that will be answered using these memories - processing_type: Type of processing ("summarize", "restructure", "enhance") - - Returns: - Dictionary containing processed memories and processing metadata - """ - if processing_type == "summarize": - prompt = f""" -You are an expert at summarizing and organizing information to help answer specific questions. - -**Target Question:** {query} - -**Raw Memories:** -{memories} - -**Task:** -Summarize and organize the above memories in a way that would be most helpful for answering the target question. Focus on: -1. Key facts and information relevant to the question -2. Important relationships and connections -3. Chronological or logical organization where applicable -4. Remove redundant or irrelevant information - -**Processed Memories:** -""" - elif processing_type == "restructure": - prompt = f""" -You are an expert at restructuring information to optimize question answering. - -**Target Question:** {query} - -**Raw Memories:** -{memories} - -**Task:** -Restructure the above memories into a clear, logical format that directly supports answering the target question. Organize by: -1. Most relevant information first -2. Supporting details and context -3. Clear categorization of different types of information -4. Logical flow that leads to the answer - -**Restructured Memories:** -""" - elif processing_type == "enhance": - prompt = f""" -You are an expert at enhancing information by adding context and making connections. - -**Target Question:** {query} - -**Raw Memories:** -{memories} - -**Task:** -Enhance the above memories by: -1. Making implicit connections explicit -2. Adding relevant context that helps answer the question -3. Highlighting key relationships between different pieces of information -4. Organizing information in a question-focused manner - -**Enhanced Memories:** -""" - else: - raise ValueError(f"Unknown processing_type: {processing_type}") - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[ - { - "role": "system", - "content": "You are an expert information processor who optimizes content for question answering.", - }, - {"role": "user", "content": prompt}, - ], - temperature=0.3, - max_tokens=2000, - ) - - processed_memories = response.choices[0].message.content.strip() - - return { - "processed_memories": processed_memories, - "processing_type": processing_type, - "original_length": len(memories), - "processed_length": len(processed_memories), - "compression_ratio": len(processed_memories) / len(memories) - if len(memories) > 0 - else 0, - } - - except Exception as e: - logger.error(f"Error in memory processing: {e}") - return { - "processed_memories": memories, # Fallback to original - "processing_type": processing_type, - "original_length": len(memories), - "processed_length": len(memories), - "compression_ratio": 1.0, - "error": str(e), - } - - def generate_answer_with_memories( - self, query: str, memories: str, memory_type: str = "original" - ) -> dict[str, Any]: - """ - Generate an answer to the query using the provided memories. - - Args: - query: The question to answer - memories: The memory content to use - memory_type: Type of memories ("original", "processed") - - Returns: - Dictionary containing the generated answer and metadata - """ - prompt = f""" - You are a knowledgeable and helpful AI assistant. - - # CONTEXT: - You have access to memories from two speakers in a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer. - 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth. - 3. If the question asks about a specific event or fact, look for direct evidence in the memories. - 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description). - 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years in your final answer. - 7. Do not confuse character names mentioned in memories with the actual users who created them. - 8. The answer must be brief (under 5-6 words) and direct, with no extra description. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question. - 2. Synthesize findings from multiple memories if a single entry is insufficient. - 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events. - 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation. - 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). - 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. - 7. Ensure your final answer is specific and avoids vague time references. - - {memories} - - Question: {query} - - Answer: -""" - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[ - { - "role": "system", - "content": "You are a precise assistant who answers questions based only on provided information.", - }, - {"role": "user", "content": prompt}, - ], - temperature=0.1, - max_tokens=1000, - ) - - answer = response.choices[0].message.content.strip() - - return { - "answer": answer, - "memory_type": memory_type, - "query": query, - "memory_length": len(memories), - "answer_length": len(answer), - } - - except Exception as e: - logger.error(f"Error in answer generation: {e}") - return { - "answer": f"Error generating answer: {e!s}", - "memory_type": memory_type, - "query": query, - "memory_length": len(memories), - "answer_length": 0, - "error": str(e), - } - - def compare_answer_quality( - self, query: str, golden_answer: str, original_answer: str, processed_answer: str - ) -> dict[str, Any]: - """ - Compare the quality of answers generated from original vs processed memories. - - Args: - query: The original query - golden_answer: The correct/expected answer - original_answer: Answer generated from original memories - processed_answer: Answer generated from processed memories - - Returns: - Dictionary containing comparison results - """ - prompt = f""" -You are an expert evaluator comparing the quality of two answers against a golden standard. - -**Question:** {query} - -**Golden Answer (Correct):** {golden_answer} - -**Answer A (Original Memories):** {original_answer} - -**Answer B (Processed Memories):** {processed_answer} - -**Task:** -Compare both answers against the golden answer and evaluate: -1. Accuracy: How correct is each answer? -2. Completeness: How complete is each answer? -3. Relevance: How relevant is each answer to the question? -4. Clarity: How clear and well-structured is each answer? - -**Response Format (JSON):** -{{ - "original_scores": {{ - "accuracy": 0.0-1.0, - "completeness": 0.0-1.0, - "relevance": 0.0-1.0, - "clarity": 0.0-1.0, - "overall": 0.0-1.0 - }}, - "processed_scores": {{ - "accuracy": 0.0-1.0, - "completeness": 0.0-1.0, - "relevance": 0.0-1.0, - "clarity": 0.0-1.0, - "overall": 0.0-1.0 - }}, - "winner": "original|processed|tie", - "improvement": 0.0-1.0, - "reasoning": "Detailed explanation of the comparison" -}} -""" - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[ - { - "role": "system", - "content": "You are an expert evaluator who compares answer quality objectively.", - }, - {"role": "user", "content": prompt}, - ], - temperature=0.1, - max_tokens=1500, - ) - - content = response.choices[0].message.content.strip() - - # Try to parse JSON response - try: - if content.startswith("```json"): - content = content[7:] - if content.endswith("```"): - content = content[:-3] - content = content.strip() - - comparison = json.loads(content) - return comparison - - except json.JSONDecodeError: - logger.warning(f"Failed to parse comparison response as JSON: {content}") - return { - "original_scores": { - "accuracy": 0.5, - "completeness": 0.5, - "relevance": 0.5, - "clarity": 0.5, - "overall": 0.5, - }, - "processed_scores": { - "accuracy": 0.5, - "completeness": 0.5, - "relevance": 0.5, - "clarity": 0.5, - "overall": 0.5, - }, - "winner": "tie", - "improvement": 0.0, - "reasoning": f"Failed to parse comparison: {content}", - } - - except Exception as e: - logger.error(f"Error in answer comparison: {e}") - return { - "original_scores": { - "accuracy": 0.0, - "completeness": 0.0, - "relevance": 0.0, - "clarity": 0.0, - "overall": 0.0, - }, - "processed_scores": { - "accuracy": 0.0, - "completeness": 0.0, - "relevance": 0.0, - "clarity": 0.0, - "overall": 0.0, - }, - "winner": "tie", - "improvement": 0.0, - "reasoning": f"Error occurred: {e!s}", - } - - def analyze_memory_processing_effectiveness( - self, - bad_cases: list[dict[str, Any]], - processing_types: list[str] | None = None, - ) -> dict[str, Any]: - """ - Analyze the effectiveness of different memory processing techniques. - - Args: - bad_cases: List of bad cases to analyze - processing_types: List of processing types to test - - Returns: - Dictionary containing comprehensive analysis results - """ - if processing_types is None: - processing_types = ["summarize", "restructure", "enhance"] - results = {"processing_results": [], "statistics": {}, "processing_types": processing_types} - - for i, case in enumerate(bad_cases): - logger.info(f"Processing case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") - - case_result = { - "case_id": i, - "query": case["query"], - "golden_answer": case["golden_answer"], - "original_memories": case["memories"], - "processing_results": {}, - } - - # Generate answer with original memories - original_answer_result = self.generate_answer_with_memories( - case["query"], case["memories"], "original" - ) - case_result["original_answer"] = original_answer_result - - # Test each processing type - for processing_type in processing_types: - logger.info(f" Testing {processing_type} processing...") - - # Process memories - processing_result = self.process_memories_with_llm( - case["memories"], case["query"], processing_type - ) - - # Generate answer with processed memories - processed_answer_result = self.generate_answer_with_memories( - case["query"], - processing_result["processed_memories"], - f"processed_{processing_type}", - ) - - # Compare answer quality - comparison_result = self.compare_answer_quality( - case["query"], - case["golden_answer"], - original_answer_result["answer"], - processed_answer_result["answer"], - ) - - case_result["processing_results"][processing_type] = { - "processing": processing_result, - "answer": processed_answer_result, - "comparison": comparison_result, - } - - results["processing_results"].append(case_result) - - # Calculate statistics - self._calculate_processing_statistics(results) - - return results - - def _calculate_processing_statistics(self, results: dict[str, Any]) -> None: - """Calculate statistics for processing effectiveness analysis.""" - processing_types = results["processing_types"] - processing_results = results["processing_results"] - - if not processing_results: - results["statistics"] = {} - return - - stats = {"total_cases": len(processing_results), "processing_type_stats": {}} - - for processing_type in processing_types: - type_stats = { - "wins": 0, - "ties": 0, - "losses": 0, - "avg_improvement": 0.0, - "avg_compression_ratio": 0.0, - "avg_scores": { - "accuracy": 0.0, - "completeness": 0.0, - "relevance": 0.0, - "clarity": 0.0, - "overall": 0.0, - }, - } - - valid_cases = [] - for case in processing_results: - if processing_type in case["processing_results"]: - result = case["processing_results"][processing_type] - comparison = result["comparison"] - - # Count wins/ties/losses - if comparison["winner"] == "processed": - type_stats["wins"] += 1 - elif comparison["winner"] == "tie": - type_stats["ties"] += 1 - else: - type_stats["losses"] += 1 - - valid_cases.append(result) - - if valid_cases: - # Calculate averages - type_stats["avg_improvement"] = sum( - case["comparison"]["improvement"] for case in valid_cases - ) / len(valid_cases) - - type_stats["avg_compression_ratio"] = sum( - case["processing"]["compression_ratio"] for case in valid_cases - ) / len(valid_cases) - - # Calculate average scores - for score_type in type_stats["avg_scores"]: - type_stats["avg_scores"][score_type] = sum( - case["comparison"]["processed_scores"][score_type] for case in valid_cases - ) / len(valid_cases) - - # Calculate win rate - total_decisions = type_stats["wins"] + type_stats["ties"] + type_stats["losses"] - type_stats["win_rate"] = ( - type_stats["wins"] / total_decisions if total_decisions > 0 else 0.0 - ) - type_stats["success_rate"] = ( - (type_stats["wins"] + type_stats["ties"]) / total_decisions - if total_decisions > 0 - else 0.0 - ) - - stats["processing_type_stats"][processing_type] = type_stats - - results["statistics"] = stats - - def analyze_bad_cases(self, bad_cases: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Analyze all bad cases to determine memory sufficiency. - - Args: - bad_cases: List of bad cases to analyze - - Returns: - List of analyzed bad cases with sufficiency information - """ - analyzed_cases = [] - - for i, case in enumerate(bad_cases): - logger.info(f"Analyzing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") - - analysis = self.analyze_memory_sufficiency( - case["query"], case["golden_answer"], case["memories"] - ) - - # Add analysis results to the case - analyzed_case = case.copy() - analyzed_case.update( - { - "memory_analysis": analysis, - "has_sufficient_memories": analysis["sufficient"], - "analysis_confidence": analysis["confidence"], - "relevant_memory_count": len(analysis["relevant_memories"]), - } - ) - - analyzed_cases.append(analyzed_case) - - return analyzed_cases - - def collect_bad_cases(self, eval_result_dir: str | None = None) -> dict[str, Any]: - """ - Main method to collect and analyze bad cases from evaluation results. - - Args: - eval_result_dir: Directory containing evaluation results - - Returns: - Dictionary containing analysis results and statistics - """ - if eval_result_dir is None: - eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-072005-fast" - - judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") - search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") - - # Extract bad cases - bad_cases = self.extract_bad_cases(judged_file, search_results_file) - - if not bad_cases: - logger.warning("No bad cases found") - return {"bad_cases": [], "statistics": {}} - - # Analyze bad cases - analyzed_cases = self.analyze_bad_cases(bad_cases) - - # Calculate statistics - total_cases = len(analyzed_cases) - sufficient_cases = sum( - 1 for case in analyzed_cases if case.get("has_sufficient_memories", False) - ) - insufficient_cases = total_cases - sufficient_cases - - avg_confidence = ( - sum(case["analysis_confidence"] for case in analyzed_cases) / total_cases - if total_cases > 0 - else 0 - ) - avg_relevant_memories = ( - sum(case["relevant_memory_count"] for case in analyzed_cases) / total_cases - if total_cases > 0 - else 0 - ) - - statistics = { - "total_bad_cases": total_cases, - "sufficient_memory_cases": sufficient_cases, - "insufficient_memory_cases": insufficient_cases, - "sufficiency_rate": sufficient_cases / total_cases if total_cases > 0 else 0, - "average_confidence": avg_confidence, - "average_relevant_memories": avg_relevant_memories, - } - - # Save results - results = { - "bad_cases": analyzed_cases, - "statistics": statistics, - "metadata": { - "eval_result_dir": eval_result_dir, - "judged_file": judged_file, - "search_results_file": search_results_file, - "analysis_model": self.openai_model, - }, - } - - output_file = self.output_dir / "bad_cases_analysis.json" - with open(output_file, "w", encoding="utf-8") as f: - json.dump(results, f, indent=2, ensure_ascii=False) - - logger.info(f"Analysis complete. Results saved to: {output_file}") - logger.info(f"Statistics: {statistics}") - - return results - - def _parse_json_response(self, response_text: str) -> dict: - """ - Parse JSON response from LLM, handling various formats and potential errors. - - Args: - response_text: Raw response text from LLM - - Returns: - Parsed JSON dictionary - - Raises: - ValueError: If JSON cannot be parsed - """ - import re - - # Try to extract JSON from response text - # Look for JSON blocks between ```json and ``` or just {} blocks - json_patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"] - - for pattern in json_patterns: - matches = re.findall(pattern, response_text, re.DOTALL) - if matches: - json_str = matches[0].strip() - try: - return json.loads(json_str) - except json.JSONDecodeError: - continue - - # If no JSON pattern found, try parsing the entire response - try: - return json.loads(response_text.strip()) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON response: {response_text[:200]}...") - raise ValueError(f"Invalid JSON response: {e!s}") from e - - def filter_memories_with_llm(self, memories: list[str], query: str) -> tuple[list[str], bool]: - """ - Use LLM to filter memories based on relevance to the query. - - Args: - memories: List of memory strings - query: Query to filter memories against - - Returns: - Tuple of (filtered_memories, success_flag) - """ - if not memories: - return [], True - - # Build prompt for memory filtering - memories_text = "\n".join([f"{i + 1}. {memory}" for i, memory in enumerate(memories)]) - - prompt = f"""You are a memory filtering system. Given a query and a list of memories, identify which memories are relevant and non-redundant for answering the query. - -Query: {query} - -Memories: -{memories_text} - -Please analyze each memory and return a JSON response with the following format: -{{ - "relevant_memory_indices": [list of indices (1-based) of memories that are relevant to the query], - "reasoning": "Brief explanation of your filtering decisions" -}} - -Only include memories that are directly relevant to answering the query. Remove redundant or unrelated memories.""" - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[{"role": "user", "content": prompt}], - temperature=0.1, - ) - - response_text = response.choices[0].message.content - - # Extract JSON from response - result = self._parse_json_response(response_text) - - if "relevant_memory_indices" in result: - relevant_indices = result["relevant_memory_indices"] - filtered_memories = [] - - for idx in relevant_indices: - if 1 <= idx <= len(memories): - filtered_memories.append(memories[idx - 1]) - - logger.info(f"Filtered memories: {len(memories)} -> {len(filtered_memories)}") - return filtered_memories, True - else: - logger.warning("Invalid response format from memory filtering LLM") - return memories, False - - except Exception as e: - logger.error(f"Error in memory filtering: {e}") - return memories, False - - def evaluate_answer_ability_with_llm(self, query: str, memories: list[str]) -> bool: - """ - Use LLM to evaluate whether the given memories can answer the query. - - Args: - query: Query to evaluate - memories: List of memory strings - - Returns: - Boolean indicating whether memories can answer the query - """ - if not memories: - return False - - memories_text = "\n".join([f"- {memory}" for memory in memories]) - - prompt = f"""You are an answer ability evaluator. Given a query and a list of memories, determine whether the memories contain sufficient information to answer the query. - -Query: {query} - -Available Memories: -{memories_text} - -Please analyze the memories and return a JSON response with the following format: -{{ - "can_answer": true/false, - "confidence": 0.0-1.0, - "reasoning": "Brief explanation of your decision" -}} - -Consider whether the memories contain the specific information needed to provide a complete and accurate answer to the query.""" - - try: - response = self.openai_client.chat.completions.create( - model=self.openai_model, - messages=[{"role": "user", "content": prompt}], - temperature=0.1, - ) - - response_text = response.choices[0].message.content - result = self._parse_json_response(response_text) - - if "can_answer" in result: - can_answer = result["can_answer"] - confidence = result.get("confidence", 0.5) - reasoning = result.get("reasoning", "No reasoning provided") - - logger.info( - f"Answer ability evaluation: {can_answer} (confidence: {confidence:.2f}) - {reasoning}" - ) - return can_answer - else: - logger.warning("Invalid response format from answer ability evaluation") - return False - - except Exception as e: - logger.error(f"Error in answer ability evaluation: {e}") - return False - - def memory_llm_processing_analysis( - self, bad_cases: list[dict[str, Any]], use_llm_filtering: bool = True - ) -> list[dict[str, Any]]: - """ - Analyze bad cases by processing memories with LLM filtering and testing answer ability. - - This method: - 1. Parses memory strings from bad cases - 2. Uses LLM to filter unrelated and redundant memories - 3. Tests whether processed memories can help answer questions correctly - 4. Compares results before and after LLM processing - - Args: - bad_cases: List of bad cases to analyze - use_llm_filtering: Whether to use LLM filtering - - Returns: - List of analyzed bad cases with LLM processing results - """ - analyzed_cases = [] - - for i, case in enumerate(bad_cases): - logger.info(f"Processing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") - - try: - # Parse memory string - memories_text = case.get("memories", "") - if not memories_text: - logger.warning(f"No memories found for case {i + 1}") - analyzed_case = case.copy() - analyzed_case.update( - { - "llm_processing_analysis": { - "error": "No memories available", - "original_memories_count": 0, - "processed_memories_count": 0, - "can_answer_with_original": False, - "can_answer_with_processed": False, - "processing_improved_answer": False, - } - } - ) - analyzed_cases.append(analyzed_case) - continue - - # Split memories by lines - memory_lines = [line.strip() for line in memories_text.split("\n") if line.strip()] - original_memories = [line for line in memory_lines if line] - - logger.info(f"Parsed {len(original_memories)} memories from text") - - # Test answer ability with original memories - can_answer_original = self.evaluate_answer_ability_with_llm( - query=case["query"], memories=original_memories - ) - - # Process memories with LLM filtering if enabled - processed_memories = original_memories - processing_success = False - - if use_llm_filtering and len(original_memories) > 0: - processed_memories, processing_success = self.filter_memories_with_llm( - memories=original_memories, query=case["query"] - ) - logger.info( - f"LLM filtering: {len(original_memories)} -> {len(processed_memories)} memories, success: {processing_success}" - ) - - # Test answer ability with processed memories - can_answer_processed = self.evaluate_answer_ability_with_llm( - query=case["query"], memories=processed_memories - ) - - # Determine if processing improved answer ability - processing_improved = can_answer_processed and not can_answer_original - - # Create analysis result - llm_analysis = { - "processing_success": processing_success, - "original_memories_count": len(original_memories), - "processed_memories_count": len(processed_memories), - "memories_removed_count": len(original_memories) - len(processed_memories), - "can_answer_with_original": can_answer_original, - "can_answer_with_processed": can_answer_processed, - "processing_improved_answer": processing_improved, - "original_memories": original_memories, - "processed_memories": processed_memories, - } - - # Add analysis to case - analyzed_case = case.copy() - analyzed_case["llm_processing_analysis"] = llm_analysis - - logger.info( - f"Case {i + 1} analysis complete: " - f"Original: {can_answer_original}, " - f"Processed: {can_answer_processed}, " - f"Improved: {processing_improved}" - ) - - except Exception as e: - logger.error(f"Error processing case {i + 1}: {e}") - analyzed_case = case.copy() - analyzed_case["llm_processing_analysis"] = { - "error": str(e), - "processing_success": False, - "original_memories_count": 0, - "processed_memories_count": 0, - "can_answer_with_original": False, - "can_answer_with_processed": False, - "processing_improved_answer": False, - } - - analyzed_cases.append(analyzed_case) - - return analyzed_cases - - def scheduler_mem_process(self, query, memories): - from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer - - _memories = [] - for mem in memories: - mem_item = TextualMemoryItem(memory=mem, metadata=TextualMemoryMetadata()) - _memories.append(mem_item) - prompt = mem_scheduler.retriever._build_enhancement_prompt( - query_history=[query], batch_texts=memories - ) - logger.debug( - f"[Enhance][batch={0}] Prompt (first 200 chars, len={len(prompt)}): {prompt[:200]}..." - ) - - response = mem_scheduler.retriever.process_llm.generate( - [{"role": "user", "content": prompt}] - ) - logger.debug(f"[Enhance][batch={0}] Response (first 200 chars): {response[:200]}...") - - processed_results = extract_list_items_in_answer(response) - - return { - "processed_memories": processed_results, - "processing_type": "enhance", - "original_length": len("\n".join(memories)), - "processed_length": len("\n".join(processed_results)), - "compression_ratio": len("\n".join(processed_results)) / len("\n".join(memories)) - if len(memories) > 0 - else 0, - } - - def analyze_bad_cases_with_llm_processing( - self, - bad_cases: list[dict[str, Any]], - save_results: bool = True, - output_file: str | None = None, - ) -> dict[str, Any]: - """ - Comprehensive analysis of bad cases with LLM memory processing. - - This method performs a complete analysis including: - 1. Basic bad case analysis - 2. LLM memory processing analysis - 3. Statistical summary of improvements - 4. Detailed reporting - - Args: - bad_cases: List of bad cases to analyze - save_results: Whether to save results to file - output_file: Optional output file path - - Returns: - Dictionary containing comprehensive analysis results - """ - from datetime import datetime - - logger.info( - f"Starting comprehensive analysis of {len(bad_cases)} bad cases with LLM processing" - ) - - # Perform LLM memory processing analysis - analyzed_cases = self.memory_llm_processing_analysis( - bad_cases=bad_cases, use_llm_filtering=True - ) - - # Calculate statistics - total_cases = len(analyzed_cases) - successful_processing = 0 - improved_cases = 0 - original_answerable = 0 - processed_answerable = 0 - total_memories_before = 0 - total_memories_after = 0 - - for case in analyzed_cases: - llm_analysis = case.get("llm_processing_analysis", {}) - - if llm_analysis.get("processing_success", False): - successful_processing += 1 - - if llm_analysis.get("processing_improved_answer", False): - improved_cases += 1 - - if llm_analysis.get("can_answer_with_original", False): - original_answerable += 1 - - if llm_analysis.get("can_answer_with_processed", False): - processed_answerable += 1 - - total_memories_before += llm_analysis.get("original_memories_count", 0) - total_memories_after += llm_analysis.get("processed_memories_count", 0) - - # Calculate improvement metrics - processing_success_rate = successful_processing / total_cases if total_cases > 0 else 0 - improvement_rate = improved_cases / total_cases if total_cases > 0 else 0 - original_answer_rate = original_answerable / total_cases if total_cases > 0 else 0 - processed_answer_rate = processed_answerable / total_cases if total_cases > 0 else 0 - memory_reduction_rate = ( - (total_memories_before - total_memories_after) / total_memories_before - if total_memories_before > 0 - else 0 - ) - - # Create comprehensive results - results = { - "analysis_metadata": { - "total_cases_analyzed": total_cases, - "analysis_timestamp": datetime.now().isoformat(), - "llm_model_used": self.openai_model, - }, - "processing_statistics": { - "successful_processing_count": successful_processing, - "processing_success_rate": processing_success_rate, - "cases_with_improvement": improved_cases, - "improvement_rate": improvement_rate, - "original_answerable_cases": original_answerable, - "original_answer_rate": original_answer_rate, - "processed_answerable_cases": processed_answerable, - "processed_answer_rate": processed_answer_rate, - "answer_rate_improvement": processed_answer_rate - original_answer_rate, - }, - "memory_statistics": { - "total_memories_before_processing": total_memories_before, - "total_memories_after_processing": total_memories_after, - "memories_removed": total_memories_before - total_memories_after, - "memory_reduction_rate": memory_reduction_rate, - "average_memories_per_case_before": total_memories_before / total_cases - if total_cases > 0 - else 0, - "average_memories_per_case_after": total_memories_after / total_cases - if total_cases > 0 - else 0, - }, - "analyzed_cases": analyzed_cases, - } - - # Log summary - logger.info("LLM Processing Analysis Summary:") - logger.info(f" - Total cases: {total_cases}") - logger.info(f" - Processing success rate: {processing_success_rate:.2%}") - logger.info(f" - Cases with improvement: {improved_cases} ({improvement_rate:.2%})") - logger.info(f" - Original answer rate: {original_answer_rate:.2%}") - logger.info(f" - Processed answer rate: {processed_answer_rate:.2%}") - logger.info( - f" - Answer rate improvement: {processed_answer_rate - original_answer_rate:.2%}" - ) - logger.info(f" - Memory reduction: {memory_reduction_rate:.2%}") - - # Save results if requested - if save_results: - if output_file is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_file = f"llm_processing_analysis_{timestamp}.json" - - try: - with open(output_file, "w", encoding="utf-8") as f: - json.dump(results, f, indent=2, ensure_ascii=False) - logger.info(f"Analysis results saved to: {output_file}") - except Exception as e: - logger.error(f"Failed to save results to {output_file}: {e}") - - return results - def main(version_name="ct-1111"): """Main test function.""" @@ -1254,7 +151,7 @@ def main(version_name="ct-1111"): print("Analyzer initialized") # Test file paths - eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-{version_name}-locomo" + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-{version_name}" judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") @@ -1319,4 +216,4 @@ def main(version_name="ct-1111"): if __name__ == "__main__": - main() + main(version_name="ct-1118") diff --git a/src/memos/mem_scheduler/analyzer/memory_processing.py b/src/memos/mem_scheduler/analyzer/memory_processing.py deleted file mode 100644 index b692341c2..000000000 --- a/src/memos/mem_scheduler/analyzer/memory_processing.py +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for memory processing functionality in eval_analyzer.py - -This script demonstrates how to use the new LLM memory processing features -to analyze and improve memory-based question answering. -""" - -import json -import os -import sys - -from pathlib import Path -from typing import Any - -from memos.log import get_logger -from memos.mem_scheduler.analyzer.eval_analyzer import EvalAnalyzer - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent # Go up to project root -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - - -logger = get_logger(__name__) - - -def create_sample_bad_cases() -> list[dict[str, Any]]: - """Create sample bad cases for testing memory processing.""" - return [ - { - "query": "What is the capital of France?", - "golden_answer": "Paris", - "memories": """ - Memory 1: France is a country in Western Europe. - Memory 2: The Eiffel Tower is located in Paris. - Memory 3: Paris is known for its art museums and fashion. - Memory 4: French cuisine is famous worldwide. - Memory 5: The Seine River flows through Paris. - """, - }, - { - "query": "When was the iPhone first released?", - "golden_answer": "June 29, 2007", - "memories": """ - Memory 1: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne. - Memory 2: The iPhone was announced by Steve Jobs at the Macworld Conference & Expo on January 9, 2007. - Memory 3: The iPhone went on sale on June 29, 2007. - Memory 4: The original iPhone had a 3.5-inch screen. - Memory 5: Apple's stock price increased significantly after the iPhone launch. - """, - }, - { - "query": "What is photosynthesis?", - "golden_answer": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.", - "memories": """ - Memory 1: Plants are living organisms that need sunlight to grow. - Memory 2: Chlorophyll is the green pigment in plants. - Memory 3: Plants take in carbon dioxide from the air. - Memory 4: Water is absorbed by plant roots from the soil. - Memory 5: Oxygen is released by plants during the day. - Memory 6: Glucose is a type of sugar that plants produce. - """, - }, - ] - - -def memory_processing(bad_cases): - """ - Test the memory processing functionality with cover rate and acc rate analysis. - - This function analyzes: - 1. Cover rate: Whether memories contain all information needed to answer the query - 2. Acc rate: Whether processed memories can correctly answer the query - """ - print("🧪 Testing Memory Processing Functionality with Cover Rate & Acc Rate Analysis") - print("=" * 80) - - # Initialize analyzer - analyzer = EvalAnalyzer() - - print(f"📊 Testing with {len(bad_cases)} sample cases") - print() - - # Initialize counters for real-time statistics - total_cases = 0 - cover_count = 0 # Cases where memories cover all needed information - acc_count = 0 # Cases where processed memories can correctly answer - - # Process each case - for i, case in enumerate(bad_cases): - total_cases += 1 - - # Safely handle query display - query_display = str(case.get("query", "Unknown query")) - print(f"🔍 Case {i + 1}/{len(bad_cases)}: {query_display}...") - - # Safely handle golden_answer display (convert to string if needed) - golden_answer = case.get("golden_answer", "Unknown answer") - golden_answer_str = str(golden_answer) if golden_answer is not None else "Unknown answer" - print(f"📝 Golden Answer: {golden_answer_str}") - print() - - # Step 1: Analyze if memories contain sufficient information (Cover Rate) - print(" 📋 Step 1: Analyzing memory coverage...") - coverage_analysis = analyzer.analyze_memory_sufficiency( - case["query"], - golden_answer_str, # Use the string version - case["memories"], - ) - - has_coverage = coverage_analysis.get("sufficient", False) - if has_coverage: - cover_count += 1 - - print(f" ✅ Memory Coverage: {'SUFFICIENT' if has_coverage else 'INSUFFICIENT'}") - print(f" 🎯 Confidence: {coverage_analysis.get('confidence', 0):.2f}") - print(f" 💭 Reasoning: {coverage_analysis.get('reasoning', 'N/A')}...") - if not has_coverage: - print( - f" ❌ Missing Info: {coverage_analysis.get('missing_information', 'N/A')[:100]}..." - ) - continue - print() - - # Step 2: Process memories and test answer ability (Acc Rate) - print(" 🔄 Step 2: Processing memories and testing answer ability...") - - processing_result = analyzer.scheduler_mem_process( - query=case["query"], - memories=case["memories"], - ) - print(f"Original Memories: {case['memories']}") - print(f"Processed Memories: {processing_result['processed_memories']}") - print(f" 📏 Compression ratio: {processing_result['compression_ratio']:.2f}") - print(f" 📄 Processed memories length: {processing_result['processed_length']} chars") - - # Generate answer with processed memories - answer_result = analyzer.generate_answer_with_memories( - case["query"], processing_result["processed_memories"], "processed_enhanced" - ) - - # Evaluate if the generated answer is correct - print(" 🎯 Step 3: Evaluating answer correctness...") - answer_evaluation = analyzer.compare_answer_quality( - case["query"], - golden_answer_str, # Use the string version - "No original answer available", # We don't have original answer - answer_result["answer"], - ) - - # Determine if processed memories can correctly answer (simplified logic) - processed_accuracy = answer_evaluation.get("processed_scores", {}).get("accuracy", 0) - can_answer_correctly = processed_accuracy >= 0.7 # Threshold for "correct" answer - - if can_answer_correctly: - acc_count += 1 - - print(f" 💬 Generated Answer: {answer_result['answer']}...") - print( - f" ✅ Answer Accuracy: {'CORRECT' if can_answer_correctly else 'INCORRECT'} (score: {processed_accuracy:.2f})" - ) - print() - - # Calculate and print real-time rates - current_cover_rate = cover_count / total_cases - current_acc_rate = acc_count / total_cases - - print(" 📊 REAL-TIME STATISTICS:") - print(f" 🎯 Cover Rate: {current_cover_rate:.2%} ({cover_count}/{total_cases})") - print(f" ✅ Acc Rate: {current_acc_rate:.2%} ({acc_count}/{total_cases})") - print() - - print("-" * 80) - print() - - # Final summary - print("🏁 FINAL ANALYSIS SUMMARY") - print("=" * 80) - print(f"📊 Total Cases Processed: {total_cases}") - print(f"🎯 Final Cover Rate: {cover_count / total_cases:.2%} ({cover_count}/{total_cases})") - print(f" - Cases with sufficient memory coverage: {cover_count}") - print(f" - Cases with insufficient memory coverage: {total_cases - cover_count}") - print() - print(f"✅ Final Acc Rate: {acc_count / total_cases:.2%} ({acc_count}/{total_cases})") - print(f" - Cases where processed memories can answer correctly: {acc_count}") - print(f" - Cases where processed memories cannot answer correctly: {total_cases - acc_count}") - print() - - # Additional insights - if cover_count > 0: - effective_processing_rate = acc_count / cover_count if cover_count > 0 else 0 - print(f"🔄 Processing Effectiveness: {effective_processing_rate:.2%}") - print( - f" - Among cases with sufficient coverage, {effective_processing_rate:.1%} can be answered correctly after processing" - ) - - print("=" * 80) - - -def load_real_bad_cases(file_path: str) -> list[dict[str, Any]]: - """Load real bad cases from JSON file.""" - print(f"📂 Loading bad cases from: {file_path}") - - with open(file_path, encoding="utf-8") as f: - data = json.load(f) - - bad_cases = data.get("bad_cases", []) - print(f"✅ Loaded {len(bad_cases)} bad cases") - - return bad_cases - - -def main(): - """Main test function.""" - print("🚀 Memory Processing Test Suite") - print("=" * 60) - print() - - # Check if OpenAI API key is set - if not os.getenv("OPENAI_API_KEY"): - print("⚠️ Warning: OPENAI_API_KEY not found in environment variables") - print(" Please set your OpenAI API key to run the tests") - return - - try: - bad_cases_file = f"{BASE_DIR}/tmp/eval_analyzer/bad_cases_extraction_only.json" - bad_cases = load_real_bad_cases(bad_cases_file) - - print(f"✅ Created {len(bad_cases)} sample bad cases") - print() - - # Run memory processing tests - memory_processing(bad_cases) - - print("✅ All tests completed successfully!") - - except Exception as e: - print(f"❌ Test failed with error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py index 3d0235871..6638fa2f5 100644 --- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py +++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py @@ -9,13 +9,13 @@ from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_MAX_QUERY_KEY_WORDS, - UserID, ) from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem if TYPE_CHECKING: from memos.memories.textual.tree import TextualMemoryItem + from memos.types import UserID logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index a53e19191..f641fc442 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -6,7 +6,7 @@ from collections.abc import Callable from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from sqlalchemy.engine import Engine @@ -33,9 +33,7 @@ DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, STARTUP_BY_PROCESS, - MemCubeID, TreeTextMemory_SEARCH_METHOD, - UserID, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, @@ -43,12 +41,15 @@ ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue +from memos.mem_scheduler.utils import metrics from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule from memos.memories.activation.kv import KVCacheMemory @@ -56,9 +57,15 @@ from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE +from memos.types.general_types import ( + MemCubeID, + UserID, +) if TYPE_CHECKING: + import redis + from memos.reranker.http_bge import HTTPBGEReranker @@ -124,12 +131,18 @@ def __init__(self, config: BaseSchedulerConfig): self.monitor: SchedulerGeneralMonitor | None = None self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None self.mem_reader = None # Will be set by MOSCore + self.status_tracker: TaskStatusTracker | None = None + self.metrics = metrics + self._monitor_thread = None self.dispatcher = SchedulerDispatcher( config=self.config, memos_message_queue=self.memos_message_queue, use_redis_queue=self.use_redis_queue, max_workers=self.thread_pool_max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch, + status_tracker=self.status_tracker, + metrics=self.metrics, + submit_web_logs=self._submit_web_logs, ) # other attributes @@ -152,6 +165,8 @@ def init_mem_cube( if searcher is None: self.searcher: Searcher = self.text_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + process_llm=self.process_llm, ) else: self.searcher = searcher @@ -162,11 +177,16 @@ def initialize_modules( process_llm: BaseLLM | None = None, db_engine: Engine | None = None, mem_reader=None, + redis_client: Union["redis.Redis", None] = None, ): if process_llm is None: process_llm = chat_llm try: + if redis_client: + self.status_tracker = TaskStatusTracker(redis_client) + if self.dispatcher: + self.dispatcher.status_tracker = self.status_tracker # initialize submodules self.chat_llm = chat_llm self.process_llm = process_llm @@ -300,6 +320,26 @@ def replace_working_memory( query_db_manager.sync_with_orm() query_history = query_db_manager.obj.get_queries_with_timesort() + + original_count = len(original_memory) + # Filter out memories tagged with "mode:fast" + filtered_original_memory = [] + for origin_mem in original_memory: + if "mode:fast" not in origin_mem.metadata.tags: + filtered_original_memory.append(origin_mem) + else: + logger.debug( + f"Filtered out memory - ID: {getattr(origin_mem, 'id', 'unknown')}, Tags: {origin_mem.metadata.tags}" + ) + # Calculate statistics + filtered_count = original_count - len(filtered_original_memory) + remaining_count = len(filtered_original_memory) + + logger.info( + f"Filtering complete. Removed {filtered_count} memories with tag 'mode:fast'. Remaining memories: {remaining_count}" + ) + original_memory = filtered_original_memory + memories_with_new_order, rerank_success_flag = ( self.retriever.process_and_rerank_memories( queries=query_history, @@ -532,6 +572,17 @@ def update_activation_memory_periodically( logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + if isinstance(messages, ScheduleMessageItem): + messages = [messages] + for message in messages: + self.metrics.task_enqueued(user_id=message.user_id, task_type=message.label) + if self.status_tracker: + self.status_tracker.task_submitted( + task_id=message.item_id, + user_id=message.user_id, + task_type=message.label, + mem_cube_id=message.mem_cube_id, + ) self.memos_message_queue.submit_messages(messages=messages) def _submit_web_logs( @@ -647,6 +698,8 @@ def _message_consumer(self) -> None: messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) if messages: + for msg in messages: + self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) try: import contextlib @@ -667,6 +720,26 @@ def _message_consumer(self) -> None: logger.error(f"Unexpected error in message consumer: {e!s}") time.sleep(self._consume_interval) # Prevent tight error loops + def _monitor_loop(self): + while self._running: + try: + q_sizes = self.memos_message_queue.qsize() + + for stream_key, queue_length in q_sizes.items(): + # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}" + parts = stream_key.split(":") + if len(parts) >= 3: + user_id = parts[2] + self.metrics.update_queue_length(queue_length, user_id) + elif not self.use_redis_queue: # local queue + user_id = stream_key + self.metrics.update_queue_length(queue_length, user_id) + + except Exception as e: + logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) + + time.sleep(15) # 每 15 秒采样一次 + def start(self) -> None: """ Start the message consumer thread/process and initialize dispatcher resources. @@ -682,6 +755,16 @@ def start(self) -> None: ) self.start_consumer() + self.start_background_monitor() + + def start_background_monitor(self): + if self._monitor_thread and self._monitor_thread.is_alive(): + return + self._monitor_thread = ContextThread( + target=self._monitor_loop, daemon=True, name="SchedulerMetricsMonitor" + ) + self._monitor_thread.start() + logger.info("Scheduler metrics monitor thread started.") def start_consumer(self) -> None: """ @@ -769,6 +852,9 @@ def stop(self) -> None: # Stop consumer first self.stop_consumer() + if self._monitor_thread: + self._monitor_thread.join(timeout=2.0) + # Shutdown dispatcher if self.dispatcher: logger.info("Shutting down dispatcher...") @@ -851,169 +937,63 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di return result - def mem_scheduler_wait( - self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01 - ) -> bool: - """ - Uses EWMA throughput, detects leaked `unfinished_tasks`, and waits for dispatcher. - """ - deadline = time.monotonic() + timeout - - # --- helpers (local, no external deps) --- - def _unfinished() -> int: - """Prefer `unfinished_tasks`; fallback to `qsize()`.""" - try: - u = getattr(self.memos_message_queue, "unfinished_tasks", None) - if u is not None: - return int(u) - except Exception: - pass - try: - return int(self.memos_message_queue.qsize()) - except Exception: - return 0 - - def _fmt_eta(seconds: float | None) -> str: - """Format seconds to human-readable string.""" - if seconds is None or seconds != seconds or seconds == float("inf"): - return "unknown" - s = max(0, int(seconds)) - h, s = divmod(s, 3600) - m, s = divmod(s, 60) - if h > 0: - return f"{h:d}h{m:02d}m{s:02d}s" - if m > 0: - return f"{m:d}m{s:02d}s" - return f"{s:d}s" - - # --- EWMA throughput state (tasks/s) --- - alpha = 0.3 - rate = 0.0 - last_t = None # type: float | None - last_done = 0 - - # --- dynamic totals & stuck detection --- - init_unfinished = _unfinished() - done_total = 0 - last_unfinished = None - stuck_ticks = 0 - next_log = 0.0 - - while True: - # 1) read counters - curr_unfinished = _unfinished() - try: - qsz = int(self.memos_message_queue.qsize()) - except Exception: - qsz = -1 - - pend = run = 0 - stats_fn = getattr(self.dispatcher, "stats", None) - if self.enable_parallel_dispatch and self.dispatcher is not None and callable(stats_fn): - try: - st = ( - stats_fn() - ) # expected: {'pending':int,'running':int,'done':int?,'rate':float?} - run = int(st.get("running", 0)) - - except Exception: - pass - - if isinstance(self.memos_message_queue, SchedulerRedisQueue): + @staticmethod + def init_task_status(): + return { + "running": 0, + "remaining": 0, + "completed": 0, + } + + def get_tasks_status(self): + task_status = self.init_task_status() + memos_message_queue = self.memos_message_queue.memos_message_queue + if isinstance(memos_message_queue, SchedulerRedisQueue): + stream_keys = memos_message_queue.get_stream_keys( + stream_key_prefix=memos_message_queue.stream_key_prefix + ) + for stream_key in stream_keys: + if stream_key not in task_status: + task_status[stream_key] = self.init_task_status() # For Redis queue, prefer XINFO GROUPS to compute pending - groups_info = self.memos_message_queue.redis.xinfo_groups( - self.memos_message_queue.stream_key_prefix - ) + groups_info = memos_message_queue.redis.xinfo_groups(stream_key) if groups_info: for group in groups_info: - if group.get("name") == self.memos_message_queue.consumer_group: - pend = int(group.get("pending", pend)) + if group.get("name") == memos_message_queue.consumer_group: + task_status[stream_key]["running"] += int(group.get("pending", 0)) + task_status[stream_key]["remaining"] += memos_message_queue.qsize()[ + stream_key + ] + task_status["running"] += int(group.get("pending", 0)) + task_status["remaining"] += task_status[stream_key]["remaining"] break - else: - pend = run - # 2) dynamic total (allows new tasks queued while waiting) - total_now = max(init_unfinished, done_total + curr_unfinished) - done_total = max(0, total_now - curr_unfinished) - - # 3) update EWMA throughput - now = time.monotonic() - if last_t is None: - last_t = now - else: - dt = max(1e-6, now - last_t) - dc = max(0, done_total - last_done) - inst = dc / dt - rate = inst if rate == 0.0 else alpha * inst + (1 - alpha) * rate - last_t = now - last_done = done_total - - eta = None if rate <= 1e-9 else (curr_unfinished / rate) - - # 4) progress log (throttled) - if now >= next_log: - print( - f"[mem_scheduler_wait] remaining≈{curr_unfinished} | throughput≈{rate:.2f} msg/s | ETA≈{_fmt_eta(eta)} " - f"| qsize={qsz} pending={pend} running={run}" - ) - next_log = now + max(0.2, log_every) - - # 5) exit / stuck detection - idle_dispatcher = ( - (pend == 0 and run == 0) - if (self.enable_parallel_dispatch and self.dispatcher is not None) - else True + elif isinstance(memos_message_queue, SchedulerLocalQueue): + running_task_count = self.dispatcher.get_running_task_count() + task_status["running"] = running_task_count + task_status["remaining"] = sum(memos_message_queue.qsize().values()) + else: + logger.error( + f"type of self.memos_message_queue is {memos_message_queue}, which is not supported" ) - if curr_unfinished == 0: - break - if curr_unfinished > 0 and qsz == 0 and idle_dispatcher: - if last_unfinished == curr_unfinished: - stuck_ticks += 1 - else: - stuck_ticks = 0 - else: - stuck_ticks = 0 - last_unfinished = curr_unfinished - - if stuck_ticks >= 3: - logger.warning( - "mem_scheduler_wait: detected leaked 'unfinished_tasks' -> treating queue as drained" - ) - break - - if now >= deadline: - logger.warning("mem_scheduler_wait: queue did not drain before timeout") - return False - - time.sleep(poll) - - # 6) wait dispatcher (second stage) - remaining = max(0.0, deadline - time.monotonic()) - if self.enable_parallel_dispatch and self.dispatcher is not None: - try: - ok = self.dispatcher.join(timeout=remaining if remaining > 0 else 0) - except TypeError: - ok = self.dispatcher.join() - if not ok: - logger.warning("mem_scheduler_wait: dispatcher did not complete before timeout") - return False - - return True + raise NotImplementedError() + return task_status def _gather_queue_stats(self) -> dict: """Collect queue/dispatcher stats for reporting.""" + memos_message_queue = self.memos_message_queue.memos_message_queue stats: dict[str, int | float | str] = {} stats["use_redis_queue"] = bool(self.use_redis_queue) # local queue metrics if not self.use_redis_queue: try: - stats["qsize"] = int(self.memos_message_queue.qsize()) + stats["qsize"] = int(memos_message_queue.qsize()) except Exception: stats["qsize"] = -1 # unfinished_tasks if available try: stats["unfinished_tasks"] = int( - getattr(self.memos_message_queue, "unfinished_tasks", 0) or 0 + getattr(memos_message_queue, "unfinished_tasks", 0) or 0 ) except Exception: stats["unfinished_tasks"] = -1 diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 2c20520ea..f18bfd715 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -20,8 +20,6 @@ QUERY_LABEL, USER_INPUT_TYPE, WORKING_MEMORY_TYPE, - MemCubeID, - UserID, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem @@ -34,6 +32,10 @@ from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory +from memos.types import ( + MemCubeID, + UserID, +) logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index 01b57563d..6cf3a9e58 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -11,8 +11,6 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, DEFAULT_SCHEDULER_RETRIEVER_RETRIES, - FINE_STRATEGY, - FineStrategy, TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, ) @@ -24,6 +22,7 @@ from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer from memos.memories.textual.item import TextualMemoryMetadata from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.types.general_types import FINE_STRATEGY, FineStrategy # Extract JSON response from .memory_filter import MemoryFilter diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 03221aa7b..f30efa52f 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -237,10 +237,6 @@ def _check_pool_health( # If we got here, pool appears healthy pool_info["last_active"] = get_utc_now() - # Log health status with comprehensive information - if self.dispatcher: - max_workers = pool_info.get("max_workers", 0) - return True, "" def _restart_pool(self, name: str, pool_info: dict) -> None: diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index a5f1c0097..b097b1e2d 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -20,8 +20,6 @@ DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, MONITOR_ACTIVATION_MEMORY_TYPE, MONITOR_WORKING_MEMORY_TYPE, - MemCubeID, - UserID, ) from memos.mem_scheduler.schemas.monitor_schemas import ( MemoryMonitorItem, @@ -31,6 +29,7 @@ from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import extract_json_obj from memos.memories.textual.tree import TreeTextMemory +from memos.types import MemCubeID, UserID logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index f6e9b86fe..0e64ea9a0 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -13,16 +13,18 @@ from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( API_MIX_SEARCH_LABEL, - MemCubeID, - SearchMode, - UserID, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.utils.api_utils import format_textual_memory_item from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -from memos.types import UserContext +from memos.types import ( + MemCubeID, + SearchMode, + UserContext, + UserID, +) if TYPE_CHECKING: diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 089a7cc6c..91d442720 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,24 +1,4 @@ -import os - -from enum import Enum from pathlib import Path -from typing import NewType - - -class SearchMode(str, Enum): - """Enumeration for search modes.""" - - FAST = "fast" - FINE = "fine" - MIXTURE = "mixture" - - -class FineStrategy(str, Enum): - """Enumeration for fine strategies.""" - - REWRITE = "rewrite" - RECREATE = "recreate" - DEEP_SEARCH = "deep_search" FILE_PATH = Path(__file__).absolute() @@ -81,22 +61,3 @@ class FineStrategy(str, Enum): DEFAULT_MAX_QUERY_KEY_WORDS = 1000 DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05] DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50 - - -# new types -UserID = NewType("UserID", str) -MemCubeID = NewType("CubeID", str) - -# algorithm strategies -DEFAULT_FINE_STRATEGY = FineStrategy.REWRITE - -# Read fine strategy from environment variable `FINE_STRATEGY`. -# If provided and valid, use it; otherwise fall back to default. -_env_fine_strategy = os.getenv("FINE_STRATEGY") -if _env_fine_strategy: - try: - FINE_STRATEGY = FineStrategy(_env_fine_strategy) - except ValueError: - FINE_STRATEGY = DEFAULT_FINE_STRATEGY -else: - FINE_STRATEGY = DEFAULT_FINE_STRATEGY diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index d7e94e0e1..9c79fc42a 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -141,6 +141,9 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin): ) memcube_name: str | None = Field(default=None, description="Display name for memcube") memory_len: int | None = Field(default=None, description="Count of items involved in the event") + status: str | None = Field( + default=None, description="Completion status of the task (e.g., 'completed', 'failed')" + ) def debug_info(self) -> dict[str, Any]: """Return structured debug information for logging purposes.""" diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index b1a304754..df3e2055e 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -1,4 +1,5 @@ import concurrent +import os import threading import time @@ -11,11 +12,13 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.general_modules.task_threads import ThreadManager -from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_STOP_WAIT, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem, ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem -from memos.mem_scheduler.utils.metrics import MetricsRegistry from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker logger = get_logger(__name__) @@ -41,6 +44,9 @@ def __init__( use_redis_queue: bool | None = None, enable_parallel_dispatch: bool = True, config=None, + status_tracker: TaskStatusTracker | None = None, + metrics: Any | None = None, + submit_web_logs: Callable | None = None, # ADDED ): super().__init__() self.config = config @@ -90,18 +96,14 @@ def __init__( self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT ) - self.metrics = MetricsRegistry( - topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50) - ) + self.metrics = metrics + self.status_tracker = status_tracker + self.submit_web_logs = submit_web_logs # ADDED def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None: if not msgs: return - now = time.time() - for m in msgs: - self.metrics.on_enqueue( - label=m.label, mem_cube_id=m.mem_cube_id, inst_rate=1.0, now=now - ) + # This is handled in BaseScheduler now def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ @@ -116,38 +118,60 @@ def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ def wrapped_handler(messages: list[ScheduleMessageItem]): + start_time = time.time() + if self.status_tracker: + self.status_tracker.task_started( + task_id=task_item.item_id, user_id=task_item.user_id + ) try: # --- mark start: record queuing time(now - enqueue_ts)--- now = time.time() - for m in messages: - enq_ts = getattr(m, "timestamp", None) - - # Path 1: epoch seconds (preferred) - if isinstance(enq_ts, int | float): - enq_epoch = float(enq_ts) - - # Path 2: datetime -> normalize to UTC epoch - elif hasattr(enq_ts, "timestamp"): - dt = enq_ts - if dt.tzinfo is None: - # treat naive as UTC to neutralize +8h skew - dt = dt.replace(tzinfo=timezone.utc) - enq_epoch = dt.timestamp() - else: - # fallback: treat as "just now" - enq_epoch = now - - wait_sec = max(0.0, now - enq_epoch) - self.metrics.on_start( - label=m.label, mem_cube_id=m.mem_cube_id, wait_sec=wait_sec, now=now - ) + m = messages[0] # All messages in this batch have same user and type + enq_ts = getattr(m, "timestamp", None) + + # Path 1: epoch seconds (preferred) + if isinstance(enq_ts, int | float): + enq_epoch = float(enq_ts) + + # Path 2: datetime -> normalize to UTC epoch + elif hasattr(enq_ts, "timestamp"): + dt = enq_ts + if dt.tzinfo is None: + # treat naive as UTC to neutralize +8h skew + dt = dt.replace(tzinfo=timezone.utc) + enq_epoch = dt.timestamp() + else: + # fallback: treat as "just now" + enq_epoch = now + + wait_sec = max(0.0, now - enq_epoch) + self.metrics.observe_task_wait_duration(wait_sec, m.user_id, m.label) # Execute the original handler result = handler(messages) # --- mark done --- - for m in messages: - self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) + duration = time.time() - start_time + self.metrics.observe_task_duration(duration, m.user_id, m.label) + if self.status_tracker: + self.status_tracker.task_completed( + task_id=task_item.item_id, user_id=task_item.user_id + ) + self.metrics.task_completed(user_id=m.user_id, task_type=m.label) + + is_cloud_env = ( + os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") == "memos-memory-change" + ) + if self.submit_web_logs and is_cloud_env: + status_log = ScheduleLogForWebItem( + user_id=task_item.user_id, + mem_cube_id=task_item.mem_cube_id, + item_id=task_item.item_id, + label=m.label, + log_content=f"Task {task_item.item_id} completed successfully for user {task_item.user_id}.", + status="completed", + ) + self.submit_web_logs([status_log]) # acknowledge redis messages if self.use_redis_queue and self.memos_message_queue is not None: @@ -172,9 +196,12 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): return result except Exception as e: - # Mark task as failed and remove from tracking - for m in messages: - self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) + m = messages[0] + self.metrics.task_failed(m.user_id, m.label, type(e).__name__) + if self.status_tracker: + self.status_tracker.task_failed( + task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e) + ) # Mark task as failed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: @@ -183,6 +210,21 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if len(self._completed_tasks) > self.completed_tasks_max_show_size: self._completed_tasks.pop(0) logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") + + is_cloud_env = ( + os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") == "memos-memory-change" + ) + if self.submit_web_logs and is_cloud_env: + status_log = ScheduleLogForWebItem( + user_id=task_item.user_id, + mem_cube_id=task_item.mem_cube_id, + item_id=task_item.item_id, + label=m.label, + log_content=f"Task {task_item.item_id} failed for user {task_item.user_id} with error: {e!s}.", + status="failed", + exception=str(e), + ) + self.submit_web_logs([status_log]) raise return wrapped_handler diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 5e850c8ce..dc2b9af26 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,6 +5,7 @@ the local memos_message_queue functionality in BaseScheduler. """ +import os import re import time @@ -33,7 +34,9 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, - stream_key_prefix: str = "scheduler:messages:stream", + stream_key_prefix: str = os.getenv( + "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", "scheduler:messages:stream" + ), consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", max_len: int = 10000, @@ -81,6 +84,10 @@ def __init__( self.seen_streams = set() + # Task Broker + + # Task Orchestrator + def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" return stream_key @@ -256,7 +263,7 @@ def get_nowait( user_id=user_id, mem_cube_id=mem_cube_id, block=False, batch_size=batch_size ) - def qsize(self) -> int: + def qsize(self) -> dict: """ Get the current size of the Redis queue (Queue-compatible interface). @@ -271,19 +278,22 @@ def qsize(self) -> int: total_size = 0 try: + qsize_stats = {} # Scan for all stream keys matching the prefix - for stream_key in self._redis_conn.scan_iter(f"{self.stream_key_prefix}:*"): - try: - # Get the length of each stream and add to total - total_size += self._redis_conn.xlen(stream_key) - except Exception as e: - logger.debug(f"Failed to get length for stream {stream_key}: {e}") - return total_size + redis_pattern = f"{self.stream_key_prefix}:*" + for stream_key in self._redis_conn.scan_iter(redis_pattern): + # Get the length of each stream and add to total + stream_qsize = self._redis_conn.xlen(stream_key) + qsize_stats[stream_key] = stream_qsize + total_size += stream_qsize + qsize_stats["total_size"] = total_size + return qsize_stats + except Exception as e: logger.error(f"Failed to get Redis queue size: {e}") - return 0 + return {} - def get_stream_keys(self) -> list[str]: + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: """ List all Redis stream keys that match this queue's prefix. @@ -293,16 +303,15 @@ def get_stream_keys(self) -> list[str]: if not self._redis_conn: return [] + if stream_key_prefix is None: + stream_key_prefix = self.stream_key_prefix # First, get all keys that might match (using Redis pattern matching) - redis_pattern = f"{self.stream_key_prefix}:*" - raw_keys = [ - key.decode("utf-8") if isinstance(key, bytes) else key - for key in self._redis_conn.scan_iter(match=redis_pattern) - ] + redis_pattern = f"{stream_key_prefix}:*" + raw_keys = self._redis_conn.scan_iter(match=redis_pattern) # Second, filter using Python regex to ensure exact prefix match # Escape special regex characters in the prefix, then add :.* - escaped_prefix = re.escape(self.stream_key_prefix) + escaped_prefix = re.escape(stream_key_prefix) regex_pattern = f"^{escaped_prefix}:" stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)] diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py index 0d781c996..d587dfb1e 100644 --- a/src/memos/mem_scheduler/utils/metrics.py +++ b/src/memos/mem_scheduler/utils/metrics.py @@ -1,244 +1,125 @@ -# metrics.py -from __future__ import annotations - -import threading +# src/memos/mem_scheduler/utils/metrics.py import time -from dataclasses import dataclass, field +from contextlib import ContextDecorator -from memos.log import get_logger +from prometheus_client import Counter, Gauge, Histogram, Summary -# ==== global window config ==== -WINDOW_SEC = 120 # 2 minutes sliding window +# --- Metric Definitions --- -logger = get_logger(__name__) +TASKS_ENQUEUED_TOTAL = Counter( + "memos_scheduler_tasks_enqueued_total", + "Total number of tasks enqueued", + ["user_id", "task_type"], +) +TASKS_DEQUEUED_TOTAL = Counter( + "memos_scheduler_tasks_dequeued_total", + "Total number of tasks dequeued", + ["user_id", "task_type"], +) + +TASK_DURATION_SECONDS = Summary( + "memos_scheduler_task_duration_seconds", + "Task processing duration in seconds", + ["user_id", "task_type"], +) + +TASK_WAIT_DURATION_SECONDS = Summary( + "memos_scheduler_task_wait_duration_seconds", + "Task waiting duration in seconds", + ["user_id", "task_type"], +) + +TASKS_FAILED_TOTAL = Counter( + "memos_scheduler_tasks_failed_total", + "Total number of failed tasks", + ["user_id", "task_type", "error_type"], +) + +TASKS_COMPLETED_TOTAL = Counter( + "memos_scheduler_tasks_completed_total", + "Total number of successfully completed tasks", + ["user_id", "task_type"], +) + +QUEUE_LENGTH = Gauge( + "memos_scheduler_queue_length", "Current length of the task queue", ["user_id"] +) + +INTERNAL_SPAN_DURATION = Histogram( + "memos_scheduler_internal_span_duration_seconds", + "Duration of internal operations", + ["span_name", "user_id", "task_id"], +) + + +# --- Instrumentation Functions --- + + +def task_enqueued(user_id: str, task_type: str, count: int = 1): + TASKS_ENQUEUED_TOTAL.labels(user_id=user_id, task_type=task_type).inc(count) + + +def task_dequeued(user_id: str, task_type: str, count: int = 1): + TASKS_DEQUEUED_TOTAL.labels(user_id=user_id, task_type=task_type).inc(count) + + +def observe_task_duration(duration: float, user_id: str, task_type: str): + TASK_DURATION_SECONDS.labels(user_id=user_id, task_type=task_type).observe(duration) -# ---------- O(1) EWMA ---------- -class Ewma: - """ - Time-decayed EWMA: - """ - __slots__ = ("alpha", "last_ts", "tau", "value") - - def __init__(self, alpha: float = 0.3, tau: float = WINDOW_SEC): - self.alpha = alpha - self.value = 0.0 - self.last_ts: float = time.time() - self.tau = max(1e-6, float(tau)) - - def _decay_to(self, now: float | None = None): - now = time.time() if now is None else now - dt = max(0.0, now - self.last_ts) - if dt <= 0: - return - from math import exp - - self.value *= exp(-dt / self.tau) - self.last_ts = now - - def update(self, instant: float, now: float | None = None): - self._decay_to(now) - self.value = self.alpha * instant + (1 - self.alpha) * self.value - - def value_at(self, now: float | None = None) -> float: - now = time.time() if now is None else now - dt = max(0.0, now - self.last_ts) - if dt <= 0: - return self.value - from math import exp - - return self.value * exp(-dt / self.tau) - - -# ---------- approximate P95(Reservoir sample) ---------- -class ReservoirP95: - __slots__ = ("_i", "buf", "k", "n", "window") - - def __init__(self, k: int = 512, window: float = WINDOW_SEC): - self.k = k - self.buf: list[tuple[float, float]] = [] # (value, ts) - self.n = 0 - self._i = 0 - self.window = float(window) - - def _gc(self, now: float): - win_start = now - self.window - self.buf = [p for p in self.buf if p[1] >= win_start] - if self.buf: - self._i %= len(self.buf) - else: - self._i = 0 - - def add(self, x: float, now: float | None = None): - now = time.time() if now is None else now - self._gc(now) - self.n += 1 - if len(self.buf) < self.k: - self.buf.append((x, now)) - return - self.buf[self._i] = (x, now) - self._i = (self._i + 1) % self.k - - def p95(self, now: float | None = None) -> float: - now = time.time() if now is None else now - self._gc(now) - if not self.buf: - return 0.0 - arr = sorted(v for v, _ in self.buf) - idx = int(0.95 * (len(arr) - 1)) - return arr[idx] - - -# ---------- Space-Saving Top-K ---------- -class SpaceSaving: - """only topK:add(key) O(1),query topk O(K log K)""" - - def __init__(self, k: int = 100): - self.k = k - self.cnt: dict[str, int] = {} - - def add(self, key: str): - if key in self.cnt: - self.cnt[key] += 1 - return - if len(self.cnt) < self.k: - self.cnt[key] = 1 - return - victim = min(self.cnt, key=self.cnt.get) - self.cnt[key] = self.cnt.pop(victim) + 1 - - def topk(self) -> list[tuple[str, int]]: - return sorted(self.cnt.items(), key=lambda kv: kv[1], reverse=True) - - -@dataclass -class KeyStats: - backlog: int = 0 - lambda_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) - mu_ewma: Ewma = field(default_factory=lambda: Ewma(0.3, WINDOW_SEC)) - wait_p95: ReservoirP95 = field(default_factory=lambda: ReservoirP95(512, WINDOW_SEC)) - last_ts: float = field(default_factory=time.time) - # last event timestamps for rate estimation - last_enqueue_ts: float | None = None - last_done_ts: float | None = None - - def snapshot(self, now: float | None = None) -> dict: - now = time.time() if now is None else now - lam = self.lambda_ewma.value_at(now) - mu = self.mu_ewma.value_at(now) - delta = mu - lam - eta = float("inf") if delta <= 1e-9 else self.backlog / delta - return { - "backlog": self.backlog, - "lambda": round(lam, 3), - "mu": round(mu, 3), - "delta": round(delta, 3), - "eta_sec": None if eta == float("inf") else round(eta, 1), - "wait_p95_sec": round(self.wait_p95.p95(now), 3), - } - - -class MetricsRegistry: +def observe_task_wait_duration(duration: float, user_id: str, task_type: str): + TASK_WAIT_DURATION_SECONDS.labels(user_id=user_id, task_type=task_type).observe(duration) + + +def task_failed(user_id: str, task_type: str, error_type: str): + TASKS_FAILED_TOTAL.labels(user_id=user_id, task_type=task_type, error_type=error_type).inc() + + +def task_completed(user_id: str, task_type: str, count: int = 1): + TASKS_COMPLETED_TOTAL.labels(user_id=user_id, task_type=task_type).inc(count) + + +def update_queue_length(length: int, user_id: str): + QUEUE_LENGTH.labels(user_id=user_id).set(length) + + +def observe_internal_span(duration: float, span_name: str, user_id: str, task_id: str): + INTERNAL_SPAN_DURATION.labels(span_name=span_name, user_id=user_id, task_id=task_id).observe( + duration + ) + + +# --- TimingSpan Context Manager --- + + +class TimingSpan(ContextDecorator): """ - metrics: - - 1st phase:label(must) - - 2nd phase:labelXmem_cube_id(only Top-K) - - on_enqueue(label, mem_cube_id) - - on_start(label, mem_cube_id, wait_sec) - - on_done(label, mem_cube_id) + A context manager/decorator to measure the duration of a code block and record it + as a Prometheus histogram observation. + + Usage as a decorator: + @TimingSpan("expensive_operation", user_id="user123") + def my_function(): + time.sleep(2) + + Usage as a context manager: + with TimingSpan("another_op", user_id="user456", task_id="t1"): + ... """ - def __init__(self, topk_per_label: int = 50): - self._lock = threading.RLock() - self._label_stats: dict[str, KeyStats] = {} - self._label_topk: dict[str, SpaceSaving] = {} - self._detail_stats: dict[tuple[str, str], KeyStats] = {} - self._topk_per_label = topk_per_label - - # ---------- helpers ---------- - def _get_label(self, label: str) -> KeyStats: - if label not in self._label_stats: - self._label_stats[label] = KeyStats() - self._label_topk[label] = SpaceSaving(self._topk_per_label) - return self._label_stats[label] - - def _get_detail(self, label: str, mem_cube_id: str) -> KeyStats | None: - # 只有 Top-K 的 mem_cube_id 才建细粒度 key - ss = self._label_topk[label] - if mem_cube_id in ss.cnt or len(ss.cnt) < ss.k: - key = (label, mem_cube_id) - if key not in self._detail_stats: - self._detail_stats[key] = KeyStats() - return self._detail_stats[key] - return None - - # ---------- events ---------- - def on_enqueue( - self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None - ): - with self._lock: - now = time.time() if now is None else now - ls = self._get_label(label) - # derive instantaneous arrival rate from inter-arrival time (events/sec) - prev_ts = ls.last_enqueue_ts - dt = (now - prev_ts) if prev_ts is not None else None - inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 # first sample: no spike - ls.last_enqueue_ts = now - ls.backlog += 1 - ls.lambda_ewma.update(inst_rate, now) - self._label_topk[label].add(mem_cube_id) - ds = self._get_detail(label, mem_cube_id) - if ds: - prev_ts_d = ds.last_enqueue_ts - dt_d = (now - prev_ts_d) if prev_ts_d is not None else None - inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 - ds.last_enqueue_ts = now - ds.backlog += 1 - ds.lambda_ewma.update(inst_rate_d, now) - - def on_start(self, label: str, mem_cube_id: str, wait_sec: float, now: float | None = None): - with self._lock: - now = time.time() if now is None else now - ls = self._get_label(label) - ls.wait_p95.add(wait_sec, now) - ds = self._detail_stats.get((label, mem_cube_id)) - if ds: - ds.wait_p95.add(wait_sec, now) - - def on_done( - self, label: str, mem_cube_id: str, inst_rate: float = 1.0, now: float | None = None - ): - with self._lock: - now = time.time() if now is None else now - ls = self._get_label(label) - # derive instantaneous service rate from inter-completion time (events/sec) - prev_ts = ls.last_done_ts - dt = (now - prev_ts) if prev_ts is not None else None - inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 - ls.last_done_ts = now - if ls.backlog > 0: - ls.backlog -= 1 - ls.mu_ewma.update(inst_rate, now) - ds = self._detail_stats.get((label, mem_cube_id)) - if ds: - prev_ts_d = ds.last_done_ts - dt_d = (now - prev_ts_d) if prev_ts_d is not None else None - inst_rate_d = (1.0 / max(1e-3, dt_d)) if dt_d is not None else 0.0 - ds.last_done_ts = now - if ds.backlog > 0: - ds.backlog -= 1 - ds.mu_ewma.update(inst_rate_d, now) - - # ---------- snapshots ---------- - def snapshot(self) -> dict: - with self._lock: - now = time.time() - by_label = {lbl: ks.snapshot(now) for lbl, ks in self._label_stats.items()} - heavy = {lbl: self._label_topk[lbl].topk() for lbl in self._label_topk} - details = {} - for (lbl, cube), ks in self._detail_stats.items(): - details.setdefault(lbl, {})[cube] = ks.snapshot(now) - return {"by_label": by_label, "heavy": heavy, "details": details} + def __init__(self, span_name: str, user_id: str = "unknown", task_id: str = "unknown"): + self.span_name = span_name + self.user_id = user_id + self.task_id = task_id + self.start_time = 0 + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + duration = time.perf_counter() - self.start_time + observe_internal_span(duration, self.span_name, self.user_id, self.task_id) diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py new file mode 100644 index 000000000..98d4c6a3f --- /dev/null +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -0,0 +1,88 @@ +# src/memos/mem_scheduler/utils/status_tracker.py +import json + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from memos.dependency import require_python_package + + +if TYPE_CHECKING: + import redis + + +class TaskStatusTracker: + @require_python_package(import_name="redis", install_command="pip install redis") + def __init__(self, redis_client: "redis.Redis"): + self.redis = redis_client + + def _get_key(self, user_id: str) -> str: + return f"memos:task_meta:{user_id}" + + def task_submitted(self, task_id: str, user_id: str, task_type: str, mem_cube_id: str): + key = self._get_key(user_id) + payload = { + "status": "waiting", + "task_type": task_type, + "mem_cube_id": mem_cube_id, + "submitted_at": datetime.now(timezone.utc).isoformat(), + } + self.redis.hset(key, task_id, json.dumps(payload)) + self.redis.expire(key, timedelta(days=7)) + + def task_started(self, task_id: str, user_id: str): + key = self._get_key(user_id) + existing_data_json = self.redis.hget(key, task_id) + if not existing_data_json: + # 容错处理: 如果任务不存在, 也创建一个 + payload = { + "status": "in_progress", + "started_at": datetime.now(timezone.utc).isoformat(), + } + else: + payload = json.loads(existing_data_json) + payload["status"] = "in_progress" + payload["started_at"] = datetime.now(timezone.utc).isoformat() + self.redis.hset(key, task_id, json.dumps(payload)) + self.redis.expire(key, timedelta(days=7)) + + def task_completed(self, task_id: str, user_id: str): + key = self._get_key(user_id) + existing_data_json = self.redis.hget(key, task_id) + if not existing_data_json: + return + payload = json.loads(existing_data_json) + payload["status"] = "completed" + payload["completed_at"] = datetime.now(timezone.utc).isoformat() + # 设置该任务条目的过期时间, 例如 24 小时 + # 注意: Redis Hash 不能为单个 field 设置 TTL, 这里我们可以 通过后台任务清理或在获取时判断时间戳 + # 简单起见, 我们暂时依赖一个后台清理任务 + self.redis.hset(key, task_id, json.dumps(payload)) + self.redis.expire(key, timedelta(days=7)) + + def task_failed(self, task_id: str, user_id: str, error_message: str): + key = self._get_key(user_id) + existing_data_json = self.redis.hget(key, task_id) + if not existing_data_json: + payload = { + "status": "failed", + "error": error_message, + "failed_at": datetime.now(timezone.utc).isoformat(), + } + else: + payload = json.loads(existing_data_json) + payload["status"] = "failed" + payload["error"] = error_message + payload["failed_at"] = datetime.now(timezone.utc).isoformat() + self.redis.hset(key, task_id, json.dumps(payload)) + self.redis.expire(key, timedelta(days=7)) + + def get_task_status(self, task_id: str, user_id: str) -> dict | None: + key = self._get_key(user_id) + data = self.redis.hget(key, task_id) + return json.loads(data) if data else None + + def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]: + key = self._get_key(user_id) + all_tasks = self.redis.hgetall(key) + return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()} diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 27c33029c..df5e05a1f 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -16,11 +16,13 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.advanced_searcher import ( + AdvancedSearcher as Searcher, +) from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.factory import RerankerFactory from memos.types import MessageList @@ -127,8 +129,7 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int return self.memory_manager.get_current_memory_size(user_name=user_name) def get_searcher( - self, - manual_close_internet: bool = False, + self, manual_close_internet: bool = False, moscube: bool = False, process_llm=None ): if (self.internet_retriever is not None) and manual_close_internet: logger.warning( @@ -140,6 +141,7 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=None, + process_llm=process_llm, ) else: searcher = Searcher( @@ -148,6 +150,7 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=self.internet_retriever, + process_llm=process_llm, ) return searcher diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py new file mode 100644 index 000000000..22cd44b8c --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -0,0 +1,540 @@ +import copy +import time + +from typing import Any + +from memos.embedders.factory import OllamaEmbedder +from memos.graph_dbs.factory import Neo4jGraphDB +from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata +from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + parse_structured_output, +) +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.base import BaseReranker +from memos.templates.advanced_search_prompts import PROMPT_MAPPING +from memos.types.general_types import SearchMode + + +logger = get_logger(__name__) + + +class AdvancedSearcher(Searcher): + def __init__( + self, + dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM, + graph_store: Neo4jGraphDB, + embedder: OllamaEmbedder, + reranker: BaseReranker, + bm25_retriever: EnhancedBM25 | None = None, + internet_retriever: None = None, + search_strategy: dict | None = None, + manual_close_internet: bool = True, + process_llm: Any | None = None, + ): + super().__init__( + dispatcher_llm=dispatcher_llm, + graph_store=graph_store, + embedder=embedder, + reranker=reranker, + bm25_retriever=bm25_retriever, + internet_retriever=internet_retriever, + search_strategy=search_strategy, + manual_close_internet=manual_close_internet, + ) + + self.stage_retrieve_top = 3 + self.process_llm = process_llm + self.thinking_stages = 0 # TODO: to increase thinking depth when the algorithm is reliable + self.max_retry_times = 2 + self.deep_search_top_k_bar = 2 + + def load_template(self, template_name: str) -> str: + if template_name not in PROMPT_MAPPING: + logger.error("Prompt template is not found!") + prompt = PROMPT_MAPPING[template_name] + return prompt + + def build_prompt(self, template_name: str, **kwargs) -> str: + template = self.load_template(template_name) + if not template: + raise FileNotFoundError(f"Prompt template `{template_name}` not found.") + return template.format(**kwargs) + + def stage_retrieve( + self, + stage_id: int, + query: str, + previous_retrieval_phrases: list[str], + text_memories: str, + context: str | None = None, + ) -> tuple[bool, str, str, list[str]]: + """Run a retrieval-expansion stage and parse structured LLM output. + + Returns a tuple of: + - can_answer: whether current memories suffice to answer + - reason: brief reasoning or hypotheses + - context: synthesized context summary + - retrieval_phrases: list of phrases to retrieve next + """ + + # Format previous phrases as bullet list to align with prompt expectations + prev_phrases_text = ( + "- " + "\n- ".join(previous_retrieval_phrases) if previous_retrieval_phrases else "" + ) + + args = { + "template_name": f"stage{stage_id}_expand_retrieve", + "query": query, + "previous_retrieval_phrases": prev_phrases_text, + "memories": text_memories, + } + if context is not None: + args["context"] = context + prompt = self.build_prompt(**args) + + max_attempts = max(0, self.max_retry_times) + 1 + for attempt in range(1, max_attempts + 1): + try: + llm_response = self.process_llm.generate( + [{"role": "user", "content": prompt}] + ).strip() + result = parse_structured_output(content=llm_response) + + # Parse booleans and fallbacks robustly + can_answer_str = str(result.get("can_answer", "")).strip().lower() + can_answer = can_answer_str in {"true", "yes", "y", "1"} + + reason = result.get("reason", "") + + context_out = str(result.get("context", "")) + + phrases_val = result.get("retrieval_phrases", result.get("retrival_phrases", [])) + if isinstance(phrases_val, list): + retrieval_phrases = [str(p).strip() for p in phrases_val if str(p).strip()] + elif isinstance(phrases_val, str) and phrases_val.strip(): + retrieval_phrases = [p.strip() for p in phrases_val.splitlines() if p.strip()] + else: + retrieval_phrases = [] + + return can_answer, reason, context_out, retrieval_phrases + + except Exception as e: + if attempt < max_attempts: + logger.debug(f"[stage_retrieve]🔁 retry {attempt}/{max_attempts} failed: {e!s}") + time.sleep(1) + else: + logger.error( + f"[stage_retrieve]❌ all {max_attempts} attempts failed: {e!s}; \nprompt: {prompt}", + exc_info=True, + ) + raise e + + def summarize_memories(self, query: str, context: str, text_memories: str, top_k: int): + args = { + "template_name": "memory_summary", + "query": query, + "context": context, + "memories": text_memories, + "top_k": top_k, + } + + prompt = self.build_prompt(**args) + + max_attempts = max(0, self.max_retry_times) + 1 + for attempt in range(1, max_attempts + 1): + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + result = parse_structured_output(content=llm_response) + context, mem_list = result["context"], result["memories"] + if not isinstance(mem_list, list): + logger.error(f"The result of summarize_memories is {result}") + return context, mem_list + except Exception as e: + if attempt < max_attempts: + logger.debug( + f"[summarize_memories]🔁 retry {attempt}/{max_attempts} failed: {e!s}" + ) + time.sleep(1) + else: + logger.error( + f"[summarize_memories]❌ all {max_attempts} attempts failed: {e!s}; \nprompt: {prompt}", + exc_info=True, + ) + raise e + + def judge_memories(self, query: str, text_memories: str): + args = { + "template_name": "memory_judgement", + "query": query, + "memories": text_memories, + } + + prompt = self.build_prompt(**args) + + max_attempts = max(0, self.max_retry_times) + 1 + for attempt in range(1, max_attempts + 1): + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + result = parse_structured_output(content=llm_response) + reason, can_answer = ( + result["reason"], + result["can_answer"], + ) + + return reason, can_answer + except Exception as e: + if attempt < max_attempts: + logger.debug( + f"[summarize_and_eval]🔁 retry {attempt}/{max_attempts} failed: {e!s}" + ) + time.sleep(1) + else: + logger.error( + f"[summarize_and_eval]❌ all {max_attempts} attempts failed: {e!s}; \nprompt: {prompt}", + exc_info=True, + ) + raise e + + def tree_memories_to_text_memories(self, memories: list[TextualMemoryItem]): + mem_list = [] + source_documents = [] + for mem in memories: + source_documents.extend( + [f"({one.chat_time}) {one.content}" for one in mem.metadata.sources] + ) + mem_list.append(mem.memory) + mem_list = list(set(mem_list)) + source_documents = list(set(source_documents)) + return mem_list, source_documents + + def get_final_memories(self, user_id: str, top_k: int, mem_list: list[str]): + enhanced_memories = [] + for new_mem in mem_list: + enhanced_memories.append( + TextualMemoryItem(memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id)) + ) + if len(enhanced_memories) > top_k: + logger.info( + f"Result count {len(enhanced_memories)} exceeds requested top_k {top_k}, truncating to top {top_k} memories" + ) + result_memories = enhanced_memories[:top_k] + return result_memories + + def recreate_enhancement( + self, + query: str, + text_memories: list[str], + retries: int, + ) -> list: + attempt = 0 + text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(text_memories)]) + prompt_name = "memory_recreate_enhancement" + prompt = self.build_prompt(template_name=prompt_name, query=query, memories=text_memories) + + llm_response = None + while attempt <= max(0, retries) + 1: + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + processed_text_memories = parse_structured_output(content=llm_response) + return processed_text_memories["answer"] + except Exception as e: + attempt += 1 + time.sleep(1) + logger.debug( + f"[memory_recreate_enhancement] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" + ) + logger.error( + f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}", + exc_info=True, + ) + raise ValueError("Fail to run memory enhancement") + + def deep_search( + self, + query: str, + top_k: int, + info=None, + memory_type="All", + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + previous_retrieval_phrases = [query] + retrieved_memories = self.retrieve( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + memory_type=memory_type, + search_filter=search_filter, + info=info, + ) + memories = self.post_retrieve( + retrieved_results=retrieved_memories, + top_k=top_k, + user_name=user_name, + info=info, + ) + if top_k < self.deep_search_top_k_bar or len(memories) == 0: + logger.warning("Requirements not met; returning memories as-is.") + return memories + + user_id = memories[0].metadata.user_id + context = None + + mem_list, _ = self.tree_memories_to_text_memories(memories=memories) + retrieved_memories = copy.deepcopy(retrieved_memories) + retrieved_memories_from_deep_search = [] + for current_stage_id in range(self.thinking_stages + 1): + try: + # at last + if current_stage_id == self.thinking_stages: + # eval to finish + reason, can_answer = self.judge_memories( + query=query, + text_memories="- " + "\n- ".join(mem_list) + "\n", + ) + + logger.info( + f"Final Stage: Stage {current_stage_id}; " + f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " + f"final can_answer: {can_answer}; reason: {reason}" + ) + mem_list = self.recreate_enhancement( + query=query, text_memories=mem_list, retries=self.max_retry_times + ) + enhanced_memories = self.get_final_memories( + user_id=user_id, top_k=top_k, mem_list=mem_list + ) + return enhanced_memories + + can_answer, reason, context, retrieval_phrases = self.stage_retrieve( + stage_id=current_stage_id + 1, + query=query, + previous_retrieval_phrases=previous_retrieval_phrases, + context=context, + text_memories="- " + "\n- ".join(mem_list) + "\n", + ) + if can_answer: + logger.info( + f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", + ) + + enhanced_memories = self.get_final_memories( + user_id=user_id, top_k=top_k, mem_list=mem_list + ) + return enhanced_memories + else: + previous_retrieval_phrases.extend(retrieval_phrases) + logger.info( + f"Start complementary retrieval for Stage {current_stage_id}; " + f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " + f"can_answer: {can_answer}; reason: {reason}" + ) + logger.info( + "Stage %d - Found %d new retrieval phrases", + current_stage_id, + len(retrieval_phrases), + ) + # Search for additional memories based on retrieval phrases + additional_retrieved_memories = [] + for phrase in retrieval_phrases: + _retrieved_memories = self.retrieve( + query=phrase, + user_name=user_name, + top_k=self.stage_retrieve_top, + mode=SearchMode.FAST, + memory_type=memory_type, + search_filter=search_filter, + info=info, + ) + logger.info( + "Found %d additional memories for phrase: '%s'", + len(_retrieved_memories), + phrase[:30] + "..." if len(phrase) > 30 else phrase, + ) + additional_retrieved_memories.extend(_retrieved_memories) + retrieved_memories_from_deep_search.extend(additional_retrieved_memories) + merged_memories = self.post_retrieve( + retrieved_results=retrieved_memories + additional_retrieved_memories, + top_k=top_k * 2, + user_name=user_name, + info=info, + ) + + _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) + mem_list = _mem_list + mem_list = list(set(mem_list)) + logger.info( + "After stage %d, total memories in list: %d", + current_stage_id, + len(mem_list), + ) + + # enhance memories + mem_list = self.recreate_enhancement( + query=query, text_memories=mem_list, retries=self.max_retry_times + ) + logger.info("After summarization, memory list contains %d items", len(mem_list)) + + except Exception as e: + logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) + # Continue to next stage instead of failing completely + continue + logger.error("Deep search failed, returning original memories") + return memories + + def deep_search_backup( + self, + query: str, + top_k: int, + info=None, + memory_type="All", + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + previous_retrieval_phrases = [query] + retrieved_memories = self.retrieve( + query=query, + user_name=user_name, + top_k=top_k, + mode=SearchMode.FAST, + memory_type=memory_type, + search_filter=search_filter, + info=info, + ) + memories = self.post_retrieve( + retrieved_results=retrieved_memories, + top_k=top_k, + user_name=user_name, + info=info, + ) + if top_k < self.deep_search_top_k_bar or len(memories) == 0: + logger.warning("Requirements not met; returning memories as-is.") + return memories + + user_id = memories[0].metadata.user_id + context = None + + mem_list, _ = self.tree_memories_to_text_memories(memories=memories) + retrieved_memories = copy.deepcopy(retrieved_memories) + retrieved_memories_from_deep_search = [] + for current_stage_id in range(self.thinking_stages + 1): + try: + # at last + if current_stage_id == self.thinking_stages: + # eval to finish + reason, can_answer = self.judge_memories( + query=query, + text_memories="- " + "\n- ".join(mem_list) + "\n", + ) + + logger.info( + f"Final Stage: Stage {current_stage_id}; " + f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " + f"final can_answer: {can_answer}; reason: {reason}" + ) + if len(retrieved_memories_from_deep_search) == 0: + memories = self.post_retrieve( + retrieved_results=retrieved_memories, + top_k=top_k, + user_name=user_name, + info=info, + ) + return memories[:top_k] + else: + enhanced_memories = self.get_final_memories( + user_id=user_id, top_k=top_k, mem_list=mem_list + ) + return enhanced_memories + + can_answer, reason, context, retrieval_phrases = self.stage_retrieve( + stage_id=current_stage_id + 1, + query=query, + previous_retrieval_phrases=previous_retrieval_phrases, + context=context, + text_memories="- " + "\n- ".join(mem_list) + "\n", + ) + if can_answer: + logger.info( + f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", + ) + if len(retrieved_memories_from_deep_search) == 0: + memories = self.post_retrieve( + retrieved_results=retrieved_memories, + top_k=top_k, + user_name=user_name, + info=info, + ) + return memories[:top_k] + else: + enhanced_memories = self.get_final_memories( + user_id=user_id, top_k=top_k, mem_list=mem_list + ) + return enhanced_memories + else: + previous_retrieval_phrases.extend(retrieval_phrases) + logger.info( + f"Start complementary retrieval for Stage {current_stage_id}; " + f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " + f"can_answer: {can_answer}; reason: {reason}" + ) + logger.info( + "Stage %d - Found %d new retrieval phrases", + current_stage_id, + len(retrieval_phrases), + ) + # Search for additional memories based on retrieval phrases + additional_retrieved_memories = [] + for phrase in retrieval_phrases: + _retrieved_memories = self.retrieve( + query=phrase, + user_name=user_name, + top_k=self.stage_retrieve_top, + mode=SearchMode.FAST, + memory_type=memory_type, + search_filter=search_filter, + info=info, + ) + logger.info( + "Found %d additional memories for phrase: '%s'", + len(_retrieved_memories), + phrase[:30] + "..." if len(phrase) > 30 else phrase, + ) + additional_retrieved_memories.extend(_retrieved_memories) + retrieved_memories_from_deep_search.extend(additional_retrieved_memories) + merged_memories = self.post_retrieve( + retrieved_results=retrieved_memories + additional_retrieved_memories, + top_k=top_k * 2, + user_name=user_name, + info=info, + ) + + _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) + mem_list = _mem_list + mem_list = list(set(mem_list)) + logger.info( + "After stage %d, total memories in list: %d", + current_stage_id, + len(mem_list), + ) + + # Summarize memories + context, mem_list = self.summarize_memories( + query=query, + context=context, + text_memories="- " + "\n- ".join(mem_list) + "\n", + top_k=top_k, + ) + logger.info("After summarization, memory list contains %d items", len(mem_list)) + + except Exception as e: + logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) + # Continue to next stage instead of failing completely + continue + logger.error("Deep search failed, returning original memories") + return memories diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 3f2b41a47..0720d1fca 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -2,6 +2,7 @@ import re from pathlib import Path +from typing import Any from memos.dependency import require_python_package from memos.log import get_logger @@ -10,6 +11,76 @@ logger = get_logger(__name__) +def parse_structured_output(content: str) -> dict[str, str | list[str]]: + """ + Parse structured text containing arbitrary XML-like tags in the format content. + + This function extracts all tagged content and automatically determines whether each tag's content + should be returned as a string or a list of strings based on its format: + + - If the content consists of multiple non-empty lines, and each line starts with "- ", + it is interpreted as a list (e.g., a bullet-point list of phrases). + - Otherwise, the entire content is returned as a single string. + + The function is generic and supports any tag name (e.g., , , ). + + Args: + content (str): Raw text containing one or more ... blocks. + + Returns: + Dict[str, Union[str, List[str]]]: A dictionary where keys are tag names and values are either: + - a string (for single-line or non-list content) + - a list of strings (for content formatted as bullet points with "- " prefix) + + Example: + Input: + + true + + + - phrase 1 + - phrase 2 + + + Output: + { + 'can_answer': 'true', + 'missing_phrases': ['phrase 1', 'phrase 2'] + } + """ + result = {} + + # Regex pattern to match any tag with name and content (supports multi-line content via DOTALL) + # Pattern explanation: + # <([a-zA-Z_][a-zA-Z0-9_]*)> : Captures valid tag name (letter/underscore + alphanumeric) + # (.*?) : Non-greedy capture of content (including newlines) + # : Closing tag matching the captured name + tag_pattern = r"<([a-zA-Z_][a-zA-Z0-9_]*)>(.*?)" + matches = re.findall(tag_pattern, content, re.DOTALL) + + for tag_name, raw_content in matches: + content = raw_content.strip() # Remove leading/trailing whitespace + + # If content is empty, store as empty string + if not content: + result[tag_name] = "" + continue + + # Split content into lines and filter out empty ones + lines = [line.strip() for line in content.splitlines() if line.strip()] + + # Check if content is formatted as a bullet list: all non-empty lines start with "- " + if lines and all(line.startswith("-") for line in lines): + # Extract the text after the "- " prefix from each line + items = [line[1:].strip() for line in lines] + result[tag_name] = items + else: + # Treat as plain string (preserve original formatting if multi-line) + result[tag_name] = content + + return result + + def find_project_root(marker=".git"): """Find the project root directory by marking the file""" current = Path(__file__).resolve() @@ -376,3 +447,19 @@ def detect_lang(text): return "en" except Exception: return "en" + + +def format_memory_item(memory_data: Any) -> dict[str, Any]: + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["usage"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d2fde36a3..ac57f8202 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -16,15 +16,18 @@ from memos.log import get_logger from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, - FINE_STRATEGY, MEM_READ_LABEL, PREF_ADD_LABEL, - FineStrategy, - SearchMode, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.multi_mem_cube.views import MemCubeView -from memos.types import MOSSearchResult, UserContext +from memos.types.general_types import ( + FINE_STRATEGY, + FineStrategy, + MOSSearchResult, + SearchMode, + UserContext, +) logger = get_logger(__name__) @@ -126,7 +129,6 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: ) self.logger.info(f"Search memories result: {memories_result}") - return memories_result def _get_search_mode(self, mode: str) -> str: @@ -147,7 +149,7 @@ def _search_text( user_context: UserContext, search_mode: str, ) -> list[dict[str, Any]]: - """ + """G Search text memories based on mode. Args: @@ -168,81 +170,37 @@ def _search_text( else: self.logger.error(f"Unsupported search mode: {search_mode}") return [] - return text_memories except Exception as e: self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) return [] - def _search_pref( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[dict[str, Any]]: - """ - Search preference memories. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted preference memory items - TODO: ADD CUBE ID IN PREFERENCE MEMORY - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - try: - results = self.naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - ) - return [format_memory_item(data) for data in results] - except Exception as e: - self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _fast_search( + def _deep_search( self, search_req: APISearchRequest, user_context: UserContext, ) -> list: - """ - Fast search using vector database. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of search results - """ target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - search_results = self.naive_mem_cube.text_mem.search( + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + enhanced_memories = self.searcher.deep_search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FAST, + mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, + info=info, ) - - formatted_memories = [format_memory_item(data) for data in search_results] - + formatted_memories = [format_memory_item(data) for data in enhanced_memories] return formatted_memories def _deep_search( @@ -270,9 +228,7 @@ def _fine_search( List of enhanced search results """ if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: - return self._deep_search( - search_req=search_req, user_context=user_context, max_thinking_depth=3 - ) + return self._deep_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None @@ -283,20 +239,21 @@ def _fine_search( "chat_history": search_req.chat_history, } - # Fast retrieve - fast_retrieved_memories = self.searcher.retrieve( + # Fine retrieve + raw_retrieved_memories = self.searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, search_filter=search_filter, info=info, ) # Post retrieve raw_memories = self.searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, + retrieved_results=raw_retrieved_memories, top_k=search_req.top_k, user_name=user_context.mem_cube_id, info=info, @@ -343,6 +300,76 @@ def _fine_search( return formatted_memories + def _search_pref( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list[dict[str, Any]]: + """ + Search preference memories. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted preference memory items + TODO: ADD CUBE ID IN PREFERENCE MEMORY + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + try: + results = self.naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _fast_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fast search using vector database. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + search_results = self.naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + + formatted_memories = [format_memory_item(data) for data in search_results] + + return formatted_memories + def _mix_search( self, search_req: APISearchRequest, diff --git a/src/memos/templates/advanced_search_prompts.py b/src/memos/templates/advanced_search_prompts.py new file mode 100644 index 000000000..13e80a79a --- /dev/null +++ b/src/memos/templates/advanced_search_prompts.py @@ -0,0 +1,276 @@ +MEMORY_SUMMARY_PROMPT = """ +# Memory Summary and Context Assembly + +## Role +You are a precise context assembler. Given a user query and a set of retrieved memories (each indexed), your task is to synthesize a factual, concise, and coherent context using only the information explicitly present in the memories. + +## Instructions + +### Core Principles +- Use ONLY facts from the provided memories. Do not invent, infer, guess, or hallucinate. +- Resolve all pronouns (e.g., "he", "it", "they") and vague terms (e.g., "this", "that", "some people") to explicit entities using memory content. +- Merge overlapping or redundant facts. Preserve temporal, spatial, and relational details. +- Each fact must be atomic, unambiguous, and verifiable. +- Preserve all key details: who, what, when, where, why — if present in memory. +- Created a summarized facts for answering query at the first item, and separate logically coherent separate memories. +- Begin the with a single, aggregated summary that directly answers the query using the most relevant facts. +- The total number of facts in must not exceed {top_k}. +- If additional context is relevant, try to weave it together logically—or chronologically—based on how the pieces connect. +- **Must preserve the full timeline of all memories**: if multiple events or states are mentioned with temporal markers (e.g., dates, sequences, phases), their chronological order must be retained in both and . + +### Processing Logic +- Aggregate logically connected memories (e.g., events involving the same person, cause-effect chains, repeated entities). +- Exclude any memory that does not directly support answering the query. +- Prioritize specificity: e.g., "Travis Tang moved to Singapore in 2021" > "He relocated abroad." + +## Input +- Query: {query} +- Current context: +{context} +- Current Memories: +{memories} + +## Output Format (STRICT TAG-BASED) +Respond ONLY with the following XML-style tags. Do NOT include any other text, explanations, or formatting. + + +A single, compact, fluent paragraph synthesizing the above facts into a coherent narrative directly relevant to the query. Use resolved entities and logical flow. No bullet points. No markdown. No commentary. + + +- Aggregated summary +- Fact 1 +- Fact 2 + + +Answer: +""" + +# Stage 1: determine answerability; if not answerable, produce concrete retrieval phrases for missing info +STAGE1_EXPAND_RETRIEVE_PROMPT = """ +# Stage 1 — Answerability and Missing Retrieval Phrases + +## Goal +Determine whether the current memories can answer the query using concrete, specific facts. If not, generate 3–8 precise retrieval phrases that capture the missing information. + +## Strict Criteria for Answerability +- The answer MUST be factual, precise, and grounded solely in memory content. +- Do NOT use vague adjectives (e.g., "usually", "often"), unresolved pronouns ("he", "it"), or generic statements. +- Do NOT answer with placeholders, speculation, or inferred information. + +## Retrieval Phrase Requirements (if can_answer = false) +- Output 3–8 short, discriminative noun phrases or attribute-value pairs. +- Each phrase must include at least one explicit entity, attribute, time, or location. +- Avoid fuzzy words, subjective terms, or pronouns. +- Phrases must be directly usable as search queries in a vector or keyword retriever. + +## Input +- Query: {query} +- Previous retrieval phrases: +{previous_retrieval_phrases} +- Current Memories: +{memories} + +## Output (STRICT TAG-BASED FORMAT) +Respond ONLY with the following structure. Do not add any other text, explanation, or formatting. + + +true or false + + +summary of current memories + + +Brief, one-sentence explanation for why the query is or isn't answerable with current memories. + + +- missing phrase 1 +- missing phrase 2 +... + + +Answer: +""" + + +# Stage 2: if Stage 1 phrases still fail, rewrite the retrieval query and phrases to maximize recall +STAGE2_EXPAND_RETRIEVE_PROMPT = """ +# Stage 2 — Rewrite Retrieval Query and Phrases to Improve Recall + +## Goal +If Stage 1's retrieval phrases failed to yield an answer, rewrite the original query and expand the phrase list to maximize recall of relevant memories. Use canonicalization, synonym expansion, and constraint enrichment. + +## Rewrite Strategy +- Canonicalize entities: use full names, official titles, or known aliases. +- Normalize time formats: e.g., "last year" → "2024", "in 2021" → "2021". +- Add discriminative tokens: entity + attribute + time + location where applicable. +- Split complex queries into focused sub-queries targeting distinct facets. +- Never include pronouns, vague terms, or subjective language. + +## Input +- Query: {query} +- Previous retrieval phrases: +{previous_retrieval_phrases} +- Context: {context} +- Current Memories: +{memories} + + +## Output (STRICT TAG-BASED FORMAT) +Respond ONLY with the following structure. Do not add any other text, explanation, or formatting. + + +true or false + + +Brief explanation (1–2 sentences) of how this rewrite improves recall over Stage 1 phrases. + + +summary of current memories + + +- new phrase 1 (Rewritten version of the original query. More precise, canonical, and retrieval-optimized.) +- new phrase 2 +... + + +Answer: +""" + + +# Stage 3: generate grounded hypotheses to guide retrieval when still not answerable +STAGE3_EXPAND_RETRIEVE_PROMPT = """ +# Stage 3 — Hypothesis Generation for Retrieval + +## Goal +When the query remains unanswerable, generate grounded, plausible hypotheses based ONLY on provided context and memories. Each hypothesis must imply a concrete retrieval target and validation criteria. + +## Rules +- Base hypotheses strictly on facts from the memories. No new entities or assumptions. +- Frame each hypothesis as a testable statement: "If [X] is true, then the query is answered." +- For each hypothesis, define 1–3 specific evidence requirements that would confirm it. +- Do NOT guess. Do NOT invent. Only extrapolate from existing facts. + +## Input +- Query: {query} +- Previous retrieval phrases: +{previous_retrieval_phrases} +- Context: {context} +- Memories: +{memories} + +## Output (STRICT TAG-BASED FORMAT) +Respond ONLY with the following structure. Do not add any other text, explanation, or formatting. + + +true or false + + +summary of current memories + + +- statement: + retrieval_query: + validation_criteria: + - + - +- statement: + retrieval_query: + validation_criteria: + - + + + +- hypothesis retrieval query 1 (searchable query derived from the hypothesis) +- hypothesis retrieval query 2: +... + + +Answer: +""" + +MEMORY_JUDGMENT_PROMPT = """ +# Memory Relevance Judgment + +## Role +You are a precise memory evaluator. Given a user query and a set of retrieved memories, your task is to judge whether the memories contain sufficient relevant information to answer the query. + +## Instructions + +### Core Principles +- Use ONLY facts from the provided memories. Do not invent, infer, guess, or hallucinate. +- Resolve all pronouns (e.g., "he", "it", "they") and vague terms (e.g., "this", "that", "some people") to explicit entities using memory content. +- Each fact must be atomic, unambiguous, and verifiable. +- Preserve all key details: who, what, when, where, why — if present in memory. +- Judge whether the memories directly support answering the query. +- Focus on relevance: does this memory content actually help answer what was asked? + +### Processing Logic +- Assess each memory's direct relevance to the query. +- Judge whether the combination of memories provides sufficient information for a complete answer. +- Exclude any memory that does not directly support answering the query. +- Prioritize specificity: e.g., "Travis Tang moved to Singapore in 2021" > "He relocated abroad." + +## Input +- Query: {query} +- Current Memories: +{memories} + +## Output Format (STRICT TAG-BASED) +Respond ONLY with the following XML-style tags. Do NOT include any other text, explanations, or formatting. + + +Brief explanation of why the memories are or are not sufficient for answering the query + + +YES or NO - indicating whether the memories are sufficient to answer the query + + +Answer: +""" + +MEMORY_RECREATE_ENHANCEMENT_PROMPT = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. + +# RULES & THINKING STEPS +1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. +2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). +3. Resolve all ambiguities using only memory content: + - Pronouns → full name: “she” → “Melanie” + - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” + - “the user” → identity from context (e.g., “Melanie” if travel/running memories) +4. Never invent, assume, or extrapolate. +5. Each output line must be a standalone, clear, factual statement. +6. Output format: one line per fact, starting with "- ", no extra text. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space). + +Wrap the final output inside: + +- enhanced memory 1 +- enhanced memory 2 +... + + +## User Query +{query} + +## Original Memories +{memories} + +Final Output: +""" + + +PROMPT_MAPPING = { + "memory_summary": MEMORY_SUMMARY_PROMPT, + "memory_judgement": MEMORY_JUDGMENT_PROMPT, + "stage1_expand_retrieve": STAGE1_EXPAND_RETRIEVE_PROMPT, + "stage2_expand_retrieve": STAGE2_EXPAND_RETRIEVE_PROMPT, + "stage3_expand_retrieve": STAGE3_EXPAND_RETRIEVE_PROMPT, + "memory_recreate_enhancement": MEMORY_RECREATE_ENHANCEMENT_PROMPT, +} diff --git a/src/memos/types/__init__.py b/src/memos/types/__init__.py index dd1b98305..9e08f8f13 100644 --- a/src/memos/types/__init__.py +++ b/src/memos/types/__init__.py @@ -1,3 +1,34 @@ -# ruff: noqa: F403, F401 +from .general_types import ( + FINE_STRATEGY, + ChatHistory, + FineStrategy, + MemCubeID, + MessageDict, + MessageList, + MessageRole, + MessagesType, + MOSSearchResult, + Permission, + PermissionDict, + SearchMode, + UserContext, + UserID, +) -from .types import * + +__all__ = [ + "FINE_STRATEGY", + "ChatHistory", + "FineStrategy", + "MOSSearchResult", + "MemCubeID", + "MessageDict", + "MessageList", + "MessageRole", + "MessagesType", + "Permission", + "PermissionDict", + "SearchMode", + "UserContext", + "UserID", +] diff --git a/src/memos/types/types.py b/src/memos/types/general_types.py similarity index 72% rename from src/memos/types/types.py rename to src/memos/types/general_types.py index 481b4c692..9babdc096 100644 --- a/src/memos/types/types.py +++ b/src/memos/types/general_types.py @@ -4,8 +4,11 @@ used throughout the MemOS project to improve type safety and code clarity. """ +import os + from datetime import datetime -from typing import Literal, TypeAlias +from enum import Enum +from typing import Literal, NewType, TypeAlias from pydantic import BaseModel from typing_extensions import TypedDict @@ -22,15 +25,20 @@ __all__ = [ + "FINE_STRATEGY", "ChatHistory", + "FineStrategy", "MOSSearchResult", + "MemCubeID", "MessageDict", "MessageList", "MessageRole", "MessagesType", "Permission", "PermissionDict", + "SearchMode", "UserContext", + "UserID", ] # ─── Message Types ────────────────────────────────────────────────────────────── @@ -73,6 +81,42 @@ class ChatHistory(BaseModel): chat_history: MessageList +# ─── Search ──────────────────────────────────────────────────────────────────── +# new types +UserID = NewType("UserID", str) +MemCubeID = NewType("CubeID", str) + + +class SearchMode(str, Enum): + """Enumeration for search modes.""" + + FAST = "fast" + FINE = "fine" + MIXTURE = "mixture" + + +class FineStrategy(str, Enum): + """Enumeration for fine strategies.""" + + REWRITE = "rewrite" + RECREATE = "recreate" + DEEP_SEARCH = "deep_search" + + +# algorithm strategies +DEFAULT_FINE_STRATEGY = FineStrategy.DEEP_SEARCH +FINE_STRATEGY = DEFAULT_FINE_STRATEGY + +# Read fine strategy from environment variable `FINE_STRATEGY`. +# If provided and valid, use it; otherwise fall back to default. +_env_fine_strategy = os.getenv("FINE_STRATEGY") +if _env_fine_strategy: + try: + FINE_STRATEGY = FineStrategy(_env_fine_strategy) + except ValueError: + FINE_STRATEGY = DEFAULT_FINE_STRATEGY + + # ─── MemOS ──────────────────────────────────────────────────────────────────── diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 2aa96257b..7c4b4be9d 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -48,6 +48,7 @@ def mock_init_server(): "pref_mem": None, "online_bot": None, "chat_llms": Mock(), + "redis_client": Mock(), "deepsearch_agent": Mock(), } diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index e687d2986..fe889559c 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -156,7 +156,9 @@ def test_register_handlers(self): def test_dispatch_serial(self): """Test dispatching messages in serial mode.""" # Create a new dispatcher with parallel dispatch disabled - serial_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=False) + serial_dispatcher = SchedulerDispatcher( + max_workers=2, enable_parallel_dispatch=False, metrics=MagicMock() + ) # Create fresh mock handlers for this test mock_handler1 = MagicMock()