From 832b4d26dd5480d7c11d42d622fb902ef7df33db Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 17 Mar 2026 16:08:55 +0800 Subject: [PATCH 1/2] fix mask in SP --- src/diffusers/models/attention_dispatch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 5b1f831ed060..cfb796d05af6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -813,6 +813,9 @@ def _native_attention_forward_op( if return_lse: raise ValueError("Native attention does not support return_lse=True") + if attn_mask is not None and attn_mask.dim() == 2: + attn_mask = attn_mask[:, None, None, :] + # used for backward pass if _save_ctx: ctx.save_for_backward(query, key, value) From 1e535f0d6de63debb7569cce70117ae2a6969f2d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 17 Mar 2026 18:03:34 +0800 Subject: [PATCH 2/2] change the modification to qwen specific --- src/diffusers/models/attention_dispatch.py | 3 --- src/diffusers/models/transformers/transformer_qwenimage.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index cfb796d05af6..5b1f831ed060 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -813,9 +813,6 @@ def _native_attention_forward_op( if return_lse: raise ValueError("Native attention does not support return_lse=True") - if attn_mask is not None and attn_mask.dim() == 2: - attn_mask = attn_mask[:, None, None, :] - # used for backward pass if _save_ctx: ctx.save_for_backward(query, key, value) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index a54cb3b8e092..a76e4dbc93b3 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -934,6 +934,7 @@ def forward( batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + joint_attention_mask = joint_attention_mask[:, None, None, :] block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks):