diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py index 9474d9c..7f6b2c6 100644 --- a/diffsynth_engine/configs/__init__.py +++ b/diffsynth_engine/configs/__init__.py @@ -1,7 +1,9 @@ from .base import PipelineConfig from .qwen_image import QwenImagePipelineConfig +from .z_image import ZImagePipelineConfig __all__ = [ "PipelineConfig", "QwenImagePipelineConfig", + "ZImagePipelineConfig", ] diff --git a/diffsynth_engine/configs/z_image.py b/diffsynth_engine/configs/z_image.py new file mode 100644 index 0000000..364cf09 --- /dev/null +++ b/diffsynth_engine/configs/z_image.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +from diffsynth_engine.configs.base import PipelineConfig + + +@dataclass +class ZImagePipelineConfig(PipelineConfig): + pass diff --git a/diffsynth_engine/models/z_image/__init__.py b/diffsynth_engine/models/z_image/__init__.py new file mode 100644 index 0000000..0c5097c --- /dev/null +++ b/diffsynth_engine/models/z_image/__init__.py @@ -0,0 +1,5 @@ +from .transformer_z_image import ZImageTransformer2DModel + +__all__ = [ + "ZImageTransformer2DModel", +] diff --git a/diffsynth_engine/models/z_image/transformer_z_image.py b/diffsynth_engine/models/z_image/transformer_z_image.py new file mode 100644 index 0000000..9f127de --- /dev/null +++ b/diffsynth_engine/models/z_image/transformer_z_image.py @@ -0,0 +1,1050 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_z_image.py + +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import RMSNorm +from torch.nn.utils.rnn import pad_sequence + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.forward_context import get_forward_context +from diffsynth_engine.layers.attention import USPAttention +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 + + +def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Apply rotary positional embeddings using complex number multiplication.""" + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + +def select_per_token(noisy_val, clean_val, noise_mask, seq_len): + """Select per-token values based on noise mask for omni mode modulation.""" + mask = noise_mask[:, :seq_len].unsqueeze(-1).float() + return noisy_val * mask + clean_val * (1.0 - mask) + + +def _linear_with_batch_padding(module: nn.Module, input_tensor: torch.Tensor, min_batch_size: int = 2) -> 
torch.Tensor: + """ + Run a Linear (or Sequential containing Linear) with batch padding to align CUDA kernel paths. + + When bsz=1, PyTorch 2.10.0's addmm routes through cublas::gemm (_badd kernel) + instead of cublasLt (_bias kernel) because mat1_sizes[0]==1 fails the + isInputCompliesAddmmCudaLt check. By padding to bsz>=2, we ensure the same + cublasLt path regardless of batch size, eliminating BF16 accumulation differences. + """ + if input_tensor.shape[0] < min_batch_size: + pad_count = min_batch_size - input_tensor.shape[0] + padded = torch.cat([input_tensor, input_tensor[-1:].expand(pad_count, *input_tensor.shape[1:])]) + result = module(padded) + return result[: input_tensor.shape[0]] + return module(input_tensor) + + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations via sinusoidal frequency encoding + MLP.""" + + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif 
compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZImageSingleStreamAttention(nn.Module): + """ + Single-stream self-attention for Z-Image, replacing diffusers Attention + ZSingleStreamAttnProcessor. + Uses USPAttention for the core attention computation. + """ + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + qk_norm: bool = True, + norm_eps: float = 1e-5, + ): + super().__init__() + self.heads = n_heads + self.head_dim = dim // n_heads + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(dim, dim, bias=False) + self.to_v = nn.Linear(dim, dim, bias=False) + + # Keep to_out as ModuleList to match diffusers weight names: attention.to_out.0.weight + self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)]) + + # QK normalization + if qk_norm: + self.norm_q = RMSNorm(self.head_dim, eps=norm_eps) + self.norm_k = RMSNorm(self.head_dim, eps=norm_eps) + else: + self.norm_q = None + self.norm_k = None + + # USPAttention for attention computation with sequence parallel support + forward_context = get_forward_context() + self.usp_attn = USPAttention( + num_heads=n_heads, + head_size=self.head_dim, + attn_type=forward_context.attn_type, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # [B, S, H*D] -> [B, S, H, D] + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + # Apply QK normalization + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE (complex form) + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) 
+ + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # Expand 2D bool mask [B, S] to 4D [B, 1, 1, S] for SDPA + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # USPAttention: [B, S, H, D] -> [B, S, H, D] + hidden_states = self.usp_attn(query, key, value, attn_mask=attention_mask) + + # [B, S, H, D] -> [B, S, H*D] + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = self.to_out[0](hidden_states) + return output + + +class FeedForward(nn.Module): + """SwiGLU-style feed-forward network (w1/w2/w3 structure).""" + + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ZImageTransformerBlock(nn.Module): + """ + Single-stream transformer block with optional adaLN modulation. + Supports both global modulation (standard) and per-token modulation (omni mode). 
+ """ + + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation: bool = True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + self.attention = ZImageSingleStreamAttention( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + qk_norm=qk_norm, + norm_eps=1e-5, + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, + ): + if self.modulation: + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens (omni mode) + mod_noisy = _linear_with_batch_padding(self.adaLN_modulation, adaln_noisy) + mod_clean = _linear_with_batch_padding(self.adaLN_modulation, adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, 
scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens + mod = _linear_with_batch_padding(self.adaLN_modulation, adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # No modulation (used by context_refiner) + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + """Final output layer with adaLN modulation.""" + + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation (omni mode) + scale_noisy = 1.0 + _linear_with_batch_padding(self.adaLN_modulation, c_noisy) + scale_clean = 1.0 + _linear_with_batch_padding(self.adaLN_modulation, c_clean) + scale = select_per_token(scale_noisy, scale_clean, 
noise_mask, seq_len) + else: + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + _linear_with_batch_padding(self.adaLN_modulation, c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale + x = self.linear(x) + return x + + +class RopeEmbedder: + """3D Rotary Position Embedding using complex number representation.""" + + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for d, e in zip(dim, end): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) + freqs_cis.append(freqs_cis_i) + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageTransformer2DModel(DiffusionModel): + """Z-Image Transformer model adapted for DiffSynth-Engine.""" + + @register_to_config + def __init__( + self, + 
all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + siglip_feat_dim=None, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=(32, 48, 48), + axes_lens=(1024, 512, 512), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), 
nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + with torch.device("cpu"): + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + + # Pad tokens forced on CPU for meta device compatibility + with torch.device("cpu"): + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + pad_len = SEQ_MULTI_OF + cu_len += pad_len + else: + img_F, img_H, img_W = size[i][j] + ori_len = (img_F // pF) * (img_H // pH) * (img_W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(img_F // pF, img_H // pH, img_W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, img_F, img_H, 
img_W) + ) + cu_len += ori_len + pad_len + result.append(x_item) + return result + else: + for i in range(bsz): + img_F, img_H, img_W = size[i] + ori_len = (img_F // pF) * (img_H // pH) * (img_W // pW) + x[i] = ( + x[i][:ori_len] + .view(img_F // pF, img_H // pH, img_W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, img_F, img_H, img_W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, img_F, img_H, img_W = image.size() + F_tokens, H_tokens, W_tokens = img_F // pF, img_H // pH, img_W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (img_F, img_H, img_W), (F_tokens, H_tokens, W_tokens) + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, 
feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def patchify_and_embed( + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int + ): + """Patchify for basic mode: single image per batch item.""" + device = all_image[0].device + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] + + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device + ) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) + + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device + ) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + all_img_pad_mask.append(img_pad_mask) + + return ( + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, + all_cap_pos_ids, + all_img_pad_mask, + all_cap_pad_mask, + ) + + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]], + ): + """Patchify for omni mode: multiple images per batch item with 
noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = 
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + 
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, + ) + + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Apply pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask (not used by USPAttention, kept for interface compatibility) + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, 
noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """ + Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] 
+ + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask (kept for interface compatibility, not used by USPAttention) + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def forward( + self, + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], + t, + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], + return_dict: bool = True, + controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 2, + f_patch_size: int = 1, + ): + """ + Forward pass of the Z-Image Transformer. 
+ + Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine + -> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify + """ + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + t_noisy = _linear_with_batch_padding(self.t_embedder, t * self.t_scale).type_as(x[0][-1]) + t_clean = _linear_with_batch_padding(self.t_embedder, torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + adaln_input = _linear_with_batch_padding(self.t_embedder, t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # X embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) + + for layer in self.noise_refiner: + x = layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), 
cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) + + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs) + + # Siglip embed & refine (omni mode only) + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) + + for layer in self.siglip_refiner: + siglip_feats = layer(siglip_feats, siglip_mask, siglip_freqs) + + # Build unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Sequence parallel shard before main transformer layers + original_unified_seq_len = unified.shape[1] + (unified, unified_freqs, unified_noise_tensor) = sequence_parallel_shard( + (unified, unified_freqs, unified_noise_tensor), seq_dims=(1, 1, 1) + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean) + if controlnet_block_samples is not None and layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + # Sequence parallel unshard after main transformer layers + (unified,) = sequence_parallel_unshard((unified,), seq_dims=(1,), seq_lens=(original_unified_seq_len,)) + + # Final layer + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, 
c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) + + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return (x,) if not return_dict else Transformer2DModelOutput(sample=x) diff --git a/diffsynth_engine/pipelines/z_image/__init__.py b/diffsynth_engine/pipelines/z_image/__init__.py new file mode 100644 index 0000000..4171e75 --- /dev/null +++ b/diffsynth_engine/pipelines/z_image/__init__.py @@ -0,0 +1,5 @@ +from .pipeline_z_image import ZImagePipeline + +__all__ = [ + "ZImagePipeline", +] diff --git a/diffsynth_engine/pipelines/z_image/pipeline_z_image.py b/diffsynth_engine/pipelines/z_image/pipeline_z_image.py new file mode 100644 index 0000000..e499bc0 --- /dev/null +++ b/diffsynth_engine/pipelines/z_image/pipeline_z_image.py @@ -0,0 +1,771 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/z_image/pipeline_z_image.py + +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from accelerate import init_empty_weights
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

from diffsynth_engine.configs.z_image import ZImagePipelineConfig
from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized
from diffsynth_engine.forward_context import set_forward_context
from diffsynth_engine.layers.attention import get_attn_backend
from diffsynth_engine.models.z_image import ZImageTransformer2DModel
from diffsynth_engine.pipelines.base import Pipeline
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights

logger = logging.get_logger(__name__)


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the flow-matching timestep shift `mu` from the image sequence length.

    The line passes through (base_seq_len, base_shift) and (max_seq_len, max_shift),
    so longer image sequences receive a larger shift.
    """
    # Slope and intercept of the interpolation line.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImagePipeline(Pipeline): + r""" + The Z-Image pipeline for text-to-image generation, adapted for DiffSynth-Engine. + + Changes from the original diffusers implementation: + - Inherits from Pipeline (DiffSynth-Engine) instead of DiffusionPipeline + - Removed ZImageLoraLoaderMixin (LoRA loading support) + - Removed FromSingleFileMixin (single-file model loading) + - Removed register_modules() — components are assigned directly + - Removed model_cpu_offload_seq — CPU offload sequence declaration (DiffusionPipeline feature) + - Removed _execution_device property — replaced with self.device + - Removed maybe_free_model_hooks() — model offload hooks (DiffusionPipeline feature) + - Removed replace_example_docstring decorator + - Reimplemented from_pretrained as classmethod with model_path_or_config pattern + - Added set_forward_context for transformer initialization and inference + - Added _build_attn_metadata for attention metadata construction + - Added _predict_noise_with_cfg for CFG-parallel denoising support + - Added attn_backend initialization for DiffSynth-Engine attention system + + Args: + pipeline_config (`ZImagePipelineConfig`): + Configuration for the pipeline. + scheduler (`FlowMatchEulerDiscreteScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder (`PreTrainedModel`): + Text encoder model for encoding prompts into embeddings. + tokenizer (`AutoTokenizer`): + Tokenizer for the text encoder. 
