diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 5926bbb8e713..158ebd6fb3a7 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -106,7 +106,8 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + x1 = hidden_states[..., 0::2] + x2 = hidden_states[..., 1::2] cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] out = torch.empty_like(hidden_states)