diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index f39ffaa58..79a5d9ea6 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -30,7 +30,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
         else:
             raise ValueError(f"Embeddings unsupported provider: {self.provider}")
 
-    @timed(log=True, log_prefix="model_timed_embedding")
+    @timed(
+        log=True,
+        log_prefix="model_timed_embedding",
+        log_extra_args={"model_name_or_path": "text-embedding-3-large"},
+    )
     def embed(self, texts: list[str]) -> list[list[float]]:
         if self.provider == "openai" or self.provider == "azure":
             try:
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 19d7a60fe..c45038e9d 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
         )
         logger.info("OpenAI LLM instance initialized")
 
-    @timed(log=True, log_prefix="OpenAI LLM")
+    @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
     def generate(self, messages: MessageList, **kwargs) -> str:
         """Generate a response from OpenAI LLM, optionally overriding generation params."""
         response = self.client.chat.completions.create(
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index 764b53032..29f41e38f 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -119,7 +119,9 @@ def __init__(
         self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
         self._warned_missing_keys: set[str] = set()
 
-    @timed(log=True, log_prefix="model_timed_rerank")
+    @timed(
+        log=True, log_prefix="model_timed_rerank", log_extra_args={"model_name_or_path": "reranker"}
+    )
     def rerank(
         self,
         query: str,
diff --git a/src/memos/utils.py b/src/memos/utils.py
index 4b1a59834..6671d88b7 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -6,20 +6,58 @@
 logger = get_logger(__name__)
 
 
-def timed(func=None, *, log=True, log_prefix=""):
-    """Decorator to measure and optionally log time of retrieval steps.
-
-    Can be used as @timed or @timed(log=True)
+def timed(func=None, *, log=True, log_prefix="", log_args=None, log_extra_args=None):
+    """Decorator that measures wall-clock time of the wrapped call and logs it.
+
+    Can be used bare (@timed) or with arguments (@timed(log=True)).
+
+    Parameters:
+    - log: enable timing logs (default True).
+    - log_prefix: log label; falls back to the wrapped function's name.
+    - log_args: argument names to include in the log line (str, or list/tuple
+      of str). Value priority: kwargs -> args[0].config.<name> (if available).
+      Non-string items are ignored.
+    - log_extra_args: static key/value pairs appended to the log line.
+
+    Examples:
+    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
+    - @timed(log=True, log_prefix="model_timed_rerank", log_extra_args={"model": "bge"})
+    - @timed()  # defaults
     """
 
     def decorator(fn):
         def wrapper(*args, **kwargs):
             start = time.perf_counter()
             result = fn(*args, **kwargs)
-            elapsed = time.perf_counter() - start
-            elapsed_ms = elapsed * 1000.0
-            if log:
-                logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")
+            elapsed_ms = (time.perf_counter() - start) * 1000.0
+
+            # Honor any truthy `log` value (the pre-existing behavior), not just `log is True`.
+            if not log:
+                return result
+
+            ctx_parts = []
+
+            if log_args:
+                # A single name is accepted as shorthand for a one-element list.
+                names = [log_args] if isinstance(log_args, str) else log_args
+                for key in names:
+                    if not isinstance(key, str):
+                        continue  # non-string items are ignored, per the docstring
+                    if key in kwargs:
+                        val = kwargs[key]
+                    else:
+                        # Fall back to the instance config, e.g. self.config.model_name_or_path.
+                        cfg = getattr(args[0], "config", None) if args else None
+                        val = getattr(cfg, key, None)
+                    ctx_parts.append(f"{key}={val}")
+
+            if log_extra_args:
+                ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
+
+            # Omit the args section entirely when there is nothing to report.
+            ctx_str = f", args: [{', '.join(ctx_parts)}]" if ctx_parts else ""
+            logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms{ctx_str}")
+
             return result
 
         return wrapper