+ transformer (`ZImageTransformer2DModel`): + Conditional Transformer architecture to denoise the encoded image latents. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + pipeline_config: ZImagePipelineConfig, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__(pipeline_config) + + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.transformer = transformer + self.scheduler = scheduler + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if self.vae is not None else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + head_dim = transformer.config.dim // transformer.config.n_heads + self.attn_backend = get_attn_backend( + head_size=head_dim, + attn_type=pipeline_config.attn_type, + ) + + @classmethod + def from_pretrained(cls, model_path_or_config: str | ZImagePipelineConfig): + """ + Load a ZImagePipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a ZImagePipelineConfig instance. + + Returns: + ZImagePipeline: The loaded pipeline. 
+ """ + if isinstance(model_path_or_config, str): + pipeline_config = ZImagePipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + # Load transformer + transformer = cls.init_transformer(pipeline_config) + + # Load scheduler + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(pipeline_config) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + # Initialize pipeline + return cls( + pipeline_config=pipeline_config, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + + @staticmethod + def init_transformer(pipeline_config: ZImagePipelineConfig, empty_weights: bool = False): + logger.info("Initializing transformer...") + with set_forward_context(attn_type=pipeline_config.attn_type): + if empty_weights: + with init_empty_weights(): + config_dict = ZImageTransformer2DModel.load_config( + pipeline_config.model_path, + subfolder="transformer", + local_files_only=True, + ) + model = ZImageTransformer2DModel.from_config(config_dict) + else: + model = ZImageTransformer2DModel.from_pretrained( + pipeline_config.model_path, + subfolder="transformer", + device=pipeline_config.device, + dtype=pipeline_config.model_dtype, + ) + return model + + @staticmethod + def init_text_encoder(pipeline_config: ZImagePipelineConfig, empty_weights: bool = False): + logger.info("Initializing text encoder...") + with init_empty_weights(): + config = AutoConfig.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + model = 
AutoModelForCausalLM.from_config(config)

        if empty_weights:
            return model

        state_dict = load_model_weights(
            pipeline_config.model_path,
            subfolder="text_encoder",
            device=pipeline_config.device,
            dtype=pipeline_config.text_encoder_dtype,
        )
        # Some checkpoints store keys under legacy names; remap them before loading.
        if key_mapping := getattr(model, "_checkpoint_conversion_mapping", None):
            state_dict = fix_state_dict_key(state_dict, key_mapping)
        # strict=False: tied/derived weights may be absent from the checkpoint; tie_weights() restores them.
        model.load_state_dict(state_dict, strict=False, assign=True)
        model.tie_weights()
        model.to(device=pipeline_config.device)
        return model

    @staticmethod
    def init_vae(pipeline_config: ZImagePipelineConfig, empty_weights: bool = False):
        """Build the VAE on the meta device, then optionally load real weights into it."""
        logger.info("Initializing VAE...")
        with init_empty_weights():
            config_dict = AutoencoderKL.load_config(
                pipeline_config.model_path,
                subfolder="vae",
                local_files_only=True,
            )
            model = AutoencoderKL.from_config(config_dict)

        if empty_weights:
            return model

        state_dict = load_model_weights(
            pipeline_config.model_path,
            subfolder="vae",
            device=pipeline_config.device,
            dtype=pipeline_config.vae_dtype,
        )
        # VAE checkpoints are expected to be complete, hence strict=True (unlike the text encoder).
        model.load_state_dict(state_dict, strict=True, assign=True)
        model.to(device=pipeline_config.device)
        return model

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        max_sequence_length: int = 512,
    ):
        """Encode positive (and, under CFG, negative) prompts into per-prompt embedding lists."""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        prompt_embeds = self._encode_prompt(
            prompt=prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        if do_classifier_free_guidance:
            # Default negative prompt is the empty string, one per positive prompt.
            if negative_prompt is None:
                negative_prompt = ["" for _ in prompt]
            else:
                negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                assert len(prompt) 
== len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self.device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected 
latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        return latents

    def _build_attn_metadata(self, attn_params):
        """Translate pipeline-level attention params into backend-specific attention metadata."""
        if attn_params is None:
            return None

        # Each attention backend supplies its own metadata builder class.
        builder_cls = self.attn_backend.get_builder_cls()
        builder = builder_cls()
        attn_params_dict = attn_params.to_dict()
        attn_metadata = builder.build(**attn_params_dict)
        return attn_metadata

    def _predict_noise_with_cfg(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: List[torch.FloatTensor],
        negative_prompt_embeds: List[torch.FloatTensor],
        attn_metadata,
        apply_cfg: bool,
        guidance_scale: float,
        cfg_normalization: bool,
        use_cfg_parallel: bool,
        actual_batch_size: int,
    ):
        """
        Predict noise with optional classifier-free guidance and CFG parallelism.

        Args:
            latents: Current noisy latents, shape (batch, channels, height, width).
            timestep: Current timestep tensor, shape (batch,).
            prompt_embeds: List of positive prompt embeddings (variable length per item).
            negative_prompt_embeds: List of negative prompt embeddings (variable length per item).
            attn_metadata: Attention metadata for set_forward_context.
            apply_cfg: Whether to apply classifier-free guidance this step.
            guidance_scale: The CFG scale factor.
            cfg_normalization: Whether to apply CFG renormalization.
            use_cfg_parallel: Whether to use CFG parallelism across devices.
            actual_batch_size: The actual batch size (batch_size * num_images_per_prompt).

        Returns:
            noise_pred: The predicted noise tensor. 
