Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/memos/api/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def generate_chat_response() -> Generator[str, None, None]:
include_preference=chat_req.include_preference,
pref_top_k=chat_req.pref_top_k,
filter=chat_req.filter,
playground_search_goal_parser=True,
)

search_response = self.search_handler.handle_search_memories(search_req)
Expand Down
3 changes: 3 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ class APISearchRequest(BaseRequest):
),
)

# TODO: tmp field for playground search goal parser, will be removed later
playground_search_goal_parser: bool = Field(False, description="Playground search goal parser")

# ==== Context ====
chat_history: MessageList | None = Field(
None,
Expand Down
67 changes: 21 additions & 46 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,15 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int
def get_searcher(
self, manual_close_internet: bool = False, moscube: bool = False, process_llm=None
):
if (self.internet_retriever is not None) and manual_close_internet:
logger.warning(
"Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
)
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
internet_retriever=None,
process_llm=process_llm,
)
else:
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
internet_retriever=self.internet_retriever,
process_llm=process_llm,
)
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
internet_retriever=self.internet_retriever,
manual_close_internet=manual_close_internet,
process_llm=process_llm,
)
return searcher

def search(
Expand Down Expand Up @@ -191,30 +179,17 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
if (self.internet_retriever is not None) and manual_close_internet:
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=None,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
else:
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
self.embedder,
self.reranker,
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
return searcher.search(
query,
top_k,
Expand All @@ -224,9 +199,9 @@ def search(
search_filter,
search_priority,
user_name=user_name,
plugin=kwargs.get("plugin", False),
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
**kwargs,
)

def get_relevant_subgraph(
Expand Down
13 changes: 10 additions & 3 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def retrieve(
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
**kwargs,
)
results = self._retrieve_paths(
query,
Expand Down Expand Up @@ -166,7 +167,7 @@ def search(
else:
logger.debug(f"[SEARCH] Received info dict: {info}")

if kwargs.get("plugin"):
if kwargs.get("plugin", False):
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
retrieved_results = self._retrieve_simple(
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
Expand All @@ -183,6 +184,7 @@ def search(
user_name=user_name,
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
**kwargs,
)

full_recall = kwargs.get("full_recall", False)
Expand Down Expand Up @@ -218,6 +220,7 @@ def _parse_task(
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
**kwargs,
):
"""Parse user query, do embedding search and create context"""
context = []
Expand Down Expand Up @@ -268,6 +271,7 @@ def _parse_task(
conversation=info.get("chat_history", []),
mode=mode,
use_fast_graph=self.use_fast_graph,
**kwargs,
)

query = parsed_goal.rephrased_query or query
Expand Down Expand Up @@ -351,7 +355,7 @@ def _retrieve_paths(
query,
parsed_goal,
query_embedding,
top_k,
tool_mem_top_k,
memory_type,
search_filter,
search_priority,
Expand Down Expand Up @@ -516,7 +520,10 @@ def _retrieve_from_internet(
user_id: str | None = None,
):
"""Retrieve and rerank from Internet source"""
if not self.internet_retriever or self.manual_close_internet:
if not self.internet_retriever:
logger.info(f"[PATH-C] '{query}' Skipped (no retriever)")
return []
if self.manual_close_internet and not parsed_goal.internet_search:
logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)")
return []
if memory_type not in ["All"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def parse(
- mode == 'fast': use jieba to split words only
- mode == 'fine': use LLM to parse structured topic/keys/tags
"""
# TODO: tmp mode for playground search goal parser, will be removed later
if kwargs.get("playground_search_goal_parser", False):
mode = "fine"

if mode == "fast":
return self._parse_fast(task_description, context=context, **kwargs)
elif mode == "fine":
Expand Down
2 changes: 2 additions & 0 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ def _fast_search(
plugin=plugin,
search_tool_memory=search_req.search_tool_memory,
tool_mem_top_k=search_req.tool_mem_top_k,
# TODO: tmp field for playground search goal parser, will be removed later
playground_search_goal_parser=search_req.playground_search_goal_parser,
)

formatted_memories = [format_memory_item(data) for data in search_results]
Expand Down
Loading