5 changes: 2 additions & 3 deletions src/memos/embedders/universal_api.py
@@ -4,7 +4,7 @@
 from memos.configs.embedder import UniversalAPIEmbedderConfig
 from memos.embedders.base import BaseEmbedder
 from memos.log import get_logger
-from memos.utils import timed
+from memos.utils import timed_with_status


 logger = get_logger(__name__)
@@ -30,8 +30,7 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
         else:
             raise ValueError(f"Embeddings unsupported provider: {self.provider}")

-    @timed(
-        log=True,
+    @timed_with_status(
         log_prefix="model_timed_embedding",
         log_extra_args={"model_name_or_path": "text-embedding-3-large"},
     )
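Per the message format added in src/memos/utils.py below, a successful embedding call through this decorator should emit a log line along these lines (elapsed time illustrative):

    [TIMER_WITH_STATUS] model_timed_embedding took 142 ms, status: SUCCESS, args:  [model_name_or_path=text-embedding-3-large]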
6 changes: 3 additions & 3 deletions src/memos/llms/openai.py
@@ -12,7 +12,7 @@
 from memos.llms.utils import remove_thinking_tags
 from memos.log import get_logger
 from memos.types import MessageList
-from memos.utils import timed
+from memos.utils import timed_with_status


 logger = get_logger(__name__)
@@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
         )
         logger.info("OpenAI LLM instance initialized")

-    @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
+    @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
     def generate(self, messages: MessageList, **kwargs) -> str:
        """Generate a response from OpenAI LLM, optionally overriding generation params."""
        response = self.client.chat.completions.create(
@@ -55,7 +55,7 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             return reasoning_content + response_content
         return response_content

-    @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
+    @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
     def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
        """Stream response from OpenAI LLM with optional reasoning support."""
        if kwargs.get("tools"):
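One thing reviewers may want to check at these call sites: the docstring retained in src/memos/utils.py still describes log_args resolution as kwargs → args[0].config.<name>, but the wrapper as diffed only calls kwargs.get(key). Since generate() and generate_stream() appear to take the model name from config rather than as a keyword argument, these log lines would show model_name_or_path=None unless callers pass it explicitly.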
106 changes: 51 additions & 55 deletions src/memos/reranker/http_bge.py
@@ -9,7 +9,7 @@
 import requests

 from memos.log import get_logger
-from memos.utils import timed
+from memos.utils import timed_with_status

 from .base import BaseReranker
 from .concat import concat_original_source
@@ -119,8 +119,12 @@ def __init__(
         self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
         self._warned_missing_keys: set[str] = set()

-    @timed(
-        log=True, log_prefix="model_timed_rerank", log_extra_args={"model_name_or_path": "reranker"}
+    @timed_with_status(
+        log_prefix="model_timed_rerank",
+        log_extra_args={"model_name_or_path": "reranker"},
+        fallback=lambda exc, self, query, graph_results, top_k, *a, **kw: [
+            (item, 0.0) for item in graph_results[:top_k]
+        ],
     )
     def rerank(
         self,
@@ -150,6 +154,7 @@ def rerank(
         list[tuple[TextualMemoryItem, float]]
             Re-ranked items with scores, sorted descending by score.
         """
+
         if not graph_results:
             return []

@@ -173,63 +178,54 @@
         headers = {"Content-Type": "application/json", **self.headers_extra}
         payload = {"model": self.model, "query": query, "documents": documents}

-        try:
-            # Make the HTTP request to the reranker service
-            resp = requests.post(
-                self.reranker_url, headers=headers, json=payload, timeout=self.timeout
-            )
-            resp.raise_for_status()
-            data = resp.json()
-
-            scored_items: list[tuple[TextualMemoryItem, float]] = []
-
-            if "results" in data:
-                # Format:
-                # dict("results": [{"index": int, "relevance_score": float},
-                #       ...])
-                rows = data.get("results", [])
-                for r in rows:
-                    idx = r.get("index")
-                    # The returned index refers to 'documents' (i.e., our 'pairs' order),
-                    # so we must map it back to the original graph_results index.
-                    if isinstance(idx, int) and 0 <= idx < len(graph_results):
-                        raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
-                        item = graph_results[idx]
-                        # generic boost
-                        score = self._apply_boost_generic(item, raw_score, search_priority)
-                        scored_items.append((item, score))
-
-                scored_items.sort(key=lambda x: x[1], reverse=True)
-                return scored_items[: min(top_k, len(scored_items))]
-
-            elif "data" in data:
-                # Format: {"data": [{"score": float}, ...]} aligned by list order
-                rows = data.get("data", [])
-                # Build a list of scores aligned with our 'documents' (pairs)
-                score_list = [float(r.get("score", 0.0)) for r in rows]
-
-                if len(score_list) < len(graph_results):
-                    score_list += [0.0] * (len(graph_results) - len(score_list))
-                elif len(score_list) > len(graph_results):
-                    score_list = score_list[: len(graph_results)]
-
-                scored_items = []
-                for item, raw_score in zip(graph_results, score_list, strict=False):
-                    score = self._apply_boost_generic(item, raw_score, search_priority)
-                    scored_items.append((item, score))
-
-                scored_items.sort(key=lambda x: x[1], reverse=True)
-                return scored_items[: min(top_k, len(scored_items))]
-
-            else:
-                # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
-                # Note: we use 'pairs' to keep alignment with valid (string) docs.
-                return [(item, 0.0) for item in graph_results[:top_k]]
-
-        except Exception as e:
-            # Network error, timeout, JSON decode error, etc.
-            # Degrade gracefully by returning first top_k valid docs with 0.0 score.
-            logger.error(f"[HTTPBGEReranker] request failed: {e}")
-            return [(item, 0.0) for item in graph_results[:top_k]]
+        # Make the HTTP request to the reranker service
+        resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout)
+        resp.raise_for_status()
+        data = resp.json()
+
+        scored_items: list[tuple[TextualMemoryItem, float]] = []
+
+        if "results" in data:
+            # Format:
+            # dict("results": [{"index": int, "relevance_score": float},
+            #       ...])
+            rows = data.get("results", [])
+            for r in rows:
+                idx = r.get("index")
+                # The returned index refers to 'documents' (i.e., our 'pairs' order),
+                # so we must map it back to the original graph_results index.
+                if isinstance(idx, int) and 0 <= idx < len(graph_results):
+                    raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
+                    item = graph_results[idx]
+                    # generic boost
+                    score = self._apply_boost_generic(item, raw_score, search_priority)
+                    scored_items.append((item, score))
+
+            scored_items.sort(key=lambda x: x[1], reverse=True)
+            return scored_items[: min(top_k, len(scored_items))]
+
+        elif "data" in data:
+            # Format: {"data": [{"score": float}, ...]} aligned by list order
+            rows = data.get("data", [])
+            # Build a list of scores aligned with our 'documents' (pairs)
+            score_list = [float(r.get("score", 0.0)) for r in rows]
+
+            if len(score_list) < len(graph_results):
+                score_list += [0.0] * (len(graph_results) - len(score_list))
+            elif len(score_list) > len(graph_results):
+                score_list = score_list[: len(graph_results)]
+
+            scored_items = []
+            for item, raw_score in zip(graph_results, score_list, strict=False):
+                score = self._apply_boost_generic(item, raw_score, search_priority)
+                scored_items.append((item, score))
+
+            scored_items.sort(key=lambda x: x[1], reverse=True)
+            return scored_items[: min(top_k, len(scored_items))]
+
+        else:
+            # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
+            # Note: we use 'pairs' to keep alignment with valid (string) docs.
+            return [(item, 0.0) for item in graph_results[:top_k]]

     def _get_attr_or_key(self, obj: Any, key: str) -> Any:
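The old inline try/except moves into the decorator: on any exception, timed_with_status calls fallback(e, *args, **kwargs), so the lambda's parameters mirror rerank's own signature with the exception prepended, and it reproduces the old except-branch result of the first top_k items scored 0.0. One difference is that the removed logger.error call has no direct replacement; the decorator logs status: FAILED plus the exception class, but not the message. A named fallback could restore it. A hypothetical sketch, not part of this PR:

    # Hypothetical alternative to the inline lambda above; assumes the decorator
    # invokes fallback(e, *args, **kwargs) as implemented in src/memos/utils.py.
    def _rerank_fallback(exc, self, query, graph_results, top_k, *args, **kwargs):
        # Keep the old error log, then degrade gracefully as before.
        logger.error(f"[HTTPBGEReranker] request failed: {exc}")
        return [(item, 0.0) for item in graph_results[:top_k]]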
90 changes: 67 additions & 23 deletions src/memos/utils.py
@@ -1,3 +1,4 @@
+import functools
 import time

 from memos.log import get_logger
@@ -6,47 +7,90 @@
 logger = get_logger(__name__)


-def timed(func=None, *, log=True, log_prefix="", log_args=None, log_extra_args=None):
+def timed_with_status(
+    func=None,
+    *,
+    log_prefix="",
+    log_args=None,
+    log_extra_args=None,
+    fallback=None,
+):
     """
     Parameters:
-    - log: enable timing logs (default True)
     - log_prefix: prefix; falls back to function name
     - log_args: names to include in logs (str or list/tuple of str).
       Value priority: kwargs → args[0].config.<name> (if available).
       Non-string items are ignored.
-
-    Examples:
-    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
-    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["temperature"])
-    - @timed()  # defaults
+    - log_extra_args: extra arguments to include in logs (dict).
     """

+    if isinstance(log_args, str):
+        effective_log_args = [log_args]
+    else:
+        effective_log_args = list(log_args) if log_args else []
+
     def decorator(fn):
+        @functools.wraps(fn)
         def wrapper(*args, **kwargs):
             start = time.perf_counter()
-            result = fn(*args, **kwargs)
-            elapsed_ms = (time.perf_counter() - start) * 1000.0
-            ctx_str = ""
-            ctx_parts = []
-
-            if log is not True:
-                return result
-
-            if log_args:
-                for key in log_args:
-                    val = kwargs.get(key)
-                    ctx_parts.append(f"{key}={val}")
-                ctx_str = f" [{', '.join(ctx_parts)}]"
-
-            if log_extra_args:
-                ctx_parts.extend([f"{key}={val}" for key, val in log_extra_args.items()])
-
-            if ctx_parts:
-                ctx_str = f" [{', '.join(ctx_parts)}]"
-
-            logger.info(
-                f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms, args: {ctx_str}"
-            )
+            exc_type = None
+            result = None
+            success_flag = False
+
+            try:
+                result = fn(*args, **kwargs)
+                success_flag = True
+                return result
+            except Exception as e:
+                exc_type = type(e)
+                success_flag = False
+
+                if fallback is not None and callable(fallback):
+                    result = fallback(e, *args, **kwargs)
+                    return result
+            finally:
+                elapsed_ms = (time.perf_counter() - start) * 1000.0
+
+                ctx_parts = []
+                for key in effective_log_args:
+                    val = kwargs.get(key)
+                    ctx_parts.append(f"{key}={val}")
+
+                if log_extra_args:
+                    ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
+
+                ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
+
+                status = "SUCCESS" if success_flag else "FAILED"
+                status_info = f", status: {status}"
+
+                if not success_flag and exc_type is not None:
+                    status_info += f", error: {exc_type.__name__}"
+
+                msg = (
+                    f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} "
+                    f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}"
+                )
+
+                logger.info(msg)
+
+        return wrapper
+
+    if func is None:
+        return decorator
+    return decorator(func)
+
+
+def timed(func=None, *, log=True, log_prefix=""):
+    def decorator(fn):
+        def wrapper(*args, **kwargs):
+            start = time.perf_counter()
+            result = fn(*args, **kwargs)
+            elapsed_ms = (time.perf_counter() - start) * 1000.0
+
+            if log is not True:
+                return result
+
+            logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")

             return result
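For reviewers, a minimal end-to-end sketch of the new decorator's failure path (function name and values are illustrative, not from this PR):

    from memos.utils import timed_with_status

    @timed_with_status(
        log_prefix="demo fetch",
        log_args=["attempts"],
        fallback=lambda exc, url, **kwargs: "",  # receives the exception plus the original call args
    )
    def fetch(url, attempts=3):
        raise TimeoutError("simulated failure")

    result = fetch("http://example.com", attempts=2)
    # Logs roughly:
    #   [TIMER_WITH_STATUS] demo fetch took 0 ms, status: FAILED, error: TimeoutError, args:  [attempts=2]
    # result == "" (the fallback's return value); the exception does not propagate.

One behavior worth flagging: as diffed here, when no fallback is supplied the except branch neither re-raises nor returns, so a failing call logs status: FAILED but the wrapper returns None rather than propagating the exception.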