diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index e74e50614..60bae15a5 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -4,7 +4,7 @@
 from memos.configs.embedder import UniversalAPIEmbedderConfig
 from memos.embedders.base import BaseEmbedder
 from memos.log import get_logger
-from memos.utils import timed
+from memos.utils import timed_with_status
 
 
 logger = get_logger(__name__)
@@ -30,8 +30,7 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
         else:
             raise ValueError(f"Embeddings unsupported provider: {self.provider}")
 
-    @timed(
-        log=True,
+    @timed_with_status(
         log_prefix="model_timed_embedding",
         log_extra_args={"model_name_or_path": "text-embedding-3-large"},
     )
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index f4ebf45c7..35a9c7117 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -12,7 +12,7 @@
 from memos.llms.utils import remove_thinking_tags
 from memos.log import get_logger
 from memos.types import MessageList
-from memos.utils import timed
+from memos.utils import timed_with_status
 
 
 logger = get_logger(__name__)
@@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
         )
         logger.info("OpenAI LLM instance initialized")
 
-    @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
+    @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
     def generate(self, messages: MessageList, **kwargs) -> str:
         """Generate a response from OpenAI LLM, optionally overriding generation params."""
         response = self.client.chat.completions.create(
@@ -55,7 +55,7 @@ def generate(self, messages: MessageList, **kwargs) -> str:
             return reasoning_content + response_content
         return response_content
 
-    @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
+    @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
     def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
         """Stream response from OpenAI LLM with optional reasoning support."""
        if kwargs.get("tools"):
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index 29f41e38f..4e9054f1e 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -9,7 +9,7 @@
 import requests
 
 from memos.log import get_logger
-from memos.utils import timed
+from memos.utils import timed_with_status
 
 from .base import BaseReranker
 from .concat import concat_original_source
@@ -119,8 +119,12 @@ def __init__(
         self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
         self._warned_missing_keys: set[str] = set()
 
-    @timed(
-        log=True, log_prefix="model_timed_rerank", log_extra_args={"model_name_or_path": "reranker"}
+    @timed_with_status(
+        log_prefix="model_timed_rerank",
+        log_extra_args={"model_name_or_path": "reranker"},
+        fallback=lambda exc, self, query, graph_results, top_k, *a, **kw: [
+            (item, 0.0) for item in graph_results[:top_k]
+        ],
     )
     def rerank(
         self,
@@ -150,6 +154,7 @@ def rerank(
         list[tuple[TextualMemoryItem, float]]
             Re-ranked items with scores, sorted descending by score.
""" + if not graph_results: return [] @@ -173,63 +178,54 @@ def rerank( headers = {"Content-Type": "application/json", **self.headers_extra} payload = {"model": self.model, "query": query, "documents": documents} - try: - # Make the HTTP request to the reranker service - resp = requests.post( - self.reranker_url, headers=headers, json=payload, timeout=self.timeout - ) - resp.raise_for_status() - data = resp.json() - - scored_items: list[tuple[TextualMemoryItem, float]] = [] - - if "results" in data: - # Format: - # dict("results": [{"index": int, "relevance_score": float}, - # ...]) - rows = data.get("results", []) - for r in rows: - idx = r.get("index") - # The returned index refers to 'documents' (i.e., our 'pairs' order), - # so we must map it back to the original graph_results index. - if isinstance(idx, int) and 0 <= idx < len(graph_results): - raw_score = float(r.get("relevance_score", r.get("score", 0.0))) - item = graph_results[idx] - # generic boost - score = self._apply_boost_generic(item, raw_score, search_priority) - scored_items.append((item, score)) - - scored_items.sort(key=lambda x: x[1], reverse=True) - return scored_items[: min(top_k, len(scored_items))] - - elif "data" in data: - # Format: {"data": [{"score": float}, ...]} aligned by list order - rows = data.get("data", []) - # Build a list of scores aligned with our 'documents' (pairs) - score_list = [float(r.get("score", 0.0)) for r in rows] - - if len(score_list) < len(graph_results): - score_list += [0.0] * (len(graph_results) - len(score_list)) - elif len(score_list) > len(graph_results): - score_list = score_list[: len(graph_results)] - - scored_items = [] - for item, raw_score in zip(graph_results, score_list, strict=False): + # Make the HTTP request to the reranker service + resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout) + resp.raise_for_status() + data = resp.json() + + scored_items: list[tuple[TextualMemoryItem, float]] = [] + + if "results" in data: + # Format: + # dict("results": [{"index": int, "relevance_score": float}, + # ...]) + rows = data.get("results", []) + for r in rows: + idx = r.get("index") + # The returned index refers to 'documents' (i.e., our 'pairs' order), + # so we must map it back to the original graph_results index. + if isinstance(idx, int) and 0 <= idx < len(graph_results): + raw_score = float(r.get("relevance_score", r.get("score", 0.0))) + item = graph_results[idx] + # generic boost score = self._apply_boost_generic(item, raw_score, search_priority) scored_items.append((item, score)) - scored_items.sort(key=lambda x: x[1], reverse=True) - return scored_items[: min(top_k, len(scored_items))] + scored_items.sort(key=lambda x: x[1], reverse=True) + return scored_items[: min(top_k, len(scored_items))] + + elif "data" in data: + # Format: {"data": [{"score": float}, ...]} aligned by list order + rows = data.get("data", []) + # Build a list of scores aligned with our 'documents' (pairs) + score_list = [float(r.get("score", 0.0)) for r in rows] + + if len(score_list) < len(graph_results): + score_list += [0.0] * (len(graph_results) - len(score_list)) + elif len(score_list) > len(graph_results): + score_list = score_list[: len(graph_results)] - else: - # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs - # Note: we use 'pairs' to keep alignment with valid (string) docs. 
-                return [(item, 0.0) for item in graph_results[:top_k]]
+            scored_items = []
+            for item, raw_score in zip(graph_results, score_list, strict=False):
+                score = self._apply_boost_generic(item, raw_score, search_priority)
+                scored_items.append((item, score))
 
-        except Exception as e:
-            # Network error, timeout, JSON decode error, etc.
-            # Degrade gracefully by returning first top_k valid docs with 0.0 score.
-            logger.error(f"[HTTPBGEReranker] request failed: {e}")
+            scored_items.sort(key=lambda x: x[1], reverse=True)
+            return scored_items[: min(top_k, len(scored_items))]
+
+        else:
+            # Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
+            # Note: we use 'pairs' to keep alignment with valid (string) docs.
             return [(item, 0.0) for item in graph_results[:top_k]]
 
     def _get_attr_or_key(self, obj: Any, key: str) -> Any:
diff --git a/src/memos/utils.py b/src/memos/utils.py
index 6671d88b7..a29eaf99d 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -1,3 +1,4 @@
+import functools
 import time
 
 from memos.log import get_logger
@@ -6,47 +7,98 @@
 logger = get_logger(__name__)
 
 
-def timed(func=None, *, log=True, log_prefix="", log_args=None, log_extra_args=None):
+def timed_with_status(
+    func=None,
+    *,
+    log_prefix="",
+    log_args=None,
+    log_extra_args=None,
+    fallback=None,
+):
     """
     Parameters:
-    - log: enable timing logs (default True)
     - log_prefix: prefix; falls back to function name
     - log_args: names to include in logs (str or list/tuple of str).
-      Value priority: kwargs → args[0].config. (if available).
-      Non-string items are ignored.
-
-    Examples:
-    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
-    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["temperature"])
-    - @timed()  # defaults
+    - log_extra_args: extra arguments to include in logs (dict).
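+    - fallback: optional callable invoked as fallback(exc, *args, **kwargs)
+      when the wrapped function raises; its return value is used as the result.
+
+    Examples:
+    - @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
+    - @timed_with_status(log_prefix="model_timed_rerank",
+                         fallback=lambda exc, *a, **kw: [])
+    - @timed_with_status()  # defaults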
""" + if isinstance(log_args, str): + effective_log_args = [log_args] + else: + effective_log_args = list(log_args) if log_args else [] + def decorator(fn): + @functools.wraps(fn) def wrapper(*args, **kwargs): start = time.perf_counter() - result = fn(*args, **kwargs) - elapsed_ms = (time.perf_counter() - start) * 1000.0 - ctx_str = "" - ctx_parts = [] + exc_type = None + result = None + success_flag = False - if log is not True: + try: + result = fn(*args, **kwargs) + success_flag = True return result + except Exception as e: + exc_type = type(e) + success_flag = False + + if fallback is not None and callable(fallback): + result = fallback(e, *args, **kwargs) + return result + finally: + elapsed_ms = (time.perf_counter() - start) * 1000.0 - if log_args: - for key in log_args: + ctx_parts = [] + for key in effective_log_args: val = kwargs.get(key) ctx_parts.append(f"{key}={val}") - ctx_str = f" [{', '.join(ctx_parts)}]" - if log_extra_args: - ctx_parts.extend([f"{key}={val}" for key, val in log_extra_args.items()]) + if log_extra_args: + ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items()) + + ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else "" + + status = "SUCCESS" if success_flag else "FAILED" + status_info = f", status: {status}" + + if not success_flag and exc_type is not None: + status_info += f", error: {exc_type.__name__}" + + msg = ( + f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} " + f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}" + ) + + logger.info(msg) + + return wrapper + + if func is None: + return decorator + return decorator(func) - if ctx_parts: - ctx_str = f" [{', '.join(ctx_parts)}]" - logger.info( - f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms, args: {ctx_str}" - ) +def timed(func=None, *, log=True, log_prefix=""): + def decorator(fn): + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = fn(*args, **kwargs) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + + if log is not True: + return result + + logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms") return result