diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 657caf054..517005c9d 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1528,7 +1528,97 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: raise NotImplementedError @timed - def seach_by_keywords( + def seach_by_keywords_like( + self, + query_word: str, + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + + # Build key + where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (query_word,) + logger.info( + f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" + ) + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + output.append({"id": id_val}) + logger.info( + f"[seach_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output + finally: + self._return_connection(conn) + + @timed + def seach_by_keywords_tfidf( self, query_words: list[str], scope: str | None = None, @@ -1603,7 +1693,9 @@ def seach_by_keywords( """ params = (tsquery_string,) - logger.info(f"[search_by_fulltext] query: {query}, params: {params}") + logger.info( + f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" + ) conn = self._get_connection() try: with conn.cursor() as cursor: @@ -1615,6 +1707,9 @@ def seach_by_keywords( id_val = str(oldid) output.append({"id": id_val}) + logger.info( + f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) return output finally: self._return_connection(conn) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 49fd382a0..b986f7f13 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -78,6 +78,7 @@ def __init__(self, config: MemFeedbackConfig): is_reorganize=self.is_reorganize, ) self.searcher: Searcher = self.memory_manager.searcher + self.DB_IDX_READY = False def _batch_embed(self, texts: list[str], embed_bs: int = 5): embed_bs = 5 @@ -569,15 +570,24 @@ def process_keyword_replace(self, user_id: str, user_name: str, kwp_judge: dict original_word = kwp_judge.get("original") target_word = kwp_judge.get("target") - # retrieve - lang = detect_lang(original_word) - queries = self._tokenize_chinese(original_word) if lang == "zh" else original_word.split() + if self.DB_IDX_READY: + # retrieve + lang = detect_lang(original_word) + queries = ( + self._tokenize_chinese(original_word) if lang == "zh" else original_word.split() + ) - must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] - retrieved_ids = self.graph_store.seach_by_keywords([must_part], user_name=user_name) - if len(retrieved_ids) < 1: - retrieved_ids = self.graph_store.search_by_fulltext( - queries, top_k=100, user_name=user_name + must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] + retrieved_ids = self.graph_store.seach_by_keywords_tfidf( + [must_part], user_name=user_name + ) + if len(retrieved_ids) < 1: + retrieved_ids = self.graph_store.search_by_fulltext( + queries, top_k=100, user_name=user_name + ) + else: + retrieved_ids = self.graph_store.seach_by_keywords_like( + f"%{original_word}%", user_name=user_name ) # filter by doc scope diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index bb5a1c552..478fa104f 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -29,3 +29,4 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager + self.DB_IDX_READY = False