diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index a8a6f8b7b..1bd83eae7 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -7,6 +7,9 @@ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, MemoryResponse +from memos.memories.textual.item import ( + list_all_fields, +) from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView @@ -44,6 +47,13 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ self.logger.info(f"[AddHandler] Add Req is: {add_req}") + if add_req.info: + exclude_fields = list_all_fields() + info_len = len(add_req.info) + add_req.info = {k: v for k, v in add_req.info.items() if k not in exclude_fields} + if len(add_req.info) < info_len: + self.logger.warning(f"[AddHandler] info fields can not contain {exclude_fields}.") + cube_view = self._build_cube_view(add_req) results = cube_view.add_memories(add_req) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index c47a3cf83..689e2b16b 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -4,7 +4,7 @@ This module handles retrieving all memories or specific subgraphs based on queries. 
""" -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal from memos.api.handlers.formatters_handler import format_memory_item from memos.api.product_models import ( @@ -24,6 +24,10 @@ ) +if TYPE_CHECKING: + from memos.memories.textual.preference import TextualMemoryItem + + logger = get_logger(__name__) @@ -161,17 +165,20 @@ def handle_get_subgraph( def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse: # TODO: Implement get memory with filter memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"] - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) + preferences: list[TextualMemoryItem] = [] + if get_mem_req.include_preference: + filter_params: dict[str, Any] = {} + if get_mem_req.user_id is not None: + filter_params["user_id"] = get_mem_req.user_id + if get_mem_req.mem_cube_id is not None: + filter_params["mem_cube_id"] = get_mem_req.mem_cube_id + preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) + preferences = [format_memory_item(mem) for mem in preferences] return GetMemoryResponse( message="Memories retrieved successfully", data={ "text_mem": memories, - "pref_mem": [format_memory_item(mem) for mem in preferences], + "pref_mem": preferences, }, ) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index ea5f8d136..2f2e9ea54 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -6,7 +6,7 @@ # Import message types from core types module from memos.log import get_logger -from memos.types import MessageDict, MessagesType, PermissionDict, SearchMode +from memos.types import MessageList, MessagesType, PermissionDict, SearchMode logger = 
get_logger(__name__) @@ -72,40 +72,57 @@ class ChatRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") - mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") readable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can read for multi-cube chat" ) writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list[MessageDict] | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") - internet_search: bool = Field(True, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") - threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") - filter: dict[str, Any] | None = Field(None, description="Filter for the memory") model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + + # ==== Filter conditions ==== + filter: dict[str, Any] | None = Field( + None, + description=""" + Filter for the memory, example: + { + "`and` or `or`": [ + {"id": 
"uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, + ) + + # ==== Extended capabilities ==== + internet_search: bool = Field(True, description="Whether to use internet search") + threshold: float = Field(0.5, description="Threshold for filtering references") + + # ==== Backward compatibility ==== + mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") moscube: bool = Field( False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" ) class ChatCompleteRequest(BaseRequest): - """Request model for chat operations.""" + """Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest.""" user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") - history: list[MessageDict] | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base prompt to use for chat") top_k: int = Field(10, description="Number of results to return") @@ -191,7 +208,7 @@ class MemoryCreateRequest(BaseRequest): """Request model for creating memories.""" user_id: str = Field(..., description="User ID") - messages: list[MessageDict] | None = Field(None, description="List of messages to store.") + messages: MessagesType | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Memory content to store") doc_path: str | None = Field(None, description="Path to document to store") mem_cube_id: str | None = Field(None, description="Cube ID") @@ -269,7 +286,15 @@ class APISearchRequest(BaseRequest): # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( None, - description=("Filter for the memory"), + description=""" + 
Filter for the memory, example: + { + "`and` or `or`": [ + {"id": "uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, ) # ==== Extended capabilities ==== @@ -291,7 +316,7 @@ class APISearchRequest(BaseRequest): ) # ==== Context ==== - chat_history: MessagesType | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. " @@ -421,7 +446,7 @@ class APIADDRequest(BaseRequest): ) # ==== Chat history ==== - chat_history: MessagesType | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. " @@ -540,31 +565,49 @@ class APIChatCompleteRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") - mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") readable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can read for multi-cube chat" ) writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list[MessageDict] | None = Field(None, description="Chat history") - internet_search: bool = Field(False, description="Whether to use internet search") - system_prompt: str | None = Field(None, description="Base system prompt to use for chat") + history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") + system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") - threshold: float = Field(0.5, description="Threshold for filtering references") - session_id: str | None = Field( - "default_session", description="Session ID for soft-filtering memories" - ) + session_id: str | None = Field(None, description="Session ID for 
soft-filtering memories") include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") - filter: dict[str, Any] | None = Field(None, description="Filter for the memory") model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + # ==== Filter conditions ==== + filter: dict[str, Any] | None = Field( + None, + description=""" + Filter for the memory, example: + { + "`and` or `or`": [ + {"id": "uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, + ) + + # ==== Extended capabilities ==== + internet_search: bool = Field(True, description="Whether to use internet search") + threshold: float = Field(0.5, description="Threshold for filtering references") + + # ==== Backward compatibility ==== + mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + moscube: bool = Field( + False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" + ) + class AddStatusRequest(BaseRequest): """Request model for checking add status.""" @@ -594,7 +637,7 @@ class SuggestionRequest(BaseRequest): user_id: str = Field(..., description="User ID") mem_cube_id: str = Field(..., description="Cube ID") language: Literal["zh", "en"] = Field("zh", description="Language for suggestions") - message: list[MessageDict] | None = Field(None, description="List of messages to store.") + message: MessagesType | None = Field(None, description="List of messages to store.") # ─── MemOS Client Response Models ────────────────────────────────────────────── diff --git 
a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 3845f37d0..29ce49d90 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -22,6 +22,8 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.parsers.factory import ParserFactory from memos.templates.mem_reader_prompts import ( + CUSTOM_TAGS_INSTRUCTION, + CUSTOM_TAGS_INSTRUCTION_ZH, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, SIMPLE_STRUCT_MEM_READER_EXAMPLE, @@ -41,6 +43,7 @@ "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, }, "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } try: @@ -121,11 +124,15 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder embedding = embedder.embed([value])[0] + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id", ""), - session_id=info.get("session_id", ""), + user_id=user_id, + session_id=session_id, memory_type="LongTermMemory", status="activated", tags=tags, @@ -136,6 +143,7 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder background="", confidence=0.99, type="fact", + info=info_, ), ) except Exception as e: @@ -183,11 +191,15 @@ def _make_memory_item( confidence: float = 0.99, ) -> TextualMemoryItem: """construct memory item""" + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id", ""), - session_id=info.get("session_id", ""), + user_id=user_id, + session_id=session_id, memory_type=memory_type, status="activated", tags=tags or 
[], @@ -198,14 +210,23 @@ def _make_memory_item( background=background, confidence=confidence, type=type_, + info=info_, ), ) - def _get_llm_response(self, mem_str: str) -> dict: + def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict: lang = detect_lang(mem_str) template = PROMPT_DICT["chat"][lang] examples = PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) + + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt) + if self.config.remove_prompt_example: prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] @@ -274,6 +295,9 @@ def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): def _process_chat_data(self, scene_data_info, info, **kwargs): mode = kwargs.get("mode", "fine") windows = list(self._iter_chat_windows(scene_data_info)) + custom_tags = info.pop( + "custom_tags", None + ) # must pop here to avoid adding it to info; only used in sync fine mode if mode == "fast": logger.debug("Using unified Fast Mode") @@ -304,7 +328,7 @@ def _build_fast_node(w): logger.debug("Using unified Fine Mode") chat_read_nodes = [] for w in windows: - resp = self._get_llm_response(w["text"]) + resp = self._get_llm_response(w["text"], custom_tags) for m in resp.get("memory list", []): try: memory_type = ( @@ -326,9 +350,12 @@ logger.error(f"[ChatFine] parse error: {e}") return chat_read_nodes - def _process_transfer_chat_data(self, raw_node: TextualMemoryItem): + def _process_transfer_chat_data( + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + ): raw_memory = raw_node.memory - response_json = self._get_llm_response(raw_memory) + response_json = self._get_llm_response(raw_memory, custom_tags) + chat_read_nodes = [] for memory_i_raw in response_json.get("memory list",
[]): try: @@ -342,6 +369,7 @@ def _process_transfer_chat_data(self, raw_node: TextualMemoryItem): node_i = self._make_memory_item( value=memory_i_raw.get("value", ""), info={ + **(raw_node.metadata.info or {}), "user_id": raw_node.metadata.user_id, "session_id": raw_node.metadata.session_id, }, @@ -429,7 +457,10 @@ def get_memory( return memory_list def fine_transfer_simple_mem( - self, input_memories: list[TextualMemoryItem], type: str + self, + input_memories: list[TextualMemoryItem], + type: str, + custom_tags: list[str] | None = None, ) -> list[list[TextualMemoryItem]]: if not input_memories: return [] @@ -446,7 +477,7 @@ def fine_transfer_simple_mem( # Process Q&A pairs concurrently with context propagation with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(processing_func, scene_data_info) + executor.submit(processing_func, scene_data_info, custom_tags) for scene_data_info in input_memories ] for future in concurrent.futures.as_completed(futures): @@ -539,11 +570,18 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): if mode == "fast": raise NotImplementedError chunks = self.chunker.chunk(scene_data_info["text"]) + custom_tags = info.pop("custom_tags", None) messages = [] for chunk in chunks: lang = detect_lang(chunk.text) template = PROMPT_DICT["doc"][lang] prompt = template.replace("{chunk_text}", chunk.text) + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt) message = [{"role": "user", "content": prompt}] messages.append(message) @@ -578,7 +616,9 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): logger.error(f"[DocReader] Future task failed: {e}") return doc_nodes - def _process_transfer_doc_data(self, raw_node: TextualMemoryItem): + def _process_transfer_doc_data( + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + ): raise 
NotImplementedError def parse_json_result(self, response_text: str) -> dict: diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py index 1fc21461e..21be8bc39 100644 --- a/src/memos/mem_reader/strategy_struct.py +++ b/src/memos/mem_reader/strategy_struct.py @@ -8,6 +8,8 @@ from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang from memos.parsers.factory import ParserFactory from memos.templates.mem_reader_prompts import ( + CUSTOM_TAGS_INSTRUCTION, + CUSTOM_TAGS_INSTRUCTION_ZH, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, SIMPLE_STRUCT_MEM_READER_EXAMPLE, @@ -28,6 +30,7 @@ "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, }, "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } @@ -38,11 +41,19 @@ def __init__(self, config: StrategyStructMemReaderConfig): super().__init__(config) self.chat_chunker = config.chat_chunker["config"] - def _get_llm_response(self, mem_str: str) -> dict: + def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict: lang = detect_lang(mem_str) template = STRATEGY_PROMPT_DICT["chat"][lang] examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) + + custom_tags_prompt = ( + STRATEGY_PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt) + if self.config.remove_prompt_example: # TODO unused prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index f18bfd715..d7c3e65f1 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -367,6 +367,7 @@ def 
process_message(message: ScheduleMessageItem): mem_cube = self.current_mem_cube content = message.content user_name = message.user_name + info = message.info or {} # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -390,6 +391,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, text_mem=text_mem, user_name=user_name, + custom_tags=info.get("custom_tags", None), ) logger.info( @@ -414,6 +416,7 @@ def _process_memories_with_reader( mem_cube_id: str, text_mem: TreeTextMemory, user_name: str, + custom_tags: list[str] | None = None, ) -> None: """ Process memories using mem_reader for enhanced memory processing. @@ -423,6 +426,7 @@ def _process_memories_with_reader( user_id: User ID mem_cube_id: Memory cube ID text_mem: Text memory instance + custom_tags: Optional list of custom tags for memory processing """ try: # Get the mem_reader from the parent MOSCore @@ -466,6 +470,7 @@ def _process_memories_with_reader( processed_memories = self.mem_reader.fine_transfer_simple_mem( memory_items, type="chat", + custom_tags=custom_tags, ) except Exception as e: logger.warning(f"{e}: Fail to transfer mem: {memory_items}") @@ -756,6 +761,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id content = message.content messages_list = json.loads(content) + info = message.info or {} logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") @@ -778,7 +784,12 @@ def process_message(message: ScheduleMessageItem): pref_memories = pref_mem.get_memory( messages_list, type="chat", - info={"user_id": user_id, "session_id": session_id, "mem_cube_id": mem_cube_id}, + info={ + **info, + "user_id": user_id, + "session_id": session_id, + "mem_cube_id": mem_cube_id, + }, ) # Add pref_mem to vector db pref_ids = pref_mem.add(pref_memories) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py 
b/src/memos/mem_scheduler/schemas/message_schemas.py index 9c79fc42a..2bd6ef1ef 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -46,6 +46,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): default="", description="user name / display name (optional)", ) + info: dict | None = Field(default=None, description="user custom info") # Pydantic V2 model configuration model_config = ConfigDict( diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index d71a86d2e..f56b2028d 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -56,7 +56,9 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: [message["role"] + ":" + message["content"] for message in messages] ) - prompt = SIMPLE_STRUCT_MEM_READER_PROMPT.replace("${conversation}", str_messages) + prompt = SIMPLE_STRUCT_MEM_READER_PROMPT.replace("${conversation}", str_messages).replace( + "${custom_tags_prompt}", "" + ) messages = [{"role": "user", "content": prompt}] response_text = self.extractor_llm.generate(messages) response_json = self.parse_json_result(response_text) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index e7595443d..fccd75bfd 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -83,6 +83,10 @@ class TextualMemoryMetadata(BaseModel): default_factory=lambda: datetime.now().isoformat(), description="The timestamp of the last modification to the memory. Useful for tracking memory freshness or change history. 
Format: ISO 8601.", ) + info: dict | None = Field( + default=None, + description="Arbitrary key-value pairs for additional metadata.", + ) model_config = ConfigDict(extra="allow") @@ -267,3 +271,17 @@ def _coerce_metadata(cls, v: Any): def __str__(self) -> str: """Pretty string representation of the memory item.""" return f"" + + +def list_all_fields() -> list[str]: + """List all possible fields of the TextualMemoryItem model.""" + top = list(TextualMemoryItem.model_fields.keys()) + meta_models = [ + TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, + SearchedTreeNodeTextualMemoryMetadata, + PreferenceTextualMemoryMetadata, + ] + meta_all = sorted(set().union(*[set(m.model_fields.keys()) for m in meta_models])) + + return top + meta_all diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1924880ad..4501dfee3 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -432,6 +432,7 @@ def _schedule_memory_tasks( content=json.dumps(mem_ids), timestamp=datetime.utcnow(), user_name=self.cube_id, + info=add_req.info, ) self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read]) self.logger.info( @@ -504,6 +505,7 @@ def _process_pref_mem( [add_req.messages], type="chat", info={ + **(add_req.info or {}), "user_id": add_req.user_id, "session_id": target_session_id, "mem_cube_id": self.cube_id, @@ -555,6 +557,8 @@ def _process_text_mem( [add_req.messages], type="chat", info={ + **(add_req.info or {}), + "custom_tags": add_req.custom_tags, "user_id": add_req.user_id, "session_id": target_session_id, }, diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index ec6812743..3223e4694 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -39,6 +39,8 @@ - The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input conversation. 
**如果输入是中文,请输出中文** - Keep `memory_type` in English. +${custom_tags_prompt} + Example: Conversation: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. @@ -132,6 +134,8 @@ - `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** - `memory_type` 保持英文。 +${custom_tags_prompt} + 示例: 对话: user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 @@ -212,6 +216,8 @@ - The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input document summaries. **如果输入是中文,请输出中文** - Keep `memory_type` in English. +{custom_tags_prompt} + Document chunk: {chunk_text} @@ -250,6 +256,8 @@ - `key`、`value`、`tags` 字段必须与输入文档摘要的主要语言一致。**如果输入是中文,请输出中文** - `memory_type` 保持英文。 +{custom_tags_prompt} + 文档片段: {chunk_text} @@ -341,3 +349,13 @@ } """ + + +CUSTOM_TAGS_INSTRUCTION = """Output tags can refer to the following tags: +{custom_tags} +You can choose tags from the above list that are relevant to the memory. Additionally, you can freely add tags based on the content of the memory.""" + + +CUSTOM_TAGS_INSTRUCTION_ZH = """输出tags可以参考下列标签: +{custom_tags} +你可以从上述列表中选择与memory相关的标签加入tags,同时你可以根据memory的内容自由添加tags。""" diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py index ba4a00d0a..21421e30b 100644 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -61,6 +61,7 @@ Language rules: - The `key`, `value`, `tags`, `summary` and `memory_type` fields must be in English.
+${custom_tags_prompt} Example: Conversations: @@ -157,6 +158,7 @@ 语言规则: - `key`、`value`、`tags`、`summary` 、`memory_type` 字段必须输出中文 +${custom_tags_prompt} 示例1: 对话: diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py index 9babdc096..2b7206c74 100644 --- a/src/memos/types/general_types.py +++ b/src/memos/types/general_types.py @@ -36,6 +36,7 @@ "MessagesType", "Permission", "PermissionDict", + "RawMessageList", "SearchMode", "UserContext", "UserID", @@ -49,7 +50,7 @@ # Message structure class MessageDict(TypedDict, total=False): - """Typed dictionary for chat message dictionaries.""" + """Typed dictionary for chat message dictionaries. Deprecated: use ChatCompletionMessageParam instead.""" role: MessageRole content: str