Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer


if TYPE_CHECKING:
Expand Down Expand Up @@ -142,7 +143,7 @@ def init_server() -> dict[str, Any]:
)

logger.debug("Memory manager initialized")

tokenizer = FastTokenizer()
# Initialize text memory
text_mem = SimpleTreeTextMemory(
llm=llm,
Expand All @@ -153,6 +154,7 @@ def init_server() -> dict[str, Any]:
memory_manager=memory_manager,
config=default_cube_config.text_mem.config,
internet_retriever=internet_retriever,
tokenizer=tokenizer,
)

logger.debug("Text memory initialized")
Expand Down Expand Up @@ -270,7 +272,6 @@ def init_server() -> dict[str, Any]:

online_bot = get_online_bot_function() if dingding_enabled else None
logger.info("DingDing bot is enabled")

# Return all components as a dictionary for easy access and extension
return {
"graph_db": graph_db,
Expand Down
3 changes: 2 additions & 1 deletion src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _fast_search(
"""
target_session_id = search_req.session_id or "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None

plugin = bool(search_req.source is not None and search_req.source == "plugin")
search_results = self.naive_mem_cube.text_mem.search(
query=search_req.query,
user_name=user_context.mem_cube_id,
Expand All @@ -205,6 +205,7 @@ def _fast_search(
"session_id": target_session_id,
"chat_history": search_req.chat_history,
},
plugin=plugin,
)

formatted_memories = [format_memory_item(data) for data in search_results]
Expand Down
1 change: 1 addition & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class APISearchRequest(BaseRequest):
)
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")
source: str | None = Field(None, description="Source of the search")


class APIADDRequest(BaseRequest):
Expand Down
109 changes: 109 additions & 0 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,115 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
"""Get the ordered context chain starting from a node."""
raise NotImplementedError

@timed
def search_by_fulltext(
    self,
    query_words: list[str],
    top_k: int = 10,
    scope: str | None = None,
    status: str | None = None,
    threshold: float | None = None,
    search_filter: dict | None = None,
    user_name: str | None = None,
    tsvector_field: str = "properties_tsvector_zh",
    tsquery_config: str = "jiebaqry",
    **kwargs,
) -> list[dict]:
    """
    Full-text search over Memory nodes using PostgreSQL's full-text search.

    The words in ``query_words`` are OR-combined ("w1 | w2 | ...") into a
    ``to_tsquery`` expression and ranked with ``ts_rank``.

    Args:
        query_words: list of query words; OR-combined into one tsquery.
        top_k: maximum number of results to return.
        scope: memory type filter (matched against the ``memory_type`` property).
        status: status filter; defaults to "activated" when not given.
        threshold: minimum ``ts_rank`` score; results below it are dropped.
        search_filter: additional property equality filters (key -> value).
        user_name: username filter; falls back to ``self.config.user_name``.
        tsvector_field: full-text index column name,
            defaults to ``properties_tsvector_zh``.
        tsquery_config: text-search configuration name, defaults to
            ``jiebaqry`` (Chinese word segmentation).
        **kwargs: other parameters (e.g. cube_name); currently unused.

    Returns:
        list[dict]: result list of ``{"id": str, "score": float}`` entries,
        ordered by descending rank.

    Raises:
        ValueError: if ``tsvector_field`` or ``tsquery_config`` is not a
            plain identifier (they are interpolated into SQL and cannot be
            bound as parameters, so they must be validated).
    """
    # Nothing to match: avoid sending to_tsquery('') which errors server-side.
    if not query_words:
        return []

    # These two names go into identifier/config positions in the SQL text
    # (bind parameters are not allowed there), so restrict them to plain
    # identifiers to keep the statement injection-free.
    if not tsvector_field.isidentifier():
        raise ValueError(f"Invalid tsvector field name: {tsvector_field!r}")
    if not tsquery_config.isidentifier():
        raise ValueError(f"Invalid tsquery config name: {tsquery_config!r}")

    where_clauses: list[str] = []
    where_params: list = []

    def _add_property_equals(key: str, value) -> None:
        # Both key and value are bound as parameters. agtype's text cast
        # parses '"..."' as a JSON string and bare numerals as numbers, so
        # no SQL text is ever built from caller-supplied data.
        where_clauses.append(
            "ag_catalog.agtype_access_operator(properties, %s::agtype) = %s::agtype"
        )
        if isinstance(value, str):
            where_params.extend([f'"{key}"', f'"{value}"'])
        else:
            where_params.extend([f'"{key}"', str(value)])

    if scope:
        _add_property_equals("memory_type", scope)
    # Default to "activated" when no explicit status filter is requested.
    _add_property_equals("status", status if status else "activated")
    _add_property_equals("user_name", user_name if user_name else self.config.user_name)

    if search_filter:
        for key, value in search_filter.items():
            _add_property_equals(key, value)

    # OR-combine the words: "word1 | word2 | word3".
    tsquery_string = " | ".join(query_words)
    where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
    where_params.append(tsquery_string)

    where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    # Placeholder order in the statement: ts_rank's tsquery (SELECT list),
    # then the WHERE-clause parameters in clause order, then LIMIT.
    query = f"""
        SELECT
            ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
            agtype_object_field_text(properties, 'memory') as memory_text,
            ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank
        FROM "{self.db_name}_graph"."Memory"
        {where_clause}
        ORDER BY rank DESC
        LIMIT %s;
    """
    params = [tsquery_string, *where_params, int(top_k)]

    conn = self._get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            results = cursor.fetchall()
            output = []
            for row in results:
                id_val = str(row[0])  # old_id
                score_val = float(row[2])  # ts_rank score

                # Apply the optional minimum-score filter.
                if threshold is None or score_val >= threshold:
                    output.append({"id": id_val, "score": score_val})

            return output
    finally:
        self._return_connection(conn)

@timed
def search_by_embedding(
self,
Expand Down
1 change: 1 addition & 0 deletions src/memos/memories/textual/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
preference: str | None = Field(default=None, description="Preference.")
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.")
score: float | None = Field(default=None, description="Score of the retrieval result.")


class TextualMemoryItem(BaseModel):
Expand Down
29 changes: 17 additions & 12 deletions src/memos/memories/textual/prefer_text_memory/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
response = response.strip().replace("```json", "").replace("```", "").strip()
result = json.loads(response)
result["preference"] = result.pop("implicit_preference")
for d in result:
d["preference"] = d.pop("implicit_preference")
return result
except Exception as e:
logger.error(f"Error extracting implicit preferences: {e}, return None")
Expand Down Expand Up @@ -136,20 +137,24 @@ def _process_single_chunk_implicit(
if not implicit_pref:
return None

vector_info = {
"embedding": self.embedder.embed([implicit_pref["context_summary"]])[0],
}
memories = []
for pref in implicit_pref:
vector_info = {
"embedding": self.embedder.embed([pref["context_summary"]])[0],
}

extract_info = {**basic_info, **implicit_pref, **vector_info, **info}
extract_info = {**basic_info, **pref, **vector_info, **info}

metadata = PreferenceTextualMemoryMetadata(
type=msg_type, preference_type="implicit_preference", **extract_info
)
memory = TextualMemoryItem(
id=extract_info["dialog_id"], memory=implicit_pref["context_summary"], metadata=metadata
)
metadata = PreferenceTextualMemoryMetadata(
type=msg_type, preference_type="implicit_preference", **extract_info
)
memory = TextualMemoryItem(
id=str(uuid.uuid4()), memory=pref["context_summary"], metadata=metadata
)

return memory
memories.append(memory)

return memories

def extract(
self,
Expand Down
40 changes: 30 additions & 10 deletions src/memos/memories/textual/prefer_text_memory/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -34,9 +36,12 @@ def _naive_reranker(
self, query: str, prefs_mem: list[TextualMemoryItem], top_k: int, **kwargs: Any
) -> list[TextualMemoryItem]:
if self.reranker:
prefs_mem = self.reranker.rerank(query, prefs_mem, top_k)
return [item for item, _ in prefs_mem]
return prefs_mem
prefs_mem_reranked = []
prefs_mem_tuple = self.reranker.rerank(query, prefs_mem, top_k)
for item, score in prefs_mem_tuple:
item.metadata.score = score
prefs_mem_reranked.append(item)
return prefs_mem_reranked

def _original_text_reranker(
self,
Expand All @@ -52,11 +57,22 @@ def _original_text_reranker(
prefs_mem_for_reranker = deepcopy(prefs_mem)
for pref_mem, pref in zip(prefs_mem_for_reranker, prefs, strict=False):
pref_mem.memory = pref_mem.memory + "\n" + pref.original_text
prefs_mem_for_reranker = self.reranker.rerank(query, prefs_mem_for_reranker, top_k)
prefs_mem_for_reranker = [item for item, _ in prefs_mem_for_reranker]
reranked_results = self.reranker.rerank(query, prefs_mem_for_reranker, top_k)
prefs_mem_for_reranker = [item for item, _ in reranked_results]
prefs_ids = [item.id for item in prefs_mem_for_reranker]
prefs_dict = {item.id: item for item in prefs_mem}
return [prefs_dict[item_id] for item_id in prefs_ids if item_id in prefs_dict]

# Create mapping from id to score from reranked results
reranked_scores = {item.id: score for item, score in reranked_results}

# Assign scores to the original items
result_items = []
for item_id in prefs_ids:
if item_id in prefs_dict:
original_item = prefs_dict[item_id]
original_item.metadata.score = reranked_scores.get(item_id)
result_items.append(original_item)
return result_items
return prefs_mem

def retrieve(
Expand Down Expand Up @@ -119,9 +135,6 @@ def retrieve(
if pref.payload.get("preference", None)
]

# store explicit id and score, use it after reranker
explicit_id_scores = {item.id: item.score for item in explicit_prefs}

reranker_map = {
"naive": self._naive_reranker,
"original_text": self._original_text_reranker,
Expand All @@ -136,7 +149,14 @@ def retrieve(

# filter explicit mem by score bigger than threshold
explicit_prefs_mem = [
item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.0
item
for item in explicit_prefs_mem
if item.metadata.score >= float(os.getenv("PREFERENCE_SEARCH_THRESHOLD", 0.0))
]
implicit_prefs_mem = [
item
for item in implicit_prefs_mem
if item.metadata.score >= float(os.getenv("PREFERENCE_SEARCH_THRESHOLD", 0.0))
]

return explicit_prefs_mem + implicit_prefs_mem
3 changes: 3 additions & 0 deletions src/memos/memories/textual/simple_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from memos.memories.textual.tree import TreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
from memos.reranker.base import BaseReranker


Expand All @@ -35,6 +36,7 @@ def __init__(
config: TreeTextMemoryConfig,
internet_retriever: None = None,
is_reorganize: bool = False,
tokenizer: FastTokenizer | None = None,
):
"""Initialize memory with the given configuration."""
self.config: TreeTextMemoryConfig = config
Expand All @@ -51,6 +53,7 @@ def __init__(
if self.search_strategy and self.search_strategy.get("bm25", False)
else None
)
self.tokenizer = tokenizer
self.reranker = reranker
self.memory_manager: MemoryManager = memory_manager
# Create internet retriever if configured
Expand Down
13 changes: 12 additions & 1 deletion src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, config: TreeTextMemoryConfig):
)
else:
logger.info("No internet retriever configured")
self.tokenizer = None

def add(
self,
Expand Down Expand Up @@ -165,6 +166,7 @@ def search(
moscube: bool = False,
search_filter: dict | None = None,
user_name: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
User query -> TaskGoalParser -> MemoryPathResolver ->
Expand Down Expand Up @@ -199,6 +201,7 @@ def search(
moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
else:
searcher = Searcher(
Expand All @@ -211,9 +214,17 @@ def search(
moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
return searcher.search(
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
query,
top_k,
info,
mode,
memory_type,
search_filter,
user_name=user_name,
plugin=kwargs.get("plugin", False),
)

def get_relevant_subgraph(
Expand Down
Loading
Loading