5 changes: 3 additions & 2 deletions src/memos/api/handlers/component_init.py
@@ -40,6 +40,7 @@
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer


if TYPE_CHECKING:
@@ -142,7 +143,7 @@ def init_server() -> dict[str, Any]:
)

logger.debug("Memory manager initialized")

tokenizer = FastTokenizer()
# Initialize text memory
text_mem = SimpleTreeTextMemory(
llm=llm,
@@ -153,6 +154,7 @@ def init_server() -> dict[str, Any]:
memory_manager=memory_manager,
config=default_cube_config.text_mem.config,
internet_retriever=internet_retriever,
tokenizer=tokenizer,
)

logger.debug("Text memory initialized")
@@ -270,7 +272,6 @@ def init_server() -> dict[str, Any]:

online_bot = get_online_bot_function() if dingding_enabled else None
logger.info("DingDing bot is enabled")

# Return all components as a dictionary for easy access and extension
return {
"graph_db": graph_db,
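The hunks above build one FastTokenizer during server init and thread it into the text memory. A minimal sketch of that wiring, assuming `llm`, `memory_manager`, `default_cube_config`, and `internet_retriever` are the objects constructed earlier in init_server; arguments not shown in the diff are elided:

```python
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer

tokenizer = FastTokenizer()  # built once, shared by the retrieval stack

# Sketch only: the real init_server call passes additional components.
text_mem = SimpleTreeTextMemory(
    llm=llm,
    memory_manager=memory_manager,
    config=default_cube_config.text_mem.config,
    internet_retriever=internet_retriever,
    tokenizer=tokenizer,
)
```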
3 changes: 2 additions & 1 deletion src/memos/api/handlers/search_handler.py
@@ -191,7 +191,7 @@ def _fast_search(
"""
target_session_id = search_req.session_id or "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None

plugin = bool(search_req.info is not None and search_req.info.get("origin_model"))
search_results = self.naive_mem_cube.text_mem.search(
query=search_req.query,
user_name=user_context.mem_cube_id,
@@ -205,6 +205,7 @@
"session_id": target_session_id,
"chat_history": search_req.chat_history,
},
plugin=plugin,
)

formatted_memories = [format_memory_item(data) for data in search_results]
1 change: 1 addition & 0 deletions src/memos/api/product_models.py
@@ -185,6 +185,7 @@ class APISearchRequest(BaseRequest):
)
include_preference: bool = Field(True, description="Whether to handle preference memory")
pref_top_k: int = Field(6, description="Number of preference results to return")
    info: dict | None = Field(None, description="Extra search metadata; a truthy 'origin_model' enables the plugin path")


class APIADDRequest(BaseRequest):
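The new `info` field is what `_fast_search` consults to set `plugin`: the flag is true only when `info` carries a truthy `origin_model`. A small self-contained sketch of that derivation (the helper name and model string are illustrative, not part of the PR):

```python
def derive_plugin_flag(info: dict | None) -> bool:
    # Mirrors the check added in _fast_search.
    return bool(info is not None and info.get("origin_model"))

assert derive_plugin_flag({"origin_model": "gpt-4o"}) is True  # hypothetical model name
assert derive_plugin_flag({"origin_model": ""}) is False       # falsy value -> plugin off
assert derive_plugin_flag(None) is False                       # field omitted -> plugin off
```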
109 changes: 109 additions & 0 deletions src/memos/graph_dbs/polardb.py
@@ -1450,6 +1450,115 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
"""Get the ordered context chain starting from a node."""
raise NotImplementedError

@timed
def search_by_fulltext(
self,
query_words: list[str],
top_k: int = 10,
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
search_filter: dict | None = None,
user_name: str | None = None,
tsvector_field: str = "properties_tsvector_zh",
tsquery_config: str = "jiebaqry",
**kwargs,
) -> list[dict]:
"""
Full-text search functionality using PostgreSQL's full-text search capabilities.

Args:
            query_words: list of query words; they are OR-ed into the tsquery
top_k: maximum number of results to return
scope: memory type filter (memory_type)
status: status filter, defaults to "activated"
threshold: similarity threshold filter
search_filter: additional property filter conditions
user_name: username filter
            tsvector_field: full-text index field name, defaults to properties_tsvector_zh
tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
**kwargs: other parameters (e.g. cube_name)

Returns:
list[dict]: result list containing id and score
"""
# Build WHERE clause dynamically, same as search_by_embedding
where_clauses = []

if scope:
where_clauses.append(
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype"
)
if status:
where_clauses.append(
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype"
)
else:
where_clauses.append(
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
)

# Add user_name filter
user_name = user_name if user_name else self.config.user_name
where_clauses.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
)

# Add search_filter conditions
if search_filter:
for key, value in search_filter.items():
if isinstance(value, str):
where_clauses.append(
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype"
)
else:
where_clauses.append(
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
)

# Add fulltext search condition
        # Convert query_words to an OR tsquery format: "word1 | word2 | word3"
tsquery_string = " | ".join(query_words)

where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

# Build fulltext search query
query = f"""
SELECT
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
agtype_object_field_text(properties, 'memory') as memory_text,
ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY rank DESC
LIMIT {top_k};
"""

params = [tsquery_string, tsquery_string]

conn = self._get_connection()
try:
with conn.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
output = []
for row in results:
oldid = row[0] # old_id
rank = row[2] # rank score

id_val = str(oldid)
score_val = float(rank)

# Apply threshold filter if specified
if threshold is None or score_val >= threshold:
output.append({"id": id_val, "score": score_val})

return output[:top_k]
finally:
self._return_connection(conn)

@timed
def search_by_embedding(
self,
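A usage sketch for the new full-text path, assuming a configured PolarDB graph store `db` and jieba for Chinese word segmentation (any tokenizer producing a word list works; the scope, threshold, and user values are illustrative):

```python
import jieba  # assumed segmenter; query_words can come from any tokenizer

query = "用户喜欢的旅行目的地"
query_words = [w for w in jieba.cut(query) if w.strip()]

hits = db.search_by_fulltext(  # `db` is a configured PolarDB instance
    query_words,
    top_k=5,
    scope="LongTermMemory",
    threshold=0.1,
    user_name="alice",
)
# Each hit is {"id": ..., "score": ...}, ordered by ts_rank descending.
```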
3 changes: 3 additions & 0 deletions src/memos/memories/textual/simple_tree.py
@@ -9,6 +9,7 @@
from memos.memories.textual.tree import TreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
from memos.reranker.base import BaseReranker


@@ -35,6 +36,7 @@ def __init__(
config: TreeTextMemoryConfig,
internet_retriever: None = None,
is_reorganize: bool = False,
tokenizer: FastTokenizer | None = None,
):
"""Initialize memory with the given configuration."""
self.config: TreeTextMemoryConfig = config
@@ -51,6 +53,7 @@ def __init__(
if self.search_strategy and self.search_strategy.get("bm25", False)
else None
)
self.tokenizer = tokenizer
self.reranker = reranker
self.memory_manager: MemoryManager = memory_manager
# Create internet retriever if configured
13 changes: 12 additions & 1 deletion src/memos/memories/textual/tree.py
@@ -89,6 +89,7 @@ def __init__(self, config: TreeTextMemoryConfig):
)
else:
logger.info("No internet retriever configured")
self.tokenizer = None

def add(
self,
@@ -165,6 +166,7 @@ def search(
moscube: bool = False,
search_filter: dict | None = None,
user_name: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -199,6 +201,7 @@
moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
else:
searcher = Searcher(
@@ -211,9 +214,17 @@
moscube=moscube,
search_strategy=self.search_strategy,
manual_close_internet=manual_close_internet,
tokenizer=self.tokenizer,
)
        return searcher.search(
-            query, top_k, info, mode, memory_type, search_filter, user_name=user_name
+            query,
+            top_k,
+            info,
+            mode,
+            memory_type,
+            search_filter,
+            user_name=user_name,
+            plugin=kwargs.get("plugin", False),
        )

def get_relevant_subgraph(
22 changes: 21 additions & 1 deletion src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -143,6 +143,26 @@ def retrieve_from_cube(

return list(combined.values())

def retrieve_from_mixed(
self,
top_k: int,
memory_scope: str | None = None,
query_embedding: list[list[float]] | None = None,
search_filter: dict | None = None,
user_name: str | None = None,
use_fast_graph: bool = False,
) -> list[TextualMemoryItem]:
"""Retrieve from mixed and memory"""
vector_results = self._vector_recall(
query_embedding or [],
memory_scope,
top_k,
search_filter=search_filter,
user_name=user_name,
        )
        # Merge and deduplicate by ID
        combined = {item.id: item for item in vector_results}
return list(combined.values())

def _graph_recall(
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
) -> list[TextualMemoryItem]:
@@ -270,7 +290,7 @@ def _vector_recall(
query_embedding: list[list[float]],
memory_scope: str,
top_k: int = 20,
-        max_num: int = 5,
+        max_num: int = 20,
status: str = "activated",
cube_name: str | None = None,
search_filter: dict | None = None,
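Despite its name, `retrieve_from_mixed` currently performs only the vector leg (`_vector_recall`) and deduplicates by id; no graph recall is wired in yet. A call sketch, where `retriever` and the toy embedding are stand-ins for the real recall component and query embedding:

```python
# `retriever` is the recall component that owns _vector_recall; the embedding
# below is a placeholder of hypothetical dimensionality.
query_embedding = [[0.01] * 768]
items = retriever.retrieve_from_mixed(
    top_k=10,
    memory_scope="LongTermMemory",
    query_embedding=query_embedding,
    user_name="alice",
)
```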
src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
@@ -3,6 +3,8 @@

from pathlib import Path

import numpy as np

from memos.dependency import require_python_package
from memos.log import get_logger

@@ -376,3 +378,28 @@ def detect_lang(text):
return "en"
except Exception:
return "en"


def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8):
    """Greedily select sentences whose pairwise similarity stays below `bar`."""
    assert len(sentences) == len(similarity_matrix)

num_sentence = len(sentences)
selected_sentences = []
selected_indices = []
for i in range(num_sentence):
can_add = True
for j in selected_indices:
if similarity_matrix[i][j] > bar:
can_add = False
break
        if can_add:
            selected_sentences.append(sentences[i])  # keep the sentence, not its index
            selected_indices.append(i)
return selected_sentences, selected_indices


def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
    x = np.asarray(embeddings, dtype=float)  # accept plain lists as input
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x_normalized = x / norms
    similarity_matrix = np.dot(x_normalized, x_normalized.T)
    return similarity_matrix
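Together these helpers drop near-duplicate sentences: build the cosine matrix once, then greedily keep a sentence only if its similarity to everything already kept stays below `bar`. A small worked example with toy 2-d embeddings:

```python
sentences = ["hello world", "hello world!", "something unrelated"]
embeddings = [[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]]  # first two nearly parallel

sim = cosine_similarity_matrix(embeddings)
kept, kept_idx = find_best_unrelated_subgroup(sentences, sim, bar=0.8)
# kept_idx == [0, 2]: index 1 is rejected because sim[1][0] ≈ 0.99 > 0.8.
```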