diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 35a9c7117..1d180eebd 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -28,7 +28,12 @@ def __init__(self, config: OpenAILLMConfig): ) logger.info("OpenAI LLM instance initialized") - @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"]) + @timed_with_status( + log_prefix="OpenAI LLM", + log_extra_args=lambda self, messages, **kwargs: { + "model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path) + }, + ) def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from OpenAI LLM, optionally overriding generation params.""" response = self.client.chat.completions.create( @@ -55,7 +60,12 @@ def generate(self, messages: MessageList, **kwargs) -> str: return reasoning_content + response_content return response_content - @timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"]) + @timed_with_status( + log_prefix="OpenAI LLM", + log_extra_args=lambda self, messages, **kwargs: { + "model_name_or_path": self.config.model_name_or_path + }, + ) def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from OpenAI LLM with optional reasoning support.""" if kwargs.get("tools"): diff --git a/src/memos/utils.py b/src/memos/utils.py index e4945b7d3..d787b7ae2 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -22,10 +22,10 @@ def timed_with_status( Parameters: - log: enable timing logs (default True) - log_prefix: prefix; falls back to function name - - log_args: names to include in logs (str or list/tuple of str). - - log_extra_args: extra arguments to include in logs (dict). If it contains - key "time_threshold", use its value (in seconds) as the logging threshold; otherwise - fall back to DEFAULT_TIME_BAR. + - log_args: names to include in logs (str or list/tuple of str), values are taken from kwargs by name. + - log_extra_args: + - can be a dict: fixed contextual fields that are always attached to logs; + - or a callable: like `fn(*args, **kwargs) -> dict`, used to dynamically generate contextual fields at runtime. """ if isinstance(log_args, str): @@ -56,12 +56,24 @@ def wrapper(*args, **kwargs): elapsed_ms = (time.perf_counter() - start) * 1000.0 ctx_parts = [] + # 1) Collect parameters from kwargs by name for key in effective_log_args: val = kwargs.get(key) ctx_parts.append(f"{key}={val}") - if log_extra_args: - ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items()) + # 2) Support log_extra_args as dict or callable, so we can dynamically + # extract values from self or other runtime context + extra_items = {} + try: + if callable(log_extra_args): + extra_items = log_extra_args(*args, **kwargs) or {} + elif isinstance(log_extra_args, dict): + extra_items = log_extra_args + except Exception as e: + logger.warning(f"[TIMER_WITH_STATUS] log_extra_args callback error: {e!r}") + + if extra_items: + ctx_parts.extend(f"{key}={val}" for key, val in extra_items.items()) ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else "" @@ -75,15 +87,8 @@ def wrapper(*args, **kwargs): f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} " f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}" ) - threshold_ms = DEFAULT_TIME_BAR * 1000.0 - if log_extra_args and "time_threshold" in log_extra_args: - try: - threshold_ms = float(log_extra_args["time_threshold"]) * 1000.0 - except Exception: - threshold_ms = DEFAULT_TIME_BAR * 1000.0 - if elapsed_ms >= threshold_ms: - logger.info(msg) + logger.info(msg) return wrapper @@ -92,7 +97,7 @@ def wrapper(*args, **kwargs): return decorator(func) -def timed(func=None, *, log=True, log_prefix=""): +def timed(func=None, *, log=False, log_prefix=""): def decorator(fn): def wrapper(*args, **kwargs): start = time.perf_counter()