Skip to content
111 changes: 68 additions & 43 deletions src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from memos.api.product_models import (
APIADDRequest,
APIChatCompleteRequest,
APISearchPlaygroundRequest,
APISearchRequest,
ChatPlaygroundRequest,
ChatRequest,
)
from memos.context.context import ContextThread
Expand Down Expand Up @@ -91,6 +93,7 @@ def __init__(
self.enable_mem_scheduler = (
hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler
)
self.dependencies = dependencies

def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -356,7 +359,7 @@ def generate_chat_response() -> Generator[str, None, None]:
self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err

def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse:
def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> StreamingResponse:
"""
Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers.

Expand Down Expand Up @@ -413,8 +416,8 @@ def generate_chat_response() -> Generator[str, None, None]:
label=QUERY_TASK_LABEL,
)

# ====== first search without parse goal ======
search_req = APISearchRequest(
# ====== first search text mem with parse goal ======
search_req = APISearchPlaygroundRequest(
query=chat_req.query,
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
Expand All @@ -426,6 +429,7 @@ def generate_chat_response() -> Generator[str, None, None]:
include_preference=chat_req.include_preference,
pref_top_k=chat_req.pref_top_k,
filter=chat_req.filter,
playground_search_goal_parser=True,
)
search_response = self.search_handler.handle_search_memories(search_req)

Expand All @@ -439,10 +443,10 @@ def generate_chat_response() -> Generator[str, None, None]:
memories_list = text_mem_results[0]["memories"]

# Filter memories by threshold
first_filtered_memories = self._filter_memories_by_threshold(memories_list)
filtered_memories = self._filter_memories_by_threshold(memories_list)

# Prepare reference data (first search)
reference = prepare_reference_data(first_filtered_memories)
reference = prepare_reference_data(filtered_memories)
# get preference string
pref_string = search_response.data.get("pref_string", "")

Expand All @@ -455,48 +459,68 @@ def generate_chat_response() -> Generator[str, None, None]:
pref_md_string = self._build_pref_md_string_for_playground(pref_memories)
yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n"

# internet status
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"

# ====== second search with parse goal ======
search_req = APISearchRequest(
query=chat_req.query,
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
mode=chat_req.mode,
internet_search=chat_req.internet_search,
top_k=chat_req.top_k,
chat_history=chat_req.history,
session_id=chat_req.session_id,
include_preference=False,
filter=chat_req.filter,
playground_search_goal_parser=True,
# parse goal for internet search
searcher = self.dependencies.searcher
parsed_goal = searcher.task_goal_parser.parse(
task_description=chat_req.query,
context="\n".join(
[memory.get("memory", "") for memory in filtered_memories]
),
conversation=chat_req.history,
mode="fine",
)
search_response = self.search_handler.handle_search_memories(search_req)

# Extract memories from search results (second search)
memories_list = []
if search_response.data and search_response.data.get("text_mem"):
text_mem_results = search_response.data["text_mem"]
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]
if chat_req.beginner_guide_step == "first":
chat_req.internet_search = False
parsed_goal.internet_search = False
elif chat_req.beginner_guide_step == "second":
chat_req.internet_search = True
parsed_goal.internet_search = True

if chat_req.internet_search or parsed_goal.internet_search:
# internet status
yield f"data: {json.dumps({'type': 'status', 'data': 'start_internet_search'})}\n\n"

# ====== internet search with parse goal ======
search_req = APISearchPlaygroundRequest(
query=chat_req.query
+ (f"{parsed_goal.tags}" if parsed_goal.tags else ""),
user_id=chat_req.user_id,
readable_cube_ids=readable_cube_ids,
mode=chat_req.mode,
internet_search=True,
top_k=chat_req.top_k,
chat_history=chat_req.history,
session_id=chat_req.session_id,
include_preference=False,
filter=chat_req.filter,
search_memory_type="OuterMemory",
)
search_response = self.search_handler.handle_search_memories(search_req)

# Filter memories by threshold
second_filtered_memories = self._filter_memories_by_threshold(memories_list)
# Extract memories from search results (second search)
memories_list = []
if search_response.data and search_response.data.get("text_mem"):
text_mem_results = search_response.data["text_mem"]
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]

# dedup and supplement memories
filtered_memories = self._dedup_and_supplement_memories(
first_filtered_memories, second_filtered_memories
)
# Filter memories by threshold
second_filtered_memories = self._filter_memories_by_threshold(memories_list)

# Prepare remain reference data (second search)
reference = prepare_reference_data(filtered_memories)
# get internet reference
internet_reference = self._get_internet_reference(
search_response.data.get("text_mem")[0]["memories"]
)
# dedup and supplement memories
filtered_memories = self._dedup_and_supplement_memories(
filtered_memories, second_filtered_memories
)

yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
# Prepare remain reference data (second search)
reference = prepare_reference_data(filtered_memories)
# get internet reference
internet_reference = self._get_internet_reference(
search_response.data.get("text_mem")[0]["memories"]
)

yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"

# Step 2: Build system prompt with memories
system_prompt = self._build_enhance_system_prompt(
Expand Down Expand Up @@ -571,8 +595,9 @@ def generate_chat_response() -> Generator[str, None, None]:
chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
yield chunk_data

# Yield internet reference after text response
yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"
if chat_req.internet_search or parsed_goal.internet_search:
# Yield internet reference after text response
yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"

# Calculate timing
time_end = time.time()
Expand Down
23 changes: 20 additions & 3 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ def _convert_deprecated_fields(self):
return self


class ChatPlaygroundRequest(ChatRequest):
    """Chat request variant used by the playground endpoint.

    Extends ``ChatRequest`` with a playground-only beginner-guide flag;
    all other chat fields are inherited unchanged.
    """

    # Optional guided-tour step; callers are expected to send "first" or
    # "second" (see the description), or omit the field entirely.
    beginner_guide_step: str | None = Field(
        default=None,
        description="Whether to use beginner guide, option: [first, second]",
    )


class ChatCompleteRequest(BaseRequest):
"""Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest."""

Expand Down Expand Up @@ -373,9 +381,11 @@ class APISearchRequest(BaseRequest):
"If None, default thresholds will be applied."
),
)

# TODO: tmp field for playground search goal parser, will be removed later
playground_search_goal_parser: bool = Field(False, description="Playground search goal parser")
# Internal field for search memory type
search_memory_type: str = Field(
"All",
description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory",
)

# ==== Context ====
chat_history: MessageList | None = Field(
Expand Down Expand Up @@ -448,6 +458,13 @@ def _convert_deprecated_fields(self) -> "APISearchRequest":
return self


class APISearchPlaygroundRequest(APISearchRequest):
    """Search request variant used by the playground chat flow.

    Extends ``APISearchRequest`` with a playground-only switch that enables
    the search goal parser.
    """

    # TODO: temporary flag for the playground search goal parser; remove once
    # the playground flow no longer needs it.
    playground_search_goal_parser: bool = Field(
        default=False, description="Playground search goal parser"
    )


class APIADDRequest(BaseRequest):
"""Request model for creating memories."""

Expand Down
3 changes: 2 additions & 1 deletion src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
APIChatCompleteRequest,
APIFeedbackRequest,
APISearchRequest,
ChatPlaygroundRequest,
ChatRequest,
DeleteMemoryRequest,
DeleteMemoryResponse,
Expand Down Expand Up @@ -200,7 +201,7 @@ def chat_stream(chat_req: ChatRequest):


@router.post("/chat/stream/playground", summary="Chat with MemOS playground")
def chat_stream_playground(chat_req: ChatRequest):
def chat_stream_playground(chat_req: ChatPlaygroundRequest):
"""
Chat with MemOS for a specific user. Returns SSE stream.

Expand Down
3 changes: 3 additions & 0 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,12 @@ def get_searcher(
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
process_llm=process_llm,
tokenizer=self.tokenizer,
)
return searcher

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ def _parse_task(
**kwargs,
)

# TODO: tmp field playground_search_goal_parser for playground, will be removed later
if kwargs.get("playground_search_goal_parser", False):
parsed_goal.internet_search = False

query = parsed_goal.rephrased_query or query
# if goal has extra memories, embed them too
if parsed_goal.memories:
Expand Down Expand Up @@ -527,7 +531,8 @@ def _retrieve_from_internet(
if self.manual_close_internet and not parsed_goal.internet_search:
logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)")
return []
if memory_type not in ["All"]:
if memory_type not in ["All", "OuterMemory"]:
logger.info(f"[PATH-C] '{query}' Skipped (memory_type does not match)")
return []
logger.info(f"[PATH-C] '{query}' Retrieving from internet...")
items = self.internet_retriever.retrieve_from_internet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse(
elif mode == "fine":
if not self.llm:
raise ValueError("LLM not provided for slow mode.")
return self._parse_fine(task_description, context, conversation)
return self._parse_fine(task_description, context, conversation, **kwargs)
else:
raise ValueError(f"Unknown mode: {mode}")

Expand Down Expand Up @@ -81,7 +81,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
)

def _parse_fine(
self, query: str, context: str = "", conversation: list[dict] | None = None
self, query: str, context: str = "", conversation: list[dict] | None = None, **kwargs
) -> ParsedTaskGoal:
"""
Slow mode: LLM structured parse.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
1. Keys: the high-level keywords directly relevant to the user’s task.
2. Tags: thematic tags to help categorize and retrieve related memories.
3. Goal Type: retrieval | qa | generation
4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string.
4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. Make full use of information related to the query. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string.
5. Need for internet search: If the user's task instruction only involves objective facts or can be completed without introducing external knowledge, set "internet_search" to False. Otherwise, set it to True.
6. Memories: Provide 2–5 short semantic expansions or rephrasings of the rephrased/original user task instruction. These are used for improved embedding search coverage. Each should be clear, concise, and meaningful for retrieval.

Expand Down
1 change: 1 addition & 0 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def _fast_search(
top_k=search_req.top_k,
mode=SearchMode.FAST,
manual_close_internet=not search_req.internet_search,
momory_type=search_req.search_memory_type,
search_filter=search_filter,
search_priority=search_priority,
info={
Expand Down
Loading