Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/memos/api/handlers/add_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.product_models import APIADDRequest, MemoryResponse
from memos.memories.textual.item import (
list_all_fields,
)
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
Expand Down Expand Up @@ -44,6 +47,13 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
"""
self.logger.info(f"[AddHandler] Add Req is: {add_req}")

if add_req.info:
exclude_fields = list_all_fields()
info_len = len(add_req.info)
add_req.info = {k: v for k, v in add_req.info.items() if k not in exclude_fields}
if len(add_req.info) < info_len:
self.logger.warning(f"[AddHandler] info fields can not contain {exclude_fields}.")

cube_view = self._build_cube_view(add_req)

results = cube_view.add_memories(add_req)
Expand Down
23 changes: 15 additions & 8 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This module handles retrieving all memories or specific subgraphs based on queries.
"""

from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

from memos.api.handlers.formatters_handler import format_memory_item
from memos.api.product_models import (
Expand All @@ -24,6 +24,10 @@
)


if TYPE_CHECKING:
from memos.memories.textual.preference import TextualMemoryItem


logger = get_logger(__name__)


Expand Down Expand Up @@ -161,17 +165,20 @@ def handle_get_subgraph(
def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse:
    """Retrieve all textual memories and (optionally) preference memories for a cube.

    Args:
        get_mem_req: Request carrying ``mem_cube_id``, optional ``user_id``,
            and the ``include_preference`` flag.
        naive_mem_cube: Cube object exposing ``text_mem`` and ``pref_mem``
            stores (project type — duck-typed here).

    Returns:
        GetMemoryResponse with ``text_mem`` (raw graph nodes) and ``pref_mem``
        (formatted preference items; empty when preferences are excluded).
    """
    # TODO: Implement get memory with filter
    memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]

    # Only query the preference store when the caller asked for it; this
    # avoids a needless backend round-trip on include_preference=False.
    preferences: list[TextualMemoryItem] = []
    if get_mem_req.include_preference:
        # Build the filter only from fields that were actually provided.
        filter_params: dict[str, Any] = {}
        if get_mem_req.user_id is not None:
            filter_params["user_id"] = get_mem_req.user_id
        if get_mem_req.mem_cube_id is not None:
            filter_params["mem_cube_id"] = get_mem_req.mem_cube_id
        preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params)
    # Formatting an empty list is a no-op, so this is safe in both branches.
    preferences = [format_memory_item(mem) for mem in preferences]
    return GetMemoryResponse(
        message="Memories retrieved successfully",
        data={
            "text_mem": memories,
            "pref_mem": preferences,
        },
    )

Expand Down
87 changes: 65 additions & 22 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Import message types from core types module
from memos.log import get_logger
from memos.types import MessageDict, MessagesType, PermissionDict, SearchMode
from memos.types import MessageList, MessagesType, PermissionDict, SearchMode


logger = get_logger(__name__)
Expand Down Expand Up @@ -72,40 +72,57 @@ class ChatRequest(BaseRequest):

user_id: str = Field(..., description="User ID")
query: str = Field(..., description="Chat query message")
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
readable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can read for multi-cube chat"
)
writable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can write for multi-cube chat"
)
history: list[MessageDict] | None = Field(None, description="Chat history")
history: MessageList | None = Field(None, description="Chat history")
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
internet_search: bool = Field(True, description="Whether to use internet search")
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
model_name_or_path: str | None = Field(None, description="Model name to use for chat")
max_tokens: int | None = Field(None, description="Max tokens to generate")
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")

# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
None,
description="""
Filter for the memory, example:
{
"`and` or `or`": [
{"id": "uuid-xxx"},
{"created_at": {"gt": "2024-01-01"}},
]
}
""",
)

# ==== Extended capabilities ====
internet_search: bool = Field(True, description="Whether to use internet search")
threshold: float = Field(0.5, description="Threshold for filtering references")

