diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 3d650c17b..fe46fbe62 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -5,15 +5,15 @@ from datetime import datetime from typing import TYPE_CHECKING, Any -from tenacity import retry, stop_after_attempt, wait_exponential +from tenacity import retry, stop_after_attempt, wait_random_exponential -from memos import log from memos.configs.memory import MemFeedbackConfig from memos.context.context import ContextThreadPoolExecutor from memos.dependency import require_python_package from memos.embedders.factory import EmbedderFactory, OllamaEmbedder from memos.graph_dbs.factory import GraphStoreFactory, PolarDBGraphDB from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM +from memos.log import get_logger from memos.mem_feedback.base import BaseMemFeedback from memos.mem_feedback.utils import make_mem_item, should_keep_update, split_into_chunks from memos.mem_reader.factory import MemReaderFactory @@ -48,7 +48,7 @@ "generation": {"en": FEEDBACK_ANSWER_PROMPT, "zh": FEEDBACK_ANSWER_PROMPT_ZH}, } -logger = log.get_logger(__name__) +logger = get_logger(__name__) class MemFeedback(BaseMemFeedback): @@ -83,19 +83,47 @@ def __init__(self, config: MemFeedbackConfig): self.reranker = None self.DB_IDX_READY = False + @require_python_package( + import_name="jieba", + install_command="pip install jieba", + install_link="https://github.com/fxsjy/jieba", + ) + def _tokenize_chinese(self, text): + """split zh jieba""" + import jieba + + tokens = jieba.lcut(text) + tokens = [token.strip() for token in tokens if token.strip()] + return self.stopword_manager.filter_words(tokens) + + @retry(stop=stop_after_attempt(4), wait=wait_random_exponential(multiplier=1, max=10)) + def _embed_once(self, texts): + return self.embedder.embed(texts) + + @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, min=4, max=10)) + def _retry_db_operation(self, operation): + try: + return operation() + except Exception as e: + logger.error( + f"[Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True + ) + raise + def _batch_embed(self, texts: list[str], embed_bs: int = 5): - embed_bs = 5 - texts_embeddings = [] + results = [] + dim = self.embedder.config.embedding_dims + for i in range(0, len(texts), embed_bs): batch = texts[i : i + embed_bs] try: - texts_embeddings.extend(self.embedder.embed(batch)) + results.extend(self._embed_once(batch)) except Exception as e: logger.error( - f"[Feedback Core: process_feedback_core] Embedding batch failed: {e}", - exc_info=True, + f"[Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" ) - return texts_embeddings + results.extend([[0.0] * dim for _ in range(len(batch))]) + return results def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, info: dict): """ @@ -108,7 +136,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i lambda: self.memory_manager.add(to_add_memories, user_name=user_name) ) logger.info( - f"[Feedback Core: _pure_add] Added {len(added_ids)} memories for user {user_name}." + f"[Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." ) return { "record": { @@ -199,7 +227,7 @@ def _single_add_operation( lambda: self.memory_manager.add([to_add_memory], user_name=user_name, mode=async_mode) ) - logger.info(f"[Memory Feedback ADD] {added_ids[0]}") + logger.info(f"[Memory Feedback ADD] memory id: {added_ids[0]}") return {"id": added_ids[0], "text": to_add_memory.memory} def _single_update_operation( @@ -305,10 +333,14 @@ def semantics_feedback( if not current_memories: operations = [{"operation": "ADD"}] + logger.warning( + "[Feedback Core]: There was no recall of the relevant memory, so it was added directly." + ) else: memory_chunks = split_into_chunks(current_memories, max_tokens_per_chunk=500) all_operations = [] + now_time = datetime.now().isoformat() with ContextThreadPoolExecutor(max_workers=10) as executor: future_to_chunk_idx = {} for chunk in memory_chunks: @@ -316,6 +348,7 @@ def semantics_feedback( [f"{item.id}: {item.memory}" for item in chunk] ) prompt = template.format( + now_time=now_time, current_memories=current_memories_str, new_facts=memory_item.memory, chat_history=history_str, @@ -337,7 +370,7 @@ def semantics_feedback( operations = self.standard_operations(all_operations, current_memories) - logger.info(f"[Feedback memory operations]: {operations!s}") + logger.info(f"[Feedback Core Operations]: {operations!s}") if not operations: return {"record": {"add": [], "update": []}} @@ -453,6 +486,7 @@ def _feedback_memory( } def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: list) -> bool: + """Filter the relevant memory items based on info""" if not _info and not memory.metadata.info: return True @@ -463,10 +497,10 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: record.append(info_v == mem_v) return all(record) - def _retrieve(self, query: str, info=None, user_name=None): + def _retrieve(self, query: str, info=None, top_k=100, user_name=None): """Retrieve memory items""" retrieved_mems = self.searcher.search( - query, info=info, user_name=user_name, topk=50, full_recall=True + query, info=info, user_name=user_name, top_k=top_k, full_recall=True ) retrieved_mems = [item[0] for item in retrieved_mems] return retrieved_mems @@ -524,11 +558,19 @@ def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict: else: return response_text except Exception as e: - logger.error(f"[Feedback Core LLM] Exception during chat generation: {e}") + logger.error( + f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}" + ) response_json = None return response_json def standard_operations(self, operations, current_memories): + """ + Regularize the operation design + 1. Map the id to the correct original memory id + 2. If there is an update, skip the memory object of add + 3. If the modified text is too long, skip the update + """ right_ids = [item.id for item in current_memories] right_lower_map = {x.lower(): x for x in right_ids} @@ -582,9 +624,16 @@ def correct_item(data): has_update = any(item.get("operation").lower() == "update" for item in llm_operations) if has_update: filtered_items = [ + item for item in llm_operations if item.get("operation").lower() == "add" + ] + update_items = [ item for item in llm_operations if item.get("operation").lower() != "add" ] - return filtered_items + if filtered_items: + logger.info( + f"[Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" + ) + return update_items else: return llm_operations @@ -683,6 +732,10 @@ def process_keyword_replace( if doc_scope != "NONE": retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) + logger.info( + f"[Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." + ) + if not retrieved_memories: return {"record": {"add": [], "update": []}} @@ -693,14 +746,14 @@ def process_keyword_replace( if original_word in old_mem.memory: mem = old_mem.model_copy(deep=True) mem.memory = mem.memory.replace(original_word, target_word) + if original_word in mem.metadata.tags: + mem.metadata.tags.remove(original_word) if target_word not in mem.metadata.tags: mem.metadata.tags.append(target_word) pick_index.append(i) update_memories.append(mem) + update_memories_embed = self._batch_embed([mem.memory for mem in update_memories]) - update_memories_embed = self._retry_db_operation( - lambda: self._batch_embed([mem.memory for mem in update_memories]) - ) for _i, embed in zip(range(len(update_memories)), update_memories_embed, strict=False): update_memories[_i].metadata.embedding = embed @@ -805,9 +858,7 @@ def check_validity(item): feedback_memories = [] corrected_infos = [item["corrected_info"] for item in valid_feedback] - feedback_memories_embeddings = self._retry_db_operation( - lambda: self._batch_embed(corrected_infos) - ) + feedback_memories_embeddings = self._batch_embed(corrected_infos) for item, embedding in zip( valid_feedback, feedback_memories_embeddings, strict=False @@ -845,8 +896,10 @@ def check_validity(item): info=info, **kwargs, ) + add_memories = mem_record["record"]["add"] + update_memories = mem_record["record"]["update"] logger.info( - f"[Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback memories for user {user_name}." + f"[Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." ) return mem_record @@ -902,42 +955,19 @@ def process_feedback( task_id = kwargs.get("task_id", "default") logger.info( - f"[MemFeedback process] Feedback Completed : user {user_name} | task_id {task_id} | record {record}." + f"[Feedback Core MemFeedback process] Feedback Completed : user {user_name} | task_id {task_id} | record {record}." ) return {"answer": answer, "record": record["record"]} except concurrent.futures.TimeoutError: logger.error( - f"[MemFeedback process] Timeout in sync mode for {user_name}", exc_info=True + f"[Feedback Core MemFeedback process] Timeout in sync mode for {user_name}", + exc_info=True, ) return {"answer": "", "record": {"add": [], "update": []}} except Exception as e: logger.error( - f"[MemFeedback process] Error in concurrent tasks for {user_name}: {e}", + f"[Feedback Core MemFeedback process] Error in concurrent tasks for {user_name}: {e}", exc_info=True, ) return {"answer": "", "record": {"add": [], "update": []}} - - # Helper for DB operations with retry - @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) - def _retry_db_operation(self, operation): - try: - return operation() - except Exception as e: - logger.error( - f"[MemFeedback: _retry_db_operation] DB operation failed: {e}", exc_info=True - ) - raise - - @require_python_package( - import_name="jieba", - install_command="pip install jieba", - install_link="https://github.com/fxsjy/jieba", - ) - def _tokenize_chinese(self, text): - """split zh jieba""" - import jieba - - tokens = jieba.lcut(text) - tokens = [token.strip() for token in tokens if token.strip()] - return self.stopword_manager.filter_words(tokens) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index bc50faab0..a36f4ff3a 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -185,7 +185,7 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: task_id=feedback_req.task_id, info=feedback_req.info, ) - self.logger.info(f"Feedback memories result: {feedback_result}") + self.logger.info(f"[Feedback memories result:] {feedback_result}") return feedback_result def _get_search_mode(self, mode: str) -> str: diff --git a/src/memos/templates/mem_feedback_prompts.py b/src/memos/templates/mem_feedback_prompts.py index cd0c46a61..bbdb187e2 100644 --- a/src/memos/templates/mem_feedback_prompts.py +++ b/src/memos/templates/mem_feedback_prompts.py @@ -441,6 +441,8 @@ ] }} +**Current time** +{now_time} **Current Memories** {current_memories} @@ -581,6 +583,9 @@ ] }} +**当前时间:** +{now_time} + **当前记忆:** {current_memories}