From 1a22018ac150dcdb7d2cb36c19be0d98167a42cd Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Thu, 7 May 2026 16:42:02 +0800 Subject: [PATCH 1/3] z-image support npu --- .../transformers/transformer_z_image.py | 72 ++++++++++++++----- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index ba401e7fdef1..7b3d016e6e58 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -24,6 +24,7 @@ from ...models.attention_processor import Attention from ...models.modeling_utils import ModelMixin from ...models.normalization import RMSNorm +from ...utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn from ..modeling_outputs import Transformer2DModelOutput @@ -322,37 +323,72 @@ def __init__( self.axes_lens = axes_lens assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" self.freqs_cis = None + self.freqs_real = None + self.freqs_imag = None @staticmethod def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0): with torch.device("cpu"): - freqs_cis = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 - freqs_cis.append(freqs_cis_i) - - return freqs_cis + if is_torch_npu_available: + freqs_real_list = [] + freqs_imag_list = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_real = torch.cos(freqs) + freqs_imag = torch.sin(freqs) + freqs_real_list.append(freqs_real.to(torch.float32)) + freqs_imag_list.append(freqs_imag.to(torch.float32)) + + return freqs_real_list, freqs_imag_list + else: + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + return freqs_cis def __call__(self, ids: torch.Tensor): assert ids.ndim == 2 assert ids.shape[-1] == len(self.axes_dims) device = ids.device - if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + if is_torch_npu_available: + if self.freqs_real is None or self.freqs_imag is None: + freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_real[0].device != device: + self.freqs_real = [fr.to(device) for fr in freqs_real] + self.freqs_imag = [fi.to(device) for fi in freqs_imag] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + real_part = self.freqs_real[i][index] + imag_part = self.freqs_imag[i][index] + complex_part = torch.complex(real_part, imag_part) + result.append(complex_part) else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) return torch.cat(result, dim=-1) From 3d12f3baf519c2d5d30ba38914ae08688cf7e4cb Mon Sep 17 00:00:00 2001 From: Zhibin Mo <97496981+luren55@users.noreply.github.com> Date: Fri, 8 May 2026 10:27:51 +0800 Subject: [PATCH 2/3] update a general solution for supporting both NPU and GPU --- .../transformers/transformer_z_image.py | 72 +++++-------------- 1 file changed, 18 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 7b3d016e6e58..42ee66f619ed 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -24,7 +24,6 @@ from ...models.attention_processor import Attention from ...models.modeling_utils import ModelMixin from ...models.normalization import RMSNorm -from ...utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention_dispatch import dispatch_attention_fn from ..modeling_outputs import Transformer2DModelOutput @@ -323,72 +322,37 @@ def __init__( self.axes_lens = axes_lens assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" self.freqs_cis = None - self.freqs_real = None - self.freqs_imag = None @staticmethod def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0): with torch.device("cpu"): - if is_torch_npu_available: - freqs_real_list = [] - freqs_imag_list = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_real = torch.cos(freqs) - freqs_imag = torch.sin(freqs) - freqs_real_list.append(freqs_real.to(torch.float32)) - freqs_imag_list.append(freqs_imag.to(torch.float32)) - - return freqs_real_list, freqs_imag_list - else: - freqs_cis = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 - freqs_cis.append(freqs_cis_i) - return freqs_cis + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis def __call__(self, ids: torch.Tensor): assert ids.ndim == 2 assert ids.shape[-1] == len(self.axes_dims) device = ids.device - if is_torch_npu_available: - if self.freqs_real is None or self.freqs_imag is None: - freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_real = [fr.to(device) for fr in freqs_real] - self.freqs_imag = [fi.to(device) for fi in freqs_imag] - else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_real[0].device != device: - self.freqs_real = [fr.to(device) for fr in freqs_real] - self.freqs_imag = [fi.to(device) for fi in freqs_imag] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - real_part = self.freqs_real[i][index] - imag_part = self.freqs_imag[i][index] - complex_part = torch.complex(real_part, imag_part) - result.append(complex_part) + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] else: - if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(torch.complex(self.freqs_cis[i].real[index], self.freqs_cis[i].imag[index])) return torch.cat(result, dim=-1) From bc2b732b833efe54c7bb0a8bcce844fae0d080a2 Mon Sep 17 00:00:00 2001 From: Matrix_K Date: Sun, 10 May 2026 15:43:46 +0000 Subject: [PATCH 3/3] fix-copies --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 85fa0d365547..84db3d14cd97 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -329,7 +329,7 @@ def __call__(self, ids: torch.Tensor): result = [] for i in range(len(self.axes_dims)): index = ids[:, i] - result.append(self.freqs_cis[i][index]) + result.append(torch.complex(self.freqs_cis[i].real[index], self.freqs_cis[i].imag[index])) return torch.cat(result, dim=-1)