diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index ba401e7fdef1..42ee66f619ed 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -352,7 +352,7 @@ def __call__(self, ids: torch.Tensor): result = [] for i in range(len(self.axes_dims)): index = ids[:, i] - result.append(self.freqs_cis[i][index]) + result.append(torch.complex(self.freqs_cis[i].real[index], self.freqs_cis[i].imag[index])) return torch.cat(result, dim=-1)