# ==== Backward compatibility ====
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
moscube: bool = Field(
False, description="(Deprecated) Whether to use legacy MemOSCube pipeline"
)


class ChatCompleteRequest(BaseRequest):
"""Request model for chat operations."""
"""Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest."""

user_id: str = Field(..., description="User ID")
query: str = Field(..., description="Chat query message")
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
history: list[MessageDict] | None = Field(None, description="Chat history")
history: MessageList | None = Field(None, description="Chat history")
internet_search: bool = Field(False, description="Whether to use internet search")
system_prompt: str | None = Field(None, description="Base prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
Expand Down Expand Up @@ -191,7 +208,7 @@ class MemoryCreateRequest(BaseRequest):
"""Request model for creating memories."""

user_id: str = Field(..., description="User ID")
messages: list[MessageDict] | None = Field(None, description="List of messages to store.")
messages: MessagesType | None = Field(None, description="List of messages to store.")
memory_content: str | None = Field(None, description="Memory content to store")
doc_path: str | None = Field(None, description="Path to document to store")
mem_cube_id: str | None = Field(None, description="Cube ID")
Expand Down Expand Up @@ -269,7 +286,15 @@ class APISearchRequest(BaseRequest):
# TODO: maybe add detailed description later
filter: dict[str, Any] | None = Field(
None,
description=("Filter for the memory"),
description="""
Filter for the memory, example:
{
"`and` or `or`": [
{"id": "uuid-xxx"},
{"created_at": {"gt": "2024-01-01"}},
]
}
""",
)

# ==== Extended capabilities ====
Expand All @@ -291,7 +316,7 @@ class APISearchRequest(BaseRequest):
)

# ==== Context ====
chat_history: MessagesType | None = Field(
chat_history: MessageList | None = Field(
None,
description=(
"Historical chat messages used internally by algorithms. "
Expand Down Expand Up @@ -421,7 +446,7 @@ class APIADDRequest(BaseRequest):
)

# ==== Chat history ====
chat_history: MessagesType | None = Field(
chat_history: MessageList | None = Field(
None,
description=(
"Historical chat messages used internally by algorithms. "
Expand Down Expand Up @@ -540,31 +565,49 @@ class APIChatCompleteRequest(BaseRequest):

user_id: str = Field(..., description="User ID")
query: str = Field(..., description="Chat query message")
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
readable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can read for multi-cube chat"
)
writable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can write for multi-cube chat"
)
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(False, description="Whether to use internet search")
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
history: MessageList | None = Field(None, description="Chat history")
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
session_id: str | None = Field(
"default_session", description="Session ID for soft-filtering memories"
)
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
model_name_or_path: str | None = Field(None, description="Model name to use for chat")
max_tokens: int | None = Field(None, description="Max tokens to generate")
temperature: float | None = Field(None, description="Temperature for sampling")
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")

# ==== Filter conditions ====
filter: dict[str, Any] | None = Field(
None,
description="""
Filter for the memory, example:
{
"`and` or `or`": [
{"id": "uuid-xxx"},
{"created_at": {"gt": "2024-01-01"}},
]
}
""",
)

# ==== Extended capabilities ====
internet_search: bool = Field(True, description="Whether to use internet search")
threshold: float = Field(0.5, description="Threshold for filtering references")

# ==== Backward compatibility ====
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
moscube: bool = Field(
False, description="(Deprecated) Whether to use legacy MemOSCube pipeline"
)


class AddStatusRequest(BaseRequest):
"""Request model for checking add status."""
Expand Down Expand Up @@ -594,7 +637,7 @@ class SuggestionRequest(BaseRequest):
user_id: str = Field(..., description="User ID")
mem_cube_id: str = Field(..., description="Cube ID")
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
message: list[MessageDict] | None = Field(None, description="List of messages to store.")
message: MessagesType | None = Field(None, description="List of messages to store.")


# ─── MemOS Client Response Models ──────────────────────────────────────────────
Expand Down
Loading
Loading