Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/diffusers/hooks/context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
Expand All @@ -32,7 +31,7 @@
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook


Expand Down Expand Up @@ -327,7 +326,7 @@ def unshard_anything(
return tensor


@functools.lru_cache(maxsize=64)
@lru_cache_unless_export(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm


Expand Down Expand Up @@ -575,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)


@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,
Expand Down
9 changes: 4 additions & 5 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math
from math import prod
from typing import Any
Expand All @@ -25,7 +24,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
Expand Down Expand Up @@ -307,7 +306,7 @@ def forward(

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
Expand Down Expand Up @@ -428,7 +427,7 @@ def forward(

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
Expand All @@ -450,7 +449,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()

@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
Expand Down
21 changes: 21 additions & 0 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@

import functools
import os
from typing import Callable, ParamSpec, TypeVar

from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version


T = TypeVar("T")
P = ParamSpec("P")


if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
Expand Down Expand Up @@ -333,5 +338,21 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)


@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to specify typed?


@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check if is_exporting exists with hasattr?

return fn(*args, **kwargs)
return cached(*args, **kwargs)

return inner_wrapper

return outer_wrapper


if is_torch_available():
torch_device = get_device()
Loading