Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def _send_message_to_scheduler(
content=query,
timestamp=datetime.utcnow(),
)
self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
self.mem_scheduler.submit_messages(messages=[message_item])
self.logger.info(f"Sent message to scheduler with label: {label}")
except Exception as e:
self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True)
Expand Down
24 changes: 18 additions & 6 deletions src/memos/api/handlers/scheduler_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def handle_scheduler_status(
Args:
user_id: User ID to query for.
status_tracker: The TaskStatusTracker instance.
task_id: Optional Task ID to query a specific task.
task_id: Optional Task ID to query. Can be either:
- business_task_id (will aggregate all related item statuses)
- item_id (will return single item status)

Returns:
StatusResponse with a list of task statuses.
Expand All @@ -46,12 +48,22 @@ def handle_scheduler_status(

try:
if task_id:
task_data = status_tracker.get_task_status(task_id, user_id)
if not task_data:
raise HTTPException(
status_code=404, detail=f"Task {task_id} not found for user {user_id}"
# First try as business_task_id (aggregated query)
business_task_data = status_tracker.get_task_status_by_business_id(task_id, user_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增了business_id?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

business_id在status tracker内部为实际业务task_id

if business_task_data:
response_data.append(
StatusResponseItem(task_id=task_id, status=business_task_data["status"])
)
else:
# Fallback: try as item_id (single item query)
item_task_data = status_tracker.get_task_status(task_id, user_id)
if not item_task_data:
raise HTTPException(
status_code=404, detail=f"Task {task_id} not found for user {user_id}"
)
response_data.append(
StatusResponseItem(task_id=task_id, status=item_task_data["status"])
)
response_data.append(StatusResponseItem(task_id=task_id, status=task_data["status"]))
else:
all_tasks = status_tracker.get_all_tasks_for_user(user_id)
# The plan returns an empty list, which is good.
Expand Down
1 change: 1 addition & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class MemoryCreateRequest(BaseRequest):
source: str | None = Field(None, description="Source of the memory")
user_profile: bool = Field(False, description="User profile memory")
session_id: str | None = Field(None, description="Session id")
task_id: str | None = Field(None, description="Task ID for monitoring async tasks")


class SearchRequest(BaseRequest):
Expand Down
46 changes: 46 additions & 0 deletions src/memos/api/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,43 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
@router.post("/add", summary="add a new memory", response_model=SimpleResponse)
def create_memory(memory_req: MemoryCreateRequest):
"""Create a new memory for a specific user."""
# Initialize status_tracker outside try block to avoid NameError in except blocks
status_tracker = None

try:
time_start_add = time.time()
mos_product = get_mos_product_instance()

# Track task if task_id is provided
item_id: str | None = None
if (
memory_req.task_id
and hasattr(mos_product, "mem_scheduler")
and mos_product.mem_scheduler
):
from uuid import uuid4

from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

item_id = str(uuid4()) # Generate a unique item_id for this submission

# Get Redis client from scheduler
if (
hasattr(mos_product.mem_scheduler, "redis_client")
and mos_product.mem_scheduler.redis_client
):
status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client)
# Submit task with "product_add" type
status_tracker.task_submitted(
task_id=item_id, # Use generated item_id for internal tracking
user_id=memory_req.user_id,
task_type="product_add",
mem_cube_id=memory_req.mem_cube_id or memory_req.user_id,
business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id
)
status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here

# Execute the add operation
mos_product.add(
user_id=memory_req.user_id,
memory_content=memory_req.memory_content,
Expand All @@ -200,15 +234,27 @@ def create_memory(memory_req: MemoryCreateRequest):
source=memory_req.source,
user_profile=memory_req.user_profile,
session_id=memory_req.session_id,
task_id=memory_req.task_id,
)

# Mark task as completed
if status_tracker and item_id:
status_tracker.task_completed(item_id, memory_req.user_id)

logger.info(
f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
)
return SimpleResponse(message="Memory created successfully")

except ValueError as err:
# Mark task as failed if tracking
if status_tracker and item_id:
status_tracker.task_failed(item_id, memory_req.user_id, str(err))
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
except Exception as err:
# Mark task as failed if tracking
if status_tracker and item_id:
status_tracker.task_failed(item_id, memory_req.user_id, str(err))
logger.error(f"Failed to create memory: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

Expand Down
3 changes: 3 additions & 0 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ def add(
mem_cube_id: str | None = None,
user_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None, # New: Add task_id parameter
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -773,6 +774,7 @@ def process_textual_memory():
label=MEM_READ_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
task_id=task_id,
)
self.mem_scheduler.memos_message_queue.submit_messages(
messages=[message_item]
Expand All @@ -784,6 +786,7 @@ def process_textual_memory():
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
task_id=task_id,
)
self.mem_scheduler.memos_message_queue.submit_messages(
messages=[message_item]
Expand Down
9 changes: 8 additions & 1 deletion src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,13 +1499,20 @@ def add(
source: str | None = None,
user_profile: bool = False,
session_id: str | None = None,
task_id: str | None = None, # Add task_id parameter
):
"""Add memory for a specific user."""

# Load user cubes if not already loaded
self._load_user_cubes(user_id, self.default_cube_config)
result = super().add(
messages, memory_content, doc_path, mem_cube_id, user_id, session_id=session_id
messages,
memory_content,
doc_path,
mem_cube_id,
user_id,
session_id=session_id,
task_id=task_id,
)
if user_profile:
try:
Expand Down
1 change: 1 addition & 0 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
user_id=message.user_id,
task_type=message.label,
mem_cube_id=message.mem_cube_id,
business_task_id=message.task_id, # Pass business task_id if provided
)
self.memos_message_queue.submit_messages(messages=messages)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_autofilled_log_item(
mem_cube: GeneralMemCube,
) -> ScheduleLogForWebItem:
text_mem_base: TreeTextMemory = mem_cube.text_mem
current_memory_sizes = text_mem_base.get_current_memory_size()
current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id)
current_memory_sizes = {
"long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0),
"user_memory_size": current_memory_sizes.get("UserMemory", 0),
Expand Down
Loading
Loading