diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 706269b52..574f2ae17 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -41,6 +41,7 @@ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer if TYPE_CHECKING: @@ -196,6 +197,7 @@ def init_server() -> dict[str, Any]: logger.debug("Memory manager initialized") + tokenizer = FastTokenizer() # Initialize text memory text_mem = SimpleTreeTextMemory( llm=llm, @@ -206,6 +208,7 @@ def init_server() -> dict[str, Any]: memory_manager=memory_manager, config=default_cube_config.text_mem.config, internet_retriever=internet_retriever, + tokenizer=tokenizer, ) logger.debug("Text memory initialized") diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 4f445e9ab..cc76e6751 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -388,6 +388,12 @@ class APISearchRequest(BaseRequest): description="(Internal) Operation definitions for multi-cube read permissions.", ) + # ==== Source for plugin ==== + source: str | None = Field( + None, + description="Source of the search query [plugin will router diff search]", + ) + @model_validator(mode="after") def _convert_deprecated_fields(self) -> "APISearchRequest": """ diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index bfde8c80c..90bea8ea3 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1451,6 +1451,130 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + @timed + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + """ + Full-text search functionality using PostgreSQL's full-text search capabilities. + + Args: + query_text: query text + top_k: maximum number of results to return + scope: memory type filter (memory_type) + status: status filter, defaults to "activated" + threshold: similarity threshold filter + search_filter: additional property filter conditions + user_name: username filter + knowledgebase_ids: knowledgebase ids filter + filter: filter conditions with 'and' or 'or' logic for search results. + tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 + tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + **kwargs: other parameters (e.g. cube_name) + + Returns: + list[dict]: result list containing id and score + """ + # Build WHERE clause dynamically, same as search_by_embedding + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + # Add fulltext search condition + # Convert query_text to OR query format: "word1 | word2 | word3" + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid) + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + + return output[:top_k] + finally: + self._return_connection(conn) + @timed def search_by_embedding( self, diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 05e62e3ee..c67271f76 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -9,6 +9,7 @@ from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.reranker.base import BaseReranker @@ -35,6 +36,7 @@ def __init__( config: TreeTextMemoryConfig, internet_retriever: None = None, is_reorganize: bool = False, + tokenizer: FastTokenizer | None = None, ): """Initialize memory with the given configuration.""" self.config: TreeTextMemoryConfig = config @@ -51,6 +53,7 @@ def __init__( if self.search_strategy and self.search_strategy.get("bm25", False) else None ) + self.tokenizer = tokenizer self.reranker = reranker self.memory_manager: MemoryManager = memory_manager # Create internet retriever if configured diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 2a109bf71..ad2bcd9c4 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -91,6 +91,7 @@ def __init__(self, config: TreeTextMemoryConfig): ) else: logger.info("No internet retriever configured") + self.tokenizer = None def add( self, @@ -165,6 +166,7 @@ def search( search_priority: dict | None = None, search_filter: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """Search for memories based on a query. User query -> TaskGoalParser -> MemoryPathResolver -> @@ -197,6 +199,7 @@ def search( internet_retriever=None, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) else: searcher = Searcher( @@ -208,6 +211,7 @@ def search( internet_retriever=self.internet_retriever, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) return searcher.search( query, @@ -218,6 +222,7 @@ def search( search_filter, search_priority, user_name=user_name, + plugin=kwargs.get("plugin", False), ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 22cd44b8c..9c892d8b8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -10,6 +10,7 @@ from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, parse_structured_output, ) from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -33,6 +34,7 @@ def __init__( search_strategy: dict | None = None, manual_close_internet: bool = True, process_llm: Any | None = None, + tokenizer: FastTokenizer | None = None, ): super().__init__( dispatcher_llm=dispatcher_llm, @@ -43,6 +45,7 @@ def __init__( internet_retriever=internet_retriever, search_strategy=search_strategy, manual_close_internet=manual_close_internet, + tokenizer=tokenizer, ) self.stage_retrieve_top = 3 diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 7fa8a87be..7ac274a62 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -148,6 +148,25 @@ def retrieve_from_cube( return list(combined.values()) + def retrieve_from_mixed( + self, + top_k: int, + memory_scope: str | None = None, + query_embedding: list[list[float]] | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + ) -> list[TextualMemoryItem]: + """Retrieve from mixed and memory""" + vector_results = self._vector_recall( + query_embedding or [], + memory_scope, + top_k, + search_filter=search_filter, + user_name=user_name, + ) # Merge and deduplicate by ID + combined = {item.id: item for item in vector_results} + return list(combined.values()) + def _graph_recall( self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs ) -> list[TextualMemoryItem]: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 0720d1fca..9e1e6c240 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Any +import numpy as np + from memos.dependency import require_python_package from memos.log import get_logger @@ -463,3 +465,28 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]: memory["metadata"]["memory"] = memory["memory"] return memory + + +def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8): + assert len(sentences) == len(similarity_matrix) + + num_sentence = len(sentences) + selected_sentences = [] + selected_indices = [] + for i in range(num_sentence): + can_add = True + for j in selected_indices: + if similarity_matrix[i][j] > bar: + can_add = False + break + if can_add: + selected_sentences.append(i) + selected_indices.append(i) + return selected_sentences, selected_indices + + +def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]: + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + x_normalized = embeddings / norms + similarity_matrix = np.dot(x_normalized, x_normalized.T) + return similarity_matrix diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 976be6a54..f428bf5c0 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -8,7 +8,10 @@ from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, + cosine_similarity_matrix, detect_lang, + find_best_unrelated_subgroup, parse_json_result, ) from memos.reranker.base import BaseReranker @@ -43,6 +46,7 @@ def __init__( internet_retriever: None = None, search_strategy: dict | None = None, manual_close_internet: bool = True, + tokenizer: FastTokenizer | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -58,6 +62,7 @@ def __init__( self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self.manual_close_internet = manual_close_internet + self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -104,9 +109,10 @@ def post_retrieve( top_k: int, user_name: str | None = None, info=None, + plugin=False, ): deduped = self._deduplicate_results(retrieved_results) - final_results = self._sort_and_trim(deduped, top_k) + final_results = self._sort_and_trim(deduped, top_k, plugin) self._update_usage_history(final_results, info, user_name) return final_results @@ -121,6 +127,7 @@ def search( search_filter: dict | None = None, search_priority: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. @@ -149,22 +156,29 @@ def search( else: logger.debug(f"[SEARCH] Received info dict: {info}") - retrieved_results = self.retrieve( - query=query, - top_k=top_k, - info=info, - mode=mode, - memory_type=memory_type, - search_filter=search_filter, - search_priority=search_priority, - user_name=user_name, - ) + if kwargs.get("plugin"): + logger.info(f"[SEARCH] Retrieve from plugin: {query}") + retrieved_results = self._retrieve_simple( + query=query, top_k=top_k, search_filter=search_filter, user_name=user_name + ) + else: + retrieved_results = self.retrieve( + query=query, + top_k=top_k, + info=info, + mode=mode, + memory_type=memory_type, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + ) final_results = self.post_retrieve( retrieved_results=retrieved_results, top_k=top_k, user_name=user_name, info=None, + plugin=kwargs.get("plugin", False), ) logger.info(f"[SEARCH] Done. Total {len(final_results)} results.") @@ -484,6 +498,49 @@ def _retrieve_from_internet( parsed_goal=parsed_goal, ) + @timed + def _retrieve_simple( + self, + query: str, + top_k: int, + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + """Retrieve from by keywords and embedding""" + query_words = [] + if self.tokenizer: + query_words = self.tokenizer.tokenize_mixed(query) + else: + query_words = query.strip().split() + query_words = list(set(query_words))[: top_k * 3] + query_words = [query, *query_words] + logger.info(f"[SIMPLESEARCH] Query words: {query_words}") + query_embeddings = self.embedder.embed(query_words) + + items = self.graph_retriever.retrieve_from_mixed( + top_k=top_k * 2, + memory_scope=None, + query_embedding=query_embeddings, + search_filter=search_filter, + user_name=user_name, + ) + logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") + documents = [getattr(item, "memory", "") for item in items] + documents_embeddings = self.embedder.embed(documents) + similarity_matrix = cosine_similarity_matrix(documents_embeddings) + selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) + selected_items = [items[i] for i in selected_indices] + logger.info( + f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" + ) + return self.reranker.rerank( + query=query, + query_embedding=query_embeddings[0], + graph_results=selected_items, + top_k=top_k, + ) + @timed def _deduplicate_results(self, results): """Deduplicate results by memory text""" @@ -494,12 +551,14 @@ def _deduplicate_results(self, results): return list(deduped.values()) @timed - def _sort_and_trim(self, results, top_k): + def _sort_and_trim(self, results, top_k, plugin=False): """Sort results by score and trim to top_k""" sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: + if plugin and round(score, 2) == 0.00: + continue meta_data = item.metadata.model_dump() meta_data["relativity"] = score final_items.append( diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index e346bdf1f..880646939 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -363,7 +363,8 @@ def _fast_search( target_session_id = search_req.session_id or "default_session" search_priority = {"session_id": search_req.session_id} if search_req.session_id else None search_filter = search_req.filter or None - print(f"type of text_mem: {type(self.naive_mem_cube.text_mem)}") + plugin = bool(search_req.source is not None and search_req.source == "plugin") + search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -377,6 +378,7 @@ def _fast_search( "session_id": target_session_id, "chat_history": search_req.chat_history, }, + plugin=plugin, ) formatted_memories = [format_memory_item(data) for data in search_results]