From 828d9daa465cff9b9dcc7ace08e3422d4437ca37 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Sun, 8 Mar 2026 20:51:08 +0530 Subject: [PATCH 01/12] Preserve torch init return contract under no_init_weights --- src/diffusers/models/modeling_utils.py | 5 ++++- tests/models/test_modeling_utils.py | 29 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_modeling_utils.py diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0901840679e3..a8b5e7e1783c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -215,7 +215,10 @@ def no_init_weights(): """ def _skip_init(*args, **kwargs): - pass + # Preserve the `torch.nn.init.*` return contract so third-party model + # constructors that chain on the returned tensor still work under + # `no_init_weights()`. + return args[0] if len(args) > 0 else None for name, init_func in TORCH_INIT_FUNCTIONS.items(): setattr(torch.nn.init, name, _skip_init) diff --git a/tests/models/test_modeling_utils.py b/tests/models/test_modeling_utils.py new file mode 100644 index 000000000000..57e69fddcb39 --- /dev/null +++ b/tests/models/test_modeling_utils.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from diffusers.models.modeling_utils import no_init_weights + + +def test_no_init_weights_preserves_torch_init_return_contract(): + tensor = torch.empty(2, 3) + + with no_init_weights(): + truncated = torch.nn.init.trunc_normal_(tensor) + zeroed = torch.nn.init.zeros_(tensor) + + assert truncated is tensor + assert zeroed is tensor From df855b739890f1f56c0591b39c2eeebcdf8053be Mon Sep 17 00:00:00 2001 From: plugyawn Date: Sun, 8 Mar 2026 20:51:18 +0530 Subject: [PATCH 02/12] Add Stage-2 RAE DiT model, pipeline, and tooling --- examples/research_projects/rae_dit/README.md | 103 +++ .../rae_dit/compare_stage2_sample.py | 188 ++++++ .../rae_dit/train_rae_dit.py | 603 ++++++++++++++++++ .../rae_dit/verify_stage2_parity.py | 190 ++++++ scripts/convert_rae_stage2_to_diffusers.py | 548 ++++++++++++++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_rae_dit.py | 538 ++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/rae_dit/__init__.py | 18 + .../pipelines/rae_dit/pipeline_rae_dit.py | 244 +++++++ .../test_models_rae_dit_transformer2d.py | 160 +++++ tests/pipelines/rae_dit/__init__.py | 1 + .../rae_dit/test_pipeline_rae_dit.py | 238 +++++++ 15 files changed, 2840 insertions(+) create mode 100644 examples/research_projects/rae_dit/README.md create mode 100644 examples/research_projects/rae_dit/compare_stage2_sample.py create mode 100644 examples/research_projects/rae_dit/train_rae_dit.py create mode 100644 examples/research_projects/rae_dit/verify_stage2_parity.py create mode 100644 scripts/convert_rae_stage2_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_rae_dit.py create mode 100644 src/diffusers/pipelines/rae_dit/__init__.py create mode 100644 src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py create mode 100644 tests/models/transformers/test_models_rae_dit_transformer2d.py create 
mode 100644 tests/pipelines/rae_dit/__init__.py create mode 100644 tests/pipelines/rae_dit/test_pipeline_rae_dit.py diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md new file mode 100644 index 000000000000..59179740df4f --- /dev/null +++ b/examples/research_projects/rae_dit/README.md @@ -0,0 +1,103 @@ +# Training RAEDiT Stage 2 + +This folder contains the minimal Stage-2 follow-up for the RAE integration: training `RAEDiTTransformer2DModel` on top of a frozen `AutoencoderRAE`. + +It is intentionally placed under `examples/research_projects/rae_dit/` rather than the top-level `examples/` trainers because this is still an experimental follow-up to the new RAE support. + +## What this mirrors + +The scaffold is deliberately composed from existing `diffusers` patterns instead of introducing a new training style: + +- `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` + for ImageFolder loading, RAE-specific preprocessing, and the experimental research-project placement. +- `examples/dreambooth/train_dreambooth_flux.py` + for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks. +- `examples/flux-control/train_control_flux.py` + for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers. + +## Current scope + +This is a minimal full-finetuning scaffold, not a paper-complete training stack. 
It currently does the following: + +- loads a frozen pretrained `AutoencoderRAE` +- encodes RGB images to normalized Stage-1 latents on the fly +- trains only the Stage-2 `RAEDiTTransformer2DModel` +- uses `FlowMatchEulerDiscreteScheduler` with the same shifted-sigma schedule shape already used elsewhere in `diffusers` +- consumes ImageFolder class ids as `class_labels` +- saves the trained transformer under `output_dir/transformer` +- saves the scheduler config under `output_dir/scheduler` +- writes `id2label.json` from the ImageFolder class mapping + +It intentionally does not yet include: + +- a latent-caching path +- validation image generation inside the script +- autoguidance or the broader upstream transport stack +- exact upstream distributed training/runtime features + +## Parity check + +`verify_stage2_parity.py` compares a converted diffusers transformer against the upstream `DiTwDDTHead` with the same published checkpoint and synthetic latent inputs. This is the quickest way to confirm that a conversion still matches upstream numerically before opening or updating a PR. + +Example: + +```bash +python examples/research_projects/rae_dit/verify_stage2_parity.py \ + --upstream_repo_path /path/to/RAE \ + --config_path /path/to/RAE/configs/stage2/sampling/ImageNet256/DiTDHXL-DINOv2-B.yaml \ + --checkpoint_path /path/to/stage2_model.pt \ + --converted_transformer_path /path/to/diffusers-transformer +``` + +## Dataset format + +The script expects an `ImageFolder`-compatible dataset: + +```text +train_data_dir/ + n01440764/ + img_0001.jpeg + n01443537/ + img_0002.jpeg +``` + +The folder names define the class labels used during Stage-2 training. 
+ +## Quickstart + +```bash +accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ + --pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/rae-dit \ + --resolution 256 \ + --train_batch_size 8 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate 1e-4 \ + --lr_scheduler cosine \ + --lr_warmup_steps 1000 \ + --max_train_steps 200000 \ + --mixed_precision bf16 \ + --report_to wandb \ + --allow_tf32 +``` + +If you already have a converted or partially trained Stage-2 checkpoint, resume from it with: + +```bash +accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ + --pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \ + --pretrained_transformer_model_name_or_path /path/to/previous/transformer \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/rae-dit-finetune \ + --resolution 256 \ + --train_batch_size 8 \ + --max_train_steps 50000 +``` + +## Notes + +- The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level. +- The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiTTransformer2DModel` operates on 2D latent feature maps. +- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents, validation sampling through the pipeline, and broader parity tooling. diff --git a/examples/research_projects/rae_dit/compare_stage2_sample.py b/examples/research_projects/rae_dit/compare_stage2_sample.py new file mode 100644 index 000000000000..7c84adf2a059 --- /dev/null +++ b/examples/research_projects/rae_dit/compare_stage2_sample.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2026 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import sys +from pathlib import Path +from typing import Any + +import torch +import yaml +from PIL import Image, ImageDraw + +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler +from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Create a visual side-by-side sample comparison between upstream and diffusers Stage-2 RAE DiT.") + parser.add_argument("--upstream_repo_path", type=str, required=True) + parser.add_argument("--config_path", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--converted_transformer_path", type=str, required=True) + parser.add_argument("--vae_model_name_or_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--class_label", type=int, default=207, help="ImageNet class id to sample.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num_inference_steps", type=int, default=25) + parser.add_argument("--device", type=str, default=None) + return parser.parse_args() + + +def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: + for key in keys: + section = config.get(key) + if isinstance(section, dict): + return section + raise 
KeyError(f"Could not find any of {keys} in config.") + + +def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: + if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): + return {key[len(prefix) :]: value for key, value in state_dict.items()} + return state_dict + + +def unwrap_state_dict(maybe_wrapped: dict[str, Any], prefer_ema: bool = True) -> dict[str, Any]: + state_dict: dict[str, Any] | Any = maybe_wrapped + if isinstance(state_dict, dict): + candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] + for key in candidate_keys: + if key in state_dict and isinstance(state_dict[key], dict): + state_dict = state_dict[key] + break + + if not isinstance(state_dict, dict): + raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") + + state_dict = dict(state_dict) + for prefix in ("module.", "model.", "model.module."): + state_dict = _maybe_strip_common_prefix(state_dict, prefix) + return state_dict + + +def load_checkpoint(checkpoint_path: Path) -> dict[str, Any]: + if checkpoint_path.suffix.lower() == ".safetensors": + import safetensors.torch + + return safetensors.torch.load_file(checkpoint_path) + + return torch.load(checkpoint_path, map_location="cpu") + + +def latent_to_pil(image: torch.Tensor) -> Image.Image: + array = image.detach().cpu().clamp(0, 1).permute(1, 2, 0).mul(255).round().byte().numpy() + return Image.fromarray(array) + + +def draw_label(image: Image.Image, text: str) -> Image.Image: + canvas = Image.new("RGB", (image.width, image.height + 24), color="white") + canvas.paste(image, (0, 24)) + draw = ImageDraw.Draw(canvas) + draw.text((8, 4), text, fill="black") + return canvas + + +def main(): + args = parse_args() + upstream_repo_path = Path(args.upstream_repo_path).expanduser().resolve() + config_path = Path(args.config_path).expanduser().resolve() + checkpoint_path = Path(args.checkpoint_path).expanduser().resolve() + 
converted_transformer_path = Path(args.converted_transformer_path).expanduser().resolve() + output_path = Path(args.output_path).expanduser().resolve() + + sys.path.insert(0, str(upstream_repo_path / "src")) + from stage2.models.DDT import DiTwDDTHead + + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + + stage2 = _resolve_section(config, "stage_2", "stage2") + stage2_params = stage2.get("params", {}) + misc = _resolve_section(config, "misc") + latent_size = misc["latent_size"] + shift = math.sqrt(int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096))) + num_train_timesteps = int(_resolve_section(config, "transport").get("params", {}).get("num_train_timesteps", 1000)) + + device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) + state_dict = unwrap_state_dict(load_checkpoint(checkpoint_path), prefer_ema=True) + + upstream_model = DiTwDDTHead(**stage2_params) + upstream_model.load_state_dict(state_dict, strict=True) + upstream_model.to(device=device, dtype=torch.float32) + upstream_model.eval() + + hf_model = RAEDiTTransformer2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) + hf_model.to(device=device, dtype=torch.float32) + hf_model.eval() + + vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path, low_cpu_mem_usage=False) + vae.to(device=device, dtype=torch.float32) + vae.eval() + + generator = torch.Generator(device=device).manual_seed(args.seed) + latents_init = torch.randn( + (1, int(stage2_params["in_channels"]), int(stage2_params["input_size"]), int(stage2_params["input_size"])), + generator=generator, + device=device, + dtype=torch.float32, + ) + class_labels = torch.tensor([args.class_label], device=device, dtype=torch.long) + + def run_sample(model, latents): + latents = latents.clone() + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=num_train_timesteps, + shift=shift, + 
stochastic_sampling=False, + ) + scheduler.set_timesteps(args.num_inference_steps, device=device) + with torch.no_grad(): + for timestep in scheduler.timesteps: + timestep_input = timestep.expand(latents.shape[0]) / scheduler.config.num_train_timesteps + if isinstance(model, DiTwDDTHead): + model_output = model(latents, timestep_input, class_labels) + else: + model_output = model(hidden_states=latents, timestep=timestep_input, class_labels=class_labels).sample + latents = scheduler.step(model_output, timestep, latents).prev_sample + return vae.decode(latents).sample.clamp(0, 1) + + upstream_image = run_sample(upstream_model, latents_init)[0] + diffusers_image = run_sample(hf_model, latents_init)[0] + abs_diff = (upstream_image - diffusers_image).abs() + diff_vis = (abs_diff / max(abs_diff.max().item(), 1e-8)).clamp(0, 1) + + max_abs_error = abs_diff.max().item() + mean_abs_error = abs_diff.mean().item() + print(f"max_abs_error={max_abs_error:.8f}") + print(f"mean_abs_error={mean_abs_error:.8f}") + + upstream_pil = draw_label(latent_to_pil(upstream_image), "Upstream") + diffusers_pil = draw_label(latent_to_pil(diffusers_image), "Diffusers") + diff_pil = draw_label(latent_to_pil(diff_vis), "Abs Diff") + + canvas = Image.new("RGB", (upstream_pil.width * 3, upstream_pil.height), color="white") + canvas.paste(upstream_pil, (0, 0)) + canvas.paste(diffusers_pil, (upstream_pil.width, 0)) + canvas.paste(diff_pil, (upstream_pil.width * 2, 0)) + output_path.parent.mkdir(parents=True, exist_ok=True) + canvas.save(output_path) + print(output_path) + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py new file mode 100644 index 000000000000..899e557cc533 --- /dev/null +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -0,0 +1,603 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import math +import os +import shutil +from pathlib import Path + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import ImageFolder +from tqdm.auto import tqdm + +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTTransformer2DModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import check_min_version +from diffusers.utils.torch_utils import is_compiled_module + + +check_min_version("0.38.0.dev0") + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Minimal stage-2 trainer for RAEDiTTransformer2DModel.") + parser.add_argument( + "--train_data_dir", + type=str, + required=True, + help="Path to an ImageFolder-style dataset root. 
Class folder names define label ids.", + ) + parser.add_argument("--output_dir", type=str, default="rae-dit", help="Directory to save checkpoints/model.") + parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.") + parser.add_argument("--report_to", type=str, default="tensorboard", help="Tracker to use with Accelerate.") + parser.add_argument("--seed", type=int, default=None, help="Seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=None, + help="Training image resolution. Defaults to the loaded RAE image size.", + ) + parser.add_argument("--center_crop", action="store_true", help="Use center crop instead of random crop.") + parser.add_argument("--random_flip", action="store_true", help="Apply random horizontal flips during training.") + parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size per device.") + parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of dataloader workers.") + parser.add_argument("--num_train_epochs", type=int, default=10, help="Training epochs if max steps is not set.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total training steps. 
Overrides num_train_epochs when provided.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of update steps to accumulate before optimizer step.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Enable gradient checkpointing on the transformer.", + ) + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Gradient clipping norm.") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by world size, accumulation steps, and batch size.", + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="AdamW beta1.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="AdamW beta2.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="AdamW weight decay.") + parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="AdamW epsilon.") + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine", + help='Scheduler type. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"].', + ) + parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Scheduler warmup steps.") + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help="Override Accelerate mixed precision mode.", + ) + parser.add_argument("--allow_tf32", action="store_true", help="Enable TF32 matmul on Ampere+ GPUs.") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=1000, + help="Save Accelerate checkpoints every N optimizer steps.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Maximum number of checkpoint folders to keep.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help='Checkpoint path or "latest" to resume from the latest checkpoint in output_dir.', + ) + parser.add_argument( + "--pretrained_rae_model_name_or_path", + type=str, + required=True, + help="Path or Hub id for the pretrained stage-1 AutoencoderRAE.", + ) + parser.add_argument( + "--pretrained_transformer_model_name_or_path", + type=str, + default=None, + help="Optional path or Hub id for a pretrained RAEDiT transformer checkpoint.", + ) + parser.add_argument("--patch_size", type=int, default=1, help="Latent patch size for the Stage-2 transformer.") + parser.add_argument("--encoder_hidden_size", type=int, default=1152, help="Encoder token width.") + parser.add_argument("--decoder_hidden_size", type=int, default=2048, help="Decoder token width.") + parser.add_argument("--encoder_num_layers", type=int, default=28, help="Number of encoder blocks.") + parser.add_argument("--decoder_num_layers", type=int, default=2, help="Number of decoder blocks.") + parser.add_argument("--encoder_num_attention_heads", type=int, default=16, help="Encoder attention heads.") + parser.add_argument("--decoder_num_attention_heads", type=int, default=16, 
help="Decoder attention heads.") + parser.add_argument("--mlp_ratio", type=float, default=4.0, help="MLP expansion ratio.") + parser.add_argument( + "--class_dropout_prob", + type=float, + default=0.1, + help="Class dropout probability for classifier-free guidance readiness.", + ) + parser.add_argument( + "--num_classes", + type=int, + default=None, + help="Number of class labels. Defaults to the number of ImageFolder classes.", + ) + parser.add_argument("--use_qknorm", action="store_true", help="Enable QK norm in attention.") + parser.add_argument("--use_swiglu", action=argparse.BooleanOptionalAction, default=True, help="Use SwiGLU MLPs.") + parser.add_argument("--use_rope", action=argparse.BooleanOptionalAction, default=True, help="Use rotary embeddings.") + parser.add_argument( + "--use_rmsnorm", + action=argparse.BooleanOptionalAction, + default=True, + help="Use RMSNorm instead of LayerNorm.", + ) + parser.add_argument("--wo_shift", action="store_true", help="Disable AdaLN shift modulation.") + parser.add_argument( + "--use_pos_embed", + action=argparse.BooleanOptionalAction, + default=True, + help="Use fixed sin-cos positional embeddings on the encoder stream.", + ) + parser.add_argument( + "--num_train_timesteps", + type=int, + default=1000, + help="Number of flow-matching training timesteps.", + ) + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Explicit flow-matching shift. 
If omitted, it is derived from the latent size.", + ) + parser.add_argument( + "--time_shift_base", + type=float, + default=4096.0, + help="Base latent dimensionality used to derive the default flow shift.", + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help='Weighting scheme for flow-matching timestep sampling and loss weighting.', + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="Mean used when the logit-normal weighting scheme is selected.", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="Std used when the logit-normal weighting scheme is selected.", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Mode weighting scale used when weighting_scheme=mode.", + ) + return parser.parse_args() + + +def build_transforms(args): + image_transforms = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + if args.random_flip: + image_transforms.append(transforms.RandomHorizontalFlip()) + image_transforms.append(transforms.ToTensor()) + return transforms.Compose(image_transforms) + + +def collate_fn(examples): + pixel_values = torch.stack([example[0] for example in examples]).float() + class_labels = torch.tensor([example[1] for example in examples], dtype=torch.long) + return {"pixel_values": pixel_values, "class_labels": class_labels} + + +def get_latent_spec(autoencoder: AutoencoderRAE) -> tuple[int, int]: + if not autoencoder.config.reshape_to_2d: + raise ValueError("Stage-2 RAE DiT training expects `AutoencoderRAE.reshape_to_2d=True`.") + + latent_channels = int(autoencoder.config.encoder_hidden_size) + latent_size = int(autoencoder.config.encoder_input_size // autoencoder.config.encoder_patch_size) + return latent_channels, 
latent_size + + +def resolve_flow_shift(args, latent_channels: int, latent_size: int) -> float: + if args.flow_shift is not None: + return float(args.flow_shift) + + latent_dim = latent_channels * latent_size * latent_size + return math.sqrt(latent_dim / float(args.time_shift_base)) + + +def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None): + if device is None: + device = timesteps.device + + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device=device) + timesteps = timesteps.to(device=device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def unwrap_model(accelerator, model): + model = accelerator.unwrap_model(model) + return model._orig_mod if is_compiled_module(model) else model + + +def maybe_prune_checkpoints(output_dir: str, checkpoints_total_limit: int | None): + if checkpoints_total_limit is None: + return + + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + if len(checkpoints) < checkpoints_total_limit: + return + + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + for checkpoint in checkpoints[:num_to_remove]: + shutil.rmtree(os.path.join(output_dir, checkpoint)) + + +def main(): + args = parse_args() + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + project_config=accelerator_project_config, + log_with=args.report_to, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - 
%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if args.seed is not None: + set_seed(args.seed) + + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + autoencoder = AutoencoderRAE.from_pretrained(args.pretrained_rae_model_name_or_path) + autoencoder.requires_grad_(False) + autoencoder.eval() + + latent_channels, latent_size = get_latent_spec(autoencoder) + if args.resolution is None: + args.resolution = int(autoencoder.config.image_size) + + dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args)) + inferred_num_classes = len(dataset.classes) + num_classes = inferred_num_classes if args.num_classes is None else int(args.num_classes) + if num_classes < inferred_num_classes: + raise ValueError( + f"`--num_classes` ({num_classes}) must be >= the number of dataset classes ({inferred_num_classes})." + ) + + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=True, + drop_last=True, + ) + + if args.pretrained_transformer_model_name_or_path is not None: + transformer = RAEDiTTransformer2DModel.from_pretrained(args.pretrained_transformer_model_name_or_path) + if transformer.config.in_channels != latent_channels or transformer.config.sample_size != latent_size: + raise ValueError( + "Loaded transformer latent shape does not match the selected AutoencoderRAE. " + f"Expected channels={latent_channels}, size={latent_size}; got " + f"channels={transformer.config.in_channels}, size={transformer.config.sample_size}." 
+ ) + if transformer.config.num_classes < num_classes: + raise ValueError( + f"Loaded transformer supports {transformer.config.num_classes} classes but dataset requires {num_classes}." + ) + else: + transformer = RAEDiTTransformer2DModel( + sample_size=latent_size, + patch_size=args.patch_size, + in_channels=latent_channels, + hidden_size=(args.encoder_hidden_size, args.decoder_hidden_size), + depth=(args.encoder_num_layers, args.decoder_num_layers), + num_heads=(args.encoder_num_attention_heads, args.decoder_num_attention_heads), + mlp_ratio=args.mlp_ratio, + class_dropout_prob=args.class_dropout_prob, + num_classes=num_classes, + use_qknorm=args.use_qknorm, + use_swiglu=args.use_swiglu, + use_rope=args.use_rope, + use_rmsnorm=args.use_rmsnorm, + wo_shift=args.wo_shift, + use_pos_embed=args.use_pos_embed, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + flow_shift = resolve_flow_shift(args, latent_channels=latent_channels, latent_size=latent_size) + noise_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=args.num_train_timesteps, + shift=flow_shift, + ) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + optimizer = torch.optim.AdamW( + transformer.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + def 
save_model_hook(models, weights, output_dir): + if not accelerator.is_main_process: + return + + for model in models: + if isinstance(unwrap_model(accelerator, model), RAEDiTTransformer2DModel): + unwrap_model(accelerator, model).save_pretrained(os.path.join(output_dir, "transformer")) + else: + raise ValueError(f"Unexpected model type during save: {type(model)}") + + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + model = models.pop() + target_model = unwrap_model(accelerator, model) + if not isinstance(target_model, RAEDiTTransformer2DModel): + raise ValueError(f"Unexpected model type during load: {type(model)}") + + load_model = RAEDiTTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") + target_model.register_to_config(**load_model.config) + target_model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + autoencoder.to(accelerator.device, dtype=weight_dtype) + + if overrode_max_train_steps: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers( + "train_rae_dit", + config={ + **vars(args), + "latent_channels": latent_channels, + "latent_size": latent_size, + "flow_shift": flow_shift, + "inferred_num_classes": inferred_num_classes, + }, + ) + with 
open(os.path.join(args.output_dir, "id2label.json"), "w", encoding="utf-8") as f: + json.dump({idx: label for idx, label in enumerate(dataset.classes)}, f, indent=2, sort_keys=True) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + logger.info("***** Running stage-2 RAE DiT training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num classes = {inferred_num_classes}") + logger.info(f" RAE latent shape = ({latent_channels}, {latent_size}, {latent_size})") + logger.info(f" Flow shift = {flow_shift:.4f}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size = {total_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + path = checkpoints[-1] if checkpoints else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.") + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(transformer): + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True) + class_labels = batch["class_labels"].to(device=accelerator.device, non_blocking=True) + + with torch.no_grad(): + latents = autoencoder.encode(pixel_values).latent + + noise = torch.randn_like(latents) + batch_size = latents.shape[0] + + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + device=latents.device, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps.to(device=latents.device)[indices] + + sigmas = get_sigmas( + noise_scheduler, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + device=latents.device, + ) + noisy_latents = noise_scheduler.scale_noise(latents, timesteps, noise) + + model_pred = transformer( + hidden_states=noisy_latents, + timestep=timesteps / noise_scheduler.config.num_train_timesteps, + class_labels=class_labels, + ).sample + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + target = noise - latents + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + dim=1, + ) + loss = loss.mean() 
+ + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + maybe_prune_checkpoints(args.output_dir, args.checkpoints_total_limit) + accelerator.wait_for_everyone() + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + if accelerator.is_main_process: + noise_scheduler.save_pretrained(os.path.join(save_path, "scheduler")) + logger.info(f"Saved state to {save_path}") + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped_transformer = unwrap_model(accelerator, transformer) + unwrapped_transformer.save_pretrained(os.path.join(args.output_dir, "transformer")) + noise_scheduler.save_pretrained(os.path.join(args.output_dir, "scheduler")) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/rae_dit/verify_stage2_parity.py b/examples/research_projects/rae_dit/verify_stage2_parity.py new file mode 100644 index 000000000000..29899628d8bd --- /dev/null +++ b/examples/research_projects/rae_dit/verify_stage2_parity.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import sys +from pathlib import Path +from typing import Any + +import torch +import yaml + +from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Compare a converted RAEDiT checkpoint against the upstream Stage-2 model.") + parser.add_argument("--upstream_repo_path", type=str, required=True, help="Path to the cloned upstream RAE repo.") + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to the upstream Stage-2 YAML config used for the published checkpoint.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the upstream Stage-2 checkpoint (.pt or .safetensors).", + ) + parser.add_argument( + "--converted_transformer_path", + type=str, + required=True, + help="Path to the converted diffusers transformer directory.", + ) + parser.add_argument("--device", type=str, default=None, help="Torch device to use. 
Defaults to cuda if available.") + parser.add_argument("--seed", type=int, default=0, help="Random seed for the synthetic parity inputs.") + parser.add_argument("--batch_size", type=int, default=2, help="Batch size for the parity run.") + parser.add_argument("--rtol", type=float, default=1e-4, help="Relative tolerance for parity.") + parser.add_argument("--atol", type=float, default=1e-5, help="Absolute tolerance for parity.") + return parser.parse_args() + + +def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: + for key in keys: + section = config.get(key) + if isinstance(section, dict): + return section + raise KeyError(f"Could not find any of {keys} in config.") + + +def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: + if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): + return {key[len(prefix) :]: value for key, value in state_dict.items()} + return state_dict + + +def unwrap_state_dict(maybe_wrapped: dict[str, Any], prefer_ema: bool = True) -> dict[str, Any]: + state_dict: dict[str, Any] | Any = maybe_wrapped + + if isinstance(state_dict, dict): + candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] + for key in candidate_keys: + if key in state_dict and isinstance(state_dict[key], dict): + state_dict = state_dict[key] + break + + if not isinstance(state_dict, dict): + raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") + + state_dict = dict(state_dict) + for prefix in ("module.", "model.", "model.module."): + state_dict = _maybe_strip_common_prefix(state_dict, prefix) + return state_dict + + +def load_checkpoint(checkpoint_path: Path) -> dict[str, Any]: + if checkpoint_path.suffix.lower() == ".safetensors": + import safetensors.torch + + return safetensors.torch.load_file(checkpoint_path) + + return torch.load(checkpoint_path, map_location="cpu") + + +def build_inputs( + batch_size: int, + 
in_channels: int, + sample_size: int, + num_classes: int, + shift: float, + seed: int, + device: torch.device, +): + generator = torch.Generator(device=device).manual_seed(seed) + clean_latents = torch.randn( + (batch_size, in_channels, sample_size, sample_size), generator=generator, device=device, dtype=torch.float32 + ) + noise = torch.randn(clean_latents.shape, generator=generator, device=device, dtype=torch.float32) + + # Use a spread of normalized timesteps inside the open interval to avoid any + # boundary-case special handling around t=0 or t=1. + timesteps = torch.linspace(0.2, 0.8, steps=batch_size, device=device, dtype=torch.float32) + sigma = shift * timesteps / (1 + (shift - 1) * timesteps) + sigma = sigma.view(-1, 1, 1, 1) + + noised_latents = (1.0 - sigma) * clean_latents + sigma * noise + class_labels = torch.arange(batch_size, device=device, dtype=torch.long) % num_classes + return noised_latents, timesteps, class_labels + + +def main(): + args = parse_args() + + upstream_repo_path = Path(args.upstream_repo_path).expanduser().resolve() + sys.path.insert(0, str(upstream_repo_path / "src")) + + from stage2.models.DDT import DiTwDDTHead + + config_path = Path(args.config_path).expanduser().resolve() + checkpoint_path = Path(args.checkpoint_path).expanduser().resolve() + converted_transformer_path = Path(args.converted_transformer_path).expanduser().resolve() + + with config_path.open("r", encoding="utf-8") as handle: + config = yaml.safe_load(handle) + + stage2 = _resolve_section(config, "stage_2", "stage2") + stage2_params = stage2.get("params", {}) + misc = _resolve_section(config, "misc") + latent_size = misc["latent_size"] + shift = math.sqrt(int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096))) + + device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) + state_dict = unwrap_state_dict(load_checkpoint(checkpoint_path), prefer_ema=True) + + upstream_model = 
DiTwDDTHead(**stage2_params) + upstream_model.load_state_dict(state_dict, strict=True) + upstream_model.to(device=device, dtype=torch.float32) + upstream_model.eval() + + hf_model = RAEDiTTransformer2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) + hf_model.to(device=device, dtype=torch.float32) + hf_model.eval() + + noised_latents, timesteps, class_labels = build_inputs( + batch_size=args.batch_size, + in_channels=int(stage2_params["in_channels"]), + sample_size=int(stage2_params["input_size"]), + num_classes=int(stage2_params.get("num_classes", misc.get("num_classes", 1000))), + shift=shift, + seed=args.seed, + device=device, + ) + + with torch.no_grad(): + upstream_output = upstream_model(noised_latents, timesteps, class_labels) + hf_output = hf_model(hidden_states=noised_latents, timestep=timesteps, class_labels=class_labels).sample + + abs_error = (upstream_output - hf_output).abs() + max_abs_error = abs_error.max().item() + mean_abs_error = abs_error.mean().item() + + print(f"device={device}") + print(f"shape={tuple(hf_output.shape)}") + print(f"max_abs_error={max_abs_error:.8f}") + print(f"mean_abs_error={mean_abs_error:.8f}") + + if not torch.allclose(upstream_output, hf_output, atol=args.atol, rtol=args.rtol): + raise AssertionError( + f"Parity failed: max_abs_error={max_abs_error:.8f}, mean_abs_error={mean_abs_error:.8f}, " + f"expected atol={args.atol}, rtol={args.rtol}" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py new file mode 100644 index 000000000000..c51011d2286c --- /dev/null +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -0,0 +1,548 @@ +import argparse +import json +import math +from pathlib import Path +from typing import Any + +import torch +import yaml +from huggingface_hub import HfApi, hf_hub_download + +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline +from 
diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel + + +DEFAULT_NUM_TRAIN_TIMESTEPS = 1000 +DEFAULT_SHIFT_BASE = 4096 +DEFAULT_TRANSFORMER_SUBFOLDER = "transformer" +DEFAULT_GUIDANCE_SUBFOLDER = "guidance_transformer" +DEFAULT_SCHEDULER_SUBFOLDER = "scheduler" + + +class RepoAccessor: + def __init__(self, repo_or_path: str, cache_dir: str | None = None): + self.repo_or_path = repo_or_path + self.cache_dir = cache_dir + self.local_root: Path | None = None + self.repo_id: str | None = None + self.repo_files: set[str] | None = None + + root = Path(repo_or_path) + if root.exists() and root.is_dir(): + self.local_root = root + else: + self.repo_id = repo_or_path + self.repo_files = set(HfApi().list_repo_files(repo_or_path)) + + def exists(self, relative_path: str) -> bool: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + return (self.local_root / relative_path).is_file() + return relative_path in self.repo_files + + def fetch(self, relative_path: str) -> Path: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + path = self.local_root / relative_path + if not path.is_file(): + raise FileNotFoundError(f"File not found: {path}") + return path + + downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir) + return Path(downloaded) + + +def read_yaml(accessor: RepoAccessor, relative_path: str) -> dict[str, Any]: + with resolve_input_path(accessor, relative_path).open() as handle: + config = yaml.safe_load(handle) + + if not isinstance(config, dict): + raise ValueError(f"Expected YAML object at `{relative_path}` to decode to a dictionary.") + + return config + + +def _get_nested(mapping: dict[str, Any], path: str) -> Any: + value: Any = mapping + for part in path.split("."): + if not isinstance(value, dict) or part not in value: + raise KeyError(f"Missing `{path}` in checkpoint/config object.") + value = value[part] + return value + 
+ +def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: + for key in keys: + section = config.get(key) + if isinstance(section, dict): + return section + raise KeyError(f"Could not find any of {keys} in config.") + + +def _normalize_pair(value: int | list[int] | tuple[int, ...], field_name: str) -> tuple[int, int]: + if isinstance(value, (list, tuple)): + if len(value) != 2: + raise ValueError(f"`{field_name}` must have length 2 when provided as a sequence, but got {value}.") + return int(value[0]), int(value[1]) + + scalar = int(value) + return scalar, scalar + + +def _normalize_patch_size(value: int | list[int] | tuple[int, ...]) -> int | tuple[int, int]: + stage1_patch_size, stage2_patch_size = _normalize_pair(value, "patch_size") + if stage1_patch_size == stage2_patch_size: + return stage1_patch_size + return stage1_patch_size, stage2_patch_size + + +def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: + if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): + return {key[len(prefix) :]: value for key, value in state_dict.items()} + return state_dict + + +def unwrap_state_dict( + maybe_wrapped: dict[str, Any], + checkpoint_key: str | None = None, + prefer_ema: bool = True, +) -> dict[str, Any]: + if checkpoint_key: + state_dict = _get_nested(maybe_wrapped, checkpoint_key) + else: + state_dict = maybe_wrapped + if isinstance(state_dict, dict): + candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] + for key in candidate_keys: + if key in state_dict and isinstance(state_dict[key], dict): + state_dict = state_dict[key] + break + + if not isinstance(state_dict, dict): + raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") + + state_dict = dict(state_dict) + for prefix in ("module.", "model.", "model.module."): + state_dict = _maybe_strip_common_prefix(state_dict, prefix) + return state_dict + + +def 
load_checkpoint(checkpoint_path: Path) -> dict[str, Any]: + suffix = checkpoint_path.suffix.lower() + if suffix == ".safetensors": + import safetensors.torch + + return safetensors.torch.load_file(checkpoint_path) + + return torch.load(checkpoint_path, map_location="cpu") + + +def build_transformer_config(stage2_params: dict[str, Any], misc: dict[str, Any]) -> dict[str, Any]: + hidden_size = _normalize_pair(stage2_params["hidden_size"], "hidden_size") + depth = _normalize_pair(stage2_params["depth"], "depth") + num_heads = _normalize_pair(stage2_params["num_heads"], "num_heads") + patch_size = _normalize_patch_size(stage2_params.get("patch_size", 1)) + + input_size = int(stage2_params["input_size"]) + in_channels = int(stage2_params["in_channels"]) + + latent_size = misc.get("latent_size") + if latent_size is not None: + if len(latent_size) != 3: + raise ValueError(f"`misc.latent_size` should have length 3, but got {latent_size}.") + latent_channels, latent_height, latent_width = [int(dim) for dim in latent_size] + if latent_channels != in_channels: + raise ValueError( + f"`misc.latent_size[0]` ({latent_channels}) does not match `stage_2.params.in_channels` ({in_channels})." + ) + if latent_height != input_size or latent_width != input_size: + raise ValueError( + f"`misc.latent_size[1:]` ({latent_height}, {latent_width}) does not match `stage_2.params.input_size` ({input_size})." 
+ ) + + return { + "sample_size": input_size, + "patch_size": patch_size, + "in_channels": in_channels, + "hidden_size": hidden_size, + "depth": depth, + "num_heads": num_heads, + "mlp_ratio": float(stage2_params.get("mlp_ratio", 4.0)), + "class_dropout_prob": float(stage2_params.get("class_dropout_prob", 0.1)), + "num_classes": int(stage2_params.get("num_classes", misc.get("num_classes", 1000))), + "use_qknorm": bool(stage2_params.get("use_qknorm", False)), + "use_swiglu": bool(stage2_params.get("use_swiglu", True)), + "use_rope": bool(stage2_params.get("use_rope", True)), + "use_rmsnorm": bool(stage2_params.get("use_rmsnorm", True)), + "wo_shift": bool(stage2_params.get("wo_shift", False)), + "use_pos_embed": bool(stage2_params.get("use_pos_embed", True)), + } + + +def build_scheduler_config(config: dict[str, Any]) -> tuple[FlowMatchEulerDiscreteScheduler, dict[str, Any]]: + transport = _resolve_section(config, "transport") + misc = _resolve_section(config, "misc") + + transport_params = transport.get("params", {}) + latent_size = misc.get("latent_size", None) + if latent_size is None: + raise KeyError("Config must define `misc.latent_size` for scheduler conversion.") + + shift_dim = int(misc.get("time_dist_shift_dim", math.prod(int(dim) for dim in latent_size))) + shift_base = int(misc.get("time_dist_shift_base", DEFAULT_SHIFT_BASE)) + shift = math.sqrt(shift_dim / shift_base) + + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=int(transport_params.get("num_train_timesteps", DEFAULT_NUM_TRAIN_TIMESTEPS)), + shift=shift, + stochastic_sampling=False, + ) + metadata = { + "num_train_timesteps": scheduler.config.num_train_timesteps, + "shift": scheduler.config.shift, + "path_type": transport_params.get("path_type", "Linear"), + "prediction": transport_params.get("prediction", "velocity"), + "time_dist_type": transport_params.get("time_dist_type", "uniform"), + } + return scheduler, metadata + + +def convert_transformer_state_dict( + 
transformer_config: dict[str, Any], + checkpoint_path: Path, + checkpoint_key: str | None, + prefer_ema: bool, + output_dir: Path, + safe_serialization: bool, + verify_load: bool, + component_name: str, +) -> dict[str, Any]: + raw_checkpoint = load_checkpoint(checkpoint_path) + state_dict = unwrap_state_dict(raw_checkpoint, checkpoint_key=checkpoint_key, prefer_ema=prefer_ema) + + with torch.device("meta"): + model = RAEDiTTransformer2DModel(**transformer_config) + + load_result = model.load_state_dict(state_dict, strict=False, assign=True) + missing_keys = set(load_result.missing_keys) + unexpected_keys = set(load_result.unexpected_keys) + + allowed_missing = { + "pos_embed", + "enc_feat_rope.freqs_cos", + "enc_feat_rope.freqs_sin", + "dec_feat_rope.freqs_cos", + "dec_feat_rope.freqs_sin", + } + missing_keys -= allowed_missing + + if unexpected_keys: + raise RuntimeError( + f"Unexpected keys while converting {component_name}: {sorted(unexpected_keys)[:20]}" + + (" ..." if len(unexpected_keys) > 20 else "") + ) + if missing_keys: + raise RuntimeError( + f"Missing keys while converting {component_name}: {sorted(missing_keys)[:20]}" + + (" ..." 
if len(missing_keys) > 20 else "") + ) + + output_dir.mkdir(parents=True, exist_ok=True) + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + + if verify_load: + reloaded = RAEDiTTransformer2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) + if not isinstance(reloaded, RAEDiTTransformer2DModel): + raise RuntimeError(f"Verification failed for {component_name}: reloaded object is not RAEDiTTransformer2DModel.") + + return { + "checkpoint_path": str(checkpoint_path), + "checkpoint_key": checkpoint_key, + "prefer_ema": prefer_ema, + "config": transformer_config, + "num_parameters": sum(t.numel() for t in state_dict.values() if isinstance(t, torch.Tensor)), + } + + +def write_metadata(output_path: Path, metadata: dict[str, Any]) -> None: + with (output_path / "conversion_metadata.json").open("w") as handle: + json.dump(metadata, handle, indent=2) + + +def resolve_input_path(accessor: RepoAccessor, path: str) -> Path: + candidates = [path] + if path.startswith("models/"): + candidates.append(path[len("models/") :]) + + for candidate in candidates: + local_path = Path(candidate) + if local_path.is_file(): + return local_path + + try: + return accessor.fetch(candidate) + except FileNotFoundError: + continue + + raise FileNotFoundError(f"Could not resolve `{path}` from `{accessor.repo_or_path}`.") + + +def resolve_checkpoint_path( + accessor: RepoAccessor, + configured_path: str | None, + override_path: str | None, + description: str, +) -> Path | None: + path = override_path or configured_path + if path is None: + return None + try: + return resolve_input_path(accessor, path) + except FileNotFoundError as error: + raise FileNotFoundError(f"{description} not found: {path}") from error + + +def convert(args: argparse.Namespace) -> None: + weights_accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) + config_accessor = RepoAccessor(args.config_repo_or_path, cache_dir=args.cache_dir) if args.config_repo_or_path else 
weights_accessor + config = read_yaml(config_accessor, args.config_path) + + stage2 = _resolve_section(config, "stage_2", "stage2") + stage2_params = stage2.get("params", {}) + misc = _resolve_section(config, "misc") + guidance = _resolve_section(config, "guidance") + + transformer_config = build_transformer_config(stage2_params, misc) + checkpoint_path = resolve_checkpoint_path( + weights_accessor, + configured_path=stage2.get("ckpt"), + override_path=args.checkpoint_path, + description="Stage-2 checkpoint", + ) + if checkpoint_path is None: + raise ValueError("Could not resolve a Stage-2 checkpoint. Pass `--checkpoint_path` or provide `stage_2.ckpt` in config.") + + scheduler, scheduler_metadata = build_scheduler_config(config) + sampler = _resolve_section(config, "sampler") + + guidance_config = guidance.get("guidance_model") + guidance_checkpoint_path = None + guidance_transformer_config = None + if guidance_config is not None and not args.skip_guidance_model: + guidance_transformer_config = build_transformer_config(guidance_config["params"], misc) + guidance_checkpoint_path = resolve_checkpoint_path( + weights_accessor, + configured_path=guidance_config.get("ckpt"), + override_path=args.guidance_checkpoint_path, + description="Guidance checkpoint", + ) + + metadata = { + "source": { + "weights_repo_or_path": args.repo_or_path, + "config_repo_or_path": args.config_repo_or_path, + "config_path": args.config_path, + "vae_model_name_or_path": args.vae_model_name_or_path, + }, + "scheduler": scheduler_metadata, + "sampler": { + "mode": sampler.get("mode", "ODE"), + "params": sampler.get("params", {}), + }, + "guidance": { + "method": guidance.get("method", "cfg"), + "scale": float(guidance.get("scale", 1.0)), + "t_min": float(guidance.get("t-min", guidance.get("t_min", 0.0))), + "t_max": float(guidance.get("t-max", guidance.get("t_max", 1.0))), + }, + "misc": misc, + } + + print(f"Using config: {args.config_path}") + print(f"Using Stage-2 checkpoint: 
{checkpoint_path}") + print(f"Derived scheduler shift: {scheduler.config.shift:.6f}") + if metadata["sampler"]["mode"] != "ODE" or metadata["sampler"]["params"].get("sampling_method", "euler") != "euler": + print( + "Warning: upstream sampler is not the public ODE/Euler path. The saved scheduler still uses " + "FlowMatchEulerDiscreteScheduler for diffusers V1 compatibility." + ) + if guidance_checkpoint_path is not None: + print(f"Using guidance checkpoint: {guidance_checkpoint_path}") + elif guidance_config is not None: + print("Guidance model found in config but not converting it (missing checkpoint or `--skip_guidance_model`).") + + if args.dry_run: + print(json.dumps(metadata, indent=2)) + print(json.dumps({"transformer_config": transformer_config}, indent=2)) + if guidance_transformer_config is not None: + print(json.dumps({"guidance_transformer_config": guidance_transformer_config}, indent=2)) + return + + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + transformer_output_dir = output_path / args.transformer_subfolder + metadata["transformer"] = convert_transformer_state_dict( + transformer_config=transformer_config, + checkpoint_path=checkpoint_path, + checkpoint_key=args.checkpoint_key, + prefer_ema=not args.disable_ema, + output_dir=transformer_output_dir, + safe_serialization=args.safe_serialization, + verify_load=args.verify_load, + component_name="transformer", + ) + + if guidance_checkpoint_path is not None and guidance_transformer_config is not None: + guidance_output_dir = output_path / args.guidance_subfolder + metadata["guidance_transformer"] = convert_transformer_state_dict( + transformer_config=guidance_transformer_config, + checkpoint_path=guidance_checkpoint_path, + checkpoint_key=args.guidance_checkpoint_key, + prefer_ema=not args.disable_ema, + output_dir=guidance_output_dir, + safe_serialization=args.safe_serialization, + verify_load=args.verify_load, + component_name="guidance_transformer", + ) + + 
scheduler_output_dir = output_path / args.scheduler_subfolder + scheduler.save_pretrained(scheduler_output_dir) + + if args.vae_model_name_or_path is not None: + vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path) + transformer = RAEDiTTransformer2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) + scheduler_for_pipe = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_output_dir) + + id2label = None + if args.id2label_json_path is not None: + with Path(args.id2label_json_path).expanduser().open("r", encoding="utf-8") as handle: + id2label = json.load(handle) + + pipe = RAEDiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler_for_pipe, id2label=id2label) + pipe.save_pretrained(output_path, safe_serialization=args.safe_serialization) + metadata["pipeline"] = {"saved": True, "id2label_json_path": args.id2label_json_path} + + write_metadata(output_path, metadata) + + print(f"Saved transformer to: {transformer_output_dir}") + print(f"Saved scheduler to: {scheduler_output_dir}") + if "guidance_transformer" in metadata: + print(f"Saved guidance model to: {output_path / args.guidance_subfolder}") + if "pipeline" in metadata: + print(f"Saved pipeline to: {output_path}") + print(f"Saved metadata to: {output_path / 'conversion_metadata.json'}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert upstream RAE Stage-2 checkpoints to diffusers format") + parser.add_argument( + "--repo_or_path", + type=str, + required=True, + help="Hub repo id or local directory containing the upstream Stage-2 weights.", + ) + parser.add_argument( + "--config_repo_or_path", + type=str, + default=None, + help="Optional separate hub repo id or local directory containing the upstream YAML configs. 
Defaults to `repo_or_path`.", + ) + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Relative path to the upstream Stage-2 YAML config inside `config_repo_or_path`, or a direct local file path.", + ) + parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted components.") + parser.add_argument( + "--vae_model_name_or_path", + type=str, + default=None, + help="Optional diffusers AutoencoderRAE checkpoint to bundle into a full RAEDiTPipeline export.", + ) + parser.add_argument( + "--id2label_json_path", + type=str, + default=None, + help="Optional JSON mapping of class ids to label strings for the saved pipeline.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Optional Stage-2 checkpoint override. Interpreted relative to `repo_or_path` unless it is already local.", + ) + parser.add_argument( + "--guidance_checkpoint_path", + type=str, + default=None, + help="Optional autoguidance checkpoint override.", + ) + parser.add_argument( + "--checkpoint_key", + type=str, + default=None, + help="Optional dotted key path inside the Stage-2 checkpoint payload. 
By default the converter auto-prefers `ema` then `model`.", + ) + parser.add_argument( + "--guidance_checkpoint_key", + type=str, + default=None, + help="Optional dotted key path inside the guidance checkpoint payload.", + ) + parser.add_argument( + "--transformer_subfolder", + type=str, + default=DEFAULT_TRANSFORMER_SUBFOLDER, + help="Subfolder name used for the converted primary transformer.", + ) + parser.add_argument( + "--guidance_subfolder", + type=str, + default=DEFAULT_GUIDANCE_SUBFOLDER, + help="Subfolder name used for the converted autoguidance transformer.", + ) + parser.add_argument( + "--scheduler_subfolder", + type=str, + default=DEFAULT_SCHEDULER_SUBFOLDER, + help="Subfolder name used for the saved scheduler config.", + ) + parser.add_argument("--cache_dir", type=str, default=None, help="Optional Hugging Face Hub cache directory.") + parser.add_argument( + "--disable_ema", + action="store_true", + help="Do not prefer `ema` when the checkpoint stores both `ema` and `model` weights.", + ) + parser.add_argument( + "--skip_guidance_model", + action="store_true", + help="Do not convert `guidance.guidance_model` even if it is present in the config.", + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Save converted transformer weights as safetensors.", + ) + parser.add_argument( + "--verify_load", + action="store_true", + help="Reload each saved transformer with `from_pretrained(low_cpu_mem_usage=False)` after conversion.", + ) + parser.add_argument( + "--dry_run", + action="store_true", + help="Only resolve paths and print the derived config/metadata without saving anything.", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert(args) + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d6d557a4c224..0d4097065f71 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,6 +258,7 @@ 
"PixArtTransformer2DModel", "PriorTransformer", "PRXTransformer2DModel", + "RAEDiTTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", @@ -328,6 +329,7 @@ "DDPMPipeline", "DiffusionPipeline", "DiTPipeline", + "RAEDiTPipeline", "ImagePipelineOutput", "KarrasVePipeline", "LDMPipeline", @@ -1033,6 +1035,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + RAEDiTTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, @@ -1101,6 +1104,7 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, + RAEDiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e4bc95fdf884..30a4f541165f 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,6 +117,7 @@ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] + _import_structure["transformers.transformer_rae_dit"] = ["RAEDiTTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -237,6 +238,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + RAEDiTTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..1006a758801f 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -43,6 +43,7 @@ from .transformer_omnigen 
import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel + from .transformer_rae_dit import RAEDiTTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py new file mode 100644 index 000000000000..ea2b4d223128 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -0,0 +1,538 @@ +from __future__ import annotations + +from math import pi, sqrt + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..embeddings import PatchEmbed, get_2d_sincos_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +def _repeat_to_length(hidden_states: torch.Tensor, target_length: int) -> torch.Tensor: + if hidden_states.shape[1] == target_length: + return hidden_states + + if target_length % hidden_states.shape[1] != 0: + raise ValueError( + f"Cannot repeat sequence of length {hidden_states.shape[1]} to match target length {target_length}." 
+ ) + + return hidden_states.repeat_interleave(target_length // hidden_states.shape[1], dim=1) + + +def _ddt_modulate(hidden_states: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor: + if shift is None: + shift = torch.zeros_like(scale) + + shift = _repeat_to_length(shift, hidden_states.shape[1]) + scale = _repeat_to_length(scale, hidden_states.shape[1]) + return hidden_states * (1 + scale) + shift + + +def _ddt_gate(hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + gate = _repeat_to_length(gate, hidden_states.shape[1]) + return hidden_states * gate + + +def _to_pair(value: int | tuple[int, int] | list[int], name: str) -> tuple[int, int]: + if isinstance(value, int): + return value, value + + if len(value) != 2: + raise ValueError(f"`{name}` must be an int or a pair, but got {value}.") + + return int(value[0]), int(value[1]) + + +def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view(*hidden_states.shape[:-1], -1, 2) + first, second = hidden_states.unbind(dim=-1) + hidden_states = torch.stack((-second, first), dim=-1) + return hidden_states.flatten(-2) + + +class _ApproximateGELUMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.act = nn.GELU(approximate="tanh") + self.fc2 = nn.Linear(intermediate_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SwiGLUFFN(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.w12 = nn.Linear(hidden_size, 2 * intermediate_size) + self.w3 = nn.Linear(intermediate_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.w12(hidden_states) 
+ hidden_states_1, hidden_states_2 = hidden_states.chunk(2, dim=-1) + hidden_states = F.silu(hidden_states_1) * hidden_states_2 + hidden_states = self.w3(hidden_states) + return hidden_states + + +class GaussianFourierEmbedding(nn.Module): + def __init__(self, hidden_size: int, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.embedding_size = embedding_size + self.scale = scale + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.mlp = nn.Sequential( + nn.Linear(embedding_size * 2, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, timestep: torch.Tensor) -> torch.Tensor: + timestep = timestep.to(self.W.dtype) + hidden_states = timestep[:, None] * self.W[None, :] * 2 * pi + hidden_states = torch.cat([torch.sin(hidden_states), torch.cos(hidden_states)], dim=-1) + hidden_states = self.mlp(hidden_states) + return hidden_states + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float): + super().__init__() + self.num_classes = num_classes + self.dropout_prob = dropout_prob + self.embedding_table = nn.Embedding(num_classes + int(dropout_prob > 0), hidden_size) + + def token_drop(self, class_labels: torch.LongTensor, force_drop_ids: torch.Tensor | None = None) -> torch.LongTensor: + if force_drop_ids is None: + drop_ids = torch.rand(class_labels.shape[0], device=class_labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + return torch.where(drop_ids, self.num_classes, class_labels) + + def forward( + self, + class_labels: torch.LongTensor, + train: bool, + force_drop_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + if (train and self.dropout_prob > 0) or (force_drop_ids is not None): + class_labels = self.token_drop(class_labels, force_drop_ids=force_drop_ids) + return self.embedding_table(class_labels) + + +class VisionRotaryEmbeddingFast(nn.Module): + def 
class VisionRotaryEmbeddingFast(nn.Module):
    """2D axial rotary embedding with precomputed cos/sin tables.

    `dim` is the per-axis rotary dimension; the feature size the tables apply to
    is `2 * dim` (two spatial axes concatenated). Tables are built once for an
    `ft_seq_len x ft_seq_len` grid whose positions are rescaled to `pt_seq_len`.
    """

    def __init__(self, dim: int, pt_seq_len: int, ft_seq_len: int | None = None, theta: float = 10000.0):
        super().__init__()

        ft_seq_len = pt_seq_len if ft_seq_len is None else ft_seq_len

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim))
        positions = torch.arange(ft_seq_len, dtype=torch.float32) / ft_seq_len * pt_seq_len
        angles = torch.einsum("n,d->nd", positions, inv_freq)
        angles = angles.repeat_interleave(2, dim=-1)
        grid = torch.cat(
            [
                angles[:, None, :].expand(ft_seq_len, ft_seq_len, -1),
                angles[None, :, :].expand(ft_seq_len, ft_seq_len, -1),
            ],
            dim=-1,
        )

        self.register_buffer("freqs_cos", grid.cos().view(-1, grid.shape[-1]))
        self.register_buffer("freqs_sin", grid.sin().view(-1, grid.shape[-1]))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expects `(batch, heads, seq, head_dim)` input.
        _, _, sequence_length, _ = hidden_states.shape
        base_sequence_length = self.freqs_cos.shape[0]
        repeat = sequence_length // base_sequence_length

        freqs_cos, freqs_sin = self.freqs_cos, self.freqs_sin
        if repeat != 1:
            # Longer sequences reuse each grid position `repeat` times.
            freqs_cos = freqs_cos.repeat_interleave(repeat, dim=0)
            freqs_sin = freqs_sin.repeat_interleave(repeat, dim=0)

        # Inlined rotate-half: interleaved pairs (x0, x1) -> (-x1, x0).
        rotated = torch.stack((-hidden_states[..., 1::2], hidden_states[..., 0::2]), dim=-1).flatten(-2)
        return hidden_states * freqs_cos + rotated * freqs_sin


class NormAttention(nn.Module):
    """Multi-head self-attention with optional per-head QK normalization and RoPE."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        use_rmsnorm: bool = False,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        if use_rmsnorm:
            norm_cls, norm_kwargs = RMSNorm, {"eps": 1e-6}
        else:
            norm_cls, norm_kwargs = nn.LayerNorm, {"elementwise_affine": True, "eps": 1e-6}

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_cls(self.head_dim, **norm_kwargs) if qk_norm else nn.Identity()
        self.k_norm = norm_cls(self.head_dim, **norm_kwargs) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, rope: VisionRotaryEmbeddingFast | None = None) -> torch.Tensor:
        batch_size, sequence_length, channels = hidden_states.shape

        projected = self.qkv(hidden_states)
        projected = projected.view(batch_size, sequence_length, 3, self.num_heads, self.head_dim)
        query, key, value = projected.permute(2, 0, 3, 1, 4).unbind(0)

        query = self.q_norm(query)
        key = self.k_norm(key)

        if rope is not None:
            query = rope(query)
            key = rope(key)

        # QK norm / RoPE may upcast; align dtypes with the value tensor before SDPA.
        attn_output = F.scaled_dot_product_attention(query.to(dtype=value.dtype), key.to(dtype=value.dtype), value)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, sequence_length, channels)
        return self.proj(attn_output)


class RAEDiTBlock(nn.Module):
    """AdaLN-modulated DiT block, shared by the encoder and decoder streams."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_qknorm: bool = False,
        use_swiglu: bool = True,
        use_rmsnorm: bool = True,
        wo_shift: bool = False,
    ):
        super().__init__()

        def make_norm():
            if use_rmsnorm:
                return RMSNorm(hidden_size, eps=1e-6)
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.norm1 = make_norm()
        self.norm2 = make_norm()

        self.attn = NormAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=use_qknorm,
            use_rmsnorm=use_rmsnorm,
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if use_swiglu:
            # 2/3 width correction keeps SwiGLU's parameter count close to the GELU MLP.
            self.mlp = SwiGLUFFN(hidden_size, int(2 * mlp_hidden_dim / 3))
        else:
            self.mlp = _ApproximateGELUMLP(hidden_size, mlp_hidden_dim)

        self.wo_shift = wo_shift
        # 4 modulation chunks (scale/gate per branch) without shifts, otherwise the full 6.
        num_modulation_chunks = 4 if wo_shift else 6
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, num_modulation_chunks * hidden_size, bias=True)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        feat_rope: VisionRotaryEmbeddingFast | None = None,
    ) -> torch.Tensor:
        if conditioning.ndim < hidden_states.ndim:
            conditioning = conditioning.unsqueeze(1)

        modulation = self.adaLN_modulation(conditioning)
        if self.wo_shift:
            scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=-1)
            shift_msa = shift_mlp = None
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)

        attn_branch = self.attn(_ddt_modulate(self.norm1(hidden_states), shift_msa, scale_msa), rope=feat_rope)
        hidden_states = hidden_states + _ddt_gate(attn_branch, gate_msa)

        mlp_branch = self.mlp(_ddt_modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
        return hidden_states + _ddt_gate(mlp_branch, gate_mlp)


class RAEDiTFinalLayer(nn.Module):
    """Final adaLN-modulated projection from decoder width to per-token output channels."""

    def __init__(self, hidden_size: int, out_channels: int, use_rmsnorm: bool = True):
        super().__init__()

        if use_rmsnorm:
            self.norm_final = RMSNorm(hidden_size, eps=1e-6)
        else:
            self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        if conditioning.ndim < hidden_states.ndim:
            conditioning = conditioning.unsqueeze(1)

        shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1)
        modulated = _ddt_modulate(self.norm_final(hidden_states), shift, scale)
        return self.linear(modulated)
    The architecture mirrors the upstream two-stream `DiTwDDTHead` design:
    an encoder path first builds conditioning tokens from the latent input,
    then a decoder path denoises the latent tokens conditioned on those
    encoded tokens.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 16,
        patch_size: int | tuple[int, int] | list[int] = 1,
        in_channels: int = 768,
        hidden_size: int | tuple[int, int] | list[int] = (1152, 2048),
        depth: int | tuple[int, int] | list[int] = (28, 2),
        num_heads: int | tuple[int, int] | list[int] = (16, 16),
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        use_qknorm: bool = False,
        use_swiglu: bool = True,
        use_rope: bool = True,
        use_rmsnorm: bool = True,
        wo_shift: bool = False,
        use_pos_embed: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        # The model denoises in latent space, so output channels equal input channels.
        self.out_channels = in_channels
        self.gradient_checkpointing = False

        # Pair-valued configs are interpreted as (encoder, decoder); a plain int is shared.
        encoder_hidden_size, decoder_hidden_size = _to_pair(hidden_size, "hidden_size")
        encoder_num_layers, decoder_num_layers = _to_pair(depth, "depth")
        encoder_num_attention_heads, decoder_num_attention_heads = _to_pair(num_heads, "num_heads")

        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.num_encoder_blocks = encoder_num_layers
        self.num_decoder_blocks = decoder_num_layers
        self.num_blocks = encoder_num_layers + decoder_num_layers
        self.use_rope = use_rope
        self.use_pos_embed = use_pos_embed

        self.s_patch_size, self.x_patch_size = _to_pair(patch_size, "patch_size")

        self.s_channel_per_token = in_channels * self.s_patch_size * self.s_patch_size
        self.x_channel_per_token = in_channels * self.x_patch_size * self.x_patch_size

        # `s_*` names belong to the conditioning/encoder stream, `x_*` to the denoised
        # latent/decoder stream, following the upstream naming.
        self.s_embedder = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=self.s_patch_size,
            in_channels=in_channels,
            embed_dim=encoder_hidden_size,
            bias=True,
            pos_embed_type=None,
        )
        self.x_embedder = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=self.x_patch_size,
            in_channels=in_channels,
            embed_dim=decoder_hidden_size,
            bias=True,
            pos_embed_type=None,
        )

        # Projects encoder-width conditioning tokens to decoder width when they differ.
        self.s_projector = (
            nn.Linear(encoder_hidden_size, decoder_hidden_size) if encoder_hidden_size != decoder_hidden_size else nn.Identity()
        )
        self.t_embedder = GaussianFourierEmbedding(encoder_hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, encoder_hidden_size, class_dropout_prob)
        self.final_layer = RAEDiTFinalLayer(
            decoder_hidden_size,
            out_channels=self.x_channel_per_token,
            use_rmsnorm=use_rmsnorm,
        )

        # NOTE(review): relies on PatchEmbed storing the patch-grid dimensions in
        # `.height`/`.width` — confirm against the diffusers PatchEmbed implementation.
        num_patches = self.s_embedder.height * self.s_embedder.width
        if use_pos_embed:
            # Frozen sin-cos table; filled in `initialize_weights`.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, encoder_hidden_size), requires_grad=False)
            self.x_pos_embed = None
        else:
            self.register_parameter("pos_embed", None)
            self.x_pos_embed = None

        if use_rope:
            # Per-axis rotary dim is half the head dim; the rope module applies 2x that.
            encoder_rope_dim = encoder_hidden_size // encoder_num_attention_heads // 2
            decoder_rope_dim = decoder_hidden_size // decoder_num_attention_heads // 2
            encoder_side = int(sqrt(num_patches))
            decoder_side = int(sqrt(self.x_embedder.height * self.x_embedder.width))
            self.enc_feat_rope = VisionRotaryEmbeddingFast(encoder_rope_dim, pt_seq_len=encoder_side)
            self.dec_feat_rope = VisionRotaryEmbeddingFast(decoder_rope_dim, pt_seq_len=decoder_side)
        else:
            self.enc_feat_rope = None
            self.dec_feat_rope = None

        # One flat list: indices [0, num_encoder_blocks) are the encoder stream, the rest
        # are decoder blocks at decoder width.
        self.blocks = nn.ModuleList(
            [
                RAEDiTBlock(
                    hidden_size=encoder_hidden_size if index < encoder_num_layers else decoder_hidden_size,
                    num_heads=encoder_num_attention_heads if index < encoder_num_layers else decoder_num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    use_qknorm=use_qknorm,
                    use_swiglu=use_swiglu,
                    use_rmsnorm=use_rmsnorm,
                    wo_shift=wo_shift,
                )
                for index in range(self.num_blocks)
            ]
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Apply the upstream Stage-2 initialization scheme (xavier linears, zeroed adaLN/head)."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Patch embedders are initialized as flattened linears, matching the DiT recipe.
        nn.init.xavier_uniform_(self.x_embedder.proj.weight.view(self.x_embedder.proj.weight.shape[0], -1))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.xavier_uniform_(self.s_embedder.proj.weight.view(self.s_embedder.proj.weight.shape[0], -1))
        nn.init.constant_(self.s_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        if self.use_pos_embed:
            pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt")
            self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))

        # Zero-initialized modulation makes every block start as (near) identity.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Fold `(batch, seq, patch*patch*C)` tokens back into a `(batch, C, H, W)` latent."""
        channels = self.in_channels
        patch_size = self.x_embedder.patch_size
        if isinstance(patch_size, tuple):
            patch_size = patch_size[0]

        height = width = int(sqrt(hidden_states.shape[1]))
        if height * width != hidden_states.shape[1]:
            raise ValueError("Sequence length must form a square grid for unpatchify.")

        hidden_states = hidden_states.reshape(hidden_states.shape[0], height, width, patch_size, patch_size, channels)
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        hidden_states = hidden_states.reshape(hidden_states.shape[0], channels, height * patch_size, width * patch_size)
        return hidden_states

    def _run_block(
        self,
        block: RAEDiTBlock,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        feat_rope: VisionRotaryEmbeddingFast | None,
    ) -> torch.Tensor:
        """Run one block, optionally through the gradient-checkpointing wrapper."""
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def custom_forward(hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
                return block(hidden_states, conditioning, feat_rope=feat_rope)

            return self._gradient_checkpointing_func(custom_forward, hidden_states, conditioning)
        return block(hidden_states, conditioning, feat_rope=feat_rope)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor | None = None,
        class_labels: torch.LongTensor | None = None,
        conditioning_hidden_states: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> Transformer2DModelOutput | tuple[torch.Tensor]:
        """Denoise a `(batch, in_channels, sample_size, sample_size)` latent.

        `timestep` and `class_labels` are required despite their `None` defaults; the
        defaults only allow keyword-style calls. If `conditioning_hidden_states` is
        passed, it is used as the encoder-stream input token sequence in place of the
        embedded latent (note the encoder blocks still run on it).
        """
        if timestep is None:
            raise ValueError("`timestep` must be provided.")
        if class_labels is None:
            raise ValueError("`class_labels` must be provided.")

        timestep = timestep.reshape(-1).to(hidden_states.device)
        class_labels = class_labels.reshape(-1).to(hidden_states.device)

        timestep_emb = self.t_embedder(timestep)
        class_emb = self.y_embedder(class_labels, self.training)
        conditioning = F.silu(timestep_emb + class_emb)

        if conditioning_hidden_states is None:
            conditioning_hidden_states = self.s_embedder(hidden_states)
            if self.use_pos_embed:
                conditioning_hidden_states = conditioning_hidden_states + self.pos_embed

        # Encoder stream: build conditioning tokens from the latent input.
        for block_idx in range(self.num_encoder_blocks):
            conditioning_hidden_states = self._run_block(
                self.blocks[block_idx],
                conditioning_hidden_states,
                conditioning,
                self.enc_feat_rope,
            )

        # Re-inject the timestep before handing tokens to the decoder stream.
        conditioning_hidden_states = F.silu(timestep_emb.unsqueeze(1) + conditioning_hidden_states)

        conditioning_hidden_states = self.s_projector(conditioning_hidden_states)

        hidden_states = self.x_embedder(hidden_states)
        # `x_pos_embed` is set to None on every config path in `__init__`, so this
        # branch is currently inert; kept for parity with the upstream model.
        if self.use_pos_embed and self.x_pos_embed is not None:
            hidden_states = hidden_states + self.x_pos_embed

        # Decoder stream: denoise latent tokens conditioned on the encoded tokens.
        for block_idx in range(self.num_encoder_blocks, self.num_blocks):
            hidden_states = self._run_block(
                self.blocks[block_idx],
                hidden_states,
                conditioning_hidden_states,
                self.dec_feat_rope,
            )

        hidden_states = self.final_layer(hidden_states, conditioning_hidden_states)
        hidden_states = self.unpatchify(hidden_states)

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)
from __future__ import annotations

import torch

from ...models import AutoencoderRAE
from ...models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class RAEDiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditioned image generation in RAE latent space.

    Parameters:
        transformer ([`RAEDiTTransformer2DModel`]):
            Class-conditioned latent transformer used for Stage-2 denoising in RAE latent space.
        vae ([`AutoencoderRAE`]):
            Representation autoencoder used to decode latent samples back to RGB images.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Flow-matching scheduler used to integrate the latent denoising trajectory.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(
        self,
        transformer: RAEDiTTransformer2DModel,
        vae: AutoencoderRAE,
        scheduler: FlowMatchEulerDiscreteScheduler,
        id2label: dict[int, str] | None = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # Build a label-name -> class-id lookup; entries may hold several
        # comma-separated synonyms for one id (ImageNet-style).
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                for label in value.split(","):
                    self.labels[label.strip()] = int(key)
            self.labels = dict(sorted(self.labels.items()))

        self._guidance_scale = 1.0

    @property
    def do_classifier_free_guidance(self) -> bool:
        # CFG is active only for scales strictly above 1.
        return self._guidance_scale > 1.0

    def get_label_ids(self, label: str | list[str]) -> list[int]:
        r"""
        Map ImageNet-style label strings to class ids.
        """

        if not isinstance(label, list):
            label = [label]

        for label_name in label:
            if label_name not in self.labels:
                raise ValueError(
                    f"{label_name} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[label_name] for label_name in label]

    def _prepare_class_labels(
        self,
        class_labels: int | list[int] | torch.Tensor,
        num_images_per_prompt: int,
        device: torch.device,
    ) -> torch.LongTensor:
        # Normalize scalars/lists/tensors to a flat 1-D long tensor on `device`.
        class_labels = torch.as_tensor(class_labels, device=device, dtype=torch.long).reshape(-1)

        if num_images_per_prompt > 1:
            class_labels = class_labels.repeat_interleave(num_images_per_prompt)

        return class_labels

    def _prepare_latents(
        self,
        batch_size: int,
        latent_channels: int,
        latent_size: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: torch.Generator | list[torch.Generator] | None,
        latents: torch.Tensor | None,
    ) -> torch.Tensor:
        # Sample fresh noise unless the caller supplied latents of exactly the right shape.
        shape = (batch_size, latent_channels, latent_size, latent_size)

        if latents is None:
            return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        latents = latents.to(device=device, dtype=dtype)
        if latents.shape != shape:
            raise ValueError(f"Expected `latents` to have shape {shape}, but got {tuple(latents.shape)}.")

        return latents

    def _prepare_timesteps(
        self, timestep: torch.Tensor | float, batch_size: int, sample: torch.Tensor
    ) -> torch.Tensor:
        # Coerce a scheduler timestep (python number or 0-d/1-d tensor) onto the sample's
        # device and broadcast it to the batch; mps/npu lack float64/int64 support.
        if not torch.is_tensor(timestep):
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
        elif timestep.ndim == 0:
            timestep = timestep[None].to(sample.device)
        else:
            timestep = timestep.to(sample.device)

        return timestep.expand(batch_size)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: int | list[int] | torch.Tensor,
        guidance_scale: float = 1.0,
        guidance_start: float = 0.0,
        guidance_end: float = 1.0,
        num_images_per_prompt: int = 1,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
        num_inference_steps: int = 50,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> ImagePipelineOutput | tuple:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_labels (`int`, `list[int]`, or `torch.Tensor`):
                The class ids for the images to generate.
            guidance_scale (`float`, *optional*, defaults to `1.0`):
                Classifier-free guidance scale. Guidance is enabled when `guidance_scale > 1`.
            guidance_start (`float`, *optional*, defaults to `0.0`):
                Lower bound of the normalized timestep interval in which classifier-free guidance is active.
            guidance_end (`float`, *optional*, defaults to `1.0`):
                Upper bound of the normalized timestep interval in which classifier-free guidance is active.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                Number of images to generate per class label.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                Random generator used for latent sampling.
            latents (`torch.Tensor`, *optional*):
                Pre-generated latent noise tensor of shape `(batch, channels, height, width)`.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                Number of denoising steps.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format. Choose from `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`ImagePipelineOutput`] instead of a tuple.

        Returns:
            [`ImagePipelineOutput`] or `tuple`: the generated images (or latents when
            `output_type="latent"`).
        """

        if num_images_per_prompt < 1:
            raise ValueError(f"`num_images_per_prompt` must be >= 1, but got {num_images_per_prompt}.")
        if guidance_scale < 1.0:
            raise ValueError(f"`guidance_scale` must be >= 1.0, but got {guidance_scale}.")
        if not 0.0 <= guidance_start <= guidance_end <= 1.0:
            raise ValueError(
                f"`guidance_start` and `guidance_end` must satisfy 0 <= guidance_start <= guidance_end <= 1, but got "
                f"{guidance_start} and {guidance_end}."
            )
        if output_type not in {"latent", "np", "pil", "pt"}:
            raise ValueError(f"Unsupported `output_type`: {output_type}.")
        if guidance_scale > 1.0 and self.transformer.config.class_dropout_prob <= 0:
            raise ValueError(
                "Classifier-free guidance requires `transformer.config.class_dropout_prob > 0` so a null class token exists."
            )

        self._guidance_scale = guidance_scale

        device = self._execution_device
        dtype = self.transformer.dtype

        class_labels = self._prepare_class_labels(class_labels, num_images_per_prompt=num_images_per_prompt, device=device)
        batch_size = class_labels.shape[0]

        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels
        latents = self._prepare_latents(
            batch_size=batch_size,
            latent_channels=latent_channels,
            latent_size=latent_size,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        self._num_timesteps = len(self.scheduler.timesteps)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            if self.do_classifier_free_guidance:
                # Conditional batch first, then an unconditional batch using the
                # null-class token (index `num_classes`).
                latent_model_input = torch.cat([latents, latents], dim=0)
                null_class_labels = torch.full(
                    (batch_size,),
                    self.transformer.config.num_classes,
                    device=device,
                    dtype=class_labels.dtype,
                )
                class_labels_input = torch.cat([class_labels, null_class_labels], dim=0)
            else:
                latent_model_input = latents
                class_labels_input = class_labels

            timestep_input = self._prepare_timesteps(timestep, latent_model_input.shape[0], latent_model_input)
            # NOTE(review): the transformer is fed a timestep normalized to [0, 1]
            # (Fourier-embedded), while the scheduler below steps with the raw timestep.
            timestep_input = timestep_input / self.scheduler.config.num_train_timesteps
            model_output = self.transformer(
                latent_model_input,
                timestep=timestep_input,
                class_labels=class_labels_input,
            ).sample

            if self.do_classifier_free_guidance:
                cond_model_output, uncond_model_output = model_output.chunk(2, dim=0)
                guided_model_output = uncond_model_output + guidance_scale * (cond_model_output - uncond_model_output)
                # Apply guidance only inside the [guidance_start, guidance_end] interval of
                # the normalized timestep; outside it, fall back to the conditional output.
                guidance_mask = ((timestep_input[:batch_size] >= guidance_start) & (timestep_input[:batch_size] <= guidance_end))
                guidance_mask = guidance_mask.view(-1, *([1] * (cond_model_output.ndim - 1)))
                model_output = torch.where(guidance_mask, guided_model_output, cond_model_output)

            latents = self.scheduler.step(model_output, timestep, latents).prev_sample

        if output_type == "latent":
            output = latents
        else:
            # Decoded samples are assumed to lie in [0, 1] after clamping.
            images = self.vae.decode(latents).sample.clamp(0, 1)
            if output_type == "pt":
                output = images
            else:
                output = images.cpu().permute(0, 2, 3, 1).float().numpy()
                if output_type == "pil":
                    output = self.numpy_to_pil(output)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(images=output)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import RAEDiTTransformer2DModel + +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +def _initialize_non_zero_stage2_head(model: RAEDiTTransformer2DModel): + torch.manual_seed(0) + + for block in model.blocks: + block.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + block.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + + model.final_layer.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) + + +class RAEDiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = RAEDiTTransformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + in_channels = 8 + sample_size = 4 + scheduler_num_train_steps = 1000 + num_class_labels = 10 + + hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device) + timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device) + class_labels = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device) + + return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_labels} + + @property + def input_shape(self): + return (8, 4, 4) + + @property + def 
output_shape(self): + return (8, 4, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 4, + "patch_size": 1, + "in_channels": 8, + "hidden_size": (32, 64), + "depth": (2, 1), + "num_heads": (4, 4), + "mlp_ratio": 2.0, + "class_dropout_prob": 0.0, + "num_classes": 10, + "use_qknorm": True, + "use_swiglu": True, + "use_rope": True, + "use_rmsnorm": True, + "wo_shift": False, + "use_pos_embed": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output( + expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + ) + + def test_output_with_precomputed_conditioning_hidden_states(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + _initialize_non_zero_stage2_head(model) + + batch_size = inputs_dict[self.main_input_name].shape[0] + num_patches = (init_dict["sample_size"] // init_dict["patch_size"]) ** 2 + conditioning_hidden_states = floats_tensor((batch_size, num_patches, init_dict["hidden_size"][0])).to( + torch_device + ) + + with torch.no_grad(): + output = model(**inputs_dict, conditioning_hidden_states=conditioning_hidden_states).sample + + self.assertEqual(output.shape, inputs_dict[self.main_input_name].shape) + + def test_precomputed_conditioning_matches_internal_encoder_path(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + _initialize_non_zero_stage2_head(model) + + hidden_states = inputs_dict["hidden_states"] + timesteps = inputs_dict["timestep"] + class_labels = inputs_dict["class_labels"] + + with torch.no_grad(): + timestep_emb = model.t_embedder(timesteps.reshape(-1).to(torch_device)) + class_emb = model.y_embedder(class_labels.reshape(-1).to(torch_device), train=False) + conditioning = 
torch.nn.functional.silu(timestep_emb + class_emb) + + conditioning_hidden_states = model.s_embedder(hidden_states) + if model.use_pos_embed: + conditioning_hidden_states = conditioning_hidden_states + model.pos_embed + + for block_idx in range(model.num_encoder_blocks): + conditioning_hidden_states = model.blocks[block_idx]( + conditioning_hidden_states, + conditioning, + feat_rope=model.enc_feat_rope, + ) + + conditioning_hidden_states = torch.nn.functional.silu( + timestep_emb.unsqueeze(1) + conditioning_hidden_states + ) + + output_internal = model(**inputs_dict).sample + output_precomputed = model( + **inputs_dict, + conditioning_hidden_states=conditioning_hidden_states, + ).sample + + self.assertTrue(torch.allclose(output_internal, output_precomputed, atol=1e-5, rtol=1e-4)) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"RAEDiTTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing(loss_tolerance=1e-4) + + @unittest.skip("RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative.") + def test_layerwise_casting_inference(self): + pass diff --git a/tests/pipelines/rae_dit/__init__.py b/tests/pipelines/rae_dit/__init__.py new file mode 100644 index 000000000000..0acff376f3a3 --- /dev/null +++ b/tests/pipelines/rae_dit/__init__.py @@ -0,0 +1 @@ +# Copyright 2026 HuggingFace Inc. diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py new file mode 100644 index 000000000000..3c64c90a36d6 --- /dev/null +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import numpy as np +import torch +import torch.nn.functional as F + +import diffusers.models.autoencoders.autoencoder_rae as _rae_module +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline +from diffusers.models.autoencoders.autoencoder_rae import _ENCODER_FORWARD_FNS, _build_encoder +from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import ( + CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS, + CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +def _initialize_non_zero_stage2_head(model: RAEDiTTransformer2DModel): + torch.manual_seed(0) + + for block in model.blocks: + block.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + block.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + + model.final_layer.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) + + +class _TinyTestEncoderModule(torch.nn.Module): + def __init__(self, hidden_size: int = 8, patch_size: int = 4, **kwargs): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + + def forward(self, images: torch.Tensor) -> torch.Tensor: + pooled = 
F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) + tokens = pooled.flatten(2).transpose(1, 2).contiguous() + return tokens.repeat(1, 1, self.hidden_size) + + +def _tiny_test_encoder_forward(model, images): + return model(images) + + +def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size) + + +_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward +_original_build_encoder = _build_encoder + + +def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + if encoder_type == "tiny_test": + return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + + +_rae_module._build_encoder = _patched_build_encoder + + +class RAEDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = RAEDiTPipeline + params = CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS + batch_params = CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS + test_attention_slicing = False + test_xformers_attention = False + + @classmethod + def tearDownClass(cls): + _rae_module._build_encoder = _original_build_encoder + _ENCODER_FORWARD_FNS.pop("tiny_test", None) + super().tearDownClass() + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = RAEDiTTransformer2DModel( + sample_size=2, + patch_size=1, + in_channels=8, + hidden_size=(16, 16), + depth=(1, 1), + num_heads=(2, 2), + mlp_ratio=2.0, + class_dropout_prob=0.1, + num_classes=4, + use_qknorm=False, + use_swiglu=True, + use_rope=True, + use_rmsnorm=True, + use_pos_embed=True, + ) + _initialize_non_zero_stage2_head(transformer) + + vae = AutoencoderRAE( + encoder_type="tiny_test", + encoder_hidden_size=8, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + decoder_hidden_size=16, + 
decoder_num_hidden_layers=1, + decoder_num_attention_heads=2, + decoder_intermediate_size=32, + patch_size=2, + encoder_input_size=8, + image_size=4, + num_channels=3, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0) + + return {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler} + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + return { + "class_labels": [1], + "generator": generator, + "guidance_scale": 1.0, + "num_inference_steps": 2, + "output_type": "np", + } + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(torch_device))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + output_loaded = pipe_loaded(**self.get_dummy_inputs(torch_device))[0] + max_diff = np.abs(output_loaded - output).max() + self.assertLess(max_diff, expected_max_difference) + + def test_inference(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + image = pipe(**self.get_dummy_inputs("cpu")).images + image_slice = image[0, -2:, -2:, -1] + + self.assertEqual(image.shape, (1, 4, 4, 3)) + expected_slice = np.array([0.78739226, 0.79371649, 0.56565261, 0.78660309]) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-4) + + def 
test_inference_classifier_free_guidance(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs("cpu") + inputs.update({"guidance_scale": 2.0}) + + image = pipe(**inputs).images + self.assertEqual(image.shape, (1, 4, 4, 3)) + self.assertTrue(np.isfinite(image).all()) + + no_guidance = pipe(**self.get_dummy_inputs("cpu")).images + self.assertGreater(np.abs(image - no_guidance).max(), 1e-6) + + def test_guidance_interval_can_disable_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + base = pipe(**self.get_dummy_inputs("cpu")).images + + inputs = self.get_dummy_inputs("cpu") + inputs.pop("guidance_scale") + cfg_disabled = pipe( + **inputs, + guidance_scale=2.0, + guidance_start=0.25, + guidance_end=0.75, + ).images + + self.assertLessEqual(np.abs(base - cfg_disabled).max(), 1e-5) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-4) + + def test_latent_output(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs("cpu") + inputs.pop("output_type") + latents = pipe(**inputs, output_type="latent").images + self.assertEqual(latents.shape, (1, 8, 2, 2)) + self.assertTrue(torch.isfinite(latents).all().item()) + + def test_get_label_ids(self): + pipe = self.pipeline_class( + **self.get_dummy_components(), + id2label={ + 0: "zero", + 1: "one, first", + }, + ) + self.assertEqual(pipe.get_label_ids("first"), [1]) + self.assertEqual(pipe.get_label_ids(["zero", "one"]), [0, 1]) From d88d1a61306230ced0159c17ad8afeed144d27ff Mon Sep 17 00:00:00 2001 From: plugyawn Date: Sun, 8 Mar 2026 21:57:22 +0530 Subject: [PATCH 03/12] Fix RAE DiT review regressions --- .../rae_dit/train_rae_dit.py | 21 ++++++--- 
.../transformers/transformer_rae_dit.py | 21 ++++++++- .../pipelines/rae_dit/pipeline_rae_dit.py | 8 +++- .../test_models_rae_dit_transformer2d.py | 44 +++++++++++++++++++ .../rae_dit/test_pipeline_rae_dit.py | 17 +++++++ 5 files changed, 102 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index 899e557cc533..9785bacdd167 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -485,23 +485,29 @@ def load_model_hook(models, input_dir): global_step = 0 first_epoch = 0 initial_global_step = 0 + resume_step = 0 if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) + path = args.resume_from_checkpoint + if not os.path.isdir(path): + path = os.path.join(args.output_dir, os.path.basename(path)) else: checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - path = checkpoints[-1] if checkpoints else None + path = os.path.join(args.output_dir, checkpoints[-1]) if checkpoints else None - if path is None: + if path is None or not os.path.isdir(path): logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.") + args.resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + accelerator.load_state(path) + global_step = int(os.path.basename(path).split("-")[1]) initial_global_step = global_step + resume_global_step = global_step * args.gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) progress_bar = tqdm( range(0, args.max_train_steps), @@ -512,8 +518,11 @@ def load_model_hook(models, input_dir): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() + epoch_dataloader = train_dataloader + if args.resume_from_checkpoint and epoch == first_epoch and resume_step > 0: + epoch_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step) - for step, batch in enumerate(train_dataloader): + for batch in epoch_dataloader: with accelerator.accumulate(transformer): pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True) class_labels = batch["class_labels"].to(device=accelerator.device, non_blocking=True) diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py index ea2b4d223128..f9e5996ee735 100644 --- a/src/diffusers/models/transformers/transformer_rae_dit.py +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -22,7 +22,26 @@ def _repeat_to_length(hidden_states: torch.Tensor, target_length: int) -> torch. f"Cannot repeat sequence of length {hidden_states.shape[1]} to match target length {target_length}." 
) - return hidden_states.repeat_interleave(target_length // hidden_states.shape[1], dim=1) + if hidden_states.shape[1] == 1: + return hidden_states.expand(-1, target_length, -1) + + source_side = int(sqrt(hidden_states.shape[1])) + target_side = int(sqrt(target_length)) + if ( + source_side * source_side == hidden_states.shape[1] + and target_side * target_side == target_length + and target_side % source_side == 0 + ): + scale = target_side // source_side + batch_size, _, channels = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, source_side, source_side, channels) + hidden_states = hidden_states.repeat_interleave(scale, dim=1).repeat_interleave(scale, dim=2) + return hidden_states.reshape(batch_size, target_length, channels) + + raise ValueError( + "Cannot expand conditioning tokens without preserving their 2D layout: " + f"source length {hidden_states.shape[1]} is incompatible with target length {target_length}." + ) def _ddt_modulate(hidden_states: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor: diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py index 7446231a572b..c9ad3ec52282 100644 --- a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -33,10 +33,14 @@ def __init__( ): super().__init__() self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + serialized_id2label = None + if id2label is not None: + serialized_id2label = {str(key): value for key, value in id2label.items()} + self.register_to_config(id2label=serialized_id2label) self.labels = {} - if id2label is not None: - for key, value in id2label.items(): + if self.config.id2label is not None: + for key, value in self.config.id2label.items(): for label in value.split(","): self.labels[label.strip()] = int(key) self.labels = dict(sorted(self.labels.items())) diff --git 
a/tests/models/transformers/test_models_rae_dit_transformer2d.py b/tests/models/transformers/test_models_rae_dit_transformer2d.py index 9023c01b70b1..baa23e0288a7 100644 --- a/tests/models/transformers/test_models_rae_dit_transformer2d.py +++ b/tests/models/transformers/test_models_rae_dit_transformer2d.py @@ -18,6 +18,7 @@ import torch from diffusers import RAEDiTTransformer2DModel +from diffusers.models.transformers.transformer_rae_dit import _repeat_to_length from ...testing_utils import enable_full_determinism, floats_tensor, torch_device from ..test_modeling_common import ModelTesterMixin @@ -148,6 +149,49 @@ def test_precomputed_conditioning_matches_internal_encoder_path(self): self.assertTrue(torch.allclose(output_internal, output_precomputed, atol=1e-5, rtol=1e-4)) + def test_repeat_to_length_preserves_2d_layout(self): + hidden_states = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]]) + + repeated = _repeat_to_length(hidden_states, target_length=16) + + expected = torch.tensor( + [ + [ + [1.0], + [1.0], + [2.0], + [2.0], + [1.0], + [1.0], + [2.0], + [2.0], + [3.0], + [3.0], + [4.0], + [4.0], + [3.0], + [3.0], + [4.0], + [4.0], + ] + ] + ) + self.assertTrue(torch.equal(repeated, expected)) + + def test_repeat_to_length_broadcasts_global_conditioning(self): + hidden_states = torch.tensor([[[1.0, 2.0]]]) + + repeated = _repeat_to_length(hidden_states, target_length=4) + + expected = torch.tensor([[[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]]) + self.assertTrue(torch.equal(repeated, expected)) + + def test_repeat_to_length_rejects_incompatible_multi_token_layouts(self): + hidden_states = torch.randn(1, 2, 4) + + with self.assertRaises(ValueError): + _repeat_to_length(hidden_states, target_length=8) + def test_gradient_checkpointing_is_applied(self): expected_set = {"RAEDiTTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py 
b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py index 3c64c90a36d6..7d1c40cfb377 100644 --- a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -236,3 +236,20 @@ def test_get_label_ids(self): ) self.assertEqual(pipe.get_label_ids("first"), [1]) self.assertEqual(pipe.get_label_ids(["zero", "one"]), [0, 1]) + + def test_save_load_preserves_label_ids(self): + pipe = self.pipeline_class( + **self.get_dummy_components(), + id2label={ + 0: "zero", + 1: "one, first", + }, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + + self.assertEqual(pipe_loaded.config.id2label, {"0": "zero", "1": "one, first"}) + self.assertEqual(pipe_loaded.get_label_ids("first"), [1]) + self.assertEqual(pipe_loaded.get_label_ids(["zero", "one"]), [0, 1]) From 38826eb4a49ad999c961c302aa9c14371e279464 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Sun, 8 Mar 2026 22:05:07 +0530 Subject: [PATCH 04/12] Add RAE DiT resume-order verifier --- .../research_projects/rae_dit/test_rae_dit.py | 49 +++++ .../rae_dit/train_rae_dit.py | 39 +++- .../rae_dit/verify_train_resume.py | 189 ++++++++++++++++++ 3 files changed, 270 insertions(+), 7 deletions(-) create mode 100644 examples/research_projects/rae_dit/test_rae_dit.py create mode 100644 examples/research_projects/rae_dit/verify_train_resume.py diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py new file mode 100644 index 000000000000..f2fd7b0a408f --- /dev/null +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class RAEDiT(ExamplesTestsAccelerate): + def test_verify_train_resume(self): + test_args = """ + examples/research_projects/rae_dit/verify_train_resume.py + --seed 123 + --resolution 16 + --num_samples 6 + --train_batch_size 1 + --gradient_accumulation_steps 2 + --max_train_steps 3 + --resume_global_step 1 + """.split() + + output = run_command(self._launch_args + test_args, return_stdout=True) + + self.assertIn("baseline_trace=", output) + self.assertIn("resumed_trace=", output) + self.assertIn("resume batch order verified", output) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index 9785bacdd167..e97356a64234 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -238,6 +238,21 @@ def collate_fn(examples): return {"pixel_values": pixel_values, "class_labels": class_labels} +def compute_resume_offsets( + global_step: int, num_update_steps_per_epoch: int, gradient_accumulation_steps: int +) -> tuple[int, int]: + first_epoch = global_step // num_update_steps_per_epoch + resume_global_step = global_step * gradient_accumulation_steps + 
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps) + return first_epoch, resume_step + + +def should_skip_resumed_batch( + should_resume: bool, epoch: int, first_epoch: int, step: int, resume_step: int +) -> bool: + return should_resume and epoch == first_epoch and step < resume_step + + def get_latent_spec(autoencoder: AutoencoderRAE) -> tuple[int, int]: if not autoencoder.config.reshape_to_2d: raise ValueError("Stage-2 RAE DiT training expects `AutoencoderRAE.reshape_to_2d=True`.") @@ -505,9 +520,11 @@ def load_model_hook(models, input_dir): accelerator.load_state(path) global_step = int(os.path.basename(path).split("-")[1]) initial_global_step = global_step - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + first_epoch, resume_step = compute_resume_offsets( + global_step=global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) progress_bar = tqdm( range(0, args.max_train_steps), @@ -518,11 +535,19 @@ def load_model_hook(models, input_dir): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() - epoch_dataloader = train_dataloader - if args.resume_from_checkpoint and epoch == first_epoch and resume_step > 0: - epoch_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step) + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + for step, batch in enumerate(train_dataloader): + if should_skip_resumed_batch( + should_resume=args.resume_from_checkpoint is not None, + epoch=epoch, + first_epoch=first_epoch, + step=step, + resume_step=resume_step, + ): + continue - for batch in epoch_dataloader: with accelerator.accumulate(transformer): pixel_values = 
batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True) class_labels = batch["class_labels"].to(device=accelerator.device, non_blocking=True) diff --git a/examples/research_projects/rae_dit/verify_train_resume.py b/examples/research_projects/rae_dit/verify_train_resume.py new file mode 100644 index 000000000000..34b1945dcf54 --- /dev/null +++ b/examples/research_projects/rae_dit/verify_train_resume.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import argparse +import json +import math +import tempfile +from pathlib import Path +from types import SimpleNamespace + +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder + +from train_rae_dit import build_transforms, collate_fn, compute_resume_offsets, should_skip_resumed_batch + + +def parse_args(): + parser = argparse.ArgumentParser(description="Verify Stage-2 RAE DiT mid-epoch resume batch ordering.") + parser.add_argument("--seed", type=int, default=123, help="Seed used for the shuffled dataloader.") + parser.add_argument("--resolution", type=int, default=16, help="Synthetic image resolution.") + parser.add_argument("--num_samples", type=int, default=6, help="Number of unique samples/classes to create.") + parser.add_argument("--train_batch_size", type=int, default=1, help="Microbatch size used by the trace harness.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=2, + help="Gradient accumulation steps used to derive the mid-epoch checkpoint position.", + ) + parser.add_argument("--max_train_steps", type=int, default=3, help="Total optimizer steps to trace.") + parser.add_argument( + "--resume_global_step", + type=int, + default=1, + help="Optimizer step at which the synthetic run resumes from a checkpoint.", + ) + return parser.parse_args() + + +def create_unique_class_dataset(dataset_dir: Path, resolution: int, num_samples: int): + for sample_idx in range(num_samples): + class_dir = dataset_dir / f"class_{sample_idx:02d}" + class_dir.mkdir(parents=True, exist_ok=True) + color = ( + (40 * sample_idx) % 256, + (80 * sample_idx) % 256, + (120 * sample_idx) % 256, + ) + image = Image.new("RGB", (resolution, resolution), color=color) + image.save(class_dir / f"sample_{sample_idx}.png") + + +def collect_class_label_trace( + dataset_dir: Path, + *, + seed: 
int, + resolution: int, + train_batch_size: int, + gradient_accumulation_steps: int, + max_train_steps: int, + resume_global_step: int = 0, +) -> list[int]: + accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + set_seed(seed) + + transform_args = SimpleNamespace(resolution=resolution, center_crop=True, random_flip=False) + dataset = ImageFolder(dataset_dir, transform=build_transforms(transform_args)) + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=train_batch_size, + num_workers=0, + pin_memory=True, + drop_last=True, + ) + train_dataloader = accelerator.prepare(train_dataloader) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + first_epoch = 0 + resume_step = 0 + should_resume = resume_global_step > 0 + + if should_resume: + first_epoch, resume_step = compute_resume_offsets( + global_step=resume_global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=gradient_accumulation_steps, + ) + + expected_microbatches = (max_train_steps - resume_global_step) * gradient_accumulation_steps + trace = [] + + for epoch in range(first_epoch, num_train_epochs): + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + for step, batch in enumerate(train_dataloader): + if should_skip_resumed_batch( + should_resume=should_resume, + epoch=epoch, + first_epoch=first_epoch, + step=step, + resume_step=resume_step, + ): + continue + + trace.extend(batch["class_labels"].tolist()) + if len(trace) >= expected_microbatches: + return trace + + raise AssertionError( + f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}." 
+ ) + + +def main(): + args = parse_args() + + if args.resume_global_step >= args.max_train_steps: + raise ValueError( + f"`resume_global_step` ({args.resume_global_step}) must be < `max_train_steps` ({args.max_train_steps})." + ) + microbatches_per_epoch = args.num_samples // args.train_batch_size + required_microbatches = args.max_train_steps * args.gradient_accumulation_steps + if microbatches_per_epoch < required_microbatches: + raise ValueError( + "The verifier keeps the proof inside a single epoch. Increase `--num_samples` or decrease " + "`--train_batch_size`, `--gradient_accumulation_steps`, or `--max_train_steps`." + ) + + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) / "trace-dataset" + create_unique_class_dataset(dataset_dir, resolution=args.resolution, num_samples=args.num_samples) + + baseline_trace = collect_class_label_trace( + dataset_dir, + seed=args.seed, + resolution=args.resolution, + train_batch_size=args.train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + max_train_steps=args.max_train_steps, + ) + resumed_trace = collect_class_label_trace( + dataset_dir, + seed=args.seed, + resolution=args.resolution, + train_batch_size=args.train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + max_train_steps=args.max_train_steps, + resume_global_step=args.resume_global_step, + ) + + consumed_microbatches = args.resume_global_step * args.gradient_accumulation_steps + expected_resumed_trace = baseline_trace[consumed_microbatches:] + + print(f"baseline_trace={json.dumps(baseline_trace)}") + print(f"consumed_trace={json.dumps(baseline_trace[:consumed_microbatches])}") + print(f"resumed_trace={json.dumps(resumed_trace)}") + + if resumed_trace != expected_resumed_trace: + raise AssertionError( + "Resumed batch order does not match the uninterrupted run tail. " + f"Expected {expected_resumed_trace}, got {resumed_trace}." 
+ ) + + print("resume batch order verified") + + +if __name__ == "__main__": + main() From 5847b0798e83b7adf91f301d897a697975f64564 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Sun, 8 Mar 2026 23:10:54 +0530 Subject: [PATCH 05/12] Add RAE DiT training smoke test --- .../research_projects/rae_dit/test_rae_dit.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py index f2fd7b0a408f..d320d562386c 100644 --- a/examples/research_projects/rae_dit/test_rae_dit.py +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -16,6 +16,11 @@ import logging import os import sys +import tempfile + +from PIL import Image + +from diffusers import AutoencoderRAE sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -30,6 +35,80 @@ class RAEDiT(ExamplesTestsAccelerate): + def _create_tiny_rae(self, tmpdir): + model = AutoencoderRAE( + encoder_type="mae", + encoder_hidden_size=64, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + encoder_input_size=16, + patch_size=4, + image_size=16, + num_channels=3, + decoder_hidden_size=64, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=4, + decoder_intermediate_size=128, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + output_dir = os.path.join(tmpdir, "tiny-rae") + model.save_pretrained(output_dir, safe_serialization=False) + return output_dir + + def _create_dataset(self, tmpdir, resolution=16): + dataset_dir = os.path.join(tmpdir, "dataset") + for class_idx in range(2): + class_dir = os.path.join(dataset_dir, f"class_{class_idx:02d}") + os.makedirs(class_dir, exist_ok=True) + for image_idx in range(2): + color = ( + (50 * class_idx + 20 * image_idx) % 256, + (80 * class_idx + 30 * image_idx) % 256, + (110 * class_idx + 40 * image_idx) % 256, + ) + image = Image.new("RGB", 
(resolution, resolution), color=color) + image.save(os.path.join(class_dir, f"sample_{image_idx}.png")) + return dataset_dir + + def test_train_rae_dit_smoke(self): + with tempfile.TemporaryDirectory() as tmpdir: + rae_dir = self._create_tiny_rae(tmpdir) + dataset_dir = self._create_dataset(tmpdir) + + test_args = f""" + examples/research_projects/rae_dit/train_rae_dit.py + --pretrained_rae_model_name_or_path {rae_dir} + --train_data_dir {dataset_dir} + --output_dir {tmpdir} + --resolution 16 + --center_crop + --train_batch_size 1 + --dataloader_num_workers 0 + --max_train_steps 2 + --gradient_accumulation_steps 1 + --learning_rate 1e-3 + --lr_scheduler constant + --lr_warmup_steps 0 + --encoder_hidden_size 32 + --decoder_hidden_size 32 + --encoder_num_layers 1 + --decoder_num_layers 1 + --encoder_num_attention_heads 4 + --decoder_num_attention_heads 4 + --mlp_ratio 2.0 + --num_train_timesteps 10 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "id2label.json"))) + def test_verify_train_resume(self): test_args = """ examples/research_projects/rae_dit/verify_train_resume.py From 5ec84b0500ee60318dc156481fe8e32975a033ed Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 09:10:59 +0530 Subject: [PATCH 06/12] Sync RAE DiT stack with diffusers quality checks --- .../rae_dit/compare_stage2_sample.py | 12 ++++++-- .../rae_dit/train_rae_dit.py | 16 +++++----- .../rae_dit/verify_stage2_parity.py | 8 +++-- .../rae_dit/verify_train_resume.py | 5 +--- scripts/convert_rae_stage2_to_diffusers.py | 19 +++++++++--- src/diffusers/__init__.py | 8 ++--- src/diffusers/models/__init__.py | 4 +-- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_rae_dit.py | 21 +++++++++---- 
src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/rae_dit/pipeline_rae_dit.py | 8 +++-- src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++ .../test_models_rae_dit_transformer2d.py | 4 ++- 13 files changed, 103 insertions(+), 36 deletions(-) diff --git a/examples/research_projects/rae_dit/compare_stage2_sample.py b/examples/research_projects/rae_dit/compare_stage2_sample.py index 7c84adf2a059..03ad194d8677 100644 --- a/examples/research_projects/rae_dit/compare_stage2_sample.py +++ b/examples/research_projects/rae_dit/compare_stage2_sample.py @@ -29,7 +29,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Create a visual side-by-side sample comparison between upstream and diffusers Stage-2 RAE DiT.") + parser = argparse.ArgumentParser( + description="Create a visual side-by-side sample comparison between upstream and diffusers Stage-2 RAE DiT." + ) parser.add_argument("--upstream_repo_path", type=str, required=True) parser.add_argument("--config_path", type=str, required=True) parser.add_argument("--checkpoint_path", type=str, required=True) @@ -115,7 +117,9 @@ def main(): stage2_params = stage2.get("params", {}) misc = _resolve_section(config, "misc") latent_size = misc["latent_size"] - shift = math.sqrt(int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096))) + shift = math.sqrt( + int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096)) + ) num_train_timesteps = int(_resolve_section(config, "transport").get("params", {}).get("num_train_timesteps", 1000)) device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) @@ -157,7 +161,9 @@ def run_sample(model, latents): if isinstance(model, DiTwDDTHead): model_output = model(latents, timestep_input, class_labels) else: - model_output = model(hidden_states=latents, timestep=timestep_input, class_labels=class_labels).sample + 
model_output = model( + hidden_states=latents, timestep=timestep_input, class_labels=class_labels + ).sample latents = scheduler.step(model_output, timestep, latents).prev_sample return vae.decode(latents).sample.clamp(0, 1) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index e97356a64234..d2fdcca5f110 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -161,7 +161,9 @@ def parse_args(): ) parser.add_argument("--use_qknorm", action="store_true", help="Enable QK norm in attention.") parser.add_argument("--use_swiglu", action=argparse.BooleanOptionalAction, default=True, help="Use SwiGLU MLPs.") - parser.add_argument("--use_rope", action=argparse.BooleanOptionalAction, default=True, help="Use rotary embeddings.") + parser.add_argument( + "--use_rope", action=argparse.BooleanOptionalAction, default=True, help="Use rotary embeddings." + ) parser.add_argument( "--use_rmsnorm", action=argparse.BooleanOptionalAction, @@ -198,7 +200,7 @@ def parse_args(): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help='Weighting scheme for flow-matching timestep sampling and loss weighting.', + help="Weighting scheme for flow-matching timestep sampling and loss weighting.", ) parser.add_argument( "--logit_mean", @@ -247,9 +249,7 @@ def compute_resume_offsets( return first_epoch, resume_step -def should_skip_resumed_batch( - should_resume: bool, epoch: int, first_epoch: int, step: int, resume_step: int -) -> bool: +def should_skip_resumed_batch(should_resume: bool, epoch: int, first_epoch: int, step: int, resume_step: int) -> bool: return should_resume and epoch == first_epoch and step < resume_step @@ -484,7 +484,7 @@ def load_model_hook(models, input_dir): }, ) with open(os.path.join(args.output_dir, "id2label.json"), "w", encoding="utf-8") as f: - json.dump({idx: label for idx, label in 
enumerate(dataset.classes)}, f, indent=2, sort_keys=True) + json.dump(dict(enumerate(dataset.classes)), f, indent=2, sort_keys=True) total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running stage-2 RAE DiT training *****") @@ -549,7 +549,9 @@ def load_model_hook(models, input_dir): continue with accelerator.accumulate(transformer): - pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True) + pixel_values = batch["pixel_values"].to( + device=accelerator.device, dtype=weight_dtype, non_blocking=True + ) class_labels = batch["class_labels"].to(device=accelerator.device, non_blocking=True) with torch.no_grad(): diff --git a/examples/research_projects/rae_dit/verify_stage2_parity.py b/examples/research_projects/rae_dit/verify_stage2_parity.py index 29899628d8bd..db1570b66c44 100644 --- a/examples/research_projects/rae_dit/verify_stage2_parity.py +++ b/examples/research_projects/rae_dit/verify_stage2_parity.py @@ -27,7 +27,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Compare a converted RAEDiT checkpoint against the upstream Stage-2 model.") + parser = argparse.ArgumentParser( + description="Compare a converted RAEDiT checkpoint against the upstream Stage-2 model." 
+ ) parser.add_argument("--upstream_repo_path", type=str, required=True, help="Path to the cloned upstream RAE repo.") parser.add_argument( "--config_path", @@ -142,7 +144,9 @@ def main(): stage2_params = stage2.get("params", {}) misc = _resolve_section(config, "misc") latent_size = misc["latent_size"] - shift = math.sqrt(int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096))) + shift = math.sqrt( + int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096)) + ) device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) state_dict = unwrap_state_dict(load_checkpoint(checkpoint_path), prefer_ema=True) diff --git a/examples/research_projects/rae_dit/verify_train_resume.py b/examples/research_projects/rae_dit/verify_train_resume.py index 34b1945dcf54..62fc9cd8c413 100644 --- a/examples/research_projects/rae_dit/verify_train_resume.py +++ b/examples/research_projects/rae_dit/verify_train_resume.py @@ -28,7 +28,6 @@ from PIL import Image from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder - from train_rae_dit import build_transforms, collate_fn, compute_resume_offsets, should_skip_resumed_batch @@ -127,9 +126,7 @@ def collect_class_label_trace( if len(trace) >= expected_microbatches: return trace - raise AssertionError( - f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}." 
- ) + raise AssertionError(f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}.") def main(): diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py index c51011d2286c..376321db20e5 100644 --- a/scripts/convert_rae_stage2_to_diffusers.py +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -253,7 +253,9 @@ def convert_transformer_state_dict( if verify_load: reloaded = RAEDiTTransformer2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) if not isinstance(reloaded, RAEDiTTransformer2DModel): - raise RuntimeError(f"Verification failed for {component_name}: reloaded object is not RAEDiTTransformer2DModel.") + raise RuntimeError( + f"Verification failed for {component_name}: reloaded object is not RAEDiTTransformer2DModel." + ) return { "checkpoint_path": str(checkpoint_path), @@ -304,7 +306,11 @@ def resolve_checkpoint_path( def convert(args: argparse.Namespace) -> None: weights_accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) - config_accessor = RepoAccessor(args.config_repo_or_path, cache_dir=args.cache_dir) if args.config_repo_or_path else weights_accessor + config_accessor = ( + RepoAccessor(args.config_repo_or_path, cache_dir=args.cache_dir) + if args.config_repo_or_path + else weights_accessor + ) config = read_yaml(config_accessor, args.config_path) stage2 = _resolve_section(config, "stage_2", "stage2") @@ -320,7 +326,9 @@ def convert(args: argparse.Namespace) -> None: description="Stage-2 checkpoint", ) if checkpoint_path is None: - raise ValueError("Could not resolve a Stage-2 checkpoint. Pass `--checkpoint_path` or provide `stage_2.ckpt` in config.") + raise ValueError( + "Could not resolve a Stage-2 checkpoint. Pass `--checkpoint_path` or provide `stage_2.ckpt` in config." 
+ ) scheduler, scheduler_metadata = build_scheduler_config(config) sampler = _resolve_section(config, "sampler") @@ -361,7 +369,10 @@ def convert(args: argparse.Namespace) -> None: print(f"Using config: {args.config_path}") print(f"Using Stage-2 checkpoint: {checkpoint_path}") print(f"Derived scheduler shift: {scheduler.config.shift:.6f}") - if metadata["sampler"]["mode"] != "ODE" or metadata["sampler"]["params"].get("sampling_method", "euler") != "euler": + if ( + metadata["sampler"]["mode"] != "ODE" + or metadata["sampler"]["params"].get("sampling_method", "euler") != "euler" + ): print( "Warning: upstream sampler is not the public ODE/Euler path. The saved scheduler still uses " "FlowMatchEulerDiscreteScheduler for diffusers V1 compatibility." diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0d4097065f71..1dd26e855e48 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,10 +258,10 @@ "PixArtTransformer2DModel", "PriorTransformer", "PRXTransformer2DModel", - "RAEDiTTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", + "RAEDiTTransformer2DModel", "SanaControlNetModel", "SanaTransformer2DModel", "SanaVideoTransformer3DModel", @@ -329,12 +329,12 @@ "DDPMPipeline", "DiffusionPipeline", "DiTPipeline", - "RAEDiTPipeline", "ImagePipelineOutput", "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", "PNDMPipeline", + "RAEDiTPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", @@ -1035,10 +1035,10 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, - RAEDiTTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, + RAEDiTTransformer2DModel, SanaControlNetModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, @@ -1104,12 +1104,12 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, - RAEDiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, 
LDMSuperResolutionPipeline, PNDMPipeline, + RAEDiTPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 30a4f541165f..10c1ed81c8a8 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,8 +117,8 @@ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] - _import_structure["transformers.transformer_rae_dit"] = ["RAEDiTTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] + _import_structure["transformers.transformer_rae_dit"] = ["RAEDiTTransformer2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] @@ -238,8 +238,8 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, - RAEDiTTransformer2DModel, QwenImageTransformer2DModel, + RAEDiTTransformer2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 1006a758801f..70df11c26ba4 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -43,8 +43,8 @@ from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel - from .transformer_rae_dit import RAEDiTTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel + from .transformer_rae_dit import RAEDiTTransformer2DModel from 
.transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py index f9e5996ee735..44963a4acbcb 100644 --- a/src/diffusers/models/transformers/transformer_rae_dit.py +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -130,7 +130,9 @@ def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float): self.dropout_prob = dropout_prob self.embedding_table = nn.Embedding(num_classes + int(dropout_prob > 0), hidden_size) - def token_drop(self, class_labels: torch.LongTensor, force_drop_ids: torch.Tensor | None = None) -> torch.LongTensor: + def token_drop( + self, class_labels: torch.LongTensor, force_drop_ids: torch.Tensor | None = None + ) -> torch.LongTensor: if force_drop_ids is None: drop_ids = torch.rand(class_labels.shape[0], device=class_labels.device) < self.dropout_prob else: @@ -394,7 +396,9 @@ def __init__( ) self.s_projector = ( - nn.Linear(encoder_hidden_size, decoder_hidden_size) if encoder_hidden_size != decoder_hidden_size else nn.Identity() + nn.Linear(encoder_hidden_size, decoder_hidden_size) + if encoder_hidden_size != decoder_hidden_size + else nn.Identity() ) self.t_embedder = GaussianFourierEmbedding(encoder_hidden_size) self.y_embedder = LabelEmbedder(num_classes, encoder_hidden_size, class_dropout_prob) @@ -427,7 +431,9 @@ def __init__( [ RAEDiTBlock( hidden_size=encoder_hidden_size if index < encoder_num_layers else decoder_hidden_size, - num_heads=encoder_num_attention_heads if index < encoder_num_layers else decoder_num_attention_heads, + num_heads=encoder_num_attention_heads + if index < encoder_num_layers + else decoder_num_attention_heads, mlp_ratio=mlp_ratio, use_qknorm=use_qknorm, use_swiglu=use_swiglu, @@ -458,7 +464,9 @@ def _basic_init(module): 
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) if self.use_pos_embed: - pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt") + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt" + ) self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0)) for block in self.blocks: @@ -482,7 +490,9 @@ def unpatchify(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.reshape(hidden_states.shape[0], height, width, patch_size, patch_size, channels) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - hidden_states = hidden_states.reshape(hidden_states.shape[0], channels, height * patch_size, width * patch_size) + hidden_states = hidden_states.reshape( + hidden_states.shape[0], channels, height * patch_size, width * patch_size + ) return hidden_states def _run_block( @@ -493,6 +503,7 @@ def _run_block( feat_rope: VisionRotaryEmbeddingFast | None, ) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: + def custom_forward(hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: return block(hidden_states, conditioning, feat_rope=feat_rope) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b1d3a95adf46..7a280b9b6a04 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -773,7 +773,6 @@ from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .prx import PRXPipeline - from .rae_dit import RAEDiTPipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, @@ -785,6 +784,7 @@ QwenImageLayeredPipeline, QwenImagePipeline, ) + from .rae_dit import RAEDiTPipeline from .sana import ( SanaControlNetPipeline, SanaPipeline, diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py 
b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py index c9ad3ec52282..136087afe2ae 100644 --- a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -180,7 +180,9 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype - class_labels = self._prepare_class_labels(class_labels, num_images_per_prompt=num_images_per_prompt, device=device) + class_labels = self._prepare_class_labels( + class_labels, num_images_per_prompt=num_images_per_prompt, device=device + ) batch_size = class_labels.shape[0] latent_size = self.transformer.config.sample_size @@ -223,7 +225,9 @@ def __call__( if self.do_classifier_free_guidance: cond_model_output, uncond_model_output = model_output.chunk(2, dim=0) guided_model_output = uncond_model_output + guidance_scale * (cond_model_output - uncond_model_output) - guidance_mask = ((timestep_input[:batch_size] >= guidance_start) & (timestep_input[:batch_size] <= guidance_end)) + guidance_mask = (timestep_input[:batch_size] >= guidance_start) & ( + timestep_input[:batch_size] <= guidance_end + ) guidance_mask = guidance_mask.view(-1, *([1] * (cond_model_output.ndim - 1))) model_output = torch.where(guidance_mask, guided_model_output, cond_model_output) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3425cc8d2b61..3dd01849b12b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1541,6 +1541,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RAEDiTTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class 
SanaControlNetModel(metaclass=DummyObject): _backends = ["torch"] @@ -2413,6 +2428,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RAEDiTPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class RePaintPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/transformers/test_models_rae_dit_transformer2d.py b/tests/models/transformers/test_models_rae_dit_transformer2d.py index baa23e0288a7..a39e7a4f67e6 100644 --- a/tests/models/transformers/test_models_rae_dit_transformer2d.py +++ b/tests/models/transformers/test_models_rae_dit_transformer2d.py @@ -199,6 +199,8 @@ def test_gradient_checkpointing_is_applied(self): def test_effective_gradient_checkpointing(self): super().test_effective_gradient_checkpointing(loss_tolerance=1e-4) - @unittest.skip("RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative.") + @unittest.skip( + "RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative." 
+ ) def test_layerwise_casting_inference(self): pass From 33c57ceed233f18248b6ddd5a98323c490e89752 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 09:30:59 +0530 Subject: [PATCH 07/12] Add RAE DiT API docs --- docs/source/en/_toctree.yml | 4 ++ .../en/api/models/rae_dit_transformer2d.md | 32 ++++++++++ docs/source/en/api/pipelines/rae_dit.md | 59 +++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 docs/source/en/api/models/rae_dit_transformer2d.md create mode 100644 docs/source/en/api/pipelines/rae_dit.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e0b7af4898b2..33b9cdbaeeb5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -388,6 +388,8 @@ title: PriorTransformer - local: api/models/qwenimage_transformer2d title: QwenImageTransformer2DModel + - local: api/models/rae_dit_transformer2d + title: RAEDiTTransformer2DModel - local: api/models/sana_transformer2d title: SanaTransformer2DModel - local: api/models/sana_video_transformer3d @@ -604,6 +606,8 @@ title: PRX - local: api/pipelines/qwenimage title: QwenImage + - local: api/pipelines/rae_dit + title: RAE DiT - local: api/pipelines/sana title: Sana - local: api/pipelines/sana_sprint diff --git a/docs/source/en/api/models/rae_dit_transformer2d.md b/docs/source/en/api/models/rae_dit_transformer2d.md new file mode 100644 index 000000000000..c9361c9e8894 --- /dev/null +++ b/docs/source/en/api/models/rae_dit_transformer2d.md @@ -0,0 +1,32 @@ + + +# RAEDiTTransformer2DModel + +The `RAEDiTTransformer2DModel` is the Stage-2 latent diffusion transformer introduced in +[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690). + +Unlike DiT models that operate on VAE latents, this transformer denoises the latent space learned by +[`AutoencoderRAE`](./autoencoder_rae). It is designed to be used with [`FlowMatchEulerDiscreteScheduler`] and +decoded back to RGB with [`AutoencoderRAE`]. 
+ +## Loading a pretrained transformer + +```python +from diffusers import RAEDiTTransformer2DModel + +transformer = RAEDiTTransformer2DModel.from_pretrained("path/to/converted-stage2-transformer") +``` + +## RAEDiTTransformer2DModel + +[[autodoc]] RAEDiTTransformer2DModel diff --git a/docs/source/en/api/pipelines/rae_dit.md b/docs/source/en/api/pipelines/rae_dit.md new file mode 100644 index 000000000000..53be9f125d43 --- /dev/null +++ b/docs/source/en/api/pipelines/rae_dit.md @@ -0,0 +1,59 @@ + + +# RAE DiT + +[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) introduces a +two-stage recipe: first train a representation autoencoder (RAE), then train a diffusion transformer on the resulting +latent space. + +[`RAEDiTPipeline`] implements the Stage-2 class-conditional generator in Diffusers. It combines: + +- [`RAEDiTTransformer2DModel`] for latent denoising +- [`FlowMatchEulerDiscreteScheduler`] for the denoising trajectory +- [`AutoencoderRAE`] for decoding latent samples to RGB images + +> [!TIP] +> [`RAEDiTPipeline`] expects a Stage-2 checkpoint converted to Diffusers format together with a compatible +> [`AutoencoderRAE`] checkpoint. 
+ +## Loading a converted pipeline + +```python +import torch +from diffusers import RAEDiTPipeline + +pipe = RAEDiTPipeline.from_pretrained( + "path/to/converted-rae-dit-imagenet256", + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe(class_labels=[207], num_inference_steps=25).images[0] +image.save("golden_retriever.png") +``` + +If the converted pipeline includes an `id2label` mapping, you can also look up class ids by name: + +```python +class_id = pipe.get_label_ids("golden retriever")[0] +image = pipe(class_labels=[class_id], num_inference_steps=25).images[0] +``` + +## RAEDiTPipeline + +[[autodoc]] RAEDiTPipeline + - all + - __call__ + +## ImagePipelineOutput + +[[autodoc]] pipelines.ImagePipelineOutput From 8b744986e85d4356b2e8447909ac2e41b82f22ef Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 10:14:42 +0530 Subject: [PATCH 08/12] Rename RAEDiTTransformer2DModel to RAEDiT2DModel --- docs/source/en/_toctree.yml | 2 +- docs/source/en/api/models/rae_dit_transformer2d.md | 12 ++++++------ docs/source/en/api/pipelines/rae_dit.md | 2 +- examples/research_projects/rae_dit/README.md | 6 +++--- .../rae_dit/compare_stage2_sample.py | 4 ++-- .../research_projects/rae_dit/train_rae_dit.py | 14 +++++++------- .../rae_dit/verify_stage2_parity.py | 4 ++-- scripts/convert_rae_stage2_to_diffusers.py | 14 ++++++-------- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/__init__.py | 4 ++-- src/diffusers/models/transformers/__init__.py | 2 +- .../models/transformers/transformer_rae_dit.py | 2 +- .../pipelines/rae_dit/pipeline_rae_dit.py | 6 +++--- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../test_models_rae_dit_transformer2d.py | 10 +++++----- tests/pipelines/rae_dit/test_pipeline_rae_dit.py | 6 +++--- 16 files changed, 46 insertions(+), 48 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 33b9cdbaeeb5..5a47e710220d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -389,7 
+389,7 @@ - local: api/models/qwenimage_transformer2d title: QwenImageTransformer2DModel - local: api/models/rae_dit_transformer2d - title: RAEDiTTransformer2DModel + title: RAEDiT2DModel - local: api/models/sana_transformer2d title: SanaTransformer2DModel - local: api/models/sana_video_transformer3d diff --git a/docs/source/en/api/models/rae_dit_transformer2d.md b/docs/source/en/api/models/rae_dit_transformer2d.md index c9361c9e8894..df171db8ef77 100644 --- a/docs/source/en/api/models/rae_dit_transformer2d.md +++ b/docs/source/en/api/models/rae_dit_transformer2d.md @@ -10,9 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# RAEDiTTransformer2DModel +# RAEDiT2DModel -The `RAEDiTTransformer2DModel` is the Stage-2 latent diffusion transformer introduced in +The `RAEDiT2DModel` is the Stage-2 latent diffusion transformer introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690). Unlike DiT models that operate on VAE latents, this transformer denoises the latent space learned by @@ -22,11 +22,11 @@ decoded back to RGB with [`AutoencoderRAE`]. ## Loading a pretrained transformer ```python -from diffusers import RAEDiTTransformer2DModel +from diffusers import RAEDiT2DModel -transformer = RAEDiTTransformer2DModel.from_pretrained("path/to/converted-stage2-transformer") +transformer = RAEDiT2DModel.from_pretrained("path/to/converted-stage2-transformer") ``` -## RAEDiTTransformer2DModel +## RAEDiT2DModel -[[autodoc]] RAEDiTTransformer2DModel +[[autodoc]] RAEDiT2DModel diff --git a/docs/source/en/api/pipelines/rae_dit.md b/docs/source/en/api/pipelines/rae_dit.md index 53be9f125d43..613da874928b 100644 --- a/docs/source/en/api/pipelines/rae_dit.md +++ b/docs/source/en/api/pipelines/rae_dit.md @@ -18,7 +18,7 @@ latent space. 
[`RAEDiTPipeline`] implements the Stage-2 class-conditional generator in Diffusers. It combines: -- [`RAEDiTTransformer2DModel`] for latent denoising +- [`RAEDiT2DModel`] for latent denoising - [`FlowMatchEulerDiscreteScheduler`] for the denoising trajectory - [`AutoencoderRAE`] for decoding latent samples to RGB images diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md index 59179740df4f..7e6be56004d2 100644 --- a/examples/research_projects/rae_dit/README.md +++ b/examples/research_projects/rae_dit/README.md @@ -1,6 +1,6 @@ # Training RAEDiT Stage 2 -This folder contains the minimal Stage-2 follow-up for the RAE integration: training `RAEDiTTransformer2DModel` on top of a frozen `AutoencoderRAE`. +This folder contains the minimal Stage-2 follow-up for the RAE integration: training `RAEDiT2DModel` on top of a frozen `AutoencoderRAE`. It is intentionally placed under `examples/research_projects/rae_dit/` rather than the top-level `examples/` trainers because this is still an experimental follow-up to the new RAE support. @@ -21,7 +21,7 @@ This is a minimal full-finetuning scaffold, not a paper-complete training stack. - loads a frozen pretrained `AutoencoderRAE` - encodes RGB images to normalized Stage-1 latents on the fly -- trains only the Stage-2 `RAEDiTTransformer2DModel` +- trains only the Stage-2 `RAEDiT2DModel` - uses `FlowMatchEulerDiscreteScheduler` with the same shifted-sigma schedule shape already used elsewhere in `diffusers` - consumes ImageFolder class ids as `class_labels` - saves the trained transformer under `output_dir/transformer` @@ -99,5 +99,5 @@ accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ ## Notes - The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level. 
-- The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiTTransformer2DModel` operates on 2D latent feature maps. +- The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiT2DModel` operates on 2D latent feature maps. - This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents, validation sampling through the pipeline, and broader parity tooling. diff --git a/examples/research_projects/rae_dit/compare_stage2_sample.py b/examples/research_projects/rae_dit/compare_stage2_sample.py index 03ad194d8677..5bada734b5e2 100644 --- a/examples/research_projects/rae_dit/compare_stage2_sample.py +++ b/examples/research_projects/rae_dit/compare_stage2_sample.py @@ -25,7 +25,7 @@ from PIL import Image, ImageDraw from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler -from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel def parse_args() -> argparse.Namespace: @@ -130,7 +130,7 @@ def main(): upstream_model.to(device=device, dtype=torch.float32) upstream_model.eval() - hf_model = RAEDiTTransformer2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) + hf_model = RAEDiT2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) hf_model.to(device=device, dtype=torch.float32) hf_model.eval() diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index d2fdcca5f110..427626544e74 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -31,7 +31,7 @@ from torchvision.datasets import ImageFolder from tqdm.auto import tqdm -from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTTransformer2DModel +from 
diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiT2DModel from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 from diffusers.utils import check_min_version @@ -44,7 +44,7 @@ def parse_args(): - parser = argparse.ArgumentParser(description="Minimal stage-2 trainer for RAEDiTTransformer2DModel.") + parser = argparse.ArgumentParser(description="Minimal stage-2 trainer for RAEDiT2DModel.") parser.add_argument( "--train_data_dir", type=str, @@ -362,7 +362,7 @@ def main(): ) if args.pretrained_transformer_model_name_or_path is not None: - transformer = RAEDiTTransformer2DModel.from_pretrained(args.pretrained_transformer_model_name_or_path) + transformer = RAEDiT2DModel.from_pretrained(args.pretrained_transformer_model_name_or_path) if transformer.config.in_channels != latent_channels or transformer.config.sample_size != latent_size: raise ValueError( "Loaded transformer latent shape does not match the selected AutoencoderRAE. " @@ -374,7 +374,7 @@ def main(): f"Loaded transformer supports {transformer.config.num_classes} classes but dataset requires {num_classes}." 
) else: - transformer = RAEDiTTransformer2DModel( + transformer = RAEDiT2DModel( sample_size=latent_size, patch_size=args.patch_size, in_channels=latent_channels, @@ -432,7 +432,7 @@ def save_model_hook(models, weights, output_dir): return for model in models: - if isinstance(unwrap_model(accelerator, model), RAEDiTTransformer2DModel): + if isinstance(unwrap_model(accelerator, model), RAEDiT2DModel): unwrap_model(accelerator, model).save_pretrained(os.path.join(output_dir, "transformer")) else: raise ValueError(f"Unexpected model type during save: {type(model)}") @@ -444,10 +444,10 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() target_model = unwrap_model(accelerator, model) - if not isinstance(target_model, RAEDiTTransformer2DModel): + if not isinstance(target_model, RAEDiT2DModel): raise ValueError(f"Unexpected model type during load: {type(model)}") - load_model = RAEDiTTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") + load_model = RAEDiT2DModel.from_pretrained(input_dir, subfolder="transformer") target_model.register_to_config(**load_model.config) target_model.load_state_dict(load_model.state_dict()) del load_model diff --git a/examples/research_projects/rae_dit/verify_stage2_parity.py b/examples/research_projects/rae_dit/verify_stage2_parity.py index db1570b66c44..0948c9ecbae0 100644 --- a/examples/research_projects/rae_dit/verify_stage2_parity.py +++ b/examples/research_projects/rae_dit/verify_stage2_parity.py @@ -23,7 +23,7 @@ import torch import yaml -from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel def parse_args() -> argparse.Namespace: @@ -156,7 +156,7 @@ def main(): upstream_model.to(device=device, dtype=torch.float32) upstream_model.eval() - hf_model = RAEDiTTransformer2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) + hf_model = 
RAEDiT2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) hf_model.to(device=device, dtype=torch.float32) hf_model.eval() diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py index 376321db20e5..660956eac718 100644 --- a/scripts/convert_rae_stage2_to_diffusers.py +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -9,7 +9,7 @@ from huggingface_hub import HfApi, hf_hub_download from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline -from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel DEFAULT_NUM_TRAIN_TIMESTEPS = 1000 @@ -221,7 +221,7 @@ def convert_transformer_state_dict( state_dict = unwrap_state_dict(raw_checkpoint, checkpoint_key=checkpoint_key, prefer_ema=prefer_ema) with torch.device("meta"): - model = RAEDiTTransformer2DModel(**transformer_config) + model = RAEDiT2DModel(**transformer_config) load_result = model.load_state_dict(state_dict, strict=False, assign=True) missing_keys = set(load_result.missing_keys) @@ -251,11 +251,9 @@ def convert_transformer_state_dict( model.save_pretrained(output_dir, safe_serialization=safe_serialization) if verify_load: - reloaded = RAEDiTTransformer2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) - if not isinstance(reloaded, RAEDiTTransformer2DModel): - raise RuntimeError( - f"Verification failed for {component_name}: reloaded object is not RAEDiTTransformer2DModel." 
- ) + reloaded = RAEDiT2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) + if not isinstance(reloaded, RAEDiT2DModel): + raise RuntimeError(f"Verification failed for {component_name}: reloaded object is not RAEDiT2DModel.") return { "checkpoint_path": str(checkpoint_path), @@ -422,7 +420,7 @@ def convert(args: argparse.Namespace) -> None: if args.vae_model_name_or_path is not None: vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path) - transformer = RAEDiTTransformer2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) + transformer = RAEDiT2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) scheduler_for_pipe = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_output_dir) id2label = None diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1dd26e855e48..63cc9e34d6f3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -261,7 +261,7 @@ "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", - "RAEDiTTransformer2DModel", + "RAEDiT2DModel", "SanaControlNetModel", "SanaTransformer2DModel", "SanaVideoTransformer3DModel", @@ -1038,7 +1038,7 @@ QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, - RAEDiTTransformer2DModel, + RAEDiT2DModel, SanaControlNetModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 10c1ed81c8a8..c33e1d8222aa 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -118,7 +118,7 @@ _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] - _import_structure["transformers.transformer_rae_dit"] = ["RAEDiTTransformer2DModel"] + 
_import_structure["transformers.transformer_rae_dit"] = ["RAEDiT2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] @@ -239,7 +239,7 @@ PriorTransformer, PRXTransformer2DModel, QwenImageTransformer2DModel, - RAEDiTTransformer2DModel, + RAEDiT2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 70df11c26ba4..5a174d5e1b07 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -44,7 +44,7 @@ from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel - from .transformer_rae_dit import RAEDiTTransformer2DModel + from .transformer_rae_dit import RAEDiT2DModel from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py index 44963a4acbcb..ac2341c7def8 100644 --- a/src/diffusers/models/transformers/transformer_rae_dit.py +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -322,7 +322,7 @@ def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> to return hidden_states -class RAEDiTTransformer2DModel(ModelMixin, ConfigMixin): +class RAEDiT2DModel(ModelMixin, ConfigMixin): r""" Stage-2 latent diffusion transformer used by the RAE paper. 
diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py index 136087afe2ae..d1823f8a6107 100644 --- a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -3,7 +3,7 @@ import torch from ...models import AutoencoderRAE -from ...models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel +from ...models.transformers.transformer_rae_dit import RAEDiT2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -14,7 +14,7 @@ class RAEDiTPipeline(DiffusionPipeline): Pipeline for class-conditioned image generation in RAE latent space. Parameters: - transformer ([`RAEDiTTransformer2DModel`]): + transformer ([`RAEDiT2DModel`]): Class-conditioned latent transformer used for Stage-2 denoising in RAE latent space. vae ([`AutoencoderRAE`]): Representation autoencoder used to decode latent samples back to RGB images. 
@@ -26,7 +26,7 @@ class RAEDiTPipeline(DiffusionPipeline): def __init__( self, - transformer: RAEDiTTransformer2DModel, + transformer: RAEDiT2DModel, vae: AutoencoderRAE, scheduler: FlowMatchEulerDiscreteScheduler, id2label: dict[int, str] | None = None, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3dd01849b12b..ff89538605fc 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1541,7 +1541,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class RAEDiTTransformer2DModel(metaclass=DummyObject): +class RAEDiT2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/transformers/test_models_rae_dit_transformer2d.py b/tests/models/transformers/test_models_rae_dit_transformer2d.py index a39e7a4f67e6..aebc09fbd7b5 100644 --- a/tests/models/transformers/test_models_rae_dit_transformer2d.py +++ b/tests/models/transformers/test_models_rae_dit_transformer2d.py @@ -17,7 +17,7 @@ import torch -from diffusers import RAEDiTTransformer2DModel +from diffusers import RAEDiT2DModel from diffusers.models.transformers.transformer_rae_dit import _repeat_to_length from ...testing_utils import enable_full_determinism, floats_tensor, torch_device @@ -27,7 +27,7 @@ enable_full_determinism() -def _initialize_non_zero_stage2_head(model: RAEDiTTransformer2DModel): +def _initialize_non_zero_stage2_head(model: RAEDiT2DModel): torch.manual_seed(0) for block in model.blocks: @@ -40,8 +40,8 @@ def _initialize_non_zero_stage2_head(model: RAEDiTTransformer2DModel): model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) -class RAEDiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase): - model_class = RAEDiTTransformer2DModel +class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = RAEDiT2DModel main_input_name = "hidden_states" @property @@ -193,7 +193,7 @@ def 
test_repeat_to_length_rejects_incompatible_multi_token_layouts(self): _repeat_to_length(hidden_states, target_length=8) def test_gradient_checkpointing_is_applied(self): - expected_set = {"RAEDiTTransformer2DModel"} + expected_set = {"RAEDiT2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_effective_gradient_checkpointing(self): diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py index 7d1c40cfb377..ae5935d25d07 100644 --- a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -23,7 +23,7 @@ import diffusers.models.autoencoders.autoencoder_rae as _rae_module from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline from diffusers.models.autoencoders.autoencoder_rae import _ENCODER_FORWARD_FNS, _build_encoder -from diffusers.models.transformers.transformer_rae_dit import RAEDiTTransformer2DModel +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel from ...testing_utils import enable_full_determinism, torch_device from ..pipeline_params import ( @@ -36,7 +36,7 @@ enable_full_determinism() -def _initialize_non_zero_stage2_head(model: RAEDiTTransformer2DModel): +def _initialize_non_zero_stage2_head(model: RAEDiT2DModel): torch.manual_seed(0) for block in model.blocks: @@ -97,7 +97,7 @@ def tearDownClass(cls): def get_dummy_components(self): torch.manual_seed(0) - transformer = RAEDiTTransformer2DModel( + transformer = RAEDiT2DModel( sample_size=2, patch_size=1, in_channels=8, From dc437f9975086e65be7494cd4d3f17a39b7c2874 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 10:51:52 +0530 Subject: [PATCH 09/12] Fix RAE DiT review regressions --- .../research_projects/rae_dit/test_rae_dit.py | 21 +++++ .../rae_dit/train_rae_dit.py | 77 ++++++++++++------- scripts/convert_rae_stage2_to_diffusers.py | 13 +++- .../pipelines/rae_dit/pipeline_rae_dit.py | 9 ++- 
tests/others/test_rae_dit_conversion.py | 26 +++++++ .../rae_dit/test_pipeline_rae_dit.py | 25 +++++- 6 files changed, 138 insertions(+), 33 deletions(-) create mode 100644 tests/others/test_rae_dit_conversion.py diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py index d320d562386c..77892db69663 100644 --- a/examples/research_projects/rae_dit/test_rae_dit.py +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -17,14 +17,17 @@ import os import sys import tempfile +from types import SimpleNamespace from PIL import Image from diffusers import AutoencoderRAE +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 +from train_rae_dit import maybe_load_resumed_scheduler # noqa: E402 logging.basicConfig(level=logging.DEBUG) @@ -126,3 +129,21 @@ def test_verify_train_resume(self): self.assertIn("baseline_trace=", output) self.assertIn("resumed_trace=", output) self.assertIn("resume batch order verified", output) + + def test_maybe_load_resumed_scheduler_prefers_checkpoint_config(self): + args = SimpleNamespace(num_train_timesteps=999, flow_shift=2.5) + + with tempfile.TemporaryDirectory() as tmpdir: + scheduler_dir = os.path.join(tmpdir, "scheduler") + FlowMatchEulerDiscreteScheduler(num_train_timesteps=10, shift=7.0).save_pretrained(scheduler_dir) + + restored = maybe_load_resumed_scheduler( + args=args, + checkpoint_path=tmpdir, + noise_scheduler=FlowMatchEulerDiscreteScheduler(num_train_timesteps=999, shift=2.5), + ) + + self.assertEqual(restored.config.num_train_timesteps, 10) + self.assertEqual(restored.config.shift, 7.0) + self.assertEqual(args.num_train_timesteps, 10) + self.assertEqual(args.flow_shift, 7.0) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py 
b/examples/research_projects/rae_dit/train_rae_dit.py index 427626544e74..c42b731c76c3 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -253,6 +253,22 @@ def should_skip_resumed_batch(should_resume: bool, epoch: int, first_epoch: int, return should_resume and epoch == first_epoch and step < resume_step +def maybe_load_resumed_scheduler( + args, checkpoint_path: str | None, noise_scheduler: FlowMatchEulerDiscreteScheduler +) -> FlowMatchEulerDiscreteScheduler: + if checkpoint_path is None: + return noise_scheduler + + scheduler_path = os.path.join(checkpoint_path, "scheduler") + if not os.path.isdir(scheduler_path): + return noise_scheduler + + restored_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_path) + args.num_train_timesteps = int(restored_scheduler.config.num_train_timesteps) + args.flow_shift = float(restored_scheduler.config.shift) + return restored_scheduler + + def get_latent_spec(autoencoder: AutoencoderRAE) -> tuple[int, int]: if not autoencoder.config.reshape_to_2d: raise ValueError("Stage-2 RAE DiT training expects `AutoencoderRAE.reshape_to_2d=True`.") @@ -472,6 +488,38 @@ def load_model_hook(models, input_dir): args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + resume_step = 0 + resume_path = None + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + resume_path = args.resume_from_checkpoint + if not os.path.isdir(resume_path): + resume_path = os.path.join(args.output_dir, os.path.basename(resume_path)) + else: + checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + resume_path = os.path.join(args.output_dir, checkpoints[-1]) if checkpoints else 
None + + if resume_path is None or not os.path.isdir(resume_path): + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + args.resume_from_checkpoint = None + else: + noise_scheduler = maybe_load_resumed_scheduler(args, resume_path, noise_scheduler) + flow_shift = float(noise_scheduler.config.shift) + accelerator.print(f"Resuming from checkpoint {resume_path}") + accelerator.load_state(resume_path) + global_step = int(os.path.basename(resume_path).split("-")[1]) + initial_global_step = global_step + first_epoch, resume_step = compute_resume_offsets( + global_step=global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + if accelerator.is_main_process: accelerator.init_trackers( "train_rae_dit", @@ -497,35 +545,6 @@ def load_model_hook(models, input_dir): logger.info(f" Total train batch size = {total_batch_size}") logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - initial_global_step = 0 - resume_step = 0 - - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = args.resume_from_checkpoint - if not os.path.isdir(path): - path = os.path.join(args.output_dir, os.path.basename(path)) - else: - checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - path = os.path.join(args.output_dir, checkpoints[-1]) if checkpoints else None - - if path is None or not os.path.isdir(path): - logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.") - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(path) - global_step = int(os.path.basename(path).split("-")[1]) - initial_global_step = global_step - first_epoch, resume_step = compute_resume_offsets( - global_step=global_step, - num_update_steps_per_epoch=num_update_steps_per_epoch, - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py index 660956eac718..c6927e5eedf3 100644 --- a/scripts/convert_rae_stage2_to_diffusers.py +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -122,7 +122,7 @@ def unwrap_state_dict( raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") state_dict = dict(state_dict) - for prefix in ("module.", "model.", "model.module."): + for prefix in ("model.module.", "model.", "module."): state_dict = _maybe_strip_common_prefix(state_dict, prefix) return state_dict @@ -421,6 +421,9 @@ def convert(args: argparse.Namespace) -> None: if args.vae_model_name_or_path is not None: vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path) transformer = RAEDiT2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) + guidance_transformer = None + if "guidance_transformer" in metadata: + guidance_transformer = RAEDiT2DModel.from_pretrained(guidance_output_dir, low_cpu_mem_usage=False) scheduler_for_pipe = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_output_dir) id2label = None @@ -428,7 +431,13 @@ def convert(args: argparse.Namespace) -> None: with Path(args.id2label_json_path).expanduser().open("r", encoding="utf-8") as handle: id2label = json.load(handle) - pipe = RAEDiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler_for_pipe, id2label=id2label) + pipe = RAEDiTPipeline( + 
transformer=transformer, + guidance_transformer=guidance_transformer, + vae=vae, + scheduler=scheduler_for_pipe, + id2label=id2label, + ) pipe.save_pretrained(output_path, safe_serialization=args.safe_serialization) metadata["pipeline"] = {"saved": True, "id2label_json_path": args.id2label_json_path} diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py index d1823f8a6107..fc11f2dad424 100644 --- a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -22,6 +22,7 @@ class RAEDiTPipeline(DiffusionPipeline): Flow-matching scheduler used to integrate the latent denoising trajectory. """ + _optional_components = ["guidance_transformer"] model_cpu_offload_seq = "transformer->vae" def __init__( @@ -29,10 +30,16 @@ def __init__( transformer: RAEDiT2DModel, vae: AutoencoderRAE, scheduler: FlowMatchEulerDiscreteScheduler, + guidance_transformer: RAEDiT2DModel | None = None, id2label: dict[int, str] | None = None, ): super().__init__() - self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + self.register_modules( + transformer=transformer, + guidance_transformer=guidance_transformer, + vae=vae, + scheduler=scheduler, + ) serialized_id2label = None if id2label is not None: serialized_id2label = {str(key): value for key, value in id2label.items()} diff --git a/tests/others/test_rae_dit_conversion.py b/tests/others/test_rae_dit_conversion.py new file mode 100644 index 000000000000..29ca175d4455 --- /dev/null +++ b/tests/others/test_rae_dit_conversion.py @@ -0,0 +1,26 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from scripts.convert_rae_stage2_to_diffusers import unwrap_state_dict + + +def test_unwrap_state_dict_strips_supported_prefixes(): + tensor = torch.randn(1) + + assert unwrap_state_dict({"model.module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + assert unwrap_state_dict({"model.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + assert unwrap_state_dict({"module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py index ae5935d25d07..6e16b0051dfe 100644 --- a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import tempfile import unittest @@ -136,7 +137,12 @@ def get_dummy_components(self): ) scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0) - return {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler} + return { + "transformer": transformer.eval(), + "guidance_transformer": None, + "vae": vae.eval(), + "scheduler": scheduler, + } def get_dummy_inputs(self, device, seed=0): if str(device).startswith("mps"): @@ -253,3 +259,20 @@ def test_save_load_preserves_label_ids(self): self.assertEqual(pipe_loaded.config.id2label, {"0": "zero", "1": "one, first"}) self.assertEqual(pipe_loaded.get_label_ids("first"), [1]) self.assertEqual(pipe_loaded.get_label_ids(["zero", "one"]), [0, 1]) + + def test_save_load_preserves_guidance_transformer(self): + components = self.get_dummy_components() + guidance_transformer = RAEDiT2DModel(**components["transformer"].config) + _initialize_non_zero_stage2_head(guidance_transformer) + components["guidance_transformer"] = guidance_transformer + + pipe = self.pipeline_class(**components) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "guidance_transformer"))) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + + self.assertIsNotNone(pipe_loaded.guidance_transformer) + self.assertIsInstance(pipe_loaded.guidance_transformer, RAEDiT2DModel) + self.assertEqual(pipe_loaded.guidance_transformer.config.sample_size, guidance_transformer.config.sample_size) From fe21820a5b0c28629eb5e2c2fd29724298b6ff3a Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 10:58:39 +0530 Subject: [PATCH 10/12] Remove RAE DiT validation helper scripts from PR --- examples/research_projects/rae_dit/README.md | 16 +- .../rae_dit/compare_stage2_sample.py | 194 ------------------ .../research_projects/rae_dit/test_rae_dit.py | 136 ++++++++++-- .../rae_dit/verify_stage2_parity.py | 194 ------------------ 
.../rae_dit/verify_train_resume.py | 186 ----------------- 5 files changed, 119 insertions(+), 607 deletions(-) delete mode 100644 examples/research_projects/rae_dit/compare_stage2_sample.py delete mode 100644 examples/research_projects/rae_dit/verify_stage2_parity.py delete mode 100644 examples/research_projects/rae_dit/verify_train_resume.py diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md index 7e6be56004d2..e47c09919fa8 100644 --- a/examples/research_projects/rae_dit/README.md +++ b/examples/research_projects/rae_dit/README.md @@ -35,20 +35,6 @@ It intentionally does not yet include: - autoguidance or the broader upstream transport stack - exact upstream distributed training/runtime features -## Parity check - -`verify_stage2_parity.py` compares a converted diffusers transformer against the upstream `DiTwDDTHead` with the same published checkpoint and synthetic latent inputs. This is the quickest way to confirm that a conversion still matches upstream numerically before opening or updating a PR. - -Example: - -```bash -python examples/research_projects/rae_dit/verify_stage2_parity.py \ - --upstream_repo_path /path/to/RAE \ - --config_path /path/to/RAE/configs/stage2/sampling/ImageNet256/DiTDHXL-DINOv2-B.yaml \ - --checkpoint_path /path/to/stage2_model.pt \ - --converted_transformer_path /path/to/diffusers-transformer -``` - ## Dataset format The script expects an `ImageFolder`-compatible dataset: @@ -100,4 +86,4 @@ accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ - The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level. - The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiT2DModel` operates on 2D latent feature maps. -- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. 
A later follow-up can add cached latents, validation sampling through the pipeline, and broader parity tooling. +- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents and validation sampling through the pipeline. diff --git a/examples/research_projects/rae_dit/compare_stage2_sample.py b/examples/research_projects/rae_dit/compare_stage2_sample.py deleted file mode 100644 index 5bada734b5e2..000000000000 --- a/examples/research_projects/rae_dit/compare_stage2_sample.py +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import math -import sys -from pathlib import Path -from typing import Any - -import torch -import yaml -from PIL import Image, ImageDraw - -from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler -from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Create a visual side-by-side sample comparison between upstream and diffusers Stage-2 RAE DiT." 
- ) - parser.add_argument("--upstream_repo_path", type=str, required=True) - parser.add_argument("--config_path", type=str, required=True) - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--converted_transformer_path", type=str, required=True) - parser.add_argument("--vae_model_name_or_path", type=str, required=True) - parser.add_argument("--output_path", type=str, required=True) - parser.add_argument("--class_label", type=int, default=207, help="ImageNet class id to sample.") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--num_inference_steps", type=int, default=25) - parser.add_argument("--device", type=str, default=None) - return parser.parse_args() - - -def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: - for key in keys: - section = config.get(key) - if isinstance(section, dict): - return section - raise KeyError(f"Could not find any of {keys} in config.") - - -def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: - if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): - return {key[len(prefix) :]: value for key, value in state_dict.items()} - return state_dict - - -def unwrap_state_dict(maybe_wrapped: dict[str, Any], prefer_ema: bool = True) -> dict[str, Any]: - state_dict: dict[str, Any] | Any = maybe_wrapped - if isinstance(state_dict, dict): - candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] - for key in candidate_keys: - if key in state_dict and isinstance(state_dict[key], dict): - state_dict = state_dict[key] - break - - if not isinstance(state_dict, dict): - raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") - - state_dict = dict(state_dict) - for prefix in ("module.", "model.", "model.module."): - state_dict = _maybe_strip_common_prefix(state_dict, prefix) - return state_dict - - -def load_checkpoint(checkpoint_path: 
Path) -> dict[str, Any]: - if checkpoint_path.suffix.lower() == ".safetensors": - import safetensors.torch - - return safetensors.torch.load_file(checkpoint_path) - - return torch.load(checkpoint_path, map_location="cpu") - - -def latent_to_pil(image: torch.Tensor) -> Image.Image: - array = image.detach().cpu().clamp(0, 1).permute(1, 2, 0).mul(255).round().byte().numpy() - return Image.fromarray(array) - - -def draw_label(image: Image.Image, text: str) -> Image.Image: - canvas = Image.new("RGB", (image.width, image.height + 24), color="white") - canvas.paste(image, (0, 24)) - draw = ImageDraw.Draw(canvas) - draw.text((8, 4), text, fill="black") - return canvas - - -def main(): - args = parse_args() - upstream_repo_path = Path(args.upstream_repo_path).expanduser().resolve() - config_path = Path(args.config_path).expanduser().resolve() - checkpoint_path = Path(args.checkpoint_path).expanduser().resolve() - converted_transformer_path = Path(args.converted_transformer_path).expanduser().resolve() - output_path = Path(args.output_path).expanduser().resolve() - - sys.path.insert(0, str(upstream_repo_path / "src")) - from stage2.models.DDT import DiTwDDTHead - - with config_path.open("r", encoding="utf-8") as handle: - config = yaml.safe_load(handle) - - stage2 = _resolve_section(config, "stage_2", "stage2") - stage2_params = stage2.get("params", {}) - misc = _resolve_section(config, "misc") - latent_size = misc["latent_size"] - shift = math.sqrt( - int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096)) - ) - num_train_timesteps = int(_resolve_section(config, "transport").get("params", {}).get("num_train_timesteps", 1000)) - - device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) - state_dict = unwrap_state_dict(load_checkpoint(checkpoint_path), prefer_ema=True) - - upstream_model = DiTwDDTHead(**stage2_params) - upstream_model.load_state_dict(state_dict, strict=True) - 
upstream_model.to(device=device, dtype=torch.float32) - upstream_model.eval() - - hf_model = RAEDiT2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) - hf_model.to(device=device, dtype=torch.float32) - hf_model.eval() - - vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path, low_cpu_mem_usage=False) - vae.to(device=device, dtype=torch.float32) - vae.eval() - - generator = torch.Generator(device=device).manual_seed(args.seed) - latents_init = torch.randn( - (1, int(stage2_params["in_channels"]), int(stage2_params["input_size"]), int(stage2_params["input_size"])), - generator=generator, - device=device, - dtype=torch.float32, - ) - class_labels = torch.tensor([args.class_label], device=device, dtype=torch.long) - - def run_sample(model, latents): - latents = latents.clone() - scheduler = FlowMatchEulerDiscreteScheduler( - num_train_timesteps=num_train_timesteps, - shift=shift, - stochastic_sampling=False, - ) - scheduler.set_timesteps(args.num_inference_steps, device=device) - with torch.no_grad(): - for timestep in scheduler.timesteps: - timestep_input = timestep.expand(latents.shape[0]) / scheduler.config.num_train_timesteps - if isinstance(model, DiTwDDTHead): - model_output = model(latents, timestep_input, class_labels) - else: - model_output = model( - hidden_states=latents, timestep=timestep_input, class_labels=class_labels - ).sample - latents = scheduler.step(model_output, timestep, latents).prev_sample - return vae.decode(latents).sample.clamp(0, 1) - - upstream_image = run_sample(upstream_model, latents_init)[0] - diffusers_image = run_sample(hf_model, latents_init)[0] - abs_diff = (upstream_image - diffusers_image).abs() - diff_vis = (abs_diff / max(abs_diff.max().item(), 1e-8)).clamp(0, 1) - - max_abs_error = abs_diff.max().item() - mean_abs_error = abs_diff.mean().item() - print(f"max_abs_error={max_abs_error:.8f}") - print(f"mean_abs_error={mean_abs_error:.8f}") - - upstream_pil = 
draw_label(latent_to_pil(upstream_image), "Upstream") - diffusers_pil = draw_label(latent_to_pil(diffusers_image), "Diffusers") - diff_pil = draw_label(latent_to_pil(diff_vis), "Abs Diff") - - canvas = Image.new("RGB", (upstream_pil.width * 3, upstream_pil.height), color="white") - canvas.paste(upstream_pil, (0, 0)) - canvas.paste(diffusers_pil, (upstream_pil.width, 0)) - canvas.paste(diff_pil, (upstream_pil.width * 2, 0)) - output_path.parent.mkdir(parents=True, exist_ok=True) - canvas.save(output_path) - print(output_path) - - -if __name__ == "__main__": - main() diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py index 77892db69663..aa22299026cb 100644 --- a/examples/research_projects/rae_dit/test_rae_dit.py +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -14,12 +14,18 @@ # limitations under the License. import logging +import math import os import sys import tempfile +from pathlib import Path from types import SimpleNamespace +from accelerate import Accelerator +from accelerate.utils import set_seed from PIL import Image +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder from diffusers import AutoencoderRAE from diffusers.schedulers import FlowMatchEulerDiscreteScheduler @@ -27,7 +33,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 -from train_rae_dit import maybe_load_resumed_scheduler # noqa: E402 +from train_rae_dit import ( # noqa: E402 + build_transforms, + collate_fn, + compute_resume_offsets, + maybe_load_resumed_scheduler, + should_skip_resumed_batch, +) logging.basicConfig(level=logging.DEBUG) @@ -37,6 +49,78 @@ logger.addHandler(stream_handler) +def _create_unique_class_dataset(dataset_dir: Path, resolution: int, num_samples: int): + for sample_idx in range(num_samples): + class_dir = dataset_dir / 
f"class_{sample_idx:02d}" + class_dir.mkdir(parents=True, exist_ok=True) + color = ((40 * sample_idx) % 256, (80 * sample_idx) % 256, (120 * sample_idx) % 256) + image = Image.new("RGB", (resolution, resolution), color=color) + image.save(class_dir / f"sample_{sample_idx}.png") + + +def _collect_class_label_trace( + dataset_dir: Path, + *, + seed: int, + resolution: int, + train_batch_size: int, + gradient_accumulation_steps: int, + max_train_steps: int, + resume_global_step: int = 0, +) -> list[int]: + accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + set_seed(seed) + + transform_args = SimpleNamespace(resolution=resolution, center_crop=True, random_flip=False) + dataset = ImageFolder(dataset_dir, transform=build_transforms(transform_args)) + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=train_batch_size, + num_workers=0, + pin_memory=True, + drop_last=True, + ) + train_dataloader = accelerator.prepare(train_dataloader) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + first_epoch = 0 + resume_step = 0 + should_resume = resume_global_step > 0 + + if should_resume: + first_epoch, resume_step = compute_resume_offsets( + global_step=resume_global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=gradient_accumulation_steps, + ) + + expected_microbatches = (max_train_steps - resume_global_step) * gradient_accumulation_steps + trace = [] + + for epoch in range(first_epoch, num_train_epochs): + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + for step, batch in enumerate(train_dataloader): + if should_skip_resumed_batch( + should_resume=should_resume, + epoch=epoch, + first_epoch=first_epoch, + step=step, + resume_step=resume_step, + ): + continue + + 
trace.extend(batch["class_labels"].tolist()) + if len(trace) >= expected_microbatches: + return trace + + raise AssertionError(f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}.") + + class RAEDiT(ExamplesTestsAccelerate): def _create_tiny_rae(self, tmpdir): model = AutoencoderRAE( @@ -112,23 +196,39 @@ def test_train_rae_dit_smoke(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "id2label.json"))) - def test_verify_train_resume(self): - test_args = """ - examples/research_projects/rae_dit/verify_train_resume.py - --seed 123 - --resolution 16 - --num_samples 6 - --train_batch_size 1 - --gradient_accumulation_steps 2 - --max_train_steps 3 - --resume_global_step 1 - """.split() - - output = run_command(self._launch_args + test_args, return_stdout=True) - - self.assertIn("baseline_trace=", output) - self.assertIn("resumed_trace=", output) - self.assertIn("resume batch order verified", output) + def test_resume_batch_order_matches_uninterrupted_tail(self): + seed = 123 + resolution = 16 + num_samples = 6 + train_batch_size = 1 + gradient_accumulation_steps = 2 + max_train_steps = 3 + resume_global_step = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) / "trace-dataset" + _create_unique_class_dataset(dataset_dir, resolution=resolution, num_samples=num_samples) + + baseline_trace = _collect_class_label_trace( + dataset_dir, + seed=seed, + resolution=resolution, + train_batch_size=train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + max_train_steps=max_train_steps, + ) + resumed_trace = _collect_class_label_trace( + dataset_dir, + seed=seed, + resolution=resolution, + train_batch_size=train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + max_train_steps=max_train_steps, + resume_global_step=resume_global_step, + ) + + consumed_microbatches = 
resume_global_step * gradient_accumulation_steps + self.assertEqual(resumed_trace, baseline_trace[consumed_microbatches:]) def test_maybe_load_resumed_scheduler_prefers_checkpoint_config(self): args = SimpleNamespace(num_train_timesteps=999, flow_shift=2.5) diff --git a/examples/research_projects/rae_dit/verify_stage2_parity.py b/examples/research_projects/rae_dit/verify_stage2_parity.py deleted file mode 100644 index 0948c9ecbae0..000000000000 --- a/examples/research_projects/rae_dit/verify_stage2_parity.py +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import math -import sys -from pathlib import Path -from typing import Any - -import torch -import yaml - -from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Compare a converted RAEDiT checkpoint against the upstream Stage-2 model." 
- ) - parser.add_argument("--upstream_repo_path", type=str, required=True, help="Path to the cloned upstream RAE repo.") - parser.add_argument( - "--config_path", - type=str, - required=True, - help="Path to the upstream Stage-2 YAML config used for the published checkpoint.", - ) - parser.add_argument( - "--checkpoint_path", - type=str, - required=True, - help="Path to the upstream Stage-2 checkpoint (.pt or .safetensors).", - ) - parser.add_argument( - "--converted_transformer_path", - type=str, - required=True, - help="Path to the converted diffusers transformer directory.", - ) - parser.add_argument("--device", type=str, default=None, help="Torch device to use. Defaults to cuda if available.") - parser.add_argument("--seed", type=int, default=0, help="Random seed for the synthetic parity inputs.") - parser.add_argument("--batch_size", type=int, default=2, help="Batch size for the parity run.") - parser.add_argument("--rtol", type=float, default=1e-4, help="Relative tolerance for parity.") - parser.add_argument("--atol", type=float, default=1e-5, help="Absolute tolerance for parity.") - return parser.parse_args() - - -def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: - for key in keys: - section = config.get(key) - if isinstance(section, dict): - return section - raise KeyError(f"Could not find any of {keys} in config.") - - -def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: - if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): - return {key[len(prefix) :]: value for key, value in state_dict.items()} - return state_dict - - -def unwrap_state_dict(maybe_wrapped: dict[str, Any], prefer_ema: bool = True) -> dict[str, Any]: - state_dict: dict[str, Any] | Any = maybe_wrapped - - if isinstance(state_dict, dict): - candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] - for key in candidate_keys: - if key in state_dict and 
isinstance(state_dict[key], dict): - state_dict = state_dict[key] - break - - if not isinstance(state_dict, dict): - raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") - - state_dict = dict(state_dict) - for prefix in ("module.", "model.", "model.module."): - state_dict = _maybe_strip_common_prefix(state_dict, prefix) - return state_dict - - -def load_checkpoint(checkpoint_path: Path) -> dict[str, Any]: - if checkpoint_path.suffix.lower() == ".safetensors": - import safetensors.torch - - return safetensors.torch.load_file(checkpoint_path) - - return torch.load(checkpoint_path, map_location="cpu") - - -def build_inputs( - batch_size: int, - in_channels: int, - sample_size: int, - num_classes: int, - shift: float, - seed: int, - device: torch.device, -): - generator = torch.Generator(device=device).manual_seed(seed) - clean_latents = torch.randn( - (batch_size, in_channels, sample_size, sample_size), generator=generator, device=device, dtype=torch.float32 - ) - noise = torch.randn(clean_latents.shape, generator=generator, device=device, dtype=torch.float32) - - # Use a spread of normalized timesteps inside the open interval to avoid any - # boundary-case special handling around t=0 or t=1. 
- timesteps = torch.linspace(0.2, 0.8, steps=batch_size, device=device, dtype=torch.float32) - sigma = shift * timesteps / (1 + (shift - 1) * timesteps) - sigma = sigma.view(-1, 1, 1, 1) - - noised_latents = (1.0 - sigma) * clean_latents + sigma * noise - class_labels = torch.arange(batch_size, device=device, dtype=torch.long) % num_classes - return noised_latents, timesteps, class_labels - - -def main(): - args = parse_args() - - upstream_repo_path = Path(args.upstream_repo_path).expanduser().resolve() - sys.path.insert(0, str(upstream_repo_path / "src")) - - from stage2.models.DDT import DiTwDDTHead - - config_path = Path(args.config_path).expanduser().resolve() - checkpoint_path = Path(args.checkpoint_path).expanduser().resolve() - converted_transformer_path = Path(args.converted_transformer_path).expanduser().resolve() - - with config_path.open("r", encoding="utf-8") as handle: - config = yaml.safe_load(handle) - - stage2 = _resolve_section(config, "stage_2", "stage2") - stage2_params = stage2.get("params", {}) - misc = _resolve_section(config, "misc") - latent_size = misc["latent_size"] - shift = math.sqrt( - int(misc.get("time_dist_shift_dim", math.prod(latent_size))) / int(misc.get("time_dist_shift_base", 4096)) - ) - - device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) - state_dict = unwrap_state_dict(load_checkpoint(checkpoint_path), prefer_ema=True) - - upstream_model = DiTwDDTHead(**stage2_params) - upstream_model.load_state_dict(state_dict, strict=True) - upstream_model.to(device=device, dtype=torch.float32) - upstream_model.eval() - - hf_model = RAEDiT2DModel.from_pretrained(converted_transformer_path, low_cpu_mem_usage=False) - hf_model.to(device=device, dtype=torch.float32) - hf_model.eval() - - noised_latents, timesteps, class_labels = build_inputs( - batch_size=args.batch_size, - in_channels=int(stage2_params["in_channels"]), - sample_size=int(stage2_params["input_size"]), - 
num_classes=int(stage2_params.get("num_classes", misc.get("num_classes", 1000))), - shift=shift, - seed=args.seed, - device=device, - ) - - with torch.no_grad(): - upstream_output = upstream_model(noised_latents, timesteps, class_labels) - hf_output = hf_model(hidden_states=noised_latents, timestep=timesteps, class_labels=class_labels).sample - - abs_error = (upstream_output - hf_output).abs() - max_abs_error = abs_error.max().item() - mean_abs_error = abs_error.mean().item() - - print(f"device={device}") - print(f"shape={tuple(hf_output.shape)}") - print(f"max_abs_error={max_abs_error:.8f}") - print(f"mean_abs_error={mean_abs_error:.8f}") - - if not torch.allclose(upstream_output, hf_output, atol=args.atol, rtol=args.rtol): - raise AssertionError( - f"Parity failed: max_abs_error={max_abs_error:.8f}, mean_abs_error={mean_abs_error:.8f}, " - f"expected atol={args.atol}, rtol={args.rtol}" - ) - - -if __name__ == "__main__": - main() diff --git a/examples/research_projects/rae_dit/verify_train_resume.py b/examples/research_projects/rae_dit/verify_train_resume.py deleted file mode 100644 index 62fc9cd8c413..000000000000 --- a/examples/research_projects/rae_dit/verify_train_resume.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import argparse -import json -import math -import tempfile -from pathlib import Path -from types import SimpleNamespace - -from accelerate import Accelerator -from accelerate.utils import set_seed -from PIL import Image -from torch.utils.data import DataLoader -from torchvision.datasets import ImageFolder -from train_rae_dit import build_transforms, collate_fn, compute_resume_offsets, should_skip_resumed_batch - - -def parse_args(): - parser = argparse.ArgumentParser(description="Verify Stage-2 RAE DiT mid-epoch resume batch ordering.") - parser.add_argument("--seed", type=int, default=123, help="Seed used for the shuffled dataloader.") - parser.add_argument("--resolution", type=int, default=16, help="Synthetic image resolution.") - parser.add_argument("--num_samples", type=int, default=6, help="Number of unique samples/classes to create.") - parser.add_argument("--train_batch_size", type=int, default=1, help="Microbatch size used by the trace harness.") - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=2, - help="Gradient accumulation steps used to derive the mid-epoch checkpoint position.", - ) - parser.add_argument("--max_train_steps", type=int, default=3, help="Total optimizer steps to trace.") - parser.add_argument( - "--resume_global_step", - type=int, - default=1, - help="Optimizer step at which the synthetic run resumes from a checkpoint.", - ) - return parser.parse_args() - - -def create_unique_class_dataset(dataset_dir: Path, resolution: int, num_samples: int): - for sample_idx in range(num_samples): - class_dir = dataset_dir / f"class_{sample_idx:02d}" - class_dir.mkdir(parents=True, exist_ok=True) - color = ( - (40 * sample_idx) % 256, - (80 * sample_idx) % 256, - (120 * sample_idx) % 256, - ) - image = Image.new("RGB", (resolution, resolution), color=color) - image.save(class_dir / f"sample_{sample_idx}.png") - - -def collect_class_label_trace( - dataset_dir: Path, - *, - seed: int, 
- resolution: int, - train_batch_size: int, - gradient_accumulation_steps: int, - max_train_steps: int, - resume_global_step: int = 0, -) -> list[int]: - accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) - set_seed(seed) - - transform_args = SimpleNamespace(resolution=resolution, center_crop=True, random_flip=False) - dataset = ImageFolder(dataset_dir, transform=build_transforms(transform_args)) - train_dataloader = DataLoader( - dataset, - shuffle=True, - collate_fn=collate_fn, - batch_size=train_batch_size, - num_workers=0, - pin_memory=True, - drop_last=True, - ) - train_dataloader = accelerator.prepare(train_dataloader) - - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) - num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) - first_epoch = 0 - resume_step = 0 - should_resume = resume_global_step > 0 - - if should_resume: - first_epoch, resume_step = compute_resume_offsets( - global_step=resume_global_step, - num_update_steps_per_epoch=num_update_steps_per_epoch, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - - expected_microbatches = (max_train_steps - resume_global_step) * gradient_accumulation_steps - trace = [] - - for epoch in range(first_epoch, num_train_epochs): - if hasattr(train_dataloader, "set_epoch"): - train_dataloader.set_epoch(epoch) - - for step, batch in enumerate(train_dataloader): - if should_skip_resumed_batch( - should_resume=should_resume, - epoch=epoch, - first_epoch=first_epoch, - step=step, - resume_step=resume_step, - ): - continue - - trace.extend(batch["class_labels"].tolist()) - if len(trace) >= expected_microbatches: - return trace - - raise AssertionError(f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}.") - - -def main(): - args = parse_args() - - if args.resume_global_step >= args.max_train_steps: - raise ValueError( - f"`resume_global_step` ({args.resume_global_step}) 
must be < `max_train_steps` ({args.max_train_steps})." - ) - microbatches_per_epoch = args.num_samples // args.train_batch_size - required_microbatches = args.max_train_steps * args.gradient_accumulation_steps - if microbatches_per_epoch < required_microbatches: - raise ValueError( - "The verifier keeps the proof inside a single epoch. Increase `--num_samples` or decrease " - "`--train_batch_size`, `--gradient_accumulation_steps`, or `--max_train_steps`." - ) - - with tempfile.TemporaryDirectory() as tmpdir: - dataset_dir = Path(tmpdir) / "trace-dataset" - create_unique_class_dataset(dataset_dir, resolution=args.resolution, num_samples=args.num_samples) - - baseline_trace = collect_class_label_trace( - dataset_dir, - seed=args.seed, - resolution=args.resolution, - train_batch_size=args.train_batch_size, - gradient_accumulation_steps=args.gradient_accumulation_steps, - max_train_steps=args.max_train_steps, - ) - resumed_trace = collect_class_label_trace( - dataset_dir, - seed=args.seed, - resolution=args.resolution, - train_batch_size=args.train_batch_size, - gradient_accumulation_steps=args.gradient_accumulation_steps, - max_train_steps=args.max_train_steps, - resume_global_step=args.resume_global_step, - ) - - consumed_microbatches = args.resume_global_step * args.gradient_accumulation_steps - expected_resumed_trace = baseline_trace[consumed_microbatches:] - - print(f"baseline_trace={json.dumps(baseline_trace)}") - print(f"consumed_trace={json.dumps(baseline_trace[:consumed_microbatches])}") - print(f"resumed_trace={json.dumps(resumed_trace)}") - - if resumed_trace != expected_resumed_trace: - raise AssertionError( - "Resumed batch order does not match the uninterrupted run tail. " - f"Expected {expected_resumed_trace}, got {resumed_trace}." 
- ) - - print("resume batch order verified") - - -if __name__ == "__main__": - main() From 92455c1c795a22dcd49ffb34b214f94c6e486d81 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 12:42:16 +0530 Subject: [PATCH 11/12] Add RAE DiT training validation sampling --- examples/research_projects/rae_dit/README.md | 17 ++- .../research_projects/rae_dit/test_rae_dit.py | 10 ++ .../rae_dit/train_rae_dit.py | 144 +++++++++++++++++- 3 files changed, 168 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md index e47c09919fa8..369adc2a0ef0 100644 --- a/examples/research_projects/rae_dit/README.md +++ b/examples/research_projects/rae_dit/README.md @@ -24,6 +24,7 @@ This is a minimal full-finetuning scaffold, not a paper-complete training stack. - trains only the Stage-2 `RAEDiT2DModel` - uses `FlowMatchEulerDiscreteScheduler` with the same shifted-sigma schedule shape already used elsewhere in `diffusers` - consumes ImageFolder class ids as `class_labels` +- can generate validation samples through `RAEDiTPipeline` during training - saves the trained transformer under `output_dir/transformer` - saves the scheduler config under `output_dir/scheduler` - writes `id2label.json` from the ImageFolder class mapping @@ -31,7 +32,6 @@ This is a minimal full-finetuning scaffold, not a paper-complete training stack. 
It intentionally does not yet include: - a latent-caching path -- validation image generation inside the script - autoguidance or the broader upstream transport stack - exact upstream distributed training/runtime features @@ -69,6 +69,18 @@ accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ --allow_tf32 ``` +To emit validation samples during training, add: + +```bash + --validation_steps 1000 \ + --validation_class_label 207 \ + --num_validation_images 4 \ + --validation_num_inference_steps 25 \ + --validation_guidance_scale 1.0 +``` + +Validation images are written to `output_dir/validation/step-/`. + If you already have a converted or partially trained Stage-2 checkpoint, resume from it with: ```bash @@ -86,4 +98,5 @@ accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ - The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level. - The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiT2DModel` operates on 2D latent feature maps. -- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents and validation sampling through the pipeline. +- Validation sampling uses a fresh scheduler cloned from the training config so sampling does not mutate the in-flight training scheduler state. +- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents and other training conveniences. 
diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py index aa22299026cb..3b2e03af4a72 100644 --- a/examples/research_projects/rae_dit/test_rae_dit.py +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -188,6 +188,11 @@ def test_train_rae_dit_smoke(self): --decoder_num_attention_heads 4 --mlp_ratio 2.0 --num_train_timesteps 10 + --validation_steps 1 + --validation_class_label 0 + --num_validation_images 1 + --validation_num_inference_steps 2 + --seed 0 """.split() run_command(self._launch_args + test_args) @@ -195,6 +200,11 @@ def test_train_rae_dit_smoke(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "id2label.json"))) + validation_image_path = os.path.join(tmpdir, "validation", "step-1", "image-0.png") + self.assertTrue(os.path.isfile(validation_image_path)) + validation_image = Image.open(validation_image_path) + self.assertEqual(validation_image.size, (16, 16)) + self.assertEqual(validation_image.mode, "RGB") def test_resume_batch_order_matches_uninterrupted_tail(self): seed = 123 diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index c42b731c76c3..018235752867 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -22,6 +22,7 @@ import shutil from pathlib import Path +import numpy as np import torch from accelerate import Accelerator from accelerate.logging import get_logger @@ -31,7 +32,7 @@ from torchvision.datasets import ImageFolder from tqdm.auto import tqdm -from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiT2DModel +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiT2DModel, 
RAEDiTPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 from diffusers.utils import check_min_version @@ -220,6 +221,36 @@ def parse_args(): default=1.29, help="Mode weighting scale used when weighting_scheme=mode.", ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help="Run validation sampling every N optimizer steps. Disabled when omitted.", + ) + parser.add_argument( + "--validation_class_label", + type=int, + default=None, + help="Class id to sample during validation. Disabled when omitted.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=1, + help="Number of validation images to generate each time validation runs.", + ) + parser.add_argument( + "--validation_num_inference_steps", + type=int, + default=25, + help="Number of denoising steps to use for validation sampling.", + ) + parser.add_argument( + "--validation_guidance_scale", + type=float, + default=1.0, + help="Classifier-free guidance scale used during validation sampling.", + ) return parser.parse_args() @@ -322,8 +353,102 @@ def maybe_prune_checkpoints(output_dir: str, checkpoints_total_limit: int | None shutil.rmtree(os.path.join(output_dir, checkpoint)) +def validate_args(args): + validation_enabled = args.validation_class_label is not None + + if validation_enabled and args.validation_steps is None: + raise ValueError("`--validation_steps` must be provided when `--validation_class_label` is set.") + if args.validation_steps is not None and args.validation_steps < 1: + raise ValueError(f"`--validation_steps` must be >= 1, but got {args.validation_steps}.") + if args.validation_class_label is not None and args.validation_class_label < 0: + raise ValueError( + f"`--validation_class_label` must be >= 0, but got {args.validation_class_label}." 
+ ) + if args.num_validation_images < 1: + raise ValueError(f"`--num_validation_images` must be >= 1, but got {args.num_validation_images}.") + if args.validation_num_inference_steps < 1: + raise ValueError( + f"`--validation_num_inference_steps` must be >= 1, but got {args.validation_num_inference_steps}." + ) + if args.validation_guidance_scale < 1.0: + raise ValueError( + f"`--validation_guidance_scale` must be >= 1.0, but got {args.validation_guidance_scale}." + ) + + +def log_validation( + transformer, + autoencoder: AutoencoderRAE, + scheduler: FlowMatchEulerDiscreteScheduler, + args, + accelerator: Accelerator, + step: int, + class_names: list[str], +): + if not accelerator.is_main_process: + return + + transformer_model = unwrap_model(accelerator, transformer) + was_training = transformer_model.training + transformer_model.eval() + + validation_scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler.config) + pipeline = RAEDiTPipeline( + transformer=transformer_model, + vae=autoencoder, + scheduler=validation_scheduler, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + generator = None + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed + step) + + label_name = class_names[args.validation_class_label] + logger.info( + "Running validation... 
Generating %s image(s) for class %s (%s).", + args.num_validation_images, + args.validation_class_label, + label_name, + ) + images = pipeline( + class_labels=[args.validation_class_label], + guidance_scale=args.validation_guidance_scale, + num_images_per_prompt=args.num_validation_images, + num_inference_steps=args.validation_num_inference_steps, + generator=generator, + output_type="pil", + ).images + + validation_dir = os.path.join(args.output_dir, "validation", f"step-{step}") + os.makedirs(validation_dir, exist_ok=True) + for image_idx, image in enumerate(images): + image.save(os.path.join(validation_dir, f"image-{image_idx}.png")) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + formatted_images = np.stack([np.asarray(image) for image in images]) + tracker.writer.add_images(f"validation/{label_name}", formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{image_idx}: {label_name}") + for image_idx, image in enumerate(images) + ] + } + ) + + if was_training: + transformer_model.train() + + def main(): args = parse_args() + validate_args(args) logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) @@ -366,6 +491,11 @@ def main(): raise ValueError( f"`--num_classes` ({num_classes}) must be >= the number of dataset classes ({inferred_num_classes})." ) + if args.validation_class_label is not None and args.validation_class_label >= inferred_num_classes: + raise ValueError( + f"`--validation_class_label` ({args.validation_class_label}) must be < the number of dataset classes " + f"({inferred_num_classes})." 
+ ) train_dataloader = DataLoader( dataset, @@ -639,6 +769,18 @@ def load_model_hook(models, input_dir): noise_scheduler.save_pretrained(os.path.join(save_path, "scheduler")) logger.info(f"Saved state to {save_path}") + if args.validation_steps is not None and global_step % args.validation_steps == 0: + log_validation( + transformer=transformer, + autoencoder=autoencoder, + scheduler=noise_scheduler, + args=args, + accelerator=accelerator, + step=global_step, + class_names=dataset.classes, + ) + accelerator.wait_for_everyone() + if global_step >= args.max_train_steps: break From 794d350cf166c0508ec1a0a046d3b3539fa12b23 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Mon, 9 Mar 2026 18:29:11 +0530 Subject: [PATCH 12/12] Align RAE DiT with diffusers patterns --- docs/source/en/api/pipelines/rae_dit.md | 4 +- examples/research_projects/rae_dit/README.md | 11 - .../research_projects/rae_dit/test_rae_dit.py | 10 + .../rae_dit/train_rae_dit.py | 27 ++- scripts/convert_rae_stage2_to_diffusers.py | 79 +++++- src/diffusers/__init__.py | 2 + src/diffusers/models/modeling_utils.py | 5 +- .../transformers/transformer_rae_dit.py | 229 +++++++++--------- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/rae_dit/__init__.py | 6 +- .../pipelines/rae_dit/pipeline_output.py | 36 +++ .../pipelines/rae_dit/pipeline_rae_dit.py | 28 ++- src/diffusers/utils/dummy_pt_objects.py | 7 + tests/models/test_modeling_utils.py | 29 --- .../test_models_rae_dit_transformer2d.py | 143 ++++++----- tests/others/test_rae_dit_conversion.py | 85 ++++++- .../rae_dit/test_pipeline_rae_dit.py | 6 +- 17 files changed, 465 insertions(+), 246 deletions(-) create mode 100644 src/diffusers/pipelines/rae_dit/pipeline_output.py delete mode 100644 tests/models/test_modeling_utils.py diff --git a/docs/source/en/api/pipelines/rae_dit.md b/docs/source/en/api/pipelines/rae_dit.md index 613da874928b..b2666cb39899 100644 --- a/docs/source/en/api/pipelines/rae_dit.md +++ 
b/docs/source/en/api/pipelines/rae_dit.md @@ -54,6 +54,6 @@ image = pipe(class_labels=[class_id], num_inference_steps=25).images[0] - all - __call__ -## ImagePipelineOutput +## RAEDiTPipelineOutput -[[autodoc]] pipelines.ImagePipelineOutput +[[autodoc]] RAEDiTPipelineOutput diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md index 369adc2a0ef0..e8f0ac3e3b04 100644 --- a/examples/research_projects/rae_dit/README.md +++ b/examples/research_projects/rae_dit/README.md @@ -4,17 +4,6 @@ This folder contains the minimal Stage-2 follow-up for the RAE integration: trai It is intentionally placed under `examples/research_projects/rae_dit/` rather than the top-level `examples/` trainers because this is still an experimental follow-up to the new RAE support. -## What this mirrors - -The scaffold is deliberately composed from existing `diffusers` patterns instead of introducing a new training style: - -- `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` - for ImageFolder loading, RAE-specific preprocessing, and the experimental research-project placement. -- `examples/dreambooth/train_dreambooth_flux.py` - for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks. -- `examples/flux-control/train_control_flux.py` - for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers. - ## Current scope This is a minimal full-finetuning scaffold, not a paper-complete training stack. 
It currently does the following: diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py index 3b2e03af4a72..431c001e5f18 100644 --- a/examples/research_projects/rae_dit/test_rae_dit.py +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -37,6 +37,7 @@ build_transforms, collate_fn, compute_resume_offsets, + load_autoencoder_rae, maybe_load_resumed_scheduler, should_skip_resumed_batch, ) @@ -161,6 +162,15 @@ def _create_dataset(self, tmpdir, resolution=16): image.save(os.path.join(class_dir, f"sample_{image_idx}.png")) return dataset_dir + def test_load_autoencoder_rae_loads_local_checkpoint(self): + with tempfile.TemporaryDirectory() as tmpdir: + rae_dir = self._create_tiny_rae(tmpdir) + model = load_autoencoder_rae(rae_dir) + + self.assertIsInstance(model, AutoencoderRAE) + self.assertEqual(model.config.image_size, 16) + self.assertEqual(model.config.patch_size, 4) + def test_train_rae_dit_smoke(self): with tempfile.TemporaryDirectory() as tmpdir: rae_dir = self._create_tiny_rae(tmpdir) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py index 018235752867..a6c980c66165 100644 --- a/examples/research_projects/rae_dit/train_rae_dit.py +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -33,9 +33,11 @@ from tqdm.auto import tqdm from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiT2DModel, RAEDiTPipeline +from diffusers.models.model_loading_utils import load_state_dict from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 -from diffusers.utils import check_min_version +from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, check_min_version +from diffusers.utils.hub_utils import _get_model_file from diffusers.utils.torch_utils import is_compiled_module @@ -44,6 +46,27 @@ logger = 
get_logger(__name__) +def load_autoencoder_rae(model_name_or_path: str) -> AutoencoderRAE: + config = AutoencoderRAE.load_config(model_name_or_path) + model = AutoencoderRAE.from_config(config) + + try: + model_file = _get_model_file(model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME) + except EnvironmentError: + model_file = _get_model_file(model_name_or_path, weights_name=WEIGHTS_NAME) + + state_dict = load_state_dict(model_file) + load_result = model.load_state_dict(state_dict, strict=False, assign=True) + + unexpected_keys = set(load_result.unexpected_keys) - {"decoder.decoder_pos_embed"} + if len(load_result.missing_keys) > 0 or len(unexpected_keys) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for AutoencoderRAE: " + f"missing_keys={load_result.missing_keys}, unexpected_keys={sorted(unexpected_keys)}" + ) + return model + + def parse_args(): parser = argparse.ArgumentParser(description="Minimal stage-2 trainer for RAEDiT2DModel.") parser.add_argument( @@ -476,7 +499,7 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) accelerator.wait_for_everyone() - autoencoder = AutoencoderRAE.from_pretrained(args.pretrained_rae_model_name_or_path) + autoencoder = load_autoencoder_rae(args.pretrained_rae_model_name_or_path) autoencoder.requires_grad_(False) autoencoder.eval() diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py index c6927e5eedf3..012a39dc4092 100644 --- a/scripts/convert_rae_stage2_to_diffusers.py +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -9,7 +9,10 @@ from huggingface_hub import HfApi, hf_hub_download from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline +from diffusers.models.model_loading_utils import load_state_dict from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel +from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME +from diffusers.utils.hub_utils import _get_model_file 
DEFAULT_NUM_TRAIN_TIMESTEPS = 1000 @@ -207,6 +210,35 @@ def build_scheduler_config(config: dict[str, Any]) -> tuple[FlowMatchEulerDiscre return scheduler, metadata +def _swap_projection_halves(tensor: torch.Tensor) -> torch.Tensor: + return torch.cat(tensor.chunk(2, dim=0)[::-1], dim=0) + + +def translate_transformer_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]: + translated = {} + + for key, value in state_dict.items(): + if key == "pos_embed": + continue + + if ".mlp.w12." in key: + new_key = key.replace(".mlp.w12.", ".mlp.net.0.proj.") + if isinstance(value, torch.Tensor): + value = _swap_projection_halves(value) + elif ".mlp.w3." in key: + new_key = key.replace(".mlp.w3.", ".mlp.net.2.") + elif ".mlp.fc1." in key: + new_key = key.replace(".mlp.fc1.", ".mlp.net.0.proj.") + elif ".mlp.fc2." in key: + new_key = key.replace(".mlp.fc2.", ".mlp.net.2.") + else: + new_key = key + + translated[new_key] = value + + return translated + + def convert_transformer_state_dict( transformer_config: dict[str, Any], checkpoint_path: Path, @@ -219,11 +251,11 @@ def convert_transformer_state_dict( ) -> dict[str, Any]: raw_checkpoint = load_checkpoint(checkpoint_path) state_dict = unwrap_state_dict(raw_checkpoint, checkpoint_key=checkpoint_key, prefer_ema=prefer_ema) + state_dict = translate_transformer_state_dict(state_dict) - with torch.device("meta"): - model = RAEDiT2DModel(**transformer_config) + model = RAEDiT2DModel(**transformer_config) - load_result = model.load_state_dict(state_dict, strict=False, assign=True) + load_result = model.load_state_dict(state_dict, strict=False) missing_keys = set(load_result.missing_keys) unexpected_keys = set(load_result.unexpected_keys) @@ -248,7 +280,15 @@ def convert_transformer_state_dict( ) output_dir.mkdir(parents=True, exist_ok=True) - model.save_pretrained(output_dir, safe_serialization=safe_serialization) + model.save_config(output_dir) + weights_path = output_dir / (SAFETENSORS_WEIGHTS_NAME if safe_serialization else 
WEIGHTS_NAME) + state_dict_to_save = model.state_dict() + if safe_serialization: + import safetensors.torch + + safetensors.torch.save_file(state_dict_to_save, weights_path, metadata={"format": "pt"}) + else: + torch.save(state_dict_to_save, weights_path) if verify_load: reloaded = RAEDiT2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) @@ -269,6 +309,35 @@ def write_metadata(output_path: Path, metadata: dict[str, Any]) -> None: json.dump(metadata, handle, indent=2) +def load_autoencoder_rae(model_name_or_path: str, cache_dir: str | None = None) -> AutoencoderRAE: + config = AutoencoderRAE.load_config(model_name_or_path, cache_dir=cache_dir) + model = AutoencoderRAE.from_config(config) + + try: + model_file = _get_model_file( + model_name_or_path, + weights_name=SAFETENSORS_WEIGHTS_NAME, + cache_dir=cache_dir, + ) + except EnvironmentError: + model_file = _get_model_file( + model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + ) + + state_dict = load_state_dict(model_file) + load_result = model.load_state_dict(state_dict, strict=False, assign=True) + + unexpected_keys = set(load_result.unexpected_keys) - {"decoder.decoder_pos_embed"} + if len(load_result.missing_keys) > 0 or len(unexpected_keys) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for AutoencoderRAE: " + f"missing_keys={load_result.missing_keys}, unexpected_keys={sorted(unexpected_keys)}" + ) + return model + + def resolve_input_path(accessor: RepoAccessor, path: str) -> Path: candidates = [path] if path.startswith("models/"): @@ -419,7 +488,7 @@ def convert(args: argparse.Namespace) -> None: scheduler.save_pretrained(scheduler_output_dir) if args.vae_model_name_or_path is not None: - vae = AutoencoderRAE.from_pretrained(args.vae_model_name_or_path) + vae = load_autoencoder_rae(args.vae_model_name_or_path, cache_dir=args.cache_dir) transformer = RAEDiT2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) guidance_transformer = None if 
"guidance_transformer" in metadata: diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 63cc9e34d6f3..4120ad22764c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -335,6 +335,7 @@ "LDMSuperResolutionPipeline", "PNDMPipeline", "RAEDiTPipeline", + "RAEDiTPipelineOutput", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", @@ -1110,6 +1111,7 @@ LDMSuperResolutionPipeline, PNDMPipeline, RAEDiTPipeline, + RAEDiTPipelineOutput, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a8b5e7e1783c..0901840679e3 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -215,10 +215,7 @@ def no_init_weights(): """ def _skip_init(*args, **kwargs): - # Preserve the `torch.nn.init.*` return contract so third-party model - # constructors that chain on the returned tensor still work under - # `no_init_weights()`. 
- return args[0] if len(args) > 0 else None + pass for name, init_func in TORCH_INIT_FUNCTIONS.items(): setattr(torch.nn.init, name, _skip_init) diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py index ac2341c7def8..324179cfca92 100644 --- a/src/diffusers/models/transformers/transformer_rae_dit.py +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -7,13 +7,14 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ..embeddings import PatchEmbed, get_2d_sincos_pos_embed +from ..attention import FeedForward +from ..embeddings import PatchEmbed, apply_rotary_emb, get_2d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm -def _repeat_to_length(hidden_states: torch.Tensor, target_length: int) -> torch.Tensor: +def _expand_conditioning_tokens(hidden_states: torch.Tensor, target_length: int) -> torch.Tensor: if hidden_states.shape[1] == target_length: return hidden_states @@ -44,20 +45,6 @@ def _repeat_to_length(hidden_states: torch.Tensor, target_length: int) -> torch. 
) -def _ddt_modulate(hidden_states: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor: - if shift is None: - shift = torch.zeros_like(scale) - - shift = _repeat_to_length(shift, hidden_states.shape[1]) - scale = _repeat_to_length(scale, hidden_states.shape[1]) - return hidden_states * (1 + scale) + shift - - -def _ddt_gate(hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - gate = _repeat_to_length(gate, hidden_states.shape[1]) - return hidden_states * gate - - def _to_pair(value: int | tuple[int, int] | list[int], name: str) -> tuple[int, int]: if isinstance(value, int): return value, value @@ -68,39 +55,17 @@ def _to_pair(value: int | tuple[int, int] | list[int], name: str) -> tuple[int, return int(value[0]), int(value[1]) -def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = hidden_states.view(*hidden_states.shape[:-1], -1, 2) - first, second = hidden_states.unbind(dim=-1) - hidden_states = torch.stack((-second, first), dim=-1) - return hidden_states.flatten(-2) - - -class _ApproximateGELUMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int): - super().__init__() - self.fc1 = nn.Linear(hidden_size, intermediate_size) - self.act = nn.GELU(approximate="tanh") - self.fc2 = nn.Linear(intermediate_size, hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class SwiGLUFFN(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int): - super().__init__() - self.w12 = nn.Linear(hidden_size, 2 * intermediate_size) - self.w3 = nn.Linear(intermediate_size, hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.w12(hidden_states) - hidden_states_1, hidden_states_2 = hidden_states.chunk(2, dim=-1) - hidden_states = F.silu(hidden_states_1) 
* hidden_states_2 - hidden_states = self.w3(hidden_states) - return hidden_states +def _swap_swiglu_projection_halves(feedforward: FeedForward) -> None: + projection = feedforward.net[0].proj + projection.weight.data = torch.cat( + projection.weight.data.chunk(2, dim=0)[::-1], + dim=0, + ) + if projection.bias is not None: + projection.bias.data = torch.cat( + projection.bias.data.chunk(2, dim=0)[::-1], + dim=0, + ) class GaussianFourierEmbedding(nn.Module): @@ -150,7 +115,7 @@ def forward( return self.embedding_table(class_labels) -class VisionRotaryEmbeddingFast(nn.Module): +class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, pt_seq_len: int, ft_seq_len: int | None = None, theta: float = 10000.0): super().__init__() @@ -183,10 +148,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_cos = freqs_cos.repeat_interleave(repeat, dim=0) freqs_sin = freqs_sin.repeat_interleave(repeat, dim=0) - return hidden_states * freqs_cos + _rotate_half(hidden_states) * freqs_sin + return apply_rotary_emb(hidden_states, (freqs_cos, freqs_sin), sequence_dim=2) -class NormAttention(nn.Module): +class RAEDiTAttention(nn.Module): def __init__( self, dim: int, @@ -210,11 +175,13 @@ def __init__( self.k_norm = norm_cls(self.head_dim, **norm_kwargs) if qk_norm else nn.Identity() self.proj = nn.Linear(dim, dim) - def forward(self, hidden_states: torch.Tensor, rope: VisionRotaryEmbeddingFast | None = None) -> torch.Tensor: - batch_size, sequence_length, channels = hidden_states.shape - qkv = self.qkv(hidden_states) - qkv = qkv.view(batch_size, sequence_length, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - query, key, value = qkv.unbind(0) + def forward(self, hidden_states: torch.Tensor, rope: VisionRotaryEmbedding | None = None) -> torch.Tensor: + batch_size, _, channels = hidden_states.shape + query, key, value = self.qkv(hidden_states).chunk(3, dim=-1) + + query = query.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + key = 
key.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) query = self.q_norm(query) key = self.k_norm(key) @@ -227,7 +194,7 @@ def forward(self, hidden_states: torch.Tensor, rope: VisionRotaryEmbeddingFast | key = key.to(dtype=value.dtype) hidden_states = F.scaled_dot_product_attention(query, key, value) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, channels) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, channels) hidden_states = self.proj(hidden_states) return hidden_states @@ -244,6 +211,7 @@ def __init__( wo_shift: bool = False, ): super().__init__() + self.use_swiglu = use_swiglu if use_rmsnorm: self.norm1 = RMSNorm(hidden_size, eps=1e-6) @@ -252,7 +220,7 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = NormAttention( + self.attn = RAEDiTAttention( hidden_size, num_heads=num_heads, qkv_bias=True, @@ -261,10 +229,13 @@ def __init__( ) mlp_hidden_dim = int(hidden_size * mlp_ratio) - if use_swiglu: - self.mlp = SwiGLUFFN(hidden_size, int(2 * mlp_hidden_dim / 3)) - else: - self.mlp = _ApproximateGELUMLP(hidden_size, mlp_hidden_dim) + mlp_inner_dim = int(2 * mlp_hidden_dim / 3) if use_swiglu else mlp_hidden_dim + self.mlp = FeedForward( + hidden_size, + inner_dim=mlp_inner_dim, + activation_fn="swiglu" if use_swiglu else "gelu-approximate", + bias=True, + ) self.wo_shift = wo_shift if wo_shift: @@ -276,7 +247,7 @@ def forward( self, hidden_states: torch.Tensor, conditioning: torch.Tensor, - feat_rope: VisionRotaryEmbeddingFast | None = None, + feat_rope: VisionRotaryEmbedding | None = None, ) -> torch.Tensor: if conditioning.ndim < hidden_states.ndim: conditioning = conditioning.unsqueeze(1) @@ -290,13 +261,25 @@ def forward( 6, dim=-1 ) - hidden_states = hidden_states + _ddt_gate( - 
self.attn(_ddt_modulate(self.norm1(hidden_states), shift_msa, scale_msa), rope=feat_rope), gate_msa - ) - hidden_states = hidden_states + _ddt_gate( - self.mlp(_ddt_modulate(self.norm2(hidden_states), shift_mlp, scale_mlp)), - gate_mlp, - ) + if shift_msa is None: + shift_msa = torch.zeros_like(scale_msa) + if shift_mlp is None: + shift_mlp = torch.zeros_like(scale_mlp) + + if shift_msa.shape[1] != hidden_states.shape[1]: + shift_msa = _expand_conditioning_tokens(shift_msa, hidden_states.shape[1]) + scale_msa = _expand_conditioning_tokens(scale_msa, hidden_states.shape[1]) + gate_msa = _expand_conditioning_tokens(gate_msa, hidden_states.shape[1]) + if shift_mlp.shape[1] != hidden_states.shape[1]: + shift_mlp = _expand_conditioning_tokens(shift_mlp, hidden_states.shape[1]) + scale_mlp = _expand_conditioning_tokens(scale_mlp, hidden_states.shape[1]) + gate_mlp = _expand_conditioning_tokens(gate_mlp, hidden_states.shape[1]) + + norm_hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states + self.attn(norm_hidden_states * (1 + scale_msa) + shift_msa, rope=feat_rope) * gate_msa + + norm_hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm_hidden_states * (1 + scale_mlp) + shift_mlp) * gate_mlp return hidden_states @@ -317,7 +300,11 @@ def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> to conditioning = conditioning.unsqueeze(1) shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1) - hidden_states = _ddt_modulate(self.norm_final(hidden_states), shift, scale) + if shift.shape[1] != hidden_states.shape[1]: + shift = _expand_conditioning_tokens(shift, hidden_states.shape[1]) + scale = _expand_conditioning_tokens(scale, hidden_states.shape[1]) + + hidden_states = self.norm_final(hidden_states) * (1 + scale) + shift hidden_states = self.linear(hidden_states) return hidden_states @@ -333,6 +320,7 @@ class RAEDiT2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = 
True + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "final_layer"] @register_to_config def __init__( @@ -410,10 +398,12 @@ def __init__( num_patches = self.s_embedder.height * self.s_embedder.width if use_pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, encoder_hidden_size), requires_grad=False) + grid_size = int(num_patches**0.5) + pos_embed = get_2d_sincos_pos_embed(encoder_hidden_size, grid_size, output_type="pt") + self.register_buffer("pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) self.x_pos_embed = None else: - self.register_parameter("pos_embed", None) + self.register_buffer("pos_embed", None, persistent=False) self.x_pos_embed = None if use_rope: @@ -421,8 +411,8 @@ def __init__( decoder_rope_dim = decoder_hidden_size // decoder_num_attention_heads // 2 encoder_side = int(sqrt(num_patches)) decoder_side = int(sqrt(self.x_embedder.height * self.x_embedder.width)) - self.enc_feat_rope = VisionRotaryEmbeddingFast(encoder_rope_dim, pt_seq_len=encoder_side) - self.dec_feat_rope = VisionRotaryEmbeddingFast(decoder_rope_dim, pt_seq_len=decoder_side) + self.enc_feat_rope = VisionRotaryEmbedding(encoder_rope_dim, pt_seq_len=encoder_side) + self.dec_feat_rope = VisionRotaryEmbedding(decoder_rope_dim, pt_seq_len=decoder_side) else: self.enc_feat_rope = None self.dec_feat_rope = None @@ -463,13 +453,9 @@ def _basic_init(module): nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - if self.use_pos_embed: - pos_embed = get_2d_sincos_pos_embed( - self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt" - ) - self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0)) - for block in self.blocks: + if block.use_swiglu: + _swap_swiglu_projection_halves(block.mlp) nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) @@ -495,21 +481,6 @@ def unpatchify(self, hidden_states: torch.Tensor) -> 
torch.Tensor: ) return hidden_states - def _run_block( - self, - block: RAEDiTBlock, - hidden_states: torch.Tensor, - conditioning: torch.Tensor, - feat_rope: VisionRotaryEmbeddingFast | None, - ) -> torch.Tensor: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def custom_forward(hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: - return block(hidden_states, conditioning, feat_rope=feat_rope) - - return self._gradient_checkpointing_func(custom_forward, hidden_states, conditioning) - return block(hidden_states, conditioning, feat_rope=feat_rope) - def forward( self, hidden_states: torch.Tensor, @@ -533,31 +504,67 @@ def forward( if conditioning_hidden_states is None: conditioning_hidden_states = self.s_embedder(hidden_states) if self.use_pos_embed: - conditioning_hidden_states = conditioning_hidden_states + self.pos_embed + conditioning_hidden_states = conditioning_hidden_states + self.pos_embed.to( + device=conditioning_hidden_states.device, dtype=conditioning_hidden_states.dtype + ) for block_idx in range(self.num_encoder_blocks): - conditioning_hidden_states = self._run_block( - self.blocks[block_idx], - conditioning_hidden_states, - conditioning, - self.enc_feat_rope, - ) + block = self.blocks[block_idx] + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def custom_forward( + hidden_states: torch.Tensor, + conditioning: torch.Tensor, + block: RAEDiTBlock = block, + feat_rope: VisionRotaryEmbedding | None = self.enc_feat_rope, + ) -> torch.Tensor: + return block(hidden_states, conditioning, feat_rope=feat_rope) + + conditioning_hidden_states = self._gradient_checkpointing_func( + custom_forward, conditioning_hidden_states, conditioning + ) + else: + conditioning_hidden_states = block( + conditioning_hidden_states, + conditioning, + feat_rope=self.enc_feat_rope, + ) conditioning_hidden_states = F.silu(timestep_emb.unsqueeze(1) + conditioning_hidden_states) + projector_dtype = conditioning_hidden_states.dtype 
+ projector_param = next(self.s_projector.parameters(), None) + if projector_param is not None: + projector_dtype = projector_param.dtype + + conditioning_hidden_states = conditioning_hidden_states.to(device=hidden_states.device, dtype=projector_dtype) conditioning_hidden_states = self.s_projector(conditioning_hidden_states) hidden_states = self.x_embedder(hidden_states) if self.use_pos_embed and self.x_pos_embed is not None: - hidden_states = hidden_states + self.x_pos_embed + hidden_states = hidden_states + self.x_pos_embed.to(device=hidden_states.device, dtype=hidden_states.dtype) for block_idx in range(self.num_encoder_blocks, self.num_blocks): - hidden_states = self._run_block( - self.blocks[block_idx], - hidden_states, - conditioning_hidden_states, - self.dec_feat_rope, - ) + block = self.blocks[block_idx] + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def custom_forward( + hidden_states: torch.Tensor, + conditioning_hidden_states: torch.Tensor, + block: RAEDiTBlock = block, + feat_rope: VisionRotaryEmbedding | None = self.dec_feat_rope, + ) -> torch.Tensor: + return block(hidden_states, conditioning_hidden_states, feat_rope=feat_rope) + + hidden_states = self._gradient_checkpointing_func( + custom_forward, hidden_states, conditioning_hidden_states + ) + else: + hidden_states = block( + hidden_states, + conditioning_hidden_states, + feat_rope=self.dec_feat_rope, + ) hidden_states = self.final_layer(hidden_states, conditioning_hidden_states) hidden_states = self.unpatchify(hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7a280b9b6a04..85e11552bda6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -147,7 +147,7 @@ "FluxKontextInpaintPipeline", ] _import_structure["prx"] = ["PRXPipeline"] - _import_structure["rae_dit"] = ["RAEDiTPipeline"] + _import_structure["rae_dit"] = ["RAEDiTPipeline", "RAEDiTPipelineOutput"] 
_import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -784,7 +784,7 @@ QwenImageLayeredPipeline, QwenImagePipeline, ) - from .rae_dit import RAEDiTPipeline + from .rae_dit import RAEDiTPipeline, RAEDiTPipelineOutput from .sana import ( SanaControlNetPipeline, SanaPipeline, diff --git a/src/diffusers/pipelines/rae_dit/__init__.py b/src/diffusers/pipelines/rae_dit/__init__.py index adb404b485b1..9716a8197a59 100644 --- a/src/diffusers/pipelines/rae_dit/__init__.py +++ b/src/diffusers/pipelines/rae_dit/__init__.py @@ -3,9 +3,13 @@ from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule -_import_structure = {"pipeline_rae_dit": ["RAEDiTPipeline"]} +_import_structure = { + "pipeline_output": ["RAEDiTPipelineOutput"], + "pipeline_rae_dit": ["RAEDiTPipeline"], +} if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_output import RAEDiTPipelineOutput from .pipeline_rae_dit import RAEDiTPipeline else: import sys diff --git a/src/diffusers/pipelines/rae_dit/pipeline_output.py b/src/diffusers/pipelines/rae_dit/pipeline_output.py new file mode 100644 index 000000000000..63296155061e --- /dev/null +++ b/src/diffusers/pipelines/rae_dit/pipeline_output.py @@ -0,0 +1,36 @@ +# Copyright 2026 HuggingFace Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class RAEDiTPipelineOutput(BaseOutput): + """ + Output class for RAE DiT image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`) + Denoised images as PIL images, a NumPy array of shape `(batch_size, height, width, num_channels)`, or a + PyTorch tensor of shape `(batch_size, num_channels, height, width)`. Torch tensors may also represent + latent outputs when `output_type="latent"`. + """ + + images: list[PIL.Image.Image] | np.ndarray | torch.Tensor diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py index fc11f2dad424..15f733878c60 100644 --- a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -2,11 +2,13 @@ import torch +from ...image_processor import VaeImageProcessor from ...models import AutoencoderRAE from ...models.transformers.transformer_rae_dit import RAEDiT2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import RAEDiTPipelineOutput class RAEDiTPipeline(DiffusionPipeline): @@ -52,6 +54,7 @@ def __init__( self.labels[label.strip()] = int(key) self.labels = dict(sorted(self.labels.items())) + self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_resize=False, do_normalize=False) self._guidance_scale = 1.0 @property @@ -87,7 +90,7 @@ def _prepare_class_labels( return class_labels - def _prepare_latents( + def prepare_latents( self, batch_size: int, latent_channels: int, @@ -99,6 +102,12 @@ def _prepare_latents( ) -> torch.Tensor: shape = (batch_size, latent_channels, latent_size, latent_size) + if isinstance(generator, list) and 
len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested a batch size of " + f"{batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: return randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -139,7 +148,7 @@ def __call__( num_inference_steps: int = 50, output_type: str = "pil", return_dict: bool = True, - ) -> ImagePipelineOutput | tuple: + ) -> RAEDiTPipelineOutput | tuple: r""" The call function to the pipeline for generation. @@ -163,7 +172,7 @@ def __call__( output_type (`str`, *optional*, defaults to `"pil"`): Output format. Choose from `"pil"`, `"np"`, `"pt"`, or `"latent"`. return_dict (`bool`, *optional*, defaults to `True`): - Whether to return an [`ImagePipelineOutput`] instead of a tuple. + Whether to return an [`RAEDiTPipelineOutput`] instead of a tuple. """ if num_images_per_prompt < 1: @@ -194,7 +203,7 @@ def __call__( latent_size = self.transformer.config.sample_size latent_channels = self.transformer.config.in_channels - latents = self._prepare_latents( + latents = self.prepare_latents( batch_size=batch_size, latent_channels=latent_channels, latent_size=latent_size, @@ -244,16 +253,11 @@ def __call__( output = latents else: images = self.vae.decode(latents).sample.clamp(0, 1) - if output_type == "pt": - output = images - else: - output = images.cpu().permute(0, 2, 3, 1).float().numpy() - if output_type == "pil": - output = self.numpy_to_pil(output) + output = self.image_processor.postprocess(images, output_type=output_type) self.maybe_free_model_hooks() if not return_dict: return (output,) - return ImagePipelineOutput(images=output) + return RAEDiTPipelineOutput(images=output) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ff89538605fc..019a42f112ea 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ 
-2443,6 +2443,13 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RAEDiTPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class RePaintPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/test_modeling_utils.py b/tests/models/test_modeling_utils.py deleted file mode 100644 index 57e69fddcb39..000000000000 --- a/tests/models/test_modeling_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -# coding=utf-8 -# Copyright 2026 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from diffusers.models.modeling_utils import no_init_weights - - -def test_no_init_weights_preserves_torch_init_return_contract(): - tensor = torch.empty(2, 3) - - with no_init_weights(): - truncated = torch.nn.init.trunc_normal_(tensor) - zeroed = torch.nn.init.zeros_(tensor) - - assert truncated is tensor - assert zeroed is tensor diff --git a/tests/models/transformers/test_models_rae_dit_transformer2d.py b/tests/models/transformers/test_models_rae_dit_transformer2d.py index aebc09fbd7b5..b5d8970820ad 100644 --- a/tests/models/transformers/test_models_rae_dit_transformer2d.py +++ b/tests/models/transformers/test_models_rae_dit_transformer2d.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - +import pytest import torch from diffusers import RAEDiT2DModel -from diffusers.models.transformers.transformer_rae_dit import _repeat_to_length +from diffusers.models.transformers.transformer_rae_dit import _expand_conditioning_tokens +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin, TrainingTesterMixin enable_full_determinism() @@ -40,34 +40,18 @@ def _initialize_non_zero_stage2_head(model: RAEDiT2DModel): model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) -class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase): +class RAEDiT2DTesterConfig(BaseModelTesterConfig): model_class = RAEDiT2DModel main_input_name = "hidden_states" + input_shape = (8, 4, 4) + output_shape = (8, 4, 4) @property - def dummy_input(self): - batch_size = 2 - in_channels = 8 - sample_size = 4 - scheduler_num_train_steps = 1000 - num_class_labels = 10 - - hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device) - timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device) - class_labels = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device) - - return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_labels} - - @property - def input_shape(self): - return (8, 4, 4) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - @property - def output_shape(self): - return (8, 4, 4) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self): + return { "sample_size": 4, "patch_size": 1, "in_channels": 8, @@ -84,37 +68,73 @@ def prepare_init_args_and_inputs_for_common(self): "wo_shift": False, "use_pos_embed": True, } - 
inputs_dict = self.dummy_input - return init_dict, inputs_dict - def test_output(self): - super().test_output( - expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + def get_dummy_inputs(self): + batch_size = 2 + in_channels = 8 + sample_size = 4 + scheduler_num_train_steps = 1000 + num_class_labels = 10 + + hidden_states = randn_tensor( + (batch_size, in_channels, sample_size, sample_size), generator=self.generator, device=torch_device + ) + timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,), generator=self.generator).to( + torch_device + ) + class_labels = torch.randint(0, num_class_labels, size=(batch_size,), generator=self.generator).to( + torch_device ) + return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_labels} + + +class TestRAEDiT2DModel(RAEDiT2DTesterConfig, ModelTesterMixin): + def test_swiglu_feedforward_matches_previous_chunk_order(self): + model = self.model_class(**self.get_init_dict()).to(torch_device).eval() + block = model.blocks[0] + + hidden_states = randn_tensor((2, 4, model.encoder_hidden_size), generator=self.generator, device=torch_device) + projection = block.mlp.net[0].proj + output_projection = block.mlp.net[2] + + unswapped_weight = torch.cat(projection.weight.data.chunk(2, dim=0)[::-1], dim=0) + unswapped_bias = None + if projection.bias is not None: + unswapped_bias = torch.cat(projection.bias.data.chunk(2, dim=0)[::-1], dim=0) + + projected = torch.nn.functional.linear(hidden_states, unswapped_weight, unswapped_bias) + first_half, second_half = projected.chunk(2, dim=-1) + expected = torch.nn.functional.linear( + torch.nn.functional.silu(first_half) * second_half, + output_projection.weight, + output_projection.bias, + ) + + actual = block.mlp(hidden_states) + assert torch.allclose(actual, expected, atol=1e-6, rtol=1e-5) + def test_output_with_precomputed_conditioning_hidden_states(self): - init_dict, inputs_dict = 
self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device).eval() _initialize_non_zero_stage2_head(model) batch_size = inputs_dict[self.main_input_name].shape[0] num_patches = (init_dict["sample_size"] // init_dict["patch_size"]) ** 2 - conditioning_hidden_states = floats_tensor((batch_size, num_patches, init_dict["hidden_size"][0])).to( - torch_device + conditioning_hidden_states = randn_tensor( + (batch_size, num_patches, init_dict["hidden_size"][0]), generator=self.generator, device=torch_device ) with torch.no_grad(): output = model(**inputs_dict, conditioning_hidden_states=conditioning_hidden_states).sample - self.assertEqual(output.shape, inputs_dict[self.main_input_name].shape) + assert output.shape == inputs_dict[self.main_input_name].shape def test_precomputed_conditioning_matches_internal_encoder_path(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device).eval() _initialize_non_zero_stage2_head(model) hidden_states = inputs_dict["hidden_states"] @@ -147,12 +167,12 @@ def test_precomputed_conditioning_matches_internal_encoder_path(self): conditioning_hidden_states=conditioning_hidden_states, ).sample - self.assertTrue(torch.allclose(output_internal, output_precomputed, atol=1e-5, rtol=1e-4)) + assert torch.allclose(output_internal, output_precomputed, atol=1e-5, rtol=1e-4) - def test_repeat_to_length_preserves_2d_layout(self): + def test_expand_conditioning_tokens_preserves_2d_layout(self): hidden_states = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]]) - repeated = _repeat_to_length(hidden_states, target_length=16) + repeated = 
_expand_conditioning_tokens(hidden_states, target_length=16) expected = torch.tensor( [ @@ -176,31 +196,26 @@ def test_repeat_to_length_preserves_2d_layout(self): ] ] ) - self.assertTrue(torch.equal(repeated, expected)) + assert torch.equal(repeated, expected) - def test_repeat_to_length_broadcasts_global_conditioning(self): + def test_expand_conditioning_tokens_broadcasts_global_conditioning(self): hidden_states = torch.tensor([[[1.0, 2.0]]]) - repeated = _repeat_to_length(hidden_states, target_length=4) + repeated = _expand_conditioning_tokens(hidden_states, target_length=4) expected = torch.tensor([[[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]]) - self.assertTrue(torch.equal(repeated, expected)) + assert torch.equal(repeated, expected) - def test_repeat_to_length_rejects_incompatible_multi_token_layouts(self): + def test_expand_conditioning_tokens_rejects_incompatible_multi_token_layouts(self): hidden_states = torch.randn(1, 2, 4) - with self.assertRaises(ValueError): - _repeat_to_length(hidden_states, target_length=8) + with pytest.raises(ValueError): + _expand_conditioning_tokens(hidden_states, target_length=8) - def test_gradient_checkpointing_is_applied(self): - expected_set = {"RAEDiT2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def test_effective_gradient_checkpointing(self): - super().test_effective_gradient_checkpointing(loss_tolerance=1e-4) +class TestRAEDiT2DTraining(RAEDiT2DTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + super().test_gradient_checkpointing_is_applied(expected_set={"RAEDiT2DModel"}) - @unittest.skip( - "RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative." 
- ) - def test_layerwise_casting_inference(self): - pass + def test_gradient_checkpointing_equivalence(self): + super().test_gradient_checkpointing_equivalence(loss_tolerance=1e-4) diff --git a/tests/others/test_rae_dit_conversion.py b/tests/others/test_rae_dit_conversion.py index 29ca175d4455..046263cd521b 100644 --- a/tests/others/test_rae_dit_conversion.py +++ b/tests/others/test_rae_dit_conversion.py @@ -13,9 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile + import torch -from scripts.convert_rae_stage2_to_diffusers import unwrap_state_dict +from diffusers import AutoencoderRAE +from scripts.convert_rae_stage2_to_diffusers import ( + load_autoencoder_rae, + translate_transformer_state_dict, + unwrap_state_dict, +) def test_unwrap_state_dict_strips_supported_prefixes(): @@ -24,3 +31,79 @@ def test_unwrap_state_dict_strips_supported_prefixes(): assert unwrap_state_dict({"model.module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} assert unwrap_state_dict({"model.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} assert unwrap_state_dict({"module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + + +def test_translate_transformer_state_dict_maps_feedforward_keys(): + weight = torch.arange(12, dtype=torch.float32).reshape(4, 3) + bias = torch.arange(4, dtype=torch.float32) + out_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3) + out_bias = torch.arange(2, dtype=torch.float32) + + translated = translate_transformer_state_dict( + { + "blocks.0.mlp.w12.weight": weight, + "blocks.0.mlp.w12.bias": bias, + "blocks.0.mlp.w3.weight": out_weight, + "blocks.0.mlp.w3.bias": out_bias, + } + ) + + assert "blocks.0.mlp.net.0.proj.weight" in translated + assert "blocks.0.mlp.net.0.proj.bias" in translated + assert "blocks.0.mlp.net.2.weight" in translated + assert "blocks.0.mlp.net.2.bias" in translated + assert torch.equal( + 
translated["blocks.0.mlp.net.0.proj.weight"], + torch.cat(weight.chunk(2, dim=0)[::-1], dim=0), + ) + assert torch.equal( + translated["blocks.0.mlp.net.0.proj.bias"], + torch.cat(bias.chunk(2, dim=0)[::-1], dim=0), + ) + assert torch.equal(translated["blocks.0.mlp.net.2.weight"], out_weight) + assert torch.equal(translated["blocks.0.mlp.net.2.bias"], out_bias) + + +def test_translate_transformer_state_dict_maps_gelu_keys(): + fc1_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3) + fc2_weight = torch.arange(6, dtype=torch.float32).reshape(3, 2) + + translated = translate_transformer_state_dict( + { + "blocks.0.mlp.fc1.weight": fc1_weight, + "blocks.0.mlp.fc2.weight": fc2_weight, + } + ) + + assert torch.equal(translated["blocks.0.mlp.net.0.proj.weight"], fc1_weight) + assert torch.equal(translated["blocks.0.mlp.net.2.weight"], fc2_weight) + + +def test_load_autoencoder_rae_loads_local_checkpoint_without_from_pretrained(): + model = AutoencoderRAE( + encoder_type="mae", + encoder_hidden_size=64, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + encoder_input_size=16, + patch_size=4, + image_size=16, + num_channels=3, + decoder_hidden_size=64, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=4, + decoder_intermediate_size=128, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained(tmpdir, safe_serialization=False) + loaded = load_autoencoder_rae(tmpdir) + + assert isinstance(loaded, AutoencoderRAE) + assert loaded.config.image_size == 16 + assert loaded.config.patch_size == 4 diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py index 6e16b0051dfe..e4e8ccc1fb81 100644 --- a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -22,7 +22,7 @@ import torch.nn.functional as F 
import diffusers.models.autoencoders.autoencoder_rae as _rae_module -from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline, RAEDiTPipelineOutput from diffusers.models.autoencoders.autoencoder_rae import _ENCODER_FORWARD_FNS, _build_encoder from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel @@ -180,7 +180,9 @@ def test_inference(self): pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") pipe.set_progress_bar_config(disable=None) - image = pipe(**self.get_dummy_inputs("cpu")).images + output = pipe(**self.get_dummy_inputs("cpu")) + self.assertIsInstance(output, RAEDiTPipelineOutput) + image = output.images image_slice = image[0, -2:, -2:, -1] self.assertEqual(image.shape, (1, 4, 4, 3))