+ """ + if not apply_cfg: + # No CFG: single forward pass + latent_model_input = latents.to(self.pipeline_config.model_dtype) + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + with set_forward_context(attn_metadata=attn_metadata): + model_out_list = self.transformer(latent_model_input_list, timestep, prompt_embeds, return_dict=False)[ + 0 + ] + + noise_pred = torch.stack([tensor.float() for tensor in model_out_list], dim=0) + return noise_pred + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not model_parallel_is_initialized(): + raise RuntimeError("Model parallel groups must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + latents_typed = latents.to(self.pipeline_config.model_dtype) + latents_typed = latents_typed.unsqueeze(2) + + pos_out_list = [torch.zeros_like(latents_typed[0])] * actual_batch_size + neg_out_list = [torch.zeros_like(latents_typed[0])] * actual_batch_size + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + latent_model_input_list = list(latents_typed.unbind(dim=0)) + with set_forward_context(attn_metadata=attn_metadata): + pos_out_list = self.transformer(latent_model_input_list, timestep, prompt_embeds, return_dict=False)[0] + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + latent_model_input_list = list(latents_typed.unbind(dim=0)) + with set_forward_context(attn_metadata=attn_metadata): + neg_out_list = self.transformer( + latent_model_input_list, timestep, negative_prompt_embeds, return_dict=False + )[0] + + # All-reduce for CFG parallel + pos_out = torch.stack([tensor.float() for tensor in pos_out_list], dim=0) + neg_out = torch.stack([tensor.float() for tensor in neg_out_list], dim=0) + + if use_cfg_parallel: + pos_out = cfg_group.all_reduce(pos_out) + neg_out = cfg_group.all_reduce(neg_out) + + # Apply CFG + 
        noise_pred_list = []
        for j in range(actual_batch_size):
            pos = pos_out[j]
            neg = neg_out[j]
            # Z-Image CFG formulation: cond + scale * (cond - uncond); scale == 0 reduces to
            # the conditional prediction.
            pred = pos + guidance_scale * (pos - neg)

            # Renormalization: never let the guided prediction's norm exceed the conditional's.
            if cfg_normalization:
                ori_pos_norm = torch.linalg.vector_norm(pos)
                new_pos_norm = torch.linalg.vector_norm(pred)
                if new_pos_norm > ori_pos_norm:
                    pred = pred * (ori_pos_norm / new_pos_norm)

            noise_pred_list.append(pred)

        noise_pred = torch.stack(noise_pred_list, dim=0)
        return noise_pred

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is only meaningful for scales strictly above 1.
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        cfg_normalization: bool = False,
        cfg_truncation: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # NOTE(review): mutable default is read-only here (diffusers convention), but
        # `None` plus an in-body fallback would be the more idiomatic Python.
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. 
+ height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale for classifier-free guidance. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply CFG renormalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + Time-aware truncation for CFG. When normalized time exceeds this value, CFG is disabled. + negative_prompt (`str` or `List[str]`, *optional*): + The negative prompt or prompts. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Random generator(s) for deterministic generation. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `ZImagePipelineOutput` instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + Tensor inputs for the callback function. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for the prompt. 
+ + Returns: + `ZImagePipelineOutput` or `tuple`: Generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self.device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # Encode prompts + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." 
+ ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware CFG 
truncation (0 at start, 1 at end)
                t_norm = timestep[0].item()

                # Handle CFG truncation: disable guidance once normalized time passes the cutoff.
                current_guidance_scale = self.guidance_scale
                if self.do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:
                    if t_norm > cfg_truncation:
                        current_guidance_scale = 0.0

                # Determine whether to apply CFG this step
                apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0

                attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params)

                noise_pred = self._predict_noise_with_cfg(
                    latents=latents,
                    timestep=timestep,
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    attn_metadata=attn_metadata,
                    apply_cfg=apply_cfg,
                    guidance_scale=current_guidance_scale,
                    cfg_normalization=cfg_normalization,
                    use_cfg_parallel=self.pipeline_config.use_cfg_parallel,
                    actual_batch_size=actual_batch_size,
                )

                # Drop the singleton frame dimension added for the transformer.
                noise_pred = noise_pred.squeeze(2)
                # NOTE(review): the sign flip converts the model output to the scheduler's
                # expected convention — confirm against FlowMatchEulerDiscreteScheduler.step.
                noise_pred = -noise_pred

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
                assert latents.dtype == torch.float32

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of these tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            image = latents
        else:
            latents = latents.to(self.vae.dtype)
            # Undo the VAE's latent normalization before decoding.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/examples/z_image/z_image.py b/examples/z_image/z_image.py new file mode 100644 index 0000000..16649e7 --- /dev/null +++ b/examples/z_image/z_image.py @@ -0,0 +1,21 @@ +import torch + +from diffsynth_engine.pipelines.z_image import ZImagePipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Tongyi-MAI/Z-Image") + pipe = ZImagePipeline.from_pretrained(model_path_or_config=model_path) + prompt = "两名年轻亚裔女性紧密站在一起,背景为朴素的灰色纹理墙面,可能是室内地毯地面。左侧女性留着长卷发,身穿藏青色毛衣,左袖有奶油色褶皱装饰,内搭白色立领衬衫,下身白色裤子;佩戴小巧金色耳钉,双臂交叉于背后。右侧女性留直肩长发,身穿奶油色卫衣,胸前印有“Tun the tables”字样,下方为“New ideas”,搭配白色裤子;佩戴银色小环耳环,双臂交叉于胸前。两人均面带微笑直视镜头。照片,自然光照明,柔和阴影,以藏青、奶油白为主的中性色调,休闲时尚摄影,中等景深,面部和上半身对焦清晰,姿态放松,表情友好,室内环境,地毯地面,纯色背景。" + negative_prompt = "" + image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=1280, + width=720, + cfg_normalization=False, + num_inference_steps=50, + guidance_scale=4, + generator=torch.Generator("cuda").manual_seed(42), + ).images[0] + image.save("z_image_example.png") diff --git a/tests/data/expect/qwen_image/qwen_image.png b/tests/data/expect/qwen_image/qwen_image.png index c414ab5..7d271af 100644 Binary files a/tests/data/expect/qwen_image/qwen_image.png and b/tests/data/expect/qwen_image/qwen_image.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_edit.png b/tests/data/expect/qwen_image/qwen_image_edit.png index c99ac34..2719f41 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_edit.png and b/tests/data/expect/qwen_image/qwen_image_edit.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2509.png b/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2509.png index 33b24d5..9a62851 100644 Binary files 
a/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2509.png and b/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2509.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2511.png b/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2511.png index 495a440..88b75b8 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2511.png and b/tests/data/expect/qwen_image/qwen_image_edit_plus_multi_2511.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2509.png b/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2509.png index 4c4e818..e01eb0a 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2509.png and b/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2509.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2511.png b/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2511.png index a4332b4..e0da9c8 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2511.png and b/tests/data/expect/qwen_image/qwen_image_edit_plus_single_2511.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_layered_0.png b/tests/data/expect/qwen_image/qwen_image_layered_0.png index e807b24..4357af8 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_layered_0.png and b/tests/data/expect/qwen_image/qwen_image_layered_0.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_layered_1.png b/tests/data/expect/qwen_image/qwen_image_layered_1.png index 82d31a0..b1968f8 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_layered_1.png and b/tests/data/expect/qwen_image/qwen_image_layered_1.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_layered_2.png b/tests/data/expect/qwen_image/qwen_image_layered_2.png index 35cb34a..5251a7b 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_layered_2.png and 
b/tests/data/expect/qwen_image/qwen_image_layered_2.png differ diff --git a/tests/data/expect/qwen_image/qwen_image_layered_3.png b/tests/data/expect/qwen_image/qwen_image_layered_3.png index d7b0902..e3c46cb 100644 Binary files a/tests/data/expect/qwen_image/qwen_image_layered_3.png and b/tests/data/expect/qwen_image/qwen_image_layered_3.png differ diff --git a/tests/data/expect/z_image/z_image.png b/tests/data/expect/z_image/z_image.png new file mode 100644 index 0000000..62abe3b Binary files /dev/null and b/tests/data/expect/z_image/z_image.png differ diff --git a/tests/data/expect/z_image/z_image_turbo.png b/tests/data/expect/z_image/z_image_turbo.png new file mode 100644 index 0000000..1f48afd Binary files /dev/null and b/tests/data/expect/z_image/z_image_turbo.png differ diff --git a/tests/test_pipelines/test_qwen_image_edit_plus_2509.py b/tests/test_pipelines/test_qwen_image_edit_plus_2509.py index df451d0..03f98a4 100644 --- a/tests/test_pipelines/test_qwen_image_edit_plus_2509.py +++ b/tests/test_pipelines/test_qwen_image_edit_plus_2509.py @@ -52,7 +52,7 @@ def test_multi_image_edit(self): generator=torch.Generator(device="cpu").manual_seed(42), ) image = output.images[0] - self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_edit_plus_multi_2509.png", threshold=0.98) + self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_edit_plus_multi_2509.png", threshold=0.99) if __name__ == "__main__": diff --git a/tests/test_pipelines/test_z_image.py b/tests/test_pipelines/test_z_image.py new file mode 100644 index 0000000..03e2c25 --- /dev/null +++ b/tests/test_pipelines/test_z_image.py @@ -0,0 +1,38 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.z_image import ZImagePipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import ImageTestCase + + +class TestZImagePipeline(ImageTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Tongyi-MAI/Z-Image") + 
cls.pipe = ZImagePipeline.from_pretrained(model_path_or_config=model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_txt2img(self): + prompt = "两名年轻亚裔女性紧密站在一起,背景为朴素的灰色纹理墙面,可能是室内地毯地面。左侧女性留着长卷发,身穿藏青色毛衣,左袖有奶油色褶皱装饰,内搭白色立领衬衫,下身白色裤子;佩戴小巧金色耳钉,双臂交叉于背后。右侧女性留直肩长发,身穿奶油色卫衣,胸前印有“Tun the tables”字样,下方为“New ideas”,搭配白色裤子;佩戴银色小环耳环,双臂交叉于胸前。两人均面带微笑直视镜头。照片,自然光照明,柔和阴影,以藏青、奶油白为主的中性色调,休闲时尚摄影,中等景深,面部和上半身对焦清晰,姿态放松,表情友好,室内环境,地毯地面,纯色背景。" + negative_prompt = "" + output = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=1280, + width=720, + cfg_normalization=False, + num_inference_steps=50, + guidance_scale=4, + generator=torch.Generator("cuda").manual_seed(42), + ) + image = output.images[0] + self.assertImageEqualAndSaveFailed(image, "z_image/z_image.png", threshold=0.99) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_z_image_turbo.py b/tests/test_pipelines/test_z_image_turbo.py new file mode 100644 index 0000000..69f70fe --- /dev/null +++ b/tests/test_pipelines/test_z_image_turbo.py @@ -0,0 +1,38 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.z_image import ZImagePipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import ImageTestCase + + +class TestZImagePipeline(ImageTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Tongyi-MAI/Z-Image-Turbo") + cls.pipe = ZImagePipeline.from_pretrained(model_path_or_config=model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_txt2img(self): + prompt = "两名年轻亚裔女性紧密站在一起,背景为朴素的灰色纹理墙面,可能是室内地毯地面。左侧女性留着长卷发,身穿藏青色毛衣,左袖有奶油色褶皱装饰,内搭白色立领衬衫,下身白色裤子;佩戴小巧金色耳钉,双臂交叉于背后。右侧女性留直肩长发,身穿奶油色卫衣,胸前印有“Tun the tables”字样,下方为“New ideas”,搭配白色裤子;佩戴银色小环耳环,双臂交叉于胸前。两人均面带微笑直视镜头。照片,自然光照明,柔和阴影,以藏青、奶油白为主的中性色调,休闲时尚摄影,中等景深,面部和上半身对焦清晰,姿态放松,表情友好,室内环境,地毯地面,纯色背景。" + negative_prompt = "" + output = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=1280, + width=720, 
+ cfg_normalization=False, + num_inference_steps=50, + guidance_scale=4, + generator=torch.Generator("cuda").manual_seed(42), + ) + image = output.images[0] + self.assertImageEqualAndSaveFailed(image, "z_image/z_image_turbo.png", threshold=0.96) + + +if __name__ == "__main__": + unittest.main()