diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index fd0dfc7f8..2758c9e32 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -5,6 +5,8 @@ using dependency injection for better modularity and testability. """ +from pydantic import validate_call + from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse from memos.memories.textual.item import ( @@ -13,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.types import MessageList class AddHandler(BaseHandler): @@ -60,38 +63,45 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: cube_view = self._build_cube_view(add_req) + @validate_call + def _check_messages(messages: MessageList) -> None: + pass + if add_req.is_feedback: - chat_history = add_req.chat_history - messages = add_req.messages - if chat_history is None: - chat_history = [] - if messages is None: - messages = [] - concatenate_chat = chat_history + messages - - last_user_index = max(i for i, d in enumerate(concatenate_chat) if d["role"] == "user") - feedback_content = concatenate_chat[last_user_index]["content"] - feedback_history = concatenate_chat[:last_user_index] - - feedback_req = APIFeedbackRequest( - user_id=add_req.user_id, - session_id=add_req.session_id, - task_id=add_req.task_id, - history=feedback_history, - feedback_content=feedback_content, - writable_cube_ids=add_req.writable_cube_ids, - async_mode=add_req.async_mode, - ) - process_record = cube_view.feedback_memories(feedback_req) + try: + messages = add_req.messages + _check_messages(messages) - self.logger.info( - f"[FeedbackHandler] Final feedback results count={len(process_record)}" - ) + chat_history = add_req.chat_history if add_req.chat_history else [] + concatenate_chat = chat_history + messages - return MemoryResponse( - message="Memory feedback successfully", - data=[process_record], - ) + last_user_index = max( + i for i, d in enumerate(concatenate_chat) if d["role"] == "user" + ) + feedback_content = concatenate_chat[last_user_index]["content"] + feedback_history = concatenate_chat[:last_user_index] + + feedback_req = APIFeedbackRequest( + user_id=add_req.user_id, + session_id=add_req.session_id, + task_id=add_req.task_id, + history=feedback_history, + feedback_content=feedback_content, + writable_cube_ids=add_req.writable_cube_ids, + async_mode=add_req.async_mode, + ) + process_record = cube_view.feedback_memories(feedback_req) + + self.logger.info( + f"[ADDFeedbackHandler] Final feedback results count={len(process_record)}" + ) + + return MemoryResponse( + message="Memory feedback successfully", + data=[process_record], + ) + except Exception as e: + self.logger.warning(f"[ADDFeedbackHandler] Running error: {e}") results = cube_view.add_memories(add_req) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index ffe736aa3..1c0f68a98 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -642,7 +642,6 @@ class APIFeedbackRequest(BaseRequest): ) feedback_content: str | None = Field(..., description="Feedback content to process") feedback_time: str | None = Field(None, description="Feedback time") - # ==== Multi-cube writing ==== writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube add" ) @@ -650,7 +649,7 @@ class APIFeedbackRequest(BaseRequest): "async", description="feedback mode: sync or async" ) corrected_answer: bool = Field(False, description="Whether need return corrected answer") - # ==== Backward compatibility ==== + # ==== mem_cube_id is NOT enabled==== mem_cube_id: str | None = Field( None, description=( diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 7657ef7e3..0ae4cfdb4 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1455,6 +1455,98 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + @timed + def seach_by_keywords( + self, + query_words: list[str], + scope: str | None = None, + status: str | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + tsvector_field: str = "properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Build user_name filter with knowledgebase_ids support (OR relationship) using common method + user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( + user_name=user_name, + knowledgebase_ids=knowledgebase_ids, + default_user_name=self.config.user_name, + ) + + # Add OR condition if we have any user_name conditions + if user_name_conditions: + if len(user_name_conditions) == 1: + where_clauses.append(user_name_conditions[0]) + else: + where_clauses.append(f"({' OR '.join(user_name_conditions)})") + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + where_clauses.extend(filter_conditions) + # Add fulltext search condition + # Convert query_text to OR query format: "word1 | word2 | word3" + tsquery_string = " | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + + params = (tsquery_string,) + logger.info(f"[search_by_fulltext] query: {query}, params: {params}") + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + output.append({"id": id_val}) + + return output + finally: + self._return_connection(conn) + @timed def search_by_fulltext( self, diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 02b737451..eed43d66e 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -10,10 +10,12 @@ from memos import log from memos.configs.memory import MemFeedbackConfig from memos.context.context import ContextThreadPoolExecutor +from memos.dependency import require_python_package from memos.embedders.factory import EmbedderFactory, OllamaEmbedder from memos.graph_dbs.factory import GraphStoreFactory, PolarDBGraphDB from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.mem_feedback.base import BaseMemFeedback +from memos.mem_feedback.utils import should_keep_update, split_into_chunks from memos.mem_reader.factory import MemReaderFactory from memos.mem_reader.simple_struct import detect_lang from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata @@ -30,6 +32,8 @@ FEEDBACK_ANSWER_PROMPT_ZH, FEEDBACK_JUDGEMENT_PROMPT, FEEDBACK_JUDGEMENT_PROMPT_ZH, + KEYWORDS_REPLACE, + KEYWORDS_REPLACE_ZH, UPDATE_FORMER_MEMORIES, UPDATE_FORMER_MEMORIES_ZH, ) @@ -37,6 +41,7 @@ FEEDBACK_PROMPT_DICT = { + "if_kw_replace": {"en": KEYWORDS_REPLACE, "zh": KEYWORDS_REPLACE_ZH}, "judge": {"en": FEEDBACK_JUDGEMENT_PROMPT, "zh": FEEDBACK_JUDGEMENT_PROMPT_ZH}, "compare": {"en": UPDATE_FORMER_MEMORIES, "zh": UPDATE_FORMER_MEMORIES_ZH}, "generation": {"en": FEEDBACK_ANSWER_PROMPT, "zh": FEEDBACK_ANSWER_PROMPT_ZH}, @@ -74,6 +79,20 @@ def __init__(self, config: MemFeedbackConfig): ) self.searcher: Searcher = self.memory_manager.searcher + def _batch_embed(self, texts: list[str], embed_bs: int = 5): + embed_bs = 5 + texts_embeddings = [] + for i in range(0, len(texts), embed_bs): + batch = texts[i : i + embed_bs] + try: + texts_embeddings.extend(self.embedder.embed(batch)) + except Exception as e: + logger.error( + f"[Feedback Core: process_feedback_core] Embedding batch failed: {e}", + exc_info=True, + ) + return texts_embeddings + def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, info: dict): """ Directly add new memory @@ -97,6 +116,25 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i } } + def _keyword_replace_judgement(self, feedback_content: str) -> dict | None: + """ + Determine whether it is keyword replacement + """ + lang = detect_lang(feedback_content) + template = FEEDBACK_PROMPT_DICT["if_kw_replace"][lang] + prompt = template.format( + user_feedback=feedback_content, + ) + + judge_res = self._get_llm_response(prompt) + if judge_res: + return judge_res + else: + logger.warning( + "[Feedback Core: _feedback_judgement] feedback judgement failed, return []" + ) + return {} + def _feedback_judgement( self, chat_history: list[MessageDict], feedback_content: str, feedback_time: str = "" ) -> dict | None: @@ -128,7 +166,7 @@ def _single_add_operation( new_memory_item: TextualMemoryItem, user_id: str, user_name: str, - async_mode: str, + async_mode: str = "sync", ) -> dict: """ Individual addition operations @@ -166,7 +204,7 @@ def _single_update_operation( new_memory_item: TextualMemoryItem, user_id: str, user_name: str, - async_mode: str, + async_mode: str = "sync", ) -> dict: """ Individual update operations @@ -231,10 +269,111 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> f"[Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) + def semantics_feedback( + self, + user_id: str, + user_name: str, + memory_item: TextualMemoryItem, + current_memories: list[TextualMemoryItem], + fact_history: str, + ): + lang = detect_lang("".join(memory_item.memory)) + template = FEEDBACK_PROMPT_DICT["compare"][lang] + if current_memories == []: + current_memories = self._retrieve( + memory_item.memory, info={"user_id": user_id}, user_name=user_name + ) + + if not current_memories: + operations = [{"operation": "ADD"}] + else: + memory_chunks = split_into_chunks(current_memories, max_tokens_per_chunk=500) + + all_operations = [] + with ContextThreadPoolExecutor(max_workers=10) as executor: + future_to_chunk_idx = {} + for chunk in memory_chunks: + current_memories_str = "\n".join( + [f"{item.id}: {item.memory}" for item in chunk] + ) + prompt = template.format( + current_memories=current_memories_str, + new_facts=memory_item.memory, + chat_history=fact_history, + ) + + future = executor.submit(self._get_llm_response, prompt) + future_to_chunk_idx[future] = chunk + for future in concurrent.futures.as_completed(future_to_chunk_idx): + try: + chunk_operations = future.result() + if ( + chunk_operations + and "operations" in chunk_operations + and isinstance(chunk_operations["operations"], list) + ): + all_operations.extend(chunk_operations["operations"]) + except Exception as e: + logger.error(f"[Feedback Core: semantics_feedback] Operation failed: {e}") + + operations = self.standard_operations(all_operations, current_memories) + + # TODO based on the operation, change memory_item memory info ; change source info + logger.info(f"[Feedback memory operations]: {operations!s}") + + if not operations: + return {"record": {"add": [], "update": []}} + + add_results = [] + update_results = [] + id_to_item = {item.id: item for item in current_memories} + + with ContextThreadPoolExecutor(max_workers=10) as executor: + future_to_op = {} + for op in operations: + event_type = op.get("operation", "").lower() + + if event_type == "add": + future = executor.submit( + self._single_add_operation, + None, + memory_item, + user_id, + user_name, + ) + future_to_op[future] = ("add", op) + elif event_type == "update": + future = executor.submit( + self._single_update_operation, + id_to_item[op["id"]], + memory_item, + user_id, + user_name, + ) + future_to_op[future] = ("update", op) + + for future in concurrent.futures.as_completed(future_to_op): + result_type, original_op = future_to_op[future] + try: + result = future.result() + if result_type == "add" and result: + add_results.append(result) + elif result_type == "update" and result: + update_results.append(result) + except Exception as e: + logger.error( + f"[Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", + exc_info=True, + ) + if update_results: + updated_ids = [item["archived_id"] for item in update_results] + self._del_working_binding(updated_ids, user_name) + + return {"record": {"add": add_results, "update": update_results}} + def _feedback_memory( self, user_id: str, user_name: str, feedback_memories: list[TextualMemoryItem], **kwargs ) -> dict: - async_mode = kwargs.get("async_mode") retrieved_memory_ids = kwargs.get("retrieved_memory_ids") or [] chat_history = kwargs.get("chat_history", []) feedback_content = kwargs.get("feedback_content", "") @@ -259,90 +398,11 @@ def _feedback_memory( if "mode:fast" not in item["metadata"]["tags"] ] - def _add_or_update( - memory_item: TextualMemoryItem, - current_memories: list[TextualMemoryItem], - fact_history: str, - ): - if current_memories == []: - current_memories = self._retrieve( - memory_item.memory, info={"user_id": user_id}, user_name=user_name - ) - - if current_memories: - lang = detect_lang("".join(memory_item.memory)) - template = FEEDBACK_PROMPT_DICT["compare"][lang] - current_memories_str = "\n".join( - [f"{item.id}: {item.memory}" for item in current_memories] - ) - prompt = template.format( - current_memories=current_memories_str, - new_facts=memory_item.memory, - chat_history=fact_history, - ) - - operations = self._get_llm_response(prompt).get("operations", []) - operations = self._id_dehallucination(operations, current_memories) - else: - operations = [{"operation": "ADD"}] - - # TODO based on the operation, change memory_item memory info ; change source info - logger.info(f"[Feedback memory operations]: {operations!s}") - - if not operations: - return {"record": {"add": [], "update": []}} - - add_results = [] - update_results = [] - id_to_item = {item.id: item for item in current_memories} - with ContextThreadPoolExecutor(max_workers=10) as executor: - future_to_op = {} - for op in operations: - event_type = op.get("operation", "").lower() - - if event_type == "add": - future = executor.submit( - self._single_add_operation, - None, - memory_item, - user_id, - user_name, - async_mode, - ) - future_to_op[future] = ("add", op) - elif event_type == "update": - future = executor.submit( - self._single_update_operation, - id_to_item[op["id"]], - memory_item, - user_id, - user_name, - async_mode, - ) - future_to_op[future] = ("update", op) - - for future in concurrent.futures.as_completed(future_to_op): - result_type, original_op = future_to_op[future] - try: - result = future.result() - if result_type == "add" and result: - add_results.append(result) - elif result_type == "update" and result: - update_results.append(result) - except Exception as e: - logger.error( - f"[Feedback Core: _add_or_update] Operation failed for {original_op}: {e}", - exc_info=True, - ) - if update_results: - updated_ids = [item["archived_id"] for item in update_results] - self._del_working_binding(updated_ids, user_name) - - return {"record": {"add": add_results, "update": update_results}} - with ContextThreadPoolExecutor(max_workers=3) as ex: futures = { - ex.submit(_add_or_update, mem, current_memories, fact_history): i + ex.submit( + self.semantics_feedback, user_id, user_name, mem, current_memories, fact_history + ): i for i, mem in enumerate(feedback_memories) } results = [None] * len(futures) @@ -368,7 +428,10 @@ def _add_or_update( def _retrieve(self, query: str, info=None, user_name=None): """Retrieve memory items""" - retrieved_mems = self.searcher.search(query, info=info, user_name=user_name) + retrieved_mems = self.searcher.search( + query, info=info, user_name=user_name, topk=50, full_recall=True + ) + retrieved_mems = [item[0] for item in retrieved_mems] return retrieved_mems def _vec_query(self, new_memories_embedding: list[float], user_name=None): @@ -430,28 +493,51 @@ def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict: response_json = None return response_json - def _id_dehallucination(self, operations, current_memories): + def standard_operations(self, operations, current_memories): right_ids = [item.id for item in current_memories] right_lower_map = {x.lower(): x for x in right_ids} def correct_item(data): - if data.get("operation", "").lower() != "update": - return data - - original_id = data["id"] - if original_id in right_ids: - return data - - lower_id = original_id.lower() - if lower_id in right_lower_map: - data["id"] = right_lower_map[lower_id] - return data - - matches = difflib.get_close_matches(original_id, right_ids, n=1, cutoff=0.8) - if matches: - data["id"] = matches[0] - return data + try: + assert "operation" in data + if data.get("operation", "").lower() == "add": + return data + + if data.get("operation", "").lower() == "none": + return None + + assert ( + "id" in data + and "text" in data + and "old_memory" in data + and data["operation"].lower() == "update" + ) + if not should_keep_update(data["text"], data["old_memory"]): + logger.warning( + f"[Feedback Core: semantics_feedback] Due to the excessive proportion of changes, skip update: {data}" + ) + return None + + # id dehallucination + original_id = data["id"] + if original_id in right_ids: + return data + + lower_id = original_id.lower() + if lower_id in right_lower_map: + data["id"] = right_lower_map[lower_id] + return data + + matches = difflib.get_close_matches(original_id, right_ids, n=1, cutoff=0.8) + if matches: + data["id"] = matches[0] + return data + except Exception: + logger.error( + f"[Feedback Core: standard_operations] Error processing operation item: {data}", + exc_info=True, + ) return None dehallu_res = [correct_item(item) for item in operations] @@ -475,6 +561,86 @@ def _generate_answer( return self._get_llm_response(prompt, dsl=False) + def process_keyword_replace(self, user_id: str, user_name: str, kwp_judge: dict | None = None): + """ + memory keyword replace process + """ + doc_scope = kwp_judge.get("doc_scope", "NONE") + original_word = kwp_judge.get("original") + target_word = kwp_judge.get("target") + + # retrieve + lang = detect_lang(original_word) + queries = self._tokenize_chinese(original_word) if lang == "zh" else original_word.split() + + must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] + retrieved_ids = self.graph_store.seach_by_keywords([must_part], user_name=user_name) + if len(retrieved_ids) < 1: + retrieved_ids = self.graph_store.search_by_fulltext( + queries, top_k=100, user_name=user_name + ) + + # filter by doc scope + mem_data = [ + self.graph_store.get_node(item["id"], user_name=user_name) for item in retrieved_ids + ] + retrieved_memories = [TextualMemoryItem(**item) for item in mem_data] + + if doc_scope != "NONE": + retrieved_memories = [ + item + for item in retrieved_memories + if doc_scope in item.metadata.sources # TODO + ] + + if not retrieved_memories: + return {"record": {"add": [], "update": []}} + + # replace keywords + pick_index = [] + update_memories = [] + for i, old_mem in enumerate(retrieved_memories): + if original_word in old_mem.memory: + mem = old_mem.model_copy(deep=True) + mem.memory = mem.memory.replace(original_word, target_word) + if target_word not in mem.metadata.tags: + mem.metadata.tags.append(target_word) + pick_index.append(i) + update_memories.append(mem) + + update_memories_embed = self._retry_db_operation( + lambda: self._batch_embed([mem.memory for mem in update_memories]) + ) + for _i, embed in zip(range(len(update_memories)), update_memories_embed, strict=False): + update_memories[_i].metadata.embedding = embed + + update_results = [] + with ContextThreadPoolExecutor(max_workers=10) as executor: + future_to_info = {} + for new_mem, old_idx in zip(update_memories, pick_index, strict=False): + old_mem = retrieved_memories[old_idx] + + future = executor.submit( + self._single_update_operation, + old_mem, + new_mem, + user_id, + user_name, + ) + future_to_info[future] = old_mem.id + + for future in future_to_info: + try: + result = future.result() + update_results.append(result) + except Exception as e: + mem_id = future_to_info[future][0] + self.logger.error( + f"[Feedback Core DB] Exception during update operation for memory {mem_id}: {e}" + ) + + return {"record": {"add": [], "update": update_results}} + def process_feedback_core( self, user_id: str, @@ -497,19 +663,28 @@ def check_validity(item): and "tags" in item ) + if feedback_content.strip() == "": + return {"record": {"add": [], "update": []}} try: feedback_time = kwargs.get("feedback_time") or datetime.now().isoformat() session_id = kwargs.get("session_id") - if feedback_content.strip() == "": - return {"record": {"add": [], "update": []}} - info = {"user_id": user_id, "user_name": user_name, "session_id": session_id} logger.info( f"[Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) + # feedback keywords update + kwp_judge = self._keyword_replace_judgement(feedback_content) + if ( + kwp_judge + and kwp_judge["if_keyword_replace"].lower() == "true" + and kwp_judge.get("original", "NONE") != "NONE" + and kwp_judge.get("target", "NONE") != "NONE" + ): + return self.process_keyword_replace(user_id, user_name, kwp_judge=kwp_judge) + + # llm update memory if not chat_history: return self._pure_add(user_name, feedback_content, feedback_time, info) - else: raw_judge = self._feedback_judgement( chat_history, feedback_content, feedback_time=feedback_time @@ -533,17 +708,9 @@ def check_validity(item): feedback_memories = [] corrected_infos = [item["corrected_info"] for item in valid_feedback] - embed_bs = 5 - feedback_memories_embeddings = [] - for i in range(0, len(corrected_infos), embed_bs): - batch = corrected_infos[i : i + embed_bs] - try: - feedback_memories_embeddings.extend(self.embedder.embed(batch)) - except Exception as e: - logger.error( - f"[Feedback Core: process_feedback_core] Embedding batch failed: {e}", - exc_info=True, - ) + feedback_memories_embeddings = self._retry_db_operation( + lambda: self._batch_embed(corrected_infos) + ) for item, embedding in zip( valid_feedback, feedback_memories_embeddings, strict=False @@ -664,3 +831,16 @@ def _retry_db_operation(self, operation): f"[MemFeedback: _retry_db_operation] DB operation failed: {e}", exc_info=True ) raise + + @require_python_package( + import_name="jieba", + install_command="pip install jieba", + install_link="https://github.com/fxsjy/jieba", + ) + def _tokenize_chinese(self, text): + """split zh jieba""" + import jieba + + tokens = jieba.lcut(text) + tokens = [token.strip() for token in tokens if token.strip()] + return self.stopword_manager.filter_words(tokens) diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 01132eb97..bb5a1c552 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -5,6 +5,7 @@ from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -27,3 +28,4 @@ def __init__( self.memory_manager = memory_manager self.mem_reader = mem_reader self.searcher = searcher + self.stopword_manager = StopwordManager diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py new file mode 100644 index 000000000..b290993cd --- /dev/null +++ b/src/memos/mem_feedback/utils.py @@ -0,0 +1,86 @@ +from memos.memories.textual.item import TextualMemoryItem + + +def estimate_tokens(text: str) -> int: + """ + Estimate the approximate number of tokens for the text + """ + if not text: + return 0 + + chinese_chars = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") + + english_parts = text.split() + english_words = 0 + for part in english_parts: + has_chinese = any("\u4e00" <= char <= "\u9fff" for char in part) + if not has_chinese and any(c.isalpha() for c in part): + english_words += 1 + + other_chars = len(text) - chinese_chars + + estimated_tokens = int(chinese_chars * 1.5 + english_words * 1.33 + other_chars * 0.5) + + return max(1, estimated_tokens) + + +def should_keep_update(new_text: str, old_text: str) -> bool: + """ + Determine whether the update should be skipped + Rule: + 1. If the length of old_text is less than 50 and the modification ratio is less than 50% => returns True + 2. If the length of old_text is greater than or equal to 50 and the modification ratio is less than 15% => returns True + 3. Return False in other cases + """ + + old_len = estimate_tokens(old_text) + + def calculate_similarity(text1: str, text2: str) -> float: + set1 = set(text1) + set2 = set(text2) + if not set1 and not set2: + return 1.0 + + intersection = len(set1.intersection(set2)) + union = len(set1.union(set2)) + return intersection / union if union > 0 else 0.0 + + similarity = calculate_similarity(old_text, new_text) + change_ratio = 1 - similarity + + if old_len < 50: + return change_ratio < 0.5 + else: + return change_ratio < 0.15 + + +def split_into_chunks(memories: list[TextualMemoryItem], max_tokens_per_chunk=500): + chunks = [] + current_chunk = [] + current_tokens = 0 + + for item in memories: + item_text = f"{item.id}: {item.memory}" + item_tokens = estimate_tokens(item_text) + + if item_tokens > max_tokens_per_chunk: + if current_chunk: + chunks.append(current_chunk) + current_chunk = [] + + chunks.append([item]) + current_tokens = 0 + + elif current_tokens + item_tokens <= max_tokens_per_chunk: + current_chunk.append(item) + current_tokens += item_tokens + else: + if current_chunk: + chunks.append(current_chunk) + current_chunk = [item] + current_tokens = item_tokens + + if current_chunk: + chunks.append(current_chunk) + + return chunks diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index b7956bfec..8067c7f72 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -90,6 +90,11 @@ class TextualMemoryMetadata(BaseModel): model_config = ConfigDict(extra="allow") + covered_history: Any | None = Field( + default=None, + description="Record the memory id covered by the update", + ) + def __str__(self) -> str: """Pretty string representation of the metadata.""" meta = self.model_dump(exclude_none=True) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 9e1e6c240..5a82883c8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -93,12 +93,6 @@ def find_project_root(marker=".git"): return Path(".") -PROJECT_ROOT = find_project_root() -DEFAULT_STOPWORD_FILE = ( - PROJECT_ROOT / "examples" / "data" / "config" / "stopwords.txt" -) # cause time delay - - class StopwordManager: _stopwords = None @@ -109,13 +103,7 @@ def _load_stopwords(cls): return cls._stopwords stopwords = set() - try: - with open(DEFAULT_STOPWORD_FILE, encoding="utf-8") as f: - stopwords = {line.strip() for line in f if line.strip()} - logger.info("Stopwords loaded successfully.") - except Exception as e: - logger.warning(f"Error loading stopwords: {e}, using default stopwords.") - stopwords = cls._load_default_stopwords() + stopwords = cls._load_default_stopwords() cls._stopwords = stopwords return stopwords @@ -370,14 +358,6 @@ def is_stopword(cls, word): cls._load_stopwords() return word in cls._stopwords - @classmethod - def reload_stopwords(cls, file_path=None): - cls._stopwords = None - if file_path: - global DEFAULT_STOPWORD_FILE - DEFAULT_STOPWORD_FILE = file_path - cls._load_stopwords() - class FastTokenizer: def __init__(self, use_jieba=True, use_stopwords=True): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 830b915c1..035aa3b96 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -173,6 +173,10 @@ def search( user_name=user_name, ) + full_recall = kwargs.get("full_recall", False) + if full_recall: + return retrieved_results + final_results = self.post_retrieve( retrieved_results=retrieved_results, top_k=top_k, diff --git a/src/memos/templates/mem_feedback_prompts.py b/src/memos/templates/mem_feedback_prompts.py index f7f2e8cb4..cd0c46a61 100644 --- a/src/memos/templates/mem_feedback_prompts.py +++ b/src/memos/templates/mem_feedback_prompts.py @@ -1,3 +1,119 @@ +KEYWORDS_REPLACE = """ +**Instruction:** +Please analyze the user's input text to determine if it is a "keyword replacement" request. If yes, follow these steps: + +1. **Identify the request type**: Confirm whether the user is asking to replace a specific word or phrase with another **within a specified scope**. +2. **Extract the modification scope**: Determine the scope where the modification should apply. + - If the user mentions a specific **document, file, or material identifier** (e.g., "in the Q1 operations plan", "in the prospectus numbered BT7868"), extract this description as the document scope. + - **If the user does not explicitly specify any scope, mark the scope as "NONE"**. +3. **Extract the original term (A)**: Identify the original word or phrase the user wants to be replaced. +4. **Extract the target term (B)**: Identify the target word or phrase the user wants to replace it with. + +**Output JSON Format**: +{{ + "if_keyword_replace": "true" | "false", + "doc_scope": "[Extracted specific file or document description]" | "NONE" | null, + "original": "[Extracted original word or phrase A]" | null, + "target": "[Extracted target word or phrase B]" | null +}} +- **If it is NOT a replacement request**, set `if_keyword_replace` to `"false"`, and set the values for `doc_scope`, `original`, and `target` to `null`. +- **If it IS a replacement request**, set `if_keyword_replace` to `"true"` and fill in the remaining fields. If the user did not specify a scope, set `doc_scope` to `"NONE"`. + +**Examples**: + +1. **User Input**: "In the file `User_Agreement.docx`, replace 'Party B' with 'User'." + **Output**: + {{ + "if_keyword_replace": "true", + "doc_scope": "User_Agreement.docx", + "original": "Party B", + "target": "User" + }} + +2. **User Input**: "Change 'Homepage' to 'Front Page'." + **Output**: + {{ + "if_keyword_replace": "true", + "doc_scope": "NONE", + "original": "Homepage", + "target": "Front Page" + }} + +3. **User Input**: "Does this sentence need modification?" + **Output**: + {{ + "if_keyword_replace": "false", + "doc_scope": null, + "original": null, + "target": null + }} + +**User Input** +{user_feedback} + +**Output**: +""" + + +KEYWORDS_REPLACE_ZH = """ +**指令:** +请分析用户输入的文本,判断是否为“关键词替换”需求。 如果是,请按以下步骤处理: + +1. **识别需求类型**:确认用户是否要求将**特定范围**内的某个词或短语替换为另一个词或短语。 +2. **提取修改范围**:确定用户指定的修改生效范围。 + - 如果用户提及了具体的**文档、文件或资料标识**(如“在第一季运营方案”、“编号为BT7868的招股书”),则提取此描述作为文件范围。 + - **如果用户未明确指定任何范围,则范围标记为 "NONE"**。 +3. **提取原始词汇(A)**:找出用户希望被替换的原始词或短语。 +4. **提取目标词汇(B)**:找出用户希望替换成的目标词或短语。 + +**输出JSON格式**: +{{ + "if_keyword_replace": "true" | "false", + "doc_scope": "[提取的具体文件或文档描述]" | "NONE" | null, + "original": "[提取的原始词或短语A]" | null, + "target": "[提取的目标词或短语B]" | null +}} +- **如果不是替换需求**,将 `if_keyword_replace` 设为 `"false"`,并将 `doc_scope`、`original`、`target` 三个键的值都设为 `null`。 +- **如果是替换需求**,将 `if_keyword_replace` 设为 `"true"`,并填充其余字段。如果用户未指定范围,`doc_scope` 设为 `"NONE"`。 + + +**示例**: + +1. **用户输入**:“在`用户协议.docx`这个文件中,把‘乙方’替换为‘用户’。” + **输出**: + {{ + "if_keyword_replace": "true", + "doc_scope": "用户协议.docx", + "original": "乙方", + "target": "用户" + }} + +2. **用户输入**:“把‘主页’改成‘首页’。” + **输出**: + {{ + "if_keyword_replace": "true", + "doc_scope": "NONE", + "original": "主页", + "target": "首页" + }} + +3. **用户输入**:“这个句子需要修改吗?” + **输出**: + {{ + "if_keyword_replace": "false", + "doc_scope": null, + "original": null, + "target": null + }} + + +**用户输入** +{user_feedback} + +**输出**: +""" + + FEEDBACK_JUDGEMENT_PROMPT = """You are a answer quality analysis expert. Please strictly follow the steps and criteria below to analyze the provided "User and Assistant Chat History" and "User Feedback," and fill the final evaluation results into the specified JSON format. Analysis Steps and Criteria: