From 1ede365aa4b3a6ddbeb6a38a2e004b053346fbcc Mon Sep 17 00:00:00 2001 From: "qijin.qq" Date: Wed, 6 May 2026 17:40:23 +0800 Subject: [PATCH] [UniLLaDA] Add UniLLaDA multimodal discrete diffusion pipeline Add UniLLaDA pipeline supporting text-to-image, image understanding, and image editing via block-wise iterative discrete diffusion. New components: - UniLLaDaPipeline: main pipeline (DiffusionPipeline subclass) - LLaDA2UniImageTransformer2DModel: image transformer model - LLaDA2UniFlowMatchEulerScheduler: flow matching scheduler - ImageTokenizer: VQ image encoder helper - Documentation and tests --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/unillada.md | 72 ++ src/diffusers/__init__.py | 8 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformer_llada2uni_image.py | 1082 +++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/unillada/__init__.py | 48 + .../pipelines/unillada/image_tokenizer.py | 617 ++++++++++ .../pipelines/unillada/pipeline_output.py | 38 + .../pipelines/unillada/pipeline_unillada.py | 508 ++++++++ src/diffusers/pipelines/unillada/sigvq.py | 56 + src/diffusers/pipelines/unillada/utils.py | 94 ++ src/diffusers/schedulers/__init__.py | 2 + .../scheduling_llada2uni_flow_match_euler.py | 236 ++++ .../dummy_torch_and_transformers_objects.py | 30 + tests/pipelines/unillada/__init__.py | 0 tests/pipelines/unillada/test_unillada.py | 159 +++ 18 files changed, 2957 insertions(+) create mode 100644 docs/source/en/api/pipelines/unillada.md create mode 100644 src/diffusers/models/transformers/transformer_llada2uni_image.py create mode 100644 src/diffusers/pipelines/unillada/__init__.py create mode 100644 src/diffusers/pipelines/unillada/image_tokenizer.py create mode 100644 src/diffusers/pipelines/unillada/pipeline_output.py create mode 100644 src/diffusers/pipelines/unillada/pipeline_unillada.py create mode 100644 src/diffusers/pipelines/unillada/sigvq.py create mode 100644 src/diffusers/pipelines/unillada/utils.py create mode 100644 src/diffusers/schedulers/scheduling_llada2uni_flow_match_euler.py create mode 100644 tests/pipelines/unillada/__init__.py create mode 100644 tests/pipelines/unillada/test_unillada.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..2c92a6ee004a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -646,6 +646,8 @@ title: VisualCloze - local: api/pipelines/z_image title: Z-Image + - local: api/pipelines/unillada + title: UniLLaDA title: Image - sections: - local: api/pipelines/llada2 diff --git a/docs/source/en/api/pipelines/unillada.md b/docs/source/en/api/pipelines/unillada.md new file mode 100644 index 000000000000..91e8d5d7961d --- /dev/null +++ b/docs/source/en/api/pipelines/unillada.md @@ -0,0 +1,72 @@ + + +# UniLLaDA + +[UniLLaDA](https://huggingface.co/inclusionAI/LLaDA2.0-Uni) is a unified discrete diffusion language model that supports +text-to-image generation, image understanding, and image editing through block-wise iterative refinement. It extends +the [LLaDA2](./llada2) framework with multimodal capabilities. + +## Usage + +UniLLaDA supports three modes: +- **Text-to-Image**: Generate images from text prompts. +- **Image Understanding**: Answer questions about images. +- **Image Editing**: Edit images based on text instructions. 
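All three modes share the same block-wise iterative refinement loop: the target token sequence starts fully masked and is decoded block by block, with high-confidence predictions committed and the remaining positions re-predicted over several steps. The toy sketch below is only meant to convey that idea; the mask id, block size, number of steps, and one-token-per-step commit rule are illustrative and are not taken from the actual model configuration.

```py
import torch

# Toy illustration of block-wise iterative refinement (NOT the pipeline's actual
# decoding code): a fully masked sequence is filled in block by block, committing
# the most confident masked position at each refinement step.
vocab_size, seq_len, block_size, steps = 16, 8, 4, 4
MASK_ID = vocab_size  # illustrative mask id outside the vocabulary

tokens = torch.full((seq_len,), MASK_ID)
for start in range(0, seq_len, block_size):
    block = slice(start, start + block_size)
    for _ in range(steps):
        masked = tokens[block] == MASK_ID
        if not masked.any():
            break
        logits = torch.randn(block_size, vocab_size)  # stand-in for model logits
        probs, preds = logits.softmax(-1).max(-1)
        probs[~masked] = -1.0  # only still-masked positions may be committed
        best = probs.argmax()
        tokens[start + best] = preds[best]  # commit the most confident prediction
print(tokens)
```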
+ +### Text-to-Image + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementScheduler, UniLLaDaPipeline + +model_id = "inclusionAI/LLaDA2.0-Uni" +model = AutoModelForCausalLM.from_pretrained( + model_id, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +scheduler = BlockRefinementScheduler() + +pipe = UniLLaDaPipeline(transformer=model, tokenizer=tokenizer, scheduler=scheduler) + +result = pipe(prompt="A cat sitting on a windowsill at sunset") +result.images[0].save("output.png") +``` + +### Image Understanding + +```py +from PIL import Image + +img = Image.open("photo.jpg") +result = pipe(image=img, question="Describe this image in detail.") +print(result.text) +``` + +### Image Editing + +```py +result = pipe(image=img, instruction="Change the background to a beach.") +result.images[0].save("edited.png") +``` + +## UniLLaDaPipeline + +[[autodoc]] UniLLaDaPipeline + - all + - __call__ + +## UniLLaDaPipelineOutput + +[[autodoc]] pipelines.UniLLaDaPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0c6083cafd0a..cf89c5602fe9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -303,6 +303,7 @@ "WanAnimateTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "LLaDA2UniImageTransformer2DModel", "ZImageControlNetModel", "ZImageTransformer2DModel", "attention_backend", @@ -380,6 +381,7 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", + "LLaDA2UniFlowMatchEulerScheduler", "HeliosDMDScheduler", "HeliosScheduler", "HeunDiscreteScheduler", @@ -611,6 +613,8 @@ "LEditsPPPipelineStableDiffusionXL", "LLaDA2Pipeline", "LLaDA2PipelineOutput", + "UniLLaDaPipeline", + "UniLLaDaPipelineOutput", "LongCatAudioDiTPipeline", "LongCatImageEditPipeline", "LongCatImagePipeline", @@ -1072,6 +1076,7 @@ Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LLaDA2UniImageTransformer2DModel, LongCatAudioDiTTransformer, LongCatAudioDiTVae, LongCatImageTransformer2DModel, @@ -1201,6 +1206,7 @@ KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, + LLaDA2UniFlowMatchEulerScheduler, LTXEulerAncestralRFScheduler, PNDMScheduler, RePaintScheduler, @@ -1529,6 +1535,8 @@ UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, + UniLLaDaPipeline, + UniLLaDaPipelineOutput, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc772fcc6d0c..ad5b9e4f5182 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -132,6 +132,7 @@ _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] + _import_structure["transformers.transformer_llada2uni_image"] = ["LLaDA2UniImageTransformer2DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -238,6 +239,7 @@ HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LLaDA2UniImageTransformer2DModel, 
LongCatAudioDiTTransformer, LongCatImageTransformer2DModel, LTX2VideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index bbd7ecfa911b..4c9cacb0b3ae 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,6 +37,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_llada2uni_image import LLaDA2UniImageTransformer2DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_llada2uni_image.py b/src/diffusers/models/transformers/transformer_llada2uni_image.py new file mode 100644 index 000000000000..5c751cf10e47 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_llada2uni_image.py @@ -0,0 +1,1082 @@ +# Copyright 2025 Ant Group, Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This is a modified version of transformer_z_image.py adapted for UniLLaDA. 
+# Key changes: +# - Optional flash_attn and dispatch_attention_fn imports +# - Modified attention mask handling for flash_attn compatibility +# - Renamed cap_embedder to semantic_embedder +# - Always create attention masks (no None case) + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph + + +try: + from ..attention_dispatch import dispatch_attention_fn + + _HAS_DISPATCH_ATTENTION = True +except ImportError: + _HAS_DISPATCH_ATTENTION = False + +try: + from flash_attn import flash_attn_func + + _HAS_FLASH_ATTN = True +except ImportError: + _HAS_FLASH_ATTN = False + +from ..modeling_outputs import Transformer2DModelOutput + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + freqs_cis: torch.Tensor | None = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to appropriate mask format + if attention_mask is not None and attention_mask.ndim == 2: + if _HAS_FLASH_ATTN: + # flash_attn: mask out inputs directly (matches original implementation) + mask_expanded = attention_mask.unsqueeze(-1).unsqueeze(-1) # (B, S, 1, 1) + query = query * mask_expanded + key = key * mask_expanded + value = value * mask_expanded + elif _HAS_DISPATCH_ATTENTION: + # dispatch_attention_fn expects 4D mask: [batch, 1, 1, seq_len] + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + if _HAS_FLASH_ATTN: + hidden_states = flash_attn_func( + query, + key, + value, + dropout_p=0.0, + causal=False, + ) + elif _HAS_DISPATCH_ATTENTION: + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + else: + raise RuntimeError("Either flash_attn or dispatch_attention_fn is required.") + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use 
diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: torch.Tensor | None = None, + noise_mask: torch.Tensor | None = None, + adaln_noisy: torch.Tensor | None = None, + adaln_clean: torch.Tensor | None = None, + ): + if self.modulation: + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if 
noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: list[int] = (16, 56, 56), + axes_lens: list[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class LLaDA2UniImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["t_embedder", "semantic_embedder"] # precision sensitive layers + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * 
f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.semantic_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True) + ) + + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify( + self, + x: list[torch.Tensor], + size: list[tuple], + patch_size, + f_patch_size, + x_pos_offsets: list[tuple[int, int]] | None = None, + ) -> list[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span 
in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: tuple, + pos_start: tuple, + device: torch.device, + noise_mask_val: int | None = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def patchify_and_embed( + self, all_image: list[torch.Tensor], all_cap_feats: list[torch.Tensor], patch_size: int, f_patch_size: int + ): + """Patchify for basic mode: single image per batch item.""" + device = all_image[0].device + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] + + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device + ) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) + + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device + ) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + all_img_pad_mask.append(img_pad_mask) + + return ( + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, + all_cap_pos_ids, + all_img_pad_mask, + all_cap_pad_mask, + ) + + def patchify_and_embed_omni( + self, + all_x: list[list[torch.Tensor]], + all_cap_feats: list[list[torch.Tensor]], + all_siglip_feats: list[list[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: list[list[int]], + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + 
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, 
dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, + ) + + def _prepare_sequence( + self, + feats: list[torch.Tensor], + pos_ids: list[torch.Tensor], + inner_pad_mask: list[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: list[list[int]] | None = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token — use index assignment to match original implementation exactly + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask — always create to match original behavior + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: list[int], + x_noise_mask: list[list[int]] | None, + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: list[int], + cap_noise_mask: list[list[int]] | None, + siglip: torch.Tensor | None, + siglip_freqs: torch.Tensor | None, + siglip_seqlens: list[int] | None, + siglip_noise_mask: list[list[int]] | None, + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. 
+ Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask — always create to match original behavior + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def forward( + self, + x: list[torch.Tensor, list[list[torch.Tensor]]], + t, + cap_feats: list[torch.Tensor, list[list[torch.Tensor]]], + return_dict: bool = True, + controlnet_block_samples: dict[int, torch.Tensor] | None = None, + siglip_feats: list[list[torch.Tensor]] | None = None, + image_noise_mask: list[list[int]] | None = None, + patch_size: int = 2, + f_patch_size: int = 1, + ): + """ + Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine + -> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify + """ + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = 
self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # X embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) + + for layer in self.noise_refiner: + x = ( + self._gradient_checkpointing_func( + layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean) + ) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.semantic_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) + + for layer in self.context_refiner: + cap_feats = ( + self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(cap_feats, cap_mask, cap_freqs) + ) + + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) + + for layer in self.siglip_refiner: + siglip_feats = ( + self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(siglip_feats, siglip_mask, siglip_freqs) + ) + + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = ( + self._gradient_checkpointing_func( + layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean) + ) + if controlnet_block_samples is not None and layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) + + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return (x,) if not return_dict else 
Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c49ad3938cdc..3c4eb261f720 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -320,6 +320,7 @@ ) _import_structure["latte"] = ["LattePipeline"] _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] + _import_structure["unillada"] = ["UniLLaDaPipeline", "UniLLaDaPipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", "LTXImageToVideoPipeline", @@ -868,6 +869,7 @@ StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline, ) + from .unillada import UniLLaDaPipeline, UniLLaDaPipelineOutput from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import ( WanAnimatePipeline, diff --git a/src/diffusers/pipelines/unillada/__init__.py b/src/diffusers/pipelines/unillada/__init__.py new file mode 100644 index 000000000000..03096ea588af --- /dev/null +++ b/src/diffusers/pipelines/unillada/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {"pipeline_output": ["UniLLaDaPipelineOutput"]} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_unillada"] = ["UniLLaDaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import UniLLaDaPipelineOutput + from .pipeline_unillada import UniLLaDaPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/unillada/image_tokenizer.py b/src/diffusers/pipelines/unillada/image_tokenizer.py new file mode 100644 index 000000000000..f5b78917cf44 --- /dev/null +++ b/src/diffusers/pipelines/unillada/image_tokenizer.py @@ -0,0 +1,617 @@ +# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Image tokenizer for UniLLaDA. +Converts PIL images into discrete VQ token IDs via a vision encoder + VQVAE. 
+""" + +import json +from pathlib import Path +from types import SimpleNamespace + +import PIL.Image +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms.v2 import functional as tvF + + +# ============================================================ +# Config loading +# ============================================================ + + +def load_configs(model_dir: str | Path) -> dict: + with open(Path(model_dir) / "config.json", "r") as f: + return json.load(f) + + +def make_vision_config(raw: dict) -> SimpleNamespace: + vc = raw.get("vision_config", raw) + # Determine best attention implementation + attn_impl = "eager" + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + + attn_impl = "flash_attention_2" + except ImportError: + try: + import torch + + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + attn_impl = "sdpa" + except Exception: + pass + return SimpleNamespace( + hidden_size=vc["hidden_size"], + intermediate_size=vc["intermediate_size"], + num_heads=vc["num_heads"], + depth=vc["depth"], + patch_size=vc["patch_size"], + image_size=vc["image_size"], + in_channels=vc.get("in_channels", 3), + hidden_act=vc.get("hidden_act", "gelu"), + attention_bias=vc.get("attention_bias", True), + attention_dropout=vc.get("attention_dropout", 0.0), + layer_norm_eps=vc.get("layer_norm_eps", 1e-6), + spatial_merge_size=vc.get("spatial_merge_size", 1), + _attn_implementation=attn_impl, + ) + + +def make_vq_config(raw: dict) -> SimpleNamespace: + vq = raw.get("vq_config", raw) + return SimpleNamespace( + num_embeddings=vq["num_embeddings"], + embed_dim=vq["embed_dim"], + latent_channels=vq["latent_channels"], + beta=vq.get("beta", 0.25), + ) + + +# ============================================================ +# Image preprocessing +# ============================================================ + +OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + + +class ImagePreprocessor: + """Image preprocessor: rescale + normalize. 
Resizing/cropping is handled externally.""" + + def __init__(self, config_path: str | Path): + config_path = Path(config_path) + if config_path.is_dir(): + config_path = config_path / "preprocessor_config.json" + with open(config_path, "r") as f: + config = json.load(f) + + self.do_rescale = config.get("do_rescale", True) + self.do_normalize = config.get("do_normalize", True) + self.rescale_factor = config.get("rescale_factor", 1.0 / 255.0) + self.image_mean = config.get("image_mean", OPENAI_CLIP_MEAN) + self.image_std = config.get("image_std", OPENAI_CLIP_STD) + self.patch_size = config.get("patch_size", 14) + self.temporal_patch_size = config.get("temporal_patch_size", 2) + self.merge_size = config.get("merge_size", 2) + self.factor = self.patch_size * self.merge_size + + def _pil_to_tensor(self, image): + return tvF.to_dtype(tvF.to_image(image), dtype=torch.float32, scale=False) + + def _rescale_and_normalize(self, images): + if self.do_rescale: + images = images * self.rescale_factor + if self.do_normalize: + mean = torch.tensor(self.image_mean, dtype=images.dtype, device=images.device).view(-1, 1, 1) + std = torch.tensor(self.image_std, dtype=images.dtype, device=images.device).view(-1, 1, 1) + images = (images - mean) / std + return images + + def __call__(self, images): + if isinstance(images, PIL.Image.Image): + images = [images] + + all_patches, all_grids = [], [] + for img in images: + if img.mode != "RGB": + img = img.convert("RGB") + img_tensor = self._pil_to_tensor(img) + height, width = img_tensor.shape[-2:] + rh, rw = height, width + + patches = self._rescale_and_normalize(img_tensor) + if patches.ndim == 3: + patches = patches.unsqueeze(0) + + # Temporal padding + if patches.shape[0] % self.temporal_patch_size != 0: + repeats = patches[-1:].repeat(self.temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=0) + + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = rh // self.patch_size + grid_w = rw // self.patch_size + channel = patches.shape[1] + + # Reshape into patch tokens + patches = patches.unsqueeze(0).view( + 1, + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten_patches = patches.reshape( + 1, + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) + all_patches.append(flatten_patches.squeeze(0)) + all_grids.append([grid_t, grid_h, grid_w]) + + return { + "pixel_values": torch.cat(all_patches, dim=0), + "image_grid_thw": torch.tensor(all_grids, dtype=torch.long), + } + + +# ============================================================ +# Vision model components +# ============================================================ + + +def _get_act_fn(name): + mapping = { + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "silu": nn.SiLU(), + "quick_gelu": lambda x: x * torch.sigmoid(1.702 * x), + } + if name in mapping: + return mapping[name] + from transformers.activations import ACT2FN + + return ACT2FN[name] + + +class VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.activation_fn = _get_act_fn(config.hidden_act) + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, x): + return self.fc2(self.activation_fn(self.fc1(x))) + + +class 
VisionAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = config.attention_dropout + self.is_causal = False + + def forward(self, hidden_states, cu_seqlens, **kwargs): + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + # Try to use the HF attention dispatch (flash_attention_2 / sdpa / eager) + try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + attn_impl = getattr(self.config, "_attn_implementation", "eager") + if attn_impl != "eager" and attn_impl in ALL_ATTENTION_FUNCTIONS: + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + if "flash" in attn_impl: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(t, lengths.tolist(), dim=2) for t in (query_states, key_states, value_states) + ] + attn_output = torch.cat( + [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ], + dim=1, + ) + attn_output = attn_output.reshape(seq_length, -1).contiguous() + return self.proj(attn_output) + except (ImportError, KeyError, AttributeError): + pass + + # Fallback: try flash_attn directly + try: + from flash_attn import flash_attn_varlen_func + + q = query_states.squeeze(0).transpose(0, 1) + k = key_states.squeeze(0).transpose(0, 1) + v = value_states.squeeze(0).transpose(0, 1) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen) + attn_output = attn_output.reshape(seq_length, -1).contiguous() + return self.proj(attn_output) + except ImportError: + pass + + # Final fallback: manual eager attention + q = query_states.squeeze(0) + k = key_states.squeeze(0) + v = value_states.squeeze(0) + lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + outputs = [] + for qc, kc, vc in zip( + torch.split(q, lengths, dim=1), torch.split(k, lengths, dim=1), torch.split(v, lengths, dim=1) + ): + attn = F.softmax(torch.matmul(qc, kc.transpose(-2, -1)) * self.scaling, dim=-1, dtype=torch.float32).to( + qc.dtype + ) + outputs.append(torch.matmul(attn, vc)) + attn_output = torch.cat(outputs, dim=1).transpose(0, 1).reshape(seq_length, -1).contiguous() + return self.proj(attn_output) + + +class VisionPatchEmbed(nn.Module): + def __init__(self, config): + 
super().__init__() + self.patch_size = config.patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size) + + def forward(self, x): + target_dtype = self.proj.weight.dtype + x = x.view(-1, self.in_channels, self.patch_size, self.patch_size) + return self.proj(x.to(dtype=target_dtype)).view(-1, self.embed_dim) + + +class VisionEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + num_patches = (config.image_size // config.patch_size) ** 2 + self.position_embedding = nn.Embedding(num_patches, self.embed_dim) + + def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords): + pos_w = self.position_embedding.weight + hidden_size = pos_w.shape[1] + device = pos_w.device + + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + + orig_size = int(pos_w.shape[0] ** 0.5) + pos_2d = pos_w.view(orig_size, orig_size, hidden_size).permute(2, 0, 1).unsqueeze(0).float() + + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + + adapted = F.grid_sample(pos_2d, grid, mode="bilinear", align_corners=False, padding_mode="border") + adapted = adapted.squeeze(0).squeeze(-1).permute(1, 0).to(pos_w.dtype).to(embeddings.device) + return embeddings + adapted + + +class VisionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = VisionAttention(config) + self.mlp = VisionMLP(config) + + def forward(self, hidden_states, cu_seqlens, **kwargs): + hidden_states = hidden_states + self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class VisionEncoder(nn.Module): + """Vision transformer encoder that produces per-patch features.""" + + def __init__(self, config): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.embeddings = VisionEmbeddings(config) + self.patch_embed = VisionPatchEmbed(config) + self.blocks = nn.ModuleList([VisionBlock(config) for _ in range(config.depth)]) + + @property + def dtype(self): + return self.patch_embed.proj.weight.dtype + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos = hpos.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos = hpos.permute(0, 2, 1, 3).flatten() + + wpos = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos = wpos.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos = wpos.permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos, wpos], dim=-1).repeat(t, 1)) + return torch.cat(pos_ids, dim=0) + + def forward(self, 
pixel_values, grid_thw): + hidden_states = self.patch_embed(pixel_values) + image_type_ids = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + cu_seqlens = F.pad(cu_seqlens.cumsum(0, dtype=torch.int32), (1, 0), value=0) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) + for blk in self.blocks: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens) + return hidden_states + + +# ============================================================ +# VQVAE quantizer +# ============================================================ + + +class VQVAEVectorQuantizer(nn.Module): + def __init__(self, config): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + + def forward(self, hidden_state): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + flat = hidden_state.view(-1, self.embedding_dim) + + flat = F.normalize(flat, p=2, dim=-1) + emb = F.normalize(self.embedding.weight, p=2, dim=-1) + + distances = ( + torch.sum(flat**2, dim=1, keepdim=True) + + torch.sum(emb**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", flat, emb.t()) + ) + return torch.argmin(distances, dim=1) + + +class VQVAE(nn.Module): + def __init__(self, config): + super().__init__() + self.quantize = VQVAEVectorQuantizer(config) + self.quant_conv = nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = nn.Conv2d(config.embed_dim, config.latent_channels, 1) + + def encode(self, hidden_states): + return self.quantize(self.quant_conv(hidden_states)) + + +# ============================================================ +# Weight loading +# ============================================================ + + +def _load_weights(model_dir, visual, vqmodel): + from safetensors.torch import load_file + + model_path = Path(model_dir) + index_file = model_path / "model.safetensors.index.json" + + if index_file.exists(): + with open(index_file) as f: + weight_map = json.load(f)["weight_map"] + needed = {fn for k, fn in weight_map.items() if k.startswith(("model.visual.", "model.vqmodel."))} + else: + needed = {f.name for f in model_path.glob("*.safetensors")} + + visual_sd, vq_sd = {}, {} + for filename in sorted(needed): + filepath = model_path / filename + if not filepath.exists(): + continue + shard = load_file(str(filepath), device="cpu") + for key, value in shard.items(): + if key.startswith("model.visual."): + visual_sd[key[len("model.visual.") :]] = value + elif key.startswith("model.vqmodel."): + vq_sd[key[len("model.vqmodel.") :]] = value + del shard + + visual.load_state_dict(visual_sd, strict=False) + vqmodel.load_state_dict(vq_sd, strict=False) + del visual_sd, vq_sd + + +# ============================================================ +# Main tokenizer class +# ============================================================ + + +class ImageTokenizer: + """ + Standalone image tokenizer that converts PIL images to discrete VQ token IDs. + + Expects the following layout under ``model_path``:: + + model_path/ + └── image_tokenizer/ + ├── config.json # vision_config + vq_config + ├── preprocessor_config.json + └── *.safetensors # visual + vqmodel weights + + Args: + model_path: Root path of the model directory (parent of image_tokenizer/). 
+ device: Torch device. + dtype: Model dtype (default: bfloat16). + """ + + def __init__(self, model_path, device="cuda", dtype=torch.bfloat16): + self.device = torch.device(device) + self.dtype = dtype + + tokenizer_dir = Path(model_path) / "image_tokenizer" + + self.image_processor = ImagePreprocessor(tokenizer_dir) + + raw_config = load_configs(tokenizer_dir) + vision_cfg = make_vision_config(raw_config) + vq_cfg = make_vq_config(raw_config) + + self.visual = VisionEncoder(vision_cfg).to(self.device, self.dtype) + self.vqmodel = VQVAE(vq_cfg).to(self.device, self.dtype) + + _load_weights(str(tokenizer_dir), self.visual, self.vqmodel) + self.visual.eval() + self.vqmodel.eval() + self.spatial_merge_size = vision_cfg.spatial_merge_size + + @staticmethod + def _whiten_transparency(img): + if img.mode == "RGBA": + canvas = PIL.Image.new("RGBA", img.size, (255, 255, 255, 255)) + canvas.alpha_composite(img) + return canvas.convert("RGB") + return img if img.mode == "RGB" else img.convert("RGB") + + def _extract_features(self, pixel_values, image_grid_thw): + with torch.no_grad(): + hidden = self.visual(pixel_values.to(self.device, self.dtype), grid_thw=image_grid_thw.to(self.device)) + split_sizes = (image_grid_thw.prod(-1) // self.spatial_merge_size**2).tolist() + return list(torch.split(hidden, split_sizes)) + + def _quantize(self, hidden_states, image_grid_thw): + hidden_size = hidden_states.shape[-1] + split_sizes = image_grid_thw.prod(dim=-1).tolist() + all_tokens = [] + with torch.no_grad(): + for i, hs in enumerate(torch.split(hidden_states, split_sizes)): + gt, gh, gw = image_grid_thw[i].tolist() + hs = hs.view(gt, gh, gw, hidden_size).permute(0, 3, 1, 2).contiguous() + all_tokens.append(self.vqmodel.encode(hs)) + return torch.cat(all_tokens, dim=0) + + @torch.no_grad() + def encode(self, image: PIL.Image.Image) -> list[int]: + """Encode a single image to VQ token IDs.""" + image = self._whiten_transparency(image) + inputs = self.image_processor([image]) + embeds = self._extract_features(inputs["pixel_values"], inputs["image_grid_thw"]) + tokens = self._quantize(torch.cat(embeds, dim=0), inputs["image_grid_thw"]) + return tokens.flatten().tolist() + + @torch.no_grad() + def encode_batch(self, images: list[PIL.Image.Image]) -> list[list[int]]: + """Encode a batch of images to VQ token IDs.""" + images = [self._whiten_transparency(img) for img in images] + inputs = self.image_processor(images) + pv, grid = inputs["pixel_values"], inputs["image_grid_thw"] + embeds = self._extract_features(pv, grid) + return [self._quantize(e, grid[i : i + 1]).flatten().tolist() for i, e in enumerate(embeds)] + + @torch.no_grad() + def encode_with_info(self, image: PIL.Image.Image) -> dict: + """Encode image and return token IDs with metadata.""" + image = self._whiten_transparency(image) + w, h = image.size + inputs = self.image_processor([image]) + pv, grid = inputs["pixel_values"], inputs["image_grid_thw"] + embeds = self._extract_features(pv, grid) + tl = self._quantize(torch.cat(embeds, dim=0), grid).flatten().tolist() + return { + "pixel_values": pv, + "token_ids": tl, + "grid_thw": tuple(grid[0].tolist()), + "num_tokens": len(tl), + "image_size": (w, h), + } + + @property + def codebook_size(self): + return self.vqmodel.quantize.num_embeddings + + @property + def embed_dim(self): + return self.vqmodel.quantize.embedding_dim diff --git a/src/diffusers/pipelines/unillada/pipeline_output.py b/src/diffusers/pipelines/unillada/pipeline_output.py new file mode 100644 index 000000000000..edd8e1e2d9dd --- 
/dev/null +++ b/src/diffusers/pipelines/unillada/pipeline_output.py @@ -0,0 +1,38 @@ +# Copyright 2025 Ant Group and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class UniLLaDaPipelineOutput(BaseOutput): + """ + Output class for UniLLaDA pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`, *optional*): + List of denoised PIL images or numpy array. Present for text-to-image and image editing tasks. + text (`str` or `list[str]`, *optional*): + Generated text response. Present for image understanding tasks. + """ + + images: list[PIL.Image.Image] | np.ndarray | None = None + text: str | list[str] | None = None diff --git a/src/diffusers/pipelines/unillada/pipeline_unillada.py b/src/diffusers/pipelines/unillada/pipeline_unillada.py new file mode 100644 index 000000000000..70e5a997383e --- /dev/null +++ b/src/diffusers/pipelines/unillada/pipeline_unillada.py @@ -0,0 +1,508 @@ +# Copyright 2025 Ant Group and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch + +from ...schedulers import BlockRefinementScheduler +from ...utils import logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import UniLLaDaPipelineOutput + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import UniLLaDaPipeline, BlockRefinementScheduler + >>> from diffusers.pipelines.unillada.image_tokenizer import ImageTokenizer + + >>> model_id = "inclusionAI/LLaDA2.0-Uni" + >>> model = AutoModelForCausalLM.from_pretrained( + ... model_id, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto" + ... ) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> scheduler = BlockRefinementScheduler() + >>> image_tokenizer = ImageTokenizer(model_path=model_id) + + >>> pipe = UniLLaDaPipeline( + ... transformer=model, tokenizer=tokenizer, scheduler=scheduler, image_tokenizer=image_tokenizer + ... 
) + + >>> # Text-to-Image + >>> result = pipe(prompt="A cat sitting on a windowsill at sunset") + >>> result.images[0].save("output.png") + ``` +""" + + +class UniLLaDaPipeline(DiffusionPipeline): + r""" + Pipeline for UniLLaDA — a discrete diffusion LLM supporting text-to-image generation, + image understanding, and image editing via block-wise iterative refinement. + + This pipeline supports three modes determined automatically by the inputs: + - **Text-to-Image**: Provide `prompt` only. + - **Image Understanding**: Provide `image` and `question`. + - **Image Editing**: Provide `image` and `instruction`. + + The model (`transformer`) is expected to be a `transformers`-compatible causal LM with + `generate_image`, `understand_image`, and `edit_image` methods (e.g., loaded with + `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`). + + Args: + transformer (`Any`): + The UniLLaDA language model backbone with image generation capabilities. + Expected to have `generate_image`, `understand_image`, and `edit_image` methods. + tokenizer (`Any`): + Tokenizer compatible with the transformer model. + scheduler ([`BlockRefinementScheduler`]): + A scheduler for block-wise refinement during discrete diffusion. + image_tokenizer (`Any`, *optional*): + An image tokenizer for encoding input images to VQ tokens (required for understanding and editing modes). + """ + + transformer: Any + tokenizer: Any + scheduler: BlockRefinementScheduler + image_tokenizer: Any + + _optional_components = ["image_tokenizer"] + model_cpu_offload_seq = "transformer" + + def __init__( + self, + transformer: Any, + tokenizer: Any, + scheduler: BlockRefinementScheduler, + image_tokenizer: Any | None = None, + ): + super().__init__() + self.register_modules( + transformer=transformer, + tokenizer=tokenizer, + scheduler=scheduler, + image_tokenizer=image_tokenizer, + ) + + # ================================================================ + # Image Encoding (for understanding and editing) + # ================================================================ + + def encode_image( + self, + image: PIL.Image.Image, + ) -> tuple[list[int], int, int]: + """ + Encode a PIL image to VQ token IDs with the `image_token_offset` applied. + + Args: + image (`PIL.Image.Image`): + Input PIL image. + + Returns: + `tuple[list[int], int, int]`: Tuple of (token_ids_with_offset, h, w). + """ + if self.image_tokenizer is None: + raise ValueError( + "`image_tokenizer` is required for image understanding and editing modes. " + "Pass it to the pipeline constructor." 
+ ) + + from .utils import generate_crop_size_list, var_center_crop + + crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32) + image = var_center_crop(image, crop_size_list=crop_size_list) + + info = self.image_tokenizer.encode_with_info(image) + # Add image_token_offset as the backbone expects + image_token_offset = getattr(self.transformer.config, "image_token_offset", 0) + image_tokens = [x + image_token_offset for x in info["token_ids"]] + _, h, w = info["grid_thw"] + return image_tokens, h, w + + # ================================================================ + # VQ Token Decoding + # ================================================================ + + @torch.inference_mode() + def decode_tokens_to_image( + self, + token_ids: list[int], + h: int, + w: int, + decode_fn: Callable | None = None, + num_steps: int = 50, + resolution_multiplier: int = 2, + decode_mode: str = "normal", + **decode_kwargs, + ) -> PIL.Image.Image: + """ + Decode VQ token IDs into a PIL Image. + + Args: + token_ids (`list[int]`): + VQ token IDs (without the image_token_offset). + h (`int`): + Semantic grid height. + w (`int`): + Semantic grid width. + decode_fn (`Callable`, *optional*): + Custom decode function. If not provided, the pipeline uses the transformer's + built-in decode method if available. + num_steps (`int`, defaults to 50): + ODE/SDE sampling steps for the decoder. + resolution_multiplier (`int`, defaults to 2): + Upscale factor (2 = 1024px from 512px tokens). + decode_mode (`str`, defaults to `"normal"`): + Decoder mode: `"normal"` (50 steps) or `"decoder-turbo"` (8 steps). + **decode_kwargs: + Additional keyword arguments passed to the decode function. + + Returns: + `PIL.Image.Image`: The decoded image. + """ + if decode_fn is not None: + return decode_fn( + token_ids, + h, + w, + resolution_multiplier=resolution_multiplier, + num_steps=num_steps, + decode_mode=decode_mode, + **decode_kwargs, + ) + + # Fallback: try the transformer's own decode method + if hasattr(self.transformer, "decode_image_tokens"): + return self.transformer.decode_image_tokens( + token_ids, + h, + w, + num_steps=num_steps, + resolution_multiplier=resolution_multiplier, + decode_mode=decode_mode, + **decode_kwargs, + ) + + raise ValueError( + "No decode function available. Pass `decode_fn` to `__call__()` or ensure " + "the transformer model has a `decode_image_tokens` method." + ) + + # ================================================================ + # Input Validation + # ================================================================ + + def check_inputs( + self, + prompt: str | None, + image: PIL.Image.Image | None, + question: str | None, + instruction: str | None, + output_type: str, + ): + """Validate input arguments.""" + has_prompt = prompt is not None + has_image = image is not None + has_question = question is not None + has_instruction = instruction is not None + + if not has_prompt and not has_image: + raise ValueError( + "Invalid input combination. Provide one of:\n" + " - `prompt` only (text-to-image)\n" + " - `image` + `question` (image understanding)\n" + " - `image` + `instruction` (image editing)" + ) + + if has_prompt and (has_image or has_question or has_instruction): + raise ValueError( + "For text-to-image mode, provide `prompt` only without `image`, `question`, or `instruction`." + ) + + if has_image and not has_question and not has_instruction: + raise ValueError( + "When `image` is provided, also provide `question` (understanding) or `instruction` (editing)." 
+ ) + + if has_question and has_instruction: + raise ValueError("Provide either `question` or `instruction`, not both.") + + if output_type not in {"pil", "np", "tokens"}: + raise ValueError(f"`output_type` must be one of 'pil', 'np', 'tokens', got {output_type!r}.") + + # ================================================================ + # Main __call__ method + # ================================================================ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | None = None, + image: PIL.Image.Image | None = None, + question: str | None = None, + instruction: str | None = None, + height: int = 1024, + width: int = 1024, + steps: int = 8, + guidance_scale: float = 2.0, + block_length: int = 32, + cfg_text_scale: float | None = None, + cfg_image_scale: float = 0.0, + decoder_steps: int | None = None, + decode_mode: str = "decoder-turbo", + decode_fn: Callable | None = None, + resolution_multiplier: int = 2, + generator: torch.Generator | None = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> UniLLaDaPipelineOutput | tuple: + r""" + Generate images or text based on the provided inputs. + + The mode is determined automatically by which arguments are provided: + - **Text-to-Image**: Provide `prompt` only. + - **Image Understanding**: Provide `image` + `question`. + - **Image Editing**: Provide `image` + `instruction`. + + Args: + prompt (`str`, *optional*): + Text prompt for text-to-image generation. + image (`PIL.Image.Image`, *optional*): + Input image for understanding or editing. + question (`str`, *optional*): + Question about the image (understanding mode). + instruction (`str`, *optional*): + Editing instruction (editing mode). + height (`int`, defaults to 1024): + Output image height in pixels (text-to-image only). + width (`int`, defaults to 1024): + Output image width in pixels (text-to-image only). + steps (`int`, defaults to 8): + Number of block diffusion steps for the LLM backbone. + guidance_scale (`float`, defaults to 2.0): + CFG scale for the LLM backbone during token generation. + block_length (`int`, defaults to 32): + Block size for the LLM block diffusion. + cfg_text_scale (`float`, *optional*): + CFG scale for text in editing mode. Defaults to `guidance_scale`. + cfg_image_scale (`float`, defaults to 0.0): + CFG scale for image in editing mode. + decoder_steps (`int`, *optional*): + Number of decoder diffusion steps. Defaults to 8 for turbo, 50 for normal. + decode_mode (`str`, defaults to `"decoder-turbo"`): + Decoder mode: `"decoder-turbo"` (fast, ~8 steps) or `"normal"` (quality, ~50 steps). + decode_fn (`Callable`, *optional*): + Custom decode function for converting VQ tokens to images. If not provided, + the transformer's built-in decode method is used. + resolution_multiplier (`int`, defaults to 2): + Upscale factor (2 = 1024px from 512px tokens). + generator (`torch.Generator`, *optional*): + Random generator for reproducibility. + output_type (`str`, defaults to `"pil"`): + Output format: `"pil"`, `"np"`, or `"tokens"`. + return_dict (`bool`, defaults to `True`): + Whether to return a [`UniLLaDaPipelineOutput`] or a tuple. + + Returns: + [`UniLLaDaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, a [`UniLLaDaPipelineOutput`] is returned, otherwise a tuple is returned. + + Examples: + """ + # 1. Validate inputs + self.check_inputs( + prompt=prompt, + image=image, + question=question, + instruction=instruction, + output_type=output_type, + ) + + # 2. 
Determine default decoder steps + if decoder_steps is None: + decoder_steps = 8 if decode_mode == "decoder-turbo" else 50 + + # 3. Route to the appropriate mode + if image is not None and question is not None: + return self._understand_image( + image=image, + question=question, + steps=steps, + output_type=output_type, + return_dict=return_dict, + ) + elif image is not None and instruction is not None: + return self._edit_image( + image=image, + instruction=instruction, + steps=steps, + block_length=block_length, + cfg_text_scale=cfg_text_scale if cfg_text_scale is not None else guidance_scale, + cfg_image_scale=cfg_image_scale, + decoder_steps=decoder_steps, + decode_mode=decode_mode, + decode_fn=decode_fn, + resolution_multiplier=resolution_multiplier, + output_type=output_type, + return_dict=return_dict, + ) + else: + return self._generate_image( + prompt=prompt, + height=height, + width=width, + steps=steps, + cfg_scale=guidance_scale, + block_length=block_length, + decoder_steps=decoder_steps, + decode_mode=decode_mode, + decode_fn=decode_fn, + resolution_multiplier=resolution_multiplier, + output_type=output_type, + return_dict=return_dict, + ) + + # ================================================================ + # Mode implementations + # ================================================================ + + def _generate_image( + self, + prompt: str, + height: int, + width: int, + steps: int, + cfg_scale: float, + block_length: int, + decoder_steps: int, + decode_mode: str, + decode_fn: Callable | None, + resolution_multiplier: int, + output_type: str, + return_dict: bool, + ) -> UniLLaDaPipelineOutput | tuple: + """Text-to-image generation.""" + result = self.transformer.generate_image( + prompt, + image_h=height, + image_w=width, + steps=steps, + cfg_scale=cfg_scale, + block_length=block_length, + ) + + if output_type == "tokens": + if not return_dict: + return (None, str(result)) + return UniLLaDaPipelineOutput(images=None, text=str(result)) + + image = self.decode_tokens_to_image( + result["token_ids"], + result["h"], + result["w"], + decode_fn=decode_fn, + num_steps=decoder_steps, + resolution_multiplier=resolution_multiplier, + decode_mode=decode_mode, + ) + + if output_type == "np": + image = np.array(image) + + if not return_dict: + return ([image],) + return UniLLaDaPipelineOutput(images=[image]) + + def _understand_image( + self, + image: PIL.Image.Image, + question: str, + steps: int, + output_type: str, + return_dict: bool, + ) -> UniLLaDaPipelineOutput | tuple: + """Image understanding.""" + image_tokens, h, w = self.encode_image(image) + + response = self.transformer.understand_image(image_tokens, h, w, question, steps=steps) + + if not return_dict: + return (response,) + return UniLLaDaPipelineOutput(text=response) + + def _edit_image( + self, + image: PIL.Image.Image, + instruction: str, + steps: int, + block_length: int, + cfg_text_scale: float, + cfg_image_scale: float, + decoder_steps: int, + decode_mode: str, + decode_fn: Callable | None, + resolution_multiplier: int, + output_type: str, + return_dict: bool, + ) -> UniLLaDaPipelineOutput | tuple: + """Image editing.""" + image_tokens, h, w = self.encode_image(image) + + result = self.transformer.edit_image( + image_tokens, + h, + w, + instruction, + steps=steps, + block_length=block_length, + cfg_text_scale=cfg_text_scale, + cfg_image_scale=cfg_image_scale, + ) + + if output_type == "tokens": + if not return_dict: + return (None, str(result)) + return UniLLaDaPipelineOutput(images=None, text=str(result)) + + 
edited_image = self.decode_tokens_to_image( + result["token_ids"], + result["h"], + result["w"], + decode_fn=decode_fn, + num_steps=decoder_steps, + resolution_multiplier=resolution_multiplier, + decode_mode=decode_mode, + ) + + if output_type == "np": + edited_image = np.array(edited_image) + + if not return_dict: + return ([edited_image],) + return UniLLaDaPipelineOutput(images=[edited_image]) + + +__all__ = ["UniLLaDaPipeline"] diff --git a/src/diffusers/pipelines/unillada/sigvq.py b/src/diffusers/pipelines/unillada/sigvq.py new file mode 100644 index 000000000000..27af099eeee0 --- /dev/null +++ b/src/diffusers/pipelines/unillada/sigvq.py @@ -0,0 +1,56 @@ +"""SigVQ: Semantic token embedding extractor for the image decoder.""" + +import torch +import torch.nn as nn + + +class _LinearWrapper(nn.Module): + """Wraps nn.Linear inside a .proj attribute to match diffusers checkpoint key format.""" + + def __init__(self, in_features, out_features): + super().__init__() + self.proj = nn.Linear(in_features, out_features) + + def forward(self, x): + return self.proj(x) + + +class _FeedForward(nn.Module): + """SiLU feed-forward matching diffusers key layout: net.0.proj / net.1 / net.2""" + + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.net = nn.Sequential( + _LinearWrapper(dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class SigVQ(nn.Module): + """ + Lightweight semantic token extractor. + Maps discrete VQ token IDs to continuous feature vectors via embedding + projection. + + Args: + vocab_size: VQ codebook size (default: 16384). + inner_dim: Feature dimension (default: 4096). + """ + + def __init__(self, vocab_size: int = 16384, inner_dim: int = 4096): + super().__init__() + self.prior_token_embedding = nn.Embedding(vocab_size, inner_dim) + self.prior_projector = _FeedForward(dim=inner_dim, hidden_dim=inner_dim) + self.requires_grad_(False) + + def forward(self, token_ids: torch.Tensor) -> torch.Tensor: + """ + Args: + token_ids: (batch, seq_len) discrete token indices. + Returns: + (batch, seq_len, inner_dim) projected feature vectors. + """ + return self.prior_projector(self.prior_token_embedding(token_ids)) diff --git a/src/diffusers/pipelines/unillada/utils.py b/src/diffusers/pipelines/unillada/utils.py new file mode 100644 index 000000000000..b45c7d11c7ae --- /dev/null +++ b/src/diffusers/pipelines/unillada/utils.py @@ -0,0 +1,94 @@ +# Copyright 2025 Ant Group and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Image preprocessing utilities for UniLLaDA pipeline.""" + +from __future__ import annotations + +import math + +import PIL.Image + + +def generate_crop_size_list( + max_num_patches: int, + patch_size: int = 32, + min_size: int = 256, +) -> list[tuple[int, int]]: + """ + Generate a list of valid (height, width) crop sizes. 
+ + Args: + max_num_patches (`int`): + Maximum number of patches (e.g., (512 // 32) ** 2 = 256). + patch_size (`int`, defaults to 32): + Patch size in pixels. + min_size (`int`, defaults to 256): + Minimum image dimension. + + Returns: + `list[tuple[int, int]]`: Sorted list of (height, width) pairs. + """ + crop_sizes = [] + for h_patches in range(1, max_num_patches + 1): + for w_patches in range(1, max_num_patches + 1): + if h_patches * w_patches <= max_num_patches: + h = h_patches * patch_size + w = w_patches * patch_size + if h >= min_size and w >= min_size: + crop_sizes.append((h, w)) + crop_sizes.sort(key=lambda x: x[0] * x[1]) + return crop_sizes + + +def var_center_crop( + image: PIL.Image.Image, + crop_size_list: list[tuple[int, int]], +) -> PIL.Image.Image: + """ + Center-crop an image to the best matching size from `crop_size_list`, + preserving aspect ratio as much as possible. + + Args: + image (`PIL.Image.Image`): + Input image. + crop_size_list (`list[tuple[int, int]]`): + List of valid (height, width) crop sizes. + + Returns: + `PIL.Image.Image`: Cropped and resized image. + """ + w, h = image.size + aspect_ratio = w / h + + # Find best matching crop size + best_size = min( + crop_size_list, + key=lambda s: abs(s[1] / s[0] - aspect_ratio), + ) + + target_h, target_w = best_size + + # Resize to cover target size while maintaining aspect ratio + scale = max(target_w / w, target_h / h) + new_w = int(math.ceil(w * scale)) + new_h = int(math.ceil(h * scale)) + image = image.resize((new_w, new_h), PIL.Image.LANCZOS) + + # Center crop + left = (new_w - target_w) // 2 + top = (new_h - target_h) // 2 + image = image.crop((left, top, left + target_w, top + target_h)) + + return image diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..711a548299bb 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -62,6 +62,7 @@ _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] + _import_structure["scheduling_llada2uni_flow_match_euler"] = ["LLaDA2UniFlowMatchEulerScheduler"] _import_structure["scheduling_helios"] = ["HeliosScheduler"] _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] @@ -175,6 +176,7 @@ from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_lcm import LCMScheduler + from .scheduling_llada2uni_flow_match_euler import LLaDA2UniFlowMatchEulerScheduler from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler diff --git a/src/diffusers/schedulers/scheduling_llada2uni_flow_match_euler.py b/src/diffusers/schedulers/scheduling_llada2uni_flow_match_euler.py new file mode 100644 index 000000000000..4fa0da700b56 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_llada2uni_flow_match_euler.py @@ -0,0 +1,236 @@ +# Copyright 2025 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Flow Matching Euler ODE Scheduler for UniLLaDA. + +Implements linear interpolation path (ICPlan) with velocity prediction and Euler ODE integration. +Supports both standard ODE sampling (50 steps) and DDPM-style re-noising (8 steps, decoder-turbo). +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class LLaDA2UniFlowMatchEulerSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.Tensor + + +class LLaDA2UniFlowMatchEulerScheduler(SchedulerMixin, ConfigMixin): + """ + Flow Matching scheduler using Euler ODE integration with linear interpolation path for UniLLaDA. + + This scheduler implements the flow matching framework with: + - Linear path: x_t = t * x_1 + (1 - t) * x_0 + - Velocity prediction: v_t = x_1 - x_0 + - Euler ODE integration + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model (not used during inference). + shift_factor (`float`, defaults to 6.0): + Time shifting factor for improved sampling at high resolutions. + use_dynamic_shifting (`bool`, defaults to `True`): + Whether to apply dynamic time shifting based on sequence length. + stochastic_ratio (`float`, defaults to 0.0): + Ratio of stochastic (DDPM-style) vs deterministic (ODE) sampling. + 0.0 = pure ODE, 1.0 = pure DDPM re-noising (decoder-turbo mode). 
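+
+    Example:
+        A minimal sampling sketch (assuming a latent of shape `(1, 16, 64, 64)`; `velocity_model` is a
+        hypothetical stand-in for the image decoder that predicts the flow velocity):
+
+        ```py
+        >>> import torch
+        >>> from diffusers import LLaDA2UniFlowMatchEulerScheduler
+
+        >>> scheduler = LLaDA2UniFlowMatchEulerScheduler(shift_factor=6.0, stochastic_ratio=0.0)
+        >>> scheduler.set_timesteps(num_inference_steps=50, device="cpu", seq_len=1024)
+
+        >>> sample = torch.randn(1, 16, 64, 64)  # start from pure noise (t = 0)
+        >>> for t in scheduler.timesteps:
+        ...     velocity = velocity_model(sample, t)  # hypothetical velocity predictor
+        ...     sample = scheduler.step(velocity, t, sample).prev_sample
+        ```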
+ """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift_factor: float = 6.0, + use_dynamic_shifting: bool = True, + stochastic_ratio: float = 0.0, + ): + self.num_train_timesteps = num_train_timesteps + self.shift_factor = shift_factor + self.use_dynamic_shifting = use_dynamic_shifting + self.stochastic_ratio = stochastic_ratio + + # Will be set in set_timesteps + self.timesteps = None + self.num_inference_steps = None + self._step_index = None + + def time_shift(self, t: torch.Tensor, seq_len: int) -> torch.Tensor: + """Apply time shifting for improved high-resolution sampling.""" + if not self.use_dynamic_shifting: + return t + + # Dynamic shifting based on sequence length + base_shift = 0.5 + max_shift = 1.15 + mu = base_shift + (max_shift - base_shift) * (seq_len - 256) / (4096 - 256) + mu = max(base_shift, min(max_shift, mu)) + + # Shift formula (original uses t=0:clean, t=1:noise; we use reverse) + t_shifted = 1 - t + t_shifted = math.exp(mu) / (math.exp(mu) + (1 / (t_shifted + 1e-10) - 1) ** 1.0) + t_shifted = 1 - t_shifted + + return t_shifted + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + seq_len: Optional[int] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + seq_len (`int`, *optional*): + Sequence length for dynamic time shifting. If None, uses default shifting. + """ + self.num_inference_steps = num_inference_steps + + # Linear timesteps from 0 to 1 + timesteps = torch.linspace(0, 1, num_inference_steps, device=device) + + # Apply time shifting if enabled + if self.use_dynamic_shifting and seq_len is not None: + timesteps = torch.tensor( + [self.time_shift(t.item(), seq_len) for t in timesteps], + device=device, + ) + elif self.shift_factor > 0: + # Apply fixed shifting + timesteps = timesteps / (timesteps + self.shift_factor - self.shift_factor * timesteps) + + self.timesteps = timesteps + self._step_index = 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor], + sample: torch.Tensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LLaDA2UniFlowMatchEulerSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model (velocity prediction). + timestep (`float` or `torch.Tensor`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator for stochastic sampling. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_utils.LLaDA2UniFlowMatchEulerSchedulerOutput`] or `tuple`. 
+
+        Returns:
+            [`~schedulers.scheduling_llada2uni_flow_match_euler.LLaDA2UniFlowMatchEulerSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_llada2uni_flow_match_euler.LLaDA2UniFlowMatchEulerSchedulerOutput`] is returned,
+                otherwise a tuple is returned where the first element is the sample tensor.
+        """
+        if self.timesteps is None:
+            raise ValueError("Timesteps must be set before calling step(). Call set_timesteps() first.")
+
+        # Get current and next timestep
+        step_idx = self._step_index
+        if step_idx == len(self.timesteps) - 1:
+            # Last step
+            dt = 0.0
+            prev_sample = sample + model_output * dt
+        else:
+            t_curr = self.timesteps[step_idx]
+            t_next = self.timesteps[step_idx + 1]
+            dt = (t_next - t_curr).item()
+
+            # Euler step: x_{t+dt} = x_t + v_t * dt
+            prev_sample = sample + model_output * dt
+
+            # Add stochastic noise if stochastic_ratio > 0 (decoder-turbo mode)
+            if self.stochastic_ratio > 0 and step_idx < len(self.timesteps) - 1:
+                # `torch.randn_like` does not accept a `generator`, so draw the noise explicitly
+                noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+                # Scale noise by stochastic ratio and timestep
+                noise_scale = self.stochastic_ratio * math.sqrt(abs(dt))
+                prev_sample = prev_sample + noise * noise_scale
+
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return LLaDA2UniFlowMatchEulerSchedulerOutput(prev_sample=prev_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the flow matching forward process.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples (x_1).
+            noise (`torch.Tensor`):
+                The noise to add (x_0).
+            timesteps (`torch.Tensor`):
+                The timesteps (t) for each sample.
+
+        Returns:
+            `torch.Tensor`: The noisy samples x_t = t * x_1 + (1 - t) * x_0.
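+                For example, with `timesteps = 0.25` the result is `0.25 * original_samples + 0.75 * noise`,
+                a point one quarter of the way along the straight path from noise to data.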
+ """ + # Ensure timesteps are in [0, 1] + timesteps = timesteps.float() + if timesteps.dim() == 0: + timesteps = timesteps.unsqueeze(0) + + # Reshape timesteps to broadcast correctly + while timesteps.dim() < original_samples.dim(): + timesteps = timesteps.unsqueeze(-1) + + # Linear interpolation: x_t = t * x_1 + (1 - t) * x_0 + noisy_samples = timesteps * original_samples + (1 - timesteps) * noise + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6511345e9511..3c85fd6e2f7e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2432,6 +2432,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UniLLaDaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniLLaDaPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LongCatAudioDiTPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/unillada/__init__.py b/tests/pipelines/unillada/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/unillada/test_unillada.py b/tests/pipelines/unillada/test_unillada.py new file mode 100644 index 000000000000..9d0c05a8e239 --- /dev/null +++ b/tests/pipelines/unillada/test_unillada.py @@ -0,0 +1,159 @@ +import unittest + +import torch + +from diffusers import BlockRefinementScheduler, UniLLaDaPipeline +from diffusers.pipelines.unillada.pipeline_output import UniLLaDaPipelineOutput + + +class _DummyConfig: + image_token_offset = 100 + + +class _DummyGenerateResult(dict): + """Mimic the dict returned by transformer.generate_image().""" + + pass + + +class _DummyTransformer: + """Mock transformer that mimics the UniLLaDA backbone interface.""" + + config = _DummyConfig() + _hf_hook = None # needed for DiffusionPipeline + + def generate_image(self, prompt, image_h=1024, image_w=1024, steps=8, cfg_scale=2.0, block_length=32): + return {"token_ids": list(range(16)), "h": 4, "w": 4} + + def understand_image(self, image_tokens, h, w, question, steps=8): + return "This is a test response." 
+ + def edit_image( + self, image_tokens, h, w, instruction, steps=8, block_length=32, cfg_text_scale=2.0, cfg_image_scale=0.0 + ): + return {"token_ids": list(range(16)), "h": h, "w": w} + + # Required for register_modules + @property + def device(self): + return torch.device("cpu") + + @property + def dtype(self): + return torch.float32 + + +class _DummyTokenizer: + """Mock tokenizer.""" + + eos_token_id = 2 + mask_token_id = 31 + _hf_hook = None + + +class _DummyImageTokenizer: + """Mock image tokenizer.""" + + _hf_hook = None + + def encode_with_info(self, image): + return {"token_ids": list(range(16)), "grid_thw": (1, 4, 4)} + + +def _make_pipeline(with_image_tokenizer=False): + transformer = _DummyTransformer() + tokenizer = _DummyTokenizer() + scheduler = BlockRefinementScheduler() + image_tokenizer = _DummyImageTokenizer() if with_image_tokenizer else None + return UniLLaDaPipeline( + transformer=transformer, + tokenizer=tokenizer, + scheduler=scheduler, + image_tokenizer=image_tokenizer, + ) + + +def _dummy_decode_fn(token_ids, h, w, **kwargs): + """Return a small dummy PIL image.""" + import PIL.Image + + return PIL.Image.new("RGB", (64, 64), color=(128, 128, 128)) + + +class UniLLaDaPipelineTest(unittest.TestCase): + def test_text_to_image(self): + pipe = _make_pipeline() + out = pipe( + prompt="A test prompt", + decode_fn=_dummy_decode_fn, + output_type="pil", + ) + self.assertIsInstance(out, UniLLaDaPipelineOutput) + self.assertIsNotNone(out.images) + self.assertEqual(len(out.images), 1) + + def test_text_to_image_tokens(self): + pipe = _make_pipeline() + out = pipe( + prompt="A test prompt", + output_type="tokens", + ) + self.assertIsNone(out.images) + self.assertIsNotNone(out.text) + + def test_image_understanding(self): + import PIL.Image + + pipe = _make_pipeline(with_image_tokenizer=True) + img = PIL.Image.new("RGB", (256, 256)) + out = pipe(image=img, question="What is this?") + self.assertIsInstance(out, UniLLaDaPipelineOutput) + self.assertIsNotNone(out.text) + self.assertEqual(out.text, "This is a test response.") + + def test_image_editing(self): + import PIL.Image + + pipe = _make_pipeline(with_image_tokenizer=True) + img = PIL.Image.new("RGB", (256, 256)) + out = pipe( + image=img, + instruction="Make it red", + decode_fn=_dummy_decode_fn, + output_type="pil", + ) + self.assertIsNotNone(out.images) + self.assertEqual(len(out.images), 1) + + def test_invalid_input_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe() # No inputs + + def test_invalid_output_type_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe(prompt="test", output_type="invalid") + + def test_understanding_without_image_tokenizer_raises(self): + import PIL.Image + + pipe = _make_pipeline(with_image_tokenizer=False) + img = PIL.Image.new("RGB", (256, 256)) + with self.assertRaises(ValueError): + pipe(image=img, question="What is this?") + + def test_return_dict_false(self): + pipe = _make_pipeline() + out = pipe( + prompt="A test prompt", + decode_fn=_dummy_decode_fn, + output_type="pil", + return_dict=False, + ) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 1) + + +if __name__ == "__main__": + unittest.main()