AttentionModuleMixin.set_attention_backend does not download hub kernels #13284

@Marius-Graml

Description

Describe the bug

Hellooo :),

I believe there is a bug with per-module attention backend setting.

Currently, set_attention_backend() works correctly when called on a top-level model (e.g. pipe.transformer.set_attention_backend("sage_hub")), but fails silently when called on individual attention submodules. This means it is not possible to apply a hub-based attention backend (like sage_hub) to only specific transformer blocks of a model.

If this is intended behavior, please feel free to close this issue.

Root Cause

There are two set_attention_backend() methods in diffusers:

  1. ModelMixin.set_attention_backend (modeling_utils.py:586): Validates the backend, calls _check_attention_backend_requirements() and _maybe_download_kernel_for_backend(), sets processor._attention_backend on all child attention modules, and updates the global active backend.

  2. AttentionModuleMixin.set_attention_backend (attention.py:161): Only validates the backend name and sets self.processor._attention_backend. It does not call _maybe_download_kernel_for_backend().

Because the per-module method skips the kernel download, hub-based backends (e.g. sage_hub) are never actually loaded. The backend name is set on the processor, but kernel_fn remains None. No error is raised at this point; the failure only surfaces later, at inference time, when dispatch_attention_fn tries to invoke the missing kernel.
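The failure mode can be sketched without diffusers at all. The names below (HubKernel, maybe_download_kernel, the two setter functions) are hypothetical stand-ins for the real private APIs, just to illustrate the lazy-registry pattern where kernel_fn stays None until an explicit download step runs:

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Stand-in for a _HUB_KERNELS_REGISTRY entry: kernel_fn starts as None and is
# only populated by an explicit download step.
@dataclass
class HubKernel:
    repo_id: str
    kernel_fn: Optional[Callable] = None

REGISTRY = {"sage_hub": HubKernel(repo_id="kernels-community/sage_attention")}

def maybe_download_kernel(backend: str) -> None:
    """Stand-in for _maybe_download_kernel_for_backend: populates kernel_fn."""
    entry = REGISTRY[backend]
    if entry.kernel_fn is None:
        entry.kernel_fn = lambda *args, **kwargs: "attention output"

def model_level_set_backend(processors, backend: str) -> None:
    """Mirrors ModelMixin.set_attention_backend: download first, then assign."""
    maybe_download_kernel(backend)
    for p in processors:
        p["_attention_backend"] = backend

def module_level_set_backend(processor, backend: str) -> None:
    """Mirrors AttentionModuleMixin.set_attention_backend: assign only."""
    processor["_attention_backend"] = backend

def dispatch(processor):
    """Mirrors the dispatch path that finally calls the kernel."""
    fn = REGISTRY[processor["_attention_backend"]].kernel_fn
    return fn()  # raises TypeError if the kernel was never downloaded

proc = {}
module_level_set_backend(proc, "sage_hub")  # backend name set, kernel_fn still None
try:
    dispatch(proc)
except TypeError as e:
    print(e)  # 'NoneType' object is not callable
```

The sketch suggests the natural fix: have the per-module setter call the same download helper the model-level setter already uses before assigning the backend name.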

Reproduction

Just copy-paste into an empty notebook:

import torch
from diffusers import WanPipeline
from diffusers.models.attention import AttentionModuleMixin

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# This works: ModelMixin.set_attention_backend downloads the hub kernel internally
# (uncomment the line below to run the working path)
# pipe.transformer.set_attention_backend("sage_hub")

# This does NOT work: AttentionModuleMixin.set_attention_backend skips the kernel download
for name, module in pipe.transformer.named_modules():
    if isinstance(module, AttentionModuleMixin):
        module.set_attention_backend("sage_hub")

pipe("a cat walking in the snow", num_inference_steps=2)  # fails here: kernel_fn is None because it was never downloaded

# Root cause: compare the two set_attention_backend implementations
from diffusers.models.attention_dispatch import _HUB_KERNELS_REGISTRY, AttentionBackendName

kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
print(f"sage_hub kernel_fn after submodule-level set_attention_backend: {kernel_fn}")
# None, because AttentionModuleMixin.set_attention_backend (attention.py:161)
# only does `self.processor._attention_backend = backend`
# without calling `_maybe_download_kernel_for_backend(backend)`

Logs

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 9
      6     if isinstance(module, AttentionModuleMixin):
      7         module.set_attention_backend("sage_hub")
----> 9 pipe("a cat walking in the snow", num_inference_steps=2)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    120 @functools.wraps(func)
    121 def decorate_context(*args, **kwargs):
    122     # pyrefly: ignore [bad-context-manager]
    123     with ctx_factory():
--> 124         return func(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py:608, in WanPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, guidance_scale_2, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    605     timestep = t.expand(latents.shape[0])
    607 with current_model.cache_context("cond"):
--> 608     noise_pred = current_model(
    609         hidden_states=latent_model_input,
    610         timestep=timestep,
    611         encoder_hidden_states=prompt_embeds,
    612         attention_kwargs=attention_kwargs,
    613         return_dict=False,
    614     )[0]
    616 if self.do_classifier_free_guidance:
    617     with current_model.cache_context("uncond"):

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/utils/peft_utils.py:315, in apply_lora_scale.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    311     scale_lora_layers(self, lora_scale)
    313 try:
    314     # Execute the forward pass
--> 315     result = forward_fn(self, *args, **kwargs)
    316     return result
    317 finally:
    318     # Always unscale, even if forward pass raises an exception

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:677, in WanTransformer3DModel.forward(self, hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image, return_dict, attention_kwargs)
    675 else:
    676     for block in self.blocks:
--> 677         hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
    679 # 5. Output norm, projection & unpatchify
    680 if temb.ndim == 3:
    681     # batch_size, seq_len, inner_dim (wan 2.2 ti2v)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:489, in WanTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb)
    487 # 1. Self-attention
    488 norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
--> 489 attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
    490 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
    492 # 2. Cross-attention

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:281, in WanAttention.forward(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
    273 def forward(
    274     self,
    275     hidden_states: torch.Tensor,
   (...)    279     **kwargs,
    280 ) -> torch.Tensor:
--> 281     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:143, in WanAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
    140     hidden_states_img = hidden_states_img.flatten(2, 3)
    141     hidden_states_img = hidden_states_img.type_as(query)
--> 143 hidden_states = dispatch_attention_fn(
    144     query,
    145     key,
    146     value,
    147     attn_mask=attention_mask,
    148     dropout_p=0.0,
    149     is_causal=False,
    150     backend=self._attention_backend,
    151     # Reference: https://github.com/huggingface/diffusers/pull/12909
    152     parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
    153 )
    154 hidden_states = hidden_states.flatten(2, 3)
    155 hidden_states = hidden_states.type_as(query)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:432, in dispatch_attention_fn(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, attention_kwargs, backend, parallel_config)
    428         check(**kwargs)
    430 kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
--> 432 return backend_fn(**kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:3277, in _sage_attention_hub(query, key, value, attn_mask, is_causal, scale, return_lse, _parallel_config)
   3275 func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
   3276 if _parallel_config is None:
-> 3277     out = func(
   3278         q=query,
   3279         k=key,
   3280         v=value,
   3281         tensor_layout="NHD",
   3282         is_causal=is_causal,
   3283         sm_scale=scale,
   3284         return_lse=return_lse,
   3285     )
   3286     if return_lse:
   3287         out, lse, *_ = out

TypeError: 'NoneType' object is not callable

System Info

  • 🤗 Diffusers version: 0.37.0
  • Platform: Linux-6.8.0-106-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.11.15
  • PyTorch version (GPU?): 2.10.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.2
  • Transformers version: 4.57.6
  • Accelerate version: 1.13.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: 0.49.2
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H100 PCIe, 81559 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
