From bc8817604afce8e2cf4c612897fd59e5e0242082 Mon Sep 17 00:00:00 2001
From: cbensimon
Date: Thu, 19 Mar 2026 17:30:58 +0000
Subject: [PATCH] [core] Add export-safe LRU cache helper

---
 src/diffusers/hooks/context_parallel.py       |  5 +-
 src/diffusers/models/attention_dispatch.py    |  4 +-
 .../transformers/transformer_qwenimage.py     |  9 ++--
 src/diffusers/utils/torch_utils.py            | 50 +++++++++++++++++++
 4 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 6130be2b8290..f6ab623a1865 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
-import functools
 import inspect
 from dataclasses import dataclass
 from typing import Type
@@ -32,7 +31,7 @@
     gather_size_by_comm,
 )
 from ..utils import get_logger
-from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
+from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
 from .hooks import HookRegistry, ModelHook
 
 
@@ -327,7 +326,7 @@ def unshard_anything(
         return tensor
 
 
-@functools.lru_cache(maxsize=64)
+@lru_cache_unless_export(maxsize=64)
 def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
     gather_shapes = []
     for i in range(world_size):
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 5b1f831ed060..b4318be5d405 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -49,7 +49,7 @@
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
-from ..utils.torch_utils import maybe_allow_in_graph
+from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
 from ._modeling_parallel import gather_size_by_comm
 
 
@@ -575,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )
 
 
-@functools.lru_cache(maxsize=128)
+@lru_cache_unless_export(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
     seq_len_q: int,
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index a54cb3b8e092..c5419b9f107e 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 import math
 from math import prod
 from typing import Any
@@ -25,7 +24,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import apply_lora_scale, deprecate, logging
-from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ def forward(
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=128)
+    @lru_cache_unless_export(maxsize=128)
     def _compute_video_freqs(
         self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
     ) -> torch.Tensor:
@@ -428,7 +427,7 @@ def forward(
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
+    @lru_cache_unless_export(maxsize=None)
     def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
         pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
         freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
         return freqs.clone().contiguous()
 
-    @functools.lru_cache(maxsize=None)
+    @lru_cache_unless_export(maxsize=None)
     def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
         pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 7f4cb3e12766..f342a0f36349 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -19,11 +19,16 @@
 
 import functools
 import os
+from typing import Callable, ParamSpec, TypeVar
 
 from . import logging
 from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
 
 
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
 if is_torch_available():
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -333,5 +338,50 @@ def disable_full_determinism():
     torch.use_deterministic_algorithms(False)
 
 
+def lru_cache_unless_export(maxsize=128, typed=False):
+    """
+    Drop-in replacement for `functools.lru_cache` that bypasses the cache while `torch.export` is running.
+
+    Outside of export, calls are memoized exactly as with `functools.lru_cache(maxsize=maxsize, typed=typed)`.
+    While `torch.compiler.is_exporting()` returns `True`, the undecorated function is called directly and the
+    cache is neither read nor written.
+
+    Args:
+        maxsize (`int`, *optional*, defaults to 128):
+            Maximum number of cached results, or `None` for an unbounded cache. As with `functools.lru_cache`, the
+            function to decorate may be passed here instead, which allows the bare `@lru_cache_unless_export` form.
+        typed (`bool`, *optional*, defaults to `False`):
+            Whether arguments of different types are cached separately.
+
+    Returns:
+        A decorator. The function it returns also exposes the `cache_info`, `cache_clear` and `cache_parameters`
+        helpers of the underlying `functools.lru_cache` wrapper.
+    """
+    if callable(maxsize) and isinstance(typed, bool):
+        # Bare `@lru_cache_unless_export`: `maxsize` is the function to decorate, so use the default cache size.
+        return lru_cache_unless_export(maxsize=128, typed=typed)(maxsize)
+
+    # `torch` may not be installed, and older releases lack `torch.compiler.is_exporting`. Export cannot be running
+    # in either case, so the wrapper then behaves like a plain `functools.lru_cache`.
+    is_exporting = getattr(getattr(torch, "compiler", None), "is_exporting", None) if is_torch_available() else None
+
+    def outer_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
+        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
+
+        @functools.wraps(fn)
+        def inner_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            if is_exporting is not None and is_exporting():
+                return fn(*args, **kwargs)
+            return cached(*args, **kwargs)
+
+        # Keep the introspection API of `functools.lru_cache` so callers can still inspect or clear the cache.
+        inner_wrapper.cache_info = cached.cache_info
+        inner_wrapper.cache_clear = cached.cache_clear
+        inner_wrapper.cache_parameters = cached.cache_parameters
+        return inner_wrapper
+
+    return outer_wrapper
+
+
 if is_torch_available():
     torch_device = get_device()