Describe the bug
Hellooo :),
I believe there is a bug with per-module attention backend setting.
Currently, `set_attention_backend()` works correctly when called on a top-level model (e.g. `pipe.transformer.set_attention_backend("sage_hub")`), but fails silently when called on individual attention submodules. This means it is not possible to apply a hub-based attention backend (such as `sage_hub`) to only specific transformer blocks of a model.
If this is intended behavior, please feel free to close this issue.
Root Cause
There are two `set_attention_backend()` methods in diffusers:

- `ModelMixin.set_attention_backend()` (`modeling_utils.py:586`): validates the backend, calls `_check_attention_backend_requirements()` and `_maybe_download_kernel_for_backend()`, sets `processor._attention_backend` on all child attention modules, and updates the global active backend.
- `AttentionModuleMixin.set_attention_backend()` (`attention.py:161`): only validates the backend name and sets `self.processor._attention_backend`. It does not call `_maybe_download_kernel_for_backend()`.
Because the per-module method skips the kernel download, hub-based backends (e.g. `sage_hub`) are never actually loaded. The backend name is set on the processor, but `kernel_fn` remains `None`. No error is raised at this point; the failure only surfaces later, at inference time, when `dispatch_attention_fn` tries to invoke the missing kernel.
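The failure mode can be illustrated without diffusers at all. The sketch below uses hypothetical stand-ins (none of these names are the actual diffusers internals): a registry entry whose `kernel_fn` starts as `None` and is only populated by an explicit download step. Setting only the backend name leaves the kernel unset, and the error surfaces only when dispatch tries to call it:

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Hypothetical stand-ins for the hub-kernel registry and dispatch logic.
@dataclass
class HubKernel:
    repo_id: str
    kernel_fn: Optional[Callable] = None  # stays None until downloaded

REGISTRY = {"sage_hub": HubKernel(repo_id="example/sage-attention")}

def maybe_download_kernel(backend: str) -> None:
    # Stand-in for the download step the top-level setter performs:
    # fetch the kernel and store the callable on the registry entry.
    REGISTRY[backend].kernel_fn = lambda q, k, v: (q, k, v)  # fake kernel

def set_backend_top_level(processor: dict, backend: str) -> None:
    maybe_download_kernel(backend)             # ModelMixin path: downloads first
    processor["_attention_backend"] = backend

def set_backend_per_module(processor: dict, backend: str) -> None:
    processor["_attention_backend"] = backend  # AttentionModuleMixin path: name only

def dispatch(processor: dict, q, k, v):
    func = REGISTRY[processor["_attention_backend"]].kernel_fn
    return func(q, k, v)  # TypeError here if kernel_fn is still None

proc = {}
set_backend_per_module(proc, "sage_hub")
try:
    dispatch(proc, 1, 2, 3)
except TypeError as e:
    print(f"per-module path fails at dispatch time: {e}")

set_backend_top_level(proc, "sage_hub")
print(dispatch(proc, 1, 2, 3))  # works once the kernel has been downloaded
```

The key point the mock captures: the backend name is accepted silently, so the gap between "backend set" and "kernel loaded" is only discovered deep inside a forward pass.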
Reproduction
Just copy-paste into an empty notebook:
```python
import torch
from diffusers import WanPipeline
from diffusers.models.attention import AttentionModuleMixin

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# This works: ModelMixin.set_attention_backend downloads the hub kernel internally
# (uncomment the line below to run the functional version)
# pipe.transformer.set_attention_backend("sage_hub")

# This does NOT work: AttentionModuleMixin.set_attention_backend skips the kernel download
for name, module in pipe.transformer.named_modules():
    if isinstance(module, AttentionModuleMixin):
        module.set_attention_backend("sage_hub")

pipe("a cat walking in the snow", num_inference_steps=2)  # raises here: the kernel is None because it was never downloaded

# Root cause: compare the two set_attention_backend implementations
from diffusers.models.attention_dispatch import _HUB_KERNELS_REGISTRY, AttentionBackendName

kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
print(f"sage_hub kernel_fn after submodule-level set_attention_backend: {kernel_fn}")
# None, because AttentionModuleMixin.set_attention_backend (attention.py:161)
# only does `self.processor._attention_backend = backend`
# without calling `_maybe_download_kernel_for_backend(backend)`
```
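Beyond downloading the kernel in the per-module path, another option would be to at least fail fast. The sketch below (again with hypothetical stand-ins, not the real diffusers code) shows a per-module setter that refuses to set a hub backend whose kernel was never loaded, so the error surfaces at set time instead of mid-inference:

```python
class KernelNotLoadedError(RuntimeError):
    pass

# Hypothetical registry: the kernel_fn slot is None until downloaded.
REGISTRY = {"sage_hub": None}

def set_backend_per_module_fixed(processor: dict, backend: str) -> None:
    # Fail fast: reject a hub backend whose kernel was never downloaded,
    # instead of deferring the failure to dispatch time.
    if backend.endswith("_hub") and REGISTRY.get(backend) is None:
        raise KernelNotLoadedError(
            f"hub backend {backend!r} has no kernel loaded; "
            "download it first (the top-level setter does this automatically)"
        )
    processor["_attention_backend"] = backend

proc = {}
try:
    set_backend_per_module_fixed(proc, "sage_hub")
except KernelNotLoadedError as e:
    print(f"caught at set time: {e}")
```

Either behavior (downloading, or raising early) would avoid the silent `None` that produces the traceback below.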
Logs
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 9
6 if isinstance(module, AttentionModuleMixin):
7 module.set_attention_backend("sage_hub")
----> 9 pipe("a cat walking in the snow", num_inference_steps=2)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs)
120 @functools.wraps(func)
121 def decorate_context(*args, **kwargs):
122 # pyrefly: ignore [bad-context-manager]
123 with ctx_factory():
--> 124 return func(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py:608, in WanPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, guidance_scale_2, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
605 timestep = t.expand(latents.shape[0])
607 with current_model.cache_context("cond"):
--> 608 noise_pred = current_model(
609 hidden_states=latent_model_input,
610 timestep=timestep,
611 encoder_hidden_states=prompt_embeds,
612 attention_kwargs=attention_kwargs,
613 return_dict=False,
614 )[0]
616 if self.do_classifier_free_guidance:
617 with current_model.cache_context("uncond"):
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/utils/peft_utils.py:315, in apply_lora_scale.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
311 scale_lora_layers(self, lora_scale)
313 try:
314 # Execute the forward pass
--> 315 result = forward_fn(self, *args, **kwargs)
316 return result
317 finally:
318 # Always unscale, even if forward pass raises an exception
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:677, in WanTransformer3DModel.forward(self, hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image, return_dict, attention_kwargs)
675 else:
676 for block in self.blocks:
--> 677 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
679 # 5. Output norm, projection & unpatchify
680 if temb.ndim == 3:
681 # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:489, in WanTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb)
487 # 1. Self-attention
488 norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
--> 489 attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
490 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
492 # 2. Cross-attention
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:281, in WanAttention.forward(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
273 def forward(
274 self,
275 hidden_states: torch.Tensor,
(...) 279 **kwargs,
280 ) -> torch.Tensor:
--> 281 return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:143, in WanAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
140 hidden_states_img = hidden_states_img.flatten(2, 3)
141 hidden_states_img = hidden_states_img.type_as(query)
--> 143 hidden_states = dispatch_attention_fn(
144 query,
145 key,
146 value,
147 attn_mask=attention_mask,
148 dropout_p=0.0,
149 is_causal=False,
150 backend=self._attention_backend,
151 # Reference: https://github.com/huggingface/diffusers/pull/12909
152 parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
153 )
154 hidden_states = hidden_states.flatten(2, 3)
155 hidden_states = hidden_states.type_as(query)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:432, in dispatch_attention_fn(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, attention_kwargs, backend, parallel_config)
428 check(**kwargs)
430 kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
--> 432 return backend_fn(**kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:3277, in _sage_attention_hub(query, key, value, attn_mask, is_causal, scale, return_lse, _parallel_config)
3275 func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
3276 if _parallel_config is None:
-> 3277 out = func(
3278 q=query,
3279 k=key,
3280 v=value,
3281 tensor_layout="NHD",
3282 is_causal=is_causal,
3283 sm_scale=scale,
3284 return_lse=return_lse,
3285 )
3286 if return_lse:
3287 out, lse, *_ = out
TypeError: 'NoneType' object is not callable

System Info
- 🤗 Diffusers version: 0.37.0
- Platform: Linux-6.8.0-106-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.11.15
- PyTorch version (GPU?): 2.10.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.2
- Transformers version: 4.57.6
- Accelerate version: 1.13.0
- PEFT version: 0.18.1
- Bitsandbytes version: 0.49.2
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H100 PCIe, 81559 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
No response