diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e0b7af4898b2..5a47e710220d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -388,6 +388,8 @@ title: PriorTransformer - local: api/models/qwenimage_transformer2d title: QwenImageTransformer2DModel + - local: api/models/rae_dit_transformer2d + title: RAEDiT2DModel - local: api/models/sana_transformer2d title: SanaTransformer2DModel - local: api/models/sana_video_transformer3d @@ -604,6 +606,8 @@ title: PRX - local: api/pipelines/qwenimage title: QwenImage + - local: api/pipelines/rae_dit + title: RAE DiT - local: api/pipelines/sana title: Sana - local: api/pipelines/sana_sprint diff --git a/docs/source/en/api/models/rae_dit_transformer2d.md b/docs/source/en/api/models/rae_dit_transformer2d.md new file mode 100644 index 000000000000..df171db8ef77 --- /dev/null +++ b/docs/source/en/api/models/rae_dit_transformer2d.md @@ -0,0 +1,32 @@ + + +# RAEDiT2DModel + +The `RAEDiT2DModel` is the Stage-2 latent diffusion transformer introduced in +[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690). + +Unlike DiT models that operate on VAE latents, this transformer denoises the latent space learned by +[`AutoencoderRAE`](./autoencoder_rae). It is designed to be used with [`FlowMatchEulerDiscreteScheduler`] and +decoded back to RGB with [`AutoencoderRAE`]. + +## Loading a pretrained transformer + +```python +from diffusers import RAEDiT2DModel + +transformer = RAEDiT2DModel.from_pretrained("path/to/converted-stage2-transformer") +``` + +## RAEDiT2DModel + +[[autodoc]] RAEDiT2DModel diff --git a/docs/source/en/api/pipelines/rae_dit.md b/docs/source/en/api/pipelines/rae_dit.md new file mode 100644 index 000000000000..b2666cb39899 --- /dev/null +++ b/docs/source/en/api/pipelines/rae_dit.md @@ -0,0 +1,59 @@ + + +# RAE DiT + +[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) introduces a +two-stage recipe: first train a representation autoencoder (RAE), then train a diffusion transformer on the resulting +latent space. + +[`RAEDiTPipeline`] implements the Stage-2 class-conditional generator in Diffusers. It combines: + +- [`RAEDiT2DModel`] for latent denoising +- [`FlowMatchEulerDiscreteScheduler`] for the denoising trajectory +- [`AutoencoderRAE`] for decoding latent samples to RGB images + +> [!TIP] +> [`RAEDiTPipeline`] expects a Stage-2 checkpoint converted to Diffusers format together with a compatible +> [`AutoencoderRAE`] checkpoint. + +## Loading a converted pipeline + +```python +import torch +from diffusers import RAEDiTPipeline + +pipe = RAEDiTPipeline.from_pretrained( + "path/to/converted-rae-dit-imagenet256", + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe(class_labels=[207], num_inference_steps=25).images[0] +image.save("golden_retriever.png") +``` + +If the converted pipeline includes an `id2label` mapping, you can also look up class ids by name: + +```python +class_id = pipe.get_label_ids("golden retriever")[0] +image = pipe(class_labels=[class_id], num_inference_steps=25).images[0] +``` + +## RAEDiTPipeline + +[[autodoc]] RAEDiTPipeline + - all + - __call__ + +## RAEDiTPipelineOutput + +[[autodoc]] RAEDiTPipelineOutput diff --git a/examples/research_projects/rae_dit/README.md b/examples/research_projects/rae_dit/README.md new file mode 100644 index 000000000000..e8f0ac3e3b04 --- /dev/null +++ b/examples/research_projects/rae_dit/README.md @@ -0,0 +1,91 @@ +# Training RAEDiT Stage 2 + +This folder contains the minimal Stage-2 follow-up for the RAE integration: training `RAEDiT2DModel` on top of a frozen `AutoencoderRAE`. + +It is intentionally placed under `examples/research_projects/rae_dit/` rather than the top-level `examples/` trainers because this is still an experimental follow-up to the new RAE support. + +## Current scope + +This is a minimal full-finetuning scaffold, not a paper-complete training stack. It currently does the following: + +- loads a frozen pretrained `AutoencoderRAE` +- encodes RGB images to normalized Stage-1 latents on the fly +- trains only the Stage-2 `RAEDiT2DModel` +- uses `FlowMatchEulerDiscreteScheduler` with the same shifted-sigma schedule shape already used elsewhere in `diffusers` +- consumes ImageFolder class ids as `class_labels` +- can generate validation samples through `RAEDiTPipeline` during training +- saves the trained transformer under `output_dir/transformer` +- saves the scheduler config under `output_dir/scheduler` +- writes `id2label.json` from the ImageFolder class mapping + +It intentionally does not yet include: + +- a latent-caching path +- autoguidance or the broader upstream transport stack +- exact upstream distributed training/runtime features + +## Dataset format + +The script expects an `ImageFolder`-compatible dataset: + +```text +train_data_dir/ + n01440764/ + img_0001.jpeg + n01443537/ + img_0002.jpeg +``` + +The folder names define the class labels used during Stage-2 training. + +## Quickstart + +```bash +accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ + --pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/rae-dit \ + --resolution 256 \ + --train_batch_size 8 \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing \ + --learning_rate 1e-4 \ + --lr_scheduler cosine \ + --lr_warmup_steps 1000 \ + --max_train_steps 200000 \ + --mixed_precision bf16 \ + --report_to wandb \ + --allow_tf32 +``` + +To emit validation samples during training, add: + +```bash + --validation_steps 1000 \ + --validation_class_label 207 \ + --num_validation_images 4 \ + --validation_num_inference_steps 25 \ + --validation_guidance_scale 1.0 +``` + +Validation images are written to `output_dir/validation/step-/`. + +If you already have a converted or partially trained Stage-2 checkpoint, resume from it with: + +```bash +accelerate launch examples/research_projects/rae_dit/train_rae_dit.py \ + --pretrained_rae_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \ + --pretrained_transformer_model_name_or_path /path/to/previous/transformer \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/rae-dit-finetune \ + --resolution 256 \ + --train_batch_size 8 \ + --max_train_steps 50000 +``` + +## Notes + +- The script derives a default flow shift from the latent dimensionality as `sqrt(latent_dim / time_shift_base)`, matching the upstream Stage-2 heuristic at a high level. +- The trainer assumes the selected `AutoencoderRAE` uses `reshape_to_2d=True`, because `RAEDiT2DModel` operates on 2D latent feature maps. +- Validation sampling uses a fresh scheduler cloned from the training config so sampling does not mutate the in-flight training scheduler state. +- This example is meant to land first as a training scaffold that matches the new Stage-2 model and export layout. A later follow-up can add cached latents and other training conveniences. diff --git a/examples/research_projects/rae_dit/test_rae_dit.py b/examples/research_projects/rae_dit/test_rae_dit.py new file mode 100644 index 000000000000..431c001e5f18 --- /dev/null +++ b/examples/research_projects/rae_dit/test_rae_dit.py @@ -0,0 +1,269 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import os +import sys +import tempfile +from pathlib import Path +from types import SimpleNamespace + +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder + +from diffusers import AutoencoderRAE +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 +from train_rae_dit import ( # noqa: E402 + build_transforms, + collate_fn, + compute_resume_offsets, + load_autoencoder_rae, + maybe_load_resumed_scheduler, + should_skip_resumed_batch, +) + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +def _create_unique_class_dataset(dataset_dir: Path, resolution: int, num_samples: int): + for sample_idx in range(num_samples): + class_dir = dataset_dir / f"class_{sample_idx:02d}" + class_dir.mkdir(parents=True, exist_ok=True) + color = ((40 * sample_idx) % 256, (80 * sample_idx) % 256, (120 * sample_idx) % 256) + image = Image.new("RGB", (resolution, resolution), color=color) + image.save(class_dir / f"sample_{sample_idx}.png") + + +def _collect_class_label_trace( + dataset_dir: Path, + *, + seed: int, + resolution: int, + train_batch_size: int, + gradient_accumulation_steps: int, + max_train_steps: int, + resume_global_step: int = 0, +) -> list[int]: + accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + set_seed(seed) + + transform_args = SimpleNamespace(resolution=resolution, center_crop=True, random_flip=False) + dataset = ImageFolder(dataset_dir, transform=build_transforms(transform_args)) + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=train_batch_size, + num_workers=0, + pin_memory=True, + drop_last=True, + ) + train_dataloader = accelerator.prepare(train_dataloader) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + first_epoch = 0 + resume_step = 0 + should_resume = resume_global_step > 0 + + if should_resume: + first_epoch, resume_step = compute_resume_offsets( + global_step=resume_global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=gradient_accumulation_steps, + ) + + expected_microbatches = (max_train_steps - resume_global_step) * gradient_accumulation_steps + trace = [] + + for epoch in range(first_epoch, num_train_epochs): + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + for step, batch in enumerate(train_dataloader): + if should_skip_resumed_batch( + should_resume=should_resume, + epoch=epoch, + first_epoch=first_epoch, + step=step, + resume_step=resume_step, + ): + continue + + trace.extend(batch["class_labels"].tolist()) + if len(trace) >= expected_microbatches: + return trace + + raise AssertionError(f"Expected to record {expected_microbatches} microbatches, but only collected {len(trace)}.") + + +class RAEDiT(ExamplesTestsAccelerate): + def _create_tiny_rae(self, tmpdir): + model = AutoencoderRAE( + encoder_type="mae", + encoder_hidden_size=64, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + encoder_input_size=16, + patch_size=4, + image_size=16, + num_channels=3, + decoder_hidden_size=64, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=4, + decoder_intermediate_size=128, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + output_dir = os.path.join(tmpdir, "tiny-rae") + model.save_pretrained(output_dir, safe_serialization=False) + return output_dir + + def _create_dataset(self, tmpdir, resolution=16): + dataset_dir = os.path.join(tmpdir, "dataset") + for class_idx in range(2): + class_dir = os.path.join(dataset_dir, f"class_{class_idx:02d}") + os.makedirs(class_dir, exist_ok=True) + for image_idx in range(2): + color = ( + (50 * class_idx + 20 * image_idx) % 256, + (80 * class_idx + 30 * image_idx) % 256, + (110 * class_idx + 40 * image_idx) % 256, + ) + image = Image.new("RGB", (resolution, resolution), color=color) + image.save(os.path.join(class_dir, f"sample_{image_idx}.png")) + return dataset_dir + + def test_load_autoencoder_rae_loads_local_checkpoint(self): + with tempfile.TemporaryDirectory() as tmpdir: + rae_dir = self._create_tiny_rae(tmpdir) + model = load_autoencoder_rae(rae_dir) + + self.assertIsInstance(model, AutoencoderRAE) + self.assertEqual(model.config.image_size, 16) + self.assertEqual(model.config.patch_size, 4) + + def test_train_rae_dit_smoke(self): + with tempfile.TemporaryDirectory() as tmpdir: + rae_dir = self._create_tiny_rae(tmpdir) + dataset_dir = self._create_dataset(tmpdir) + + test_args = f""" + examples/research_projects/rae_dit/train_rae_dit.py + --pretrained_rae_model_name_or_path {rae_dir} + --train_data_dir {dataset_dir} + --output_dir {tmpdir} + --resolution 16 + --center_crop + --train_batch_size 1 + --dataloader_num_workers 0 + --max_train_steps 2 + --gradient_accumulation_steps 1 + --learning_rate 1e-3 + --lr_scheduler constant + --lr_warmup_steps 0 + --encoder_hidden_size 32 + --decoder_hidden_size 32 + --encoder_num_layers 1 + --decoder_num_layers 1 + --encoder_num_attention_heads 4 + --decoder_num_attention_heads 4 + --mlp_ratio 2.0 + --num_train_timesteps 10 + --validation_steps 1 + --validation_class_label 0 + --num_validation_images 1 + --validation_num_inference_steps 2 + --seed 0 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "id2label.json"))) + validation_image_path = os.path.join(tmpdir, "validation", "step-1", "image-0.png") + self.assertTrue(os.path.isfile(validation_image_path)) + validation_image = Image.open(validation_image_path) + self.assertEqual(validation_image.size, (16, 16)) + self.assertEqual(validation_image.mode, "RGB") + + def test_resume_batch_order_matches_uninterrupted_tail(self): + seed = 123 + resolution = 16 + num_samples = 6 + train_batch_size = 1 + gradient_accumulation_steps = 2 + max_train_steps = 3 + resume_global_step = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) / "trace-dataset" + _create_unique_class_dataset(dataset_dir, resolution=resolution, num_samples=num_samples) + + baseline_trace = _collect_class_label_trace( + dataset_dir, + seed=seed, + resolution=resolution, + train_batch_size=train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + max_train_steps=max_train_steps, + ) + resumed_trace = _collect_class_label_trace( + dataset_dir, + seed=seed, + resolution=resolution, + train_batch_size=train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + max_train_steps=max_train_steps, + resume_global_step=resume_global_step, + ) + + consumed_microbatches = resume_global_step * gradient_accumulation_steps + self.assertEqual(resumed_trace, baseline_trace[consumed_microbatches:]) + + def test_maybe_load_resumed_scheduler_prefers_checkpoint_config(self): + args = SimpleNamespace(num_train_timesteps=999, flow_shift=2.5) + + with tempfile.TemporaryDirectory() as tmpdir: + scheduler_dir = os.path.join(tmpdir, "scheduler") + FlowMatchEulerDiscreteScheduler(num_train_timesteps=10, shift=7.0).save_pretrained(scheduler_dir) + + restored = maybe_load_resumed_scheduler( + args=args, + checkpoint_path=tmpdir, + noise_scheduler=FlowMatchEulerDiscreteScheduler(num_train_timesteps=999, shift=2.5), + ) + + self.assertEqual(restored.config.num_train_timesteps, 10) + self.assertEqual(restored.config.shift, 7.0) + self.assertEqual(args.num_train_timesteps, 10) + self.assertEqual(args.flow_shift, 7.0) diff --git a/examples/research_projects/rae_dit/train_rae_dit.py b/examples/research_projects/rae_dit/train_rae_dit.py new file mode 100644 index 000000000000..a6c980c66165 --- /dev/null +++ b/examples/research_projects/rae_dit/train_rae_dit.py @@ -0,0 +1,823 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import math +import os +import shutil +from pathlib import Path + +import numpy as np +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import ImageFolder +from tqdm.auto import tqdm + +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiT2DModel, RAEDiTPipeline +from diffusers.models.model_loading_utils import load_state_dict +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, check_min_version +from diffusers.utils.hub_utils import _get_model_file +from diffusers.utils.torch_utils import is_compiled_module + + +check_min_version("0.38.0.dev0") + +logger = get_logger(__name__) + + +def load_autoencoder_rae(model_name_or_path: str) -> AutoencoderRAE: + config = AutoencoderRAE.load_config(model_name_or_path) + model = AutoencoderRAE.from_config(config) + + try: + model_file = _get_model_file(model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME) + except EnvironmentError: + model_file = _get_model_file(model_name_or_path, weights_name=WEIGHTS_NAME) + + state_dict = load_state_dict(model_file) + load_result = model.load_state_dict(state_dict, strict=False, assign=True) + + unexpected_keys = set(load_result.unexpected_keys) - {"decoder.decoder_pos_embed"} + if len(load_result.missing_keys) > 0 or len(unexpected_keys) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for AutoencoderRAE: " + f"missing_keys={load_result.missing_keys}, unexpected_keys={sorted(unexpected_keys)}" + ) + return model + + +def parse_args(): + parser = argparse.ArgumentParser(description="Minimal stage-2 trainer for RAEDiT2DModel.") + parser.add_argument( + "--train_data_dir", + type=str, + required=True, + help="Path to an ImageFolder-style dataset root. Class folder names define label ids.", + ) + parser.add_argument("--output_dir", type=str, default="rae-dit", help="Directory to save checkpoints/model.") + parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.") + parser.add_argument("--report_to", type=str, default="tensorboard", help="Tracker to use with Accelerate.") + parser.add_argument("--seed", type=int, default=None, help="Seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=None, + help="Training image resolution. Defaults to the loaded RAE image size.", + ) + parser.add_argument("--center_crop", action="store_true", help="Use center crop instead of random crop.") + parser.add_argument("--random_flip", action="store_true", help="Apply random horizontal flips during training.") + parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size per device.") + parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of dataloader workers.") + parser.add_argument("--num_train_epochs", type=int, default=10, help="Training epochs if max steps is not set.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total training steps. Overrides num_train_epochs when provided.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of update steps to accumulate before optimizer step.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Enable gradient checkpointing on the transformer.", + ) + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Gradient clipping norm.") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by world size, accumulation steps, and batch size.", + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="AdamW beta1.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="AdamW beta2.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="AdamW weight decay.") + parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="AdamW epsilon.") + parser.add_argument( + "--lr_scheduler", + type=str, + default="cosine", + help='Scheduler type. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"].', + ) + parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Scheduler warmup steps.") + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help="Override Accelerate mixed precision mode.", + ) + parser.add_argument("--allow_tf32", action="store_true", help="Enable TF32 matmul on Ampere+ GPUs.") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=1000, + help="Save Accelerate checkpoints every N optimizer steps.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help="Maximum number of checkpoint folders to keep.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help='Checkpoint path or "latest" to resume from the latest checkpoint in output_dir.', + ) + parser.add_argument( + "--pretrained_rae_model_name_or_path", + type=str, + required=True, + help="Path or Hub id for the pretrained stage-1 AutoencoderRAE.", + ) + parser.add_argument( + "--pretrained_transformer_model_name_or_path", + type=str, + default=None, + help="Optional path or Hub id for a pretrained RAEDiT transformer checkpoint.", + ) + parser.add_argument("--patch_size", type=int, default=1, help="Latent patch size for the Stage-2 transformer.") + parser.add_argument("--encoder_hidden_size", type=int, default=1152, help="Encoder token width.") + parser.add_argument("--decoder_hidden_size", type=int, default=2048, help="Decoder token width.") + parser.add_argument("--encoder_num_layers", type=int, default=28, help="Number of encoder blocks.") + parser.add_argument("--decoder_num_layers", type=int, default=2, help="Number of decoder blocks.") + parser.add_argument("--encoder_num_attention_heads", type=int, default=16, help="Encoder attention heads.") + parser.add_argument("--decoder_num_attention_heads", type=int, default=16, help="Decoder attention heads.") + parser.add_argument("--mlp_ratio", type=float, default=4.0, help="MLP expansion ratio.") + parser.add_argument( + "--class_dropout_prob", + type=float, + default=0.1, + help="Class dropout probability for classifier-free guidance readiness.", + ) + parser.add_argument( + "--num_classes", + type=int, + default=None, + help="Number of class labels. Defaults to the number of ImageFolder classes.", + ) + parser.add_argument("--use_qknorm", action="store_true", help="Enable QK norm in attention.") + parser.add_argument("--use_swiglu", action=argparse.BooleanOptionalAction, default=True, help="Use SwiGLU MLPs.") + parser.add_argument( + "--use_rope", action=argparse.BooleanOptionalAction, default=True, help="Use rotary embeddings." + ) + parser.add_argument( + "--use_rmsnorm", + action=argparse.BooleanOptionalAction, + default=True, + help="Use RMSNorm instead of LayerNorm.", + ) + parser.add_argument("--wo_shift", action="store_true", help="Disable AdaLN shift modulation.") + parser.add_argument( + "--use_pos_embed", + action=argparse.BooleanOptionalAction, + default=True, + help="Use fixed sin-cos positional embeddings on the encoder stream.", + ) + parser.add_argument( + "--num_train_timesteps", + type=int, + default=1000, + help="Number of flow-matching training timesteps.", + ) + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Explicit flow-matching shift. If omitted, it is derived from the latent size.", + ) + parser.add_argument( + "--time_shift_base", + type=float, + default=4096.0, + help="Base latent dimensionality used to derive the default flow shift.", + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help="Weighting scheme for flow-matching timestep sampling and loss weighting.", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="Mean used when the logit-normal weighting scheme is selected.", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="Std used when the logit-normal weighting scheme is selected.", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Mode weighting scale used when weighting_scheme=mode.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=None, + help="Run validation sampling every N optimizer steps. Disabled when omitted.", + ) + parser.add_argument( + "--validation_class_label", + type=int, + default=None, + help="Class id to sample during validation. Disabled when omitted.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=1, + help="Number of validation images to generate each time validation runs.", + ) + parser.add_argument( + "--validation_num_inference_steps", + type=int, + default=25, + help="Number of denoising steps to use for validation sampling.", + ) + parser.add_argument( + "--validation_guidance_scale", + type=float, + default=1.0, + help="Classifier-free guidance scale used during validation sampling.", + ) + return parser.parse_args() + + +def build_transforms(args): + image_transforms = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + if args.random_flip: + image_transforms.append(transforms.RandomHorizontalFlip()) + image_transforms.append(transforms.ToTensor()) + return transforms.Compose(image_transforms) + + +def collate_fn(examples): + pixel_values = torch.stack([example[0] for example in examples]).float() + class_labels = torch.tensor([example[1] for example in examples], dtype=torch.long) + return {"pixel_values": pixel_values, "class_labels": class_labels} + + +def compute_resume_offsets( + global_step: int, num_update_steps_per_epoch: int, gradient_accumulation_steps: int +) -> tuple[int, int]: + first_epoch = global_step // num_update_steps_per_epoch + resume_global_step = global_step * gradient_accumulation_steps + resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps) + return first_epoch, resume_step + + +def should_skip_resumed_batch(should_resume: bool, epoch: int, first_epoch: int, step: int, resume_step: int) -> bool: + return should_resume and epoch == first_epoch and step < resume_step + + +def maybe_load_resumed_scheduler( + args, checkpoint_path: str | None, noise_scheduler: FlowMatchEulerDiscreteScheduler +) -> FlowMatchEulerDiscreteScheduler: + if checkpoint_path is None: + return noise_scheduler + + scheduler_path = os.path.join(checkpoint_path, "scheduler") + if not os.path.isdir(scheduler_path): + return noise_scheduler + + restored_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_path) + args.num_train_timesteps = int(restored_scheduler.config.num_train_timesteps) + args.flow_shift = float(restored_scheduler.config.shift) + return restored_scheduler + + +def get_latent_spec(autoencoder: AutoencoderRAE) -> tuple[int, int]: + if not autoencoder.config.reshape_to_2d: + raise ValueError("Stage-2 RAE DiT training expects `AutoencoderRAE.reshape_to_2d=True`.") + + latent_channels = int(autoencoder.config.encoder_hidden_size) + latent_size = int(autoencoder.config.encoder_input_size // autoencoder.config.encoder_patch_size) + return latent_channels, latent_size + + +def resolve_flow_shift(args, latent_channels: int, latent_size: int) -> float: + if args.flow_shift is not None: + return float(args.flow_shift) + + latent_dim = latent_channels * latent_size * latent_size + return math.sqrt(latent_dim / float(args.time_shift_base)) + + +def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None): + if device is None: + device = timesteps.device + + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device=device) + timesteps = timesteps.to(device=device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def unwrap_model(accelerator, model): + model = accelerator.unwrap_model(model) + return model._orig_mod if is_compiled_module(model) else model + + +def maybe_prune_checkpoints(output_dir: str, checkpoints_total_limit: int | None): + if checkpoints_total_limit is None: + return + + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + if len(checkpoints) < checkpoints_total_limit: + return + + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + for checkpoint in checkpoints[:num_to_remove]: + shutil.rmtree(os.path.join(output_dir, checkpoint)) + + +def validate_args(args): + validation_enabled = args.validation_class_label is not None + + if validation_enabled and args.validation_steps is None: + raise ValueError("`--validation_steps` must be provided when `--validation_class_label` is set.") + if args.validation_steps is not None and args.validation_steps < 1: + raise ValueError(f"`--validation_steps` must be >= 1, but got {args.validation_steps}.") + if args.validation_class_label is not None and args.validation_class_label < 0: + raise ValueError( + f"`--validation_class_label` must be >= 0, but got {args.validation_class_label}." + ) + if args.num_validation_images < 1: + raise ValueError(f"`--num_validation_images` must be >= 1, but got {args.num_validation_images}.") + if args.validation_num_inference_steps < 1: + raise ValueError( + f"`--validation_num_inference_steps` must be >= 1, but got {args.validation_num_inference_steps}." + ) + if args.validation_guidance_scale < 1.0: + raise ValueError( + f"`--validation_guidance_scale` must be >= 1.0, but got {args.validation_guidance_scale}." + ) + + +def log_validation( + transformer, + autoencoder: AutoencoderRAE, + scheduler: FlowMatchEulerDiscreteScheduler, + args, + accelerator: Accelerator, + step: int, + class_names: list[str], +): + if not accelerator.is_main_process: + return + + transformer_model = unwrap_model(accelerator, transformer) + was_training = transformer_model.training + transformer_model.eval() + + validation_scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler.config) + pipeline = RAEDiTPipeline( + transformer=transformer_model, + vae=autoencoder, + scheduler=validation_scheduler, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + generator = None + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed + step) + + label_name = class_names[args.validation_class_label] + logger.info( + "Running validation... Generating %s image(s) for class %s (%s).", + args.num_validation_images, + args.validation_class_label, + label_name, + ) + images = pipeline( + class_labels=[args.validation_class_label], + guidance_scale=args.validation_guidance_scale, + num_images_per_prompt=args.num_validation_images, + num_inference_steps=args.validation_num_inference_steps, + generator=generator, + output_type="pil", + ).images + + validation_dir = os.path.join(args.output_dir, "validation", f"step-{step}") + os.makedirs(validation_dir, exist_ok=True) + for image_idx, image in enumerate(images): + image.save(os.path.join(validation_dir, f"image-{image_idx}.png")) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + formatted_images = np.stack([np.asarray(image) for image in images]) + tracker.writer.add_images(f"validation/{label_name}", formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{image_idx}: {label_name}") + for image_idx, image in enumerate(images) + ] + } + ) + + if was_training: + transformer_model.train() + + +def main(): + args = parse_args() + validate_args(args) + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + project_config=accelerator_project_config, + log_with=args.report_to, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if args.seed is not None: + set_seed(args.seed) + + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + autoencoder = load_autoencoder_rae(args.pretrained_rae_model_name_or_path) + autoencoder.requires_grad_(False) + autoencoder.eval() + + latent_channels, latent_size = get_latent_spec(autoencoder) + if args.resolution is None: + args.resolution = int(autoencoder.config.image_size) + + dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args)) + inferred_num_classes = len(dataset.classes) + num_classes = inferred_num_classes if args.num_classes is None else int(args.num_classes) + if num_classes < inferred_num_classes: + raise ValueError( + f"`--num_classes` ({num_classes}) must be >= the number of dataset classes ({inferred_num_classes})." + ) + if args.validation_class_label is not None and args.validation_class_label >= inferred_num_classes: + raise ValueError( + f"`--validation_class_label` ({args.validation_class_label}) must be < the number of dataset classes " + f"({inferred_num_classes})." + ) + + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=True, + drop_last=True, + ) + + if args.pretrained_transformer_model_name_or_path is not None: + transformer = RAEDiT2DModel.from_pretrained(args.pretrained_transformer_model_name_or_path) + if transformer.config.in_channels != latent_channels or transformer.config.sample_size != latent_size: + raise ValueError( + "Loaded transformer latent shape does not match the selected AutoencoderRAE. " + f"Expected channels={latent_channels}, size={latent_size}; got " + f"channels={transformer.config.in_channels}, size={transformer.config.sample_size}." + ) + if transformer.config.num_classes < num_classes: + raise ValueError( + f"Loaded transformer supports {transformer.config.num_classes} classes but dataset requires {num_classes}." + ) + else: + transformer = RAEDiT2DModel( + sample_size=latent_size, + patch_size=args.patch_size, + in_channels=latent_channels, + hidden_size=(args.encoder_hidden_size, args.decoder_hidden_size), + depth=(args.encoder_num_layers, args.decoder_num_layers), + num_heads=(args.encoder_num_attention_heads, args.decoder_num_attention_heads), + mlp_ratio=args.mlp_ratio, + class_dropout_prob=args.class_dropout_prob, + num_classes=num_classes, + use_qknorm=args.use_qknorm, + use_swiglu=args.use_swiglu, + use_rope=args.use_rope, + use_rmsnorm=args.use_rmsnorm, + wo_shift=args.wo_shift, + use_pos_embed=args.use_pos_embed, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + flow_shift = resolve_flow_shift(args, latent_channels=latent_channels, latent_size=latent_size) + noise_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=args.num_train_timesteps, + shift=flow_shift, + ) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + optimizer = torch.optim.AdamW( + transformer.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + def save_model_hook(models, weights, output_dir): + if not accelerator.is_main_process: + return + + for model in models: + if isinstance(unwrap_model(accelerator, model), RAEDiT2DModel): + unwrap_model(accelerator, model).save_pretrained(os.path.join(output_dir, "transformer")) + else: + raise ValueError(f"Unexpected model type during save: {type(model)}") + + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + model = models.pop() + target_model = unwrap_model(accelerator, model) + if not isinstance(target_model, RAEDiT2DModel): + raise ValueError(f"Unexpected model type during load: {type(model)}") + + load_model = RAEDiT2DModel.from_pretrained(input_dir, subfolder="transformer") + target_model.register_to_config(**load_model.config) + target_model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + autoencoder.to(accelerator.device, dtype=weight_dtype) + + if overrode_max_train_steps: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + resume_step = 0 + resume_path = None + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + resume_path = args.resume_from_checkpoint + if not os.path.isdir(resume_path): + resume_path = os.path.join(args.output_dir, os.path.basename(resume_path)) + else: + checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + resume_path = os.path.join(args.output_dir, checkpoints[-1]) if checkpoints else None + + if resume_path is None or not os.path.isdir(resume_path): + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + args.resume_from_checkpoint = None + else: + noise_scheduler = maybe_load_resumed_scheduler(args, resume_path, noise_scheduler) + flow_shift = float(noise_scheduler.config.shift) + accelerator.print(f"Resuming from checkpoint {resume_path}") + accelerator.load_state(resume_path) + global_step = int(os.path.basename(resume_path).split("-")[1]) + initial_global_step = global_step + first_epoch, resume_step = compute_resume_offsets( + global_step=global_step, + num_update_steps_per_epoch=num_update_steps_per_epoch, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + + if accelerator.is_main_process: + accelerator.init_trackers( + "train_rae_dit", + config={ + **vars(args), + "latent_channels": latent_channels, + "latent_size": latent_size, + "flow_shift": flow_shift, + "inferred_num_classes": inferred_num_classes, + }, + ) + with open(os.path.join(args.output_dir, "id2label.json"), "w", encoding="utf-8") as f: + json.dump(dict(enumerate(dataset.classes)), f, indent=2, sort_keys=True) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + logger.info("***** Running stage-2 RAE DiT training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num classes = {inferred_num_classes}") + logger.info(f" RAE latent shape = ({latent_channels}, {latent_size}, {latent_size})") + logger.info(f" Flow shift = {flow_shift:.4f}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size = {total_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + for step, batch in enumerate(train_dataloader): + if should_skip_resumed_batch( + should_resume=args.resume_from_checkpoint is not None, + epoch=epoch, + first_epoch=first_epoch, + step=step, + resume_step=resume_step, + ): + continue + + with accelerator.accumulate(transformer): + pixel_values = batch["pixel_values"].to( + device=accelerator.device, dtype=weight_dtype, non_blocking=True + ) + class_labels = batch["class_labels"].to(device=accelerator.device, non_blocking=True) + + with torch.no_grad(): + latents = autoencoder.encode(pixel_values).latent + + noise = torch.randn_like(latents) + batch_size = latents.shape[0] + + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + device=latents.device, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps.to(device=latents.device)[indices] + + sigmas = get_sigmas( + noise_scheduler, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + device=latents.device, + ) + noisy_latents = noise_scheduler.scale_noise(latents, timesteps, noise) + + model_pred = transformer( + hidden_states=noisy_latents, + timestep=timesteps / noise_scheduler.config.num_train_timesteps, + class_labels=class_labels, + ).sample + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + target = noise - latents + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + dim=1, + ) + loss = loss.mean() + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + maybe_prune_checkpoints(args.output_dir, args.checkpoints_total_limit) + accelerator.wait_for_everyone() + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + if accelerator.is_main_process: + noise_scheduler.save_pretrained(os.path.join(save_path, "scheduler")) + logger.info(f"Saved state to {save_path}") + + if args.validation_steps is not None and global_step % args.validation_steps == 0: + log_validation( + transformer=transformer, + autoencoder=autoencoder, + scheduler=noise_scheduler, + args=args, + accelerator=accelerator, + step=global_step, + class_names=dataset.classes, + ) + accelerator.wait_for_everyone() + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped_transformer = unwrap_model(accelerator, transformer) + unwrapped_transformer.save_pretrained(os.path.join(args.output_dir, "transformer")) + noise_scheduler.save_pretrained(os.path.join(args.output_dir, "scheduler")) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/convert_rae_stage2_to_diffusers.py b/scripts/convert_rae_stage2_to_diffusers.py new file mode 100644 index 000000000000..012a39dc4092 --- /dev/null +++ b/scripts/convert_rae_stage2_to_diffusers.py @@ -0,0 +1,635 @@ +import argparse +import json +import math +from pathlib import Path +from typing import Any + +import torch +import yaml +from huggingface_hub import HfApi, hf_hub_download + +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline +from diffusers.models.model_loading_utils import load_state_dict +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel +from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME +from diffusers.utils.hub_utils import _get_model_file + + +DEFAULT_NUM_TRAIN_TIMESTEPS = 1000 +DEFAULT_SHIFT_BASE = 4096 +DEFAULT_TRANSFORMER_SUBFOLDER = "transformer" +DEFAULT_GUIDANCE_SUBFOLDER = "guidance_transformer" +DEFAULT_SCHEDULER_SUBFOLDER = "scheduler" + + +class RepoAccessor: + def __init__(self, repo_or_path: str, cache_dir: str | None = None): + self.repo_or_path = repo_or_path + self.cache_dir = cache_dir + self.local_root: Path | None = None + self.repo_id: str | None = None + self.repo_files: set[str] | None = None + + root = Path(repo_or_path) + if root.exists() and root.is_dir(): + self.local_root = root + else: + self.repo_id = repo_or_path + self.repo_files = set(HfApi().list_repo_files(repo_or_path)) + + def exists(self, relative_path: str) -> bool: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + return (self.local_root / relative_path).is_file() + return relative_path in self.repo_files + + def fetch(self, relative_path: str) -> Path: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + path = self.local_root / relative_path + if not path.is_file(): + raise FileNotFoundError(f"File not found: {path}") + return path + + downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir) + return Path(downloaded) + + +def read_yaml(accessor: RepoAccessor, relative_path: str) -> dict[str, Any]: + with resolve_input_path(accessor, relative_path).open() as handle: + config = yaml.safe_load(handle) + + if not isinstance(config, dict): + raise ValueError(f"Expected YAML object at `{relative_path}` to decode to a dictionary.") + + return config + + +def _get_nested(mapping: dict[str, Any], path: str) -> Any: + value: Any = mapping + for part in path.split("."): + if not isinstance(value, dict) or part not in value: + raise KeyError(f"Missing `{path}` in checkpoint/config object.") + value = value[part] + return value + + +def _resolve_section(config: dict[str, Any], *keys: str) -> dict[str, Any]: + for key in keys: + section = config.get(key) + if isinstance(section, dict): + return section + raise KeyError(f"Could not find any of {keys} in config.") + + +def _normalize_pair(value: int | list[int] | tuple[int, ...], field_name: str) -> tuple[int, int]: + if isinstance(value, (list, tuple)): + if len(value) != 2: + raise ValueError(f"`{field_name}` must have length 2 when provided as a sequence, but got {value}.") + return int(value[0]), int(value[1]) + + scalar = int(value) + return scalar, scalar + + +def _normalize_patch_size(value: int | list[int] | tuple[int, ...]) -> int | tuple[int, int]: + stage1_patch_size, stage2_patch_size = _normalize_pair(value, "patch_size") + if stage1_patch_size == stage2_patch_size: + return stage1_patch_size + return stage1_patch_size, stage2_patch_size + + +def _maybe_strip_common_prefix(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]: + if len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict): + return {key[len(prefix) :]: value for key, value in state_dict.items()} + return state_dict + + +def unwrap_state_dict( + maybe_wrapped: dict[str, Any], + checkpoint_key: str | None = None, + prefer_ema: bool = True, +) -> dict[str, Any]: + if checkpoint_key: + state_dict = _get_nested(maybe_wrapped, checkpoint_key) + else: + state_dict = maybe_wrapped + if isinstance(state_dict, dict): + candidate_keys = ["ema", "model", "state_dict"] if prefer_ema else ["model", "ema", "state_dict"] + for key in candidate_keys: + if key in state_dict and isinstance(state_dict[key], dict): + state_dict = state_dict[key] + break + + if not isinstance(state_dict, dict): + raise ValueError("Resolved checkpoint payload is not a dictionary state dict.") + + state_dict = dict(state_dict) + for prefix in ("model.module.", "model.", "module."): + state_dict = _maybe_strip_common_prefix(state_dict, prefix) + return state_dict + + +def load_checkpoint(checkpoint_path: Path) -> dict[str, Any]: + suffix = checkpoint_path.suffix.lower() + if suffix == ".safetensors": + import safetensors.torch + + return safetensors.torch.load_file(checkpoint_path) + + return torch.load(checkpoint_path, map_location="cpu") + + +def build_transformer_config(stage2_params: dict[str, Any], misc: dict[str, Any]) -> dict[str, Any]: + hidden_size = _normalize_pair(stage2_params["hidden_size"], "hidden_size") + depth = _normalize_pair(stage2_params["depth"], "depth") + num_heads = _normalize_pair(stage2_params["num_heads"], "num_heads") + patch_size = _normalize_patch_size(stage2_params.get("patch_size", 1)) + + input_size = int(stage2_params["input_size"]) + in_channels = int(stage2_params["in_channels"]) + + latent_size = misc.get("latent_size") + if latent_size is not None: + if len(latent_size) != 3: + raise ValueError(f"`misc.latent_size` should have length 3, but got {latent_size}.") + latent_channels, latent_height, latent_width = [int(dim) for dim in latent_size] + if latent_channels != in_channels: + raise ValueError( + f"`misc.latent_size[0]` ({latent_channels}) does not match `stage_2.params.in_channels` ({in_channels})." + ) + if latent_height != input_size or latent_width != input_size: + raise ValueError( + f"`misc.latent_size[1:]` ({latent_height}, {latent_width}) does not match `stage_2.params.input_size` ({input_size})." + ) + + return { + "sample_size": input_size, + "patch_size": patch_size, + "in_channels": in_channels, + "hidden_size": hidden_size, + "depth": depth, + "num_heads": num_heads, + "mlp_ratio": float(stage2_params.get("mlp_ratio", 4.0)), + "class_dropout_prob": float(stage2_params.get("class_dropout_prob", 0.1)), + "num_classes": int(stage2_params.get("num_classes", misc.get("num_classes", 1000))), + "use_qknorm": bool(stage2_params.get("use_qknorm", False)), + "use_swiglu": bool(stage2_params.get("use_swiglu", True)), + "use_rope": bool(stage2_params.get("use_rope", True)), + "use_rmsnorm": bool(stage2_params.get("use_rmsnorm", True)), + "wo_shift": bool(stage2_params.get("wo_shift", False)), + "use_pos_embed": bool(stage2_params.get("use_pos_embed", True)), + } + + +def build_scheduler_config(config: dict[str, Any]) -> tuple[FlowMatchEulerDiscreteScheduler, dict[str, Any]]: + transport = _resolve_section(config, "transport") + misc = _resolve_section(config, "misc") + + transport_params = transport.get("params", {}) + latent_size = misc.get("latent_size", None) + if latent_size is None: + raise KeyError("Config must define `misc.latent_size` for scheduler conversion.") + + shift_dim = int(misc.get("time_dist_shift_dim", math.prod(int(dim) for dim in latent_size))) + shift_base = int(misc.get("time_dist_shift_base", DEFAULT_SHIFT_BASE)) + shift = math.sqrt(shift_dim / shift_base) + + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=int(transport_params.get("num_train_timesteps", DEFAULT_NUM_TRAIN_TIMESTEPS)), + shift=shift, + stochastic_sampling=False, + ) + metadata = { + "num_train_timesteps": scheduler.config.num_train_timesteps, + "shift": scheduler.config.shift, + "path_type": transport_params.get("path_type", "Linear"), + "prediction": transport_params.get("prediction", "velocity"), + "time_dist_type": transport_params.get("time_dist_type", "uniform"), + } + return scheduler, metadata + + +def _swap_projection_halves(tensor: torch.Tensor) -> torch.Tensor: + return torch.cat(tensor.chunk(2, dim=0)[::-1], dim=0) + + +def translate_transformer_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]: + translated = {} + + for key, value in state_dict.items(): + if key == "pos_embed": + continue + + if ".mlp.w12." in key: + new_key = key.replace(".mlp.w12.", ".mlp.net.0.proj.") + if isinstance(value, torch.Tensor): + value = _swap_projection_halves(value) + elif ".mlp.w3." in key: + new_key = key.replace(".mlp.w3.", ".mlp.net.2.") + elif ".mlp.fc1." in key: + new_key = key.replace(".mlp.fc1.", ".mlp.net.0.proj.") + elif ".mlp.fc2." in key: + new_key = key.replace(".mlp.fc2.", ".mlp.net.2.") + else: + new_key = key + + translated[new_key] = value + + return translated + + +def convert_transformer_state_dict( + transformer_config: dict[str, Any], + checkpoint_path: Path, + checkpoint_key: str | None, + prefer_ema: bool, + output_dir: Path, + safe_serialization: bool, + verify_load: bool, + component_name: str, +) -> dict[str, Any]: + raw_checkpoint = load_checkpoint(checkpoint_path) + state_dict = unwrap_state_dict(raw_checkpoint, checkpoint_key=checkpoint_key, prefer_ema=prefer_ema) + state_dict = translate_transformer_state_dict(state_dict) + + model = RAEDiT2DModel(**transformer_config) + + load_result = model.load_state_dict(state_dict, strict=False) + missing_keys = set(load_result.missing_keys) + unexpected_keys = set(load_result.unexpected_keys) + + allowed_missing = { + "pos_embed", + "enc_feat_rope.freqs_cos", + "enc_feat_rope.freqs_sin", + "dec_feat_rope.freqs_cos", + "dec_feat_rope.freqs_sin", + } + missing_keys -= allowed_missing + + if unexpected_keys: + raise RuntimeError( + f"Unexpected keys while converting {component_name}: {sorted(unexpected_keys)[:20]}" + + (" ..." if len(unexpected_keys) > 20 else "") + ) + if missing_keys: + raise RuntimeError( + f"Missing keys while converting {component_name}: {sorted(missing_keys)[:20]}" + + (" ..." if len(missing_keys) > 20 else "") + ) + + output_dir.mkdir(parents=True, exist_ok=True) + model.save_config(output_dir) + weights_path = output_dir / (SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME) + state_dict_to_save = model.state_dict() + if safe_serialization: + import safetensors.torch + + safetensors.torch.save_file(state_dict_to_save, weights_path, metadata={"format": "pt"}) + else: + torch.save(state_dict_to_save, weights_path) + + if verify_load: + reloaded = RAEDiT2DModel.from_pretrained(output_dir, low_cpu_mem_usage=False) + if not isinstance(reloaded, RAEDiT2DModel): + raise RuntimeError(f"Verification failed for {component_name}: reloaded object is not RAEDiT2DModel.") + + return { + "checkpoint_path": str(checkpoint_path), + "checkpoint_key": checkpoint_key, + "prefer_ema": prefer_ema, + "config": transformer_config, + "num_parameters": sum(t.numel() for t in state_dict.values() if isinstance(t, torch.Tensor)), + } + + +def write_metadata(output_path: Path, metadata: dict[str, Any]) -> None: + with (output_path / "conversion_metadata.json").open("w") as handle: + json.dump(metadata, handle, indent=2) + + +def load_autoencoder_rae(model_name_or_path: str, cache_dir: str | None = None) -> AutoencoderRAE: + config = AutoencoderRAE.load_config(model_name_or_path, cache_dir=cache_dir) + model = AutoencoderRAE.from_config(config) + + try: + model_file = _get_model_file( + model_name_or_path, + weights_name=SAFETENSORS_WEIGHTS_NAME, + cache_dir=cache_dir, + ) + except EnvironmentError: + model_file = _get_model_file( + model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + ) + + state_dict = load_state_dict(model_file) + load_result = model.load_state_dict(state_dict, strict=False, assign=True) + + unexpected_keys = set(load_result.unexpected_keys) - {"decoder.decoder_pos_embed"} + if len(load_result.missing_keys) > 0 or len(unexpected_keys) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for AutoencoderRAE: " + f"missing_keys={load_result.missing_keys}, unexpected_keys={sorted(unexpected_keys)}" + ) + return model + + +def resolve_input_path(accessor: RepoAccessor, path: str) -> Path: + candidates = [path] + if path.startswith("models/"): + candidates.append(path[len("models/") :]) + + for candidate in candidates: + local_path = Path(candidate) + if local_path.is_file(): + return local_path + + try: + return accessor.fetch(candidate) + except FileNotFoundError: + continue + + raise FileNotFoundError(f"Could not resolve `{path}` from `{accessor.repo_or_path}`.") + + +def resolve_checkpoint_path( + accessor: RepoAccessor, + configured_path: str | None, + override_path: str | None, + description: str, +) -> Path | None: + path = override_path or configured_path + if path is None: + return None + try: + return resolve_input_path(accessor, path) + except FileNotFoundError as error: + raise FileNotFoundError(f"{description} not found: {path}") from error + + +def convert(args: argparse.Namespace) -> None: + weights_accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) + config_accessor = ( + RepoAccessor(args.config_repo_or_path, cache_dir=args.cache_dir) + if args.config_repo_or_path + else weights_accessor + ) + config = read_yaml(config_accessor, args.config_path) + + stage2 = _resolve_section(config, "stage_2", "stage2") + stage2_params = stage2.get("params", {}) + misc = _resolve_section(config, "misc") + guidance = _resolve_section(config, "guidance") + + transformer_config = build_transformer_config(stage2_params, misc) + checkpoint_path = resolve_checkpoint_path( + weights_accessor, + configured_path=stage2.get("ckpt"), + override_path=args.checkpoint_path, + description="Stage-2 checkpoint", + ) + if checkpoint_path is None: + raise ValueError( + "Could not resolve a Stage-2 checkpoint. Pass `--checkpoint_path` or provide `stage_2.ckpt` in config." + ) + + scheduler, scheduler_metadata = build_scheduler_config(config) + sampler = _resolve_section(config, "sampler") + + guidance_config = guidance.get("guidance_model") + guidance_checkpoint_path = None + guidance_transformer_config = None + if guidance_config is not None and not args.skip_guidance_model: + guidance_transformer_config = build_transformer_config(guidance_config["params"], misc) + guidance_checkpoint_path = resolve_checkpoint_path( + weights_accessor, + configured_path=guidance_config.get("ckpt"), + override_path=args.guidance_checkpoint_path, + description="Guidance checkpoint", + ) + + metadata = { + "source": { + "weights_repo_or_path": args.repo_or_path, + "config_repo_or_path": args.config_repo_or_path, + "config_path": args.config_path, + "vae_model_name_or_path": args.vae_model_name_or_path, + }, + "scheduler": scheduler_metadata, + "sampler": { + "mode": sampler.get("mode", "ODE"), + "params": sampler.get("params", {}), + }, + "guidance": { + "method": guidance.get("method", "cfg"), + "scale": float(guidance.get("scale", 1.0)), + "t_min": float(guidance.get("t-min", guidance.get("t_min", 0.0))), + "t_max": float(guidance.get("t-max", guidance.get("t_max", 1.0))), + }, + "misc": misc, + } + + print(f"Using config: {args.config_path}") + print(f"Using Stage-2 checkpoint: {checkpoint_path}") + print(f"Derived scheduler shift: {scheduler.config.shift:.6f}") + if ( + metadata["sampler"]["mode"] != "ODE" + or metadata["sampler"]["params"].get("sampling_method", "euler") != "euler" + ): + print( + "Warning: upstream sampler is not the public ODE/Euler path. The saved scheduler still uses " + "FlowMatchEulerDiscreteScheduler for diffusers V1 compatibility." + ) + if guidance_checkpoint_path is not None: + print(f"Using guidance checkpoint: {guidance_checkpoint_path}") + elif guidance_config is not None: + print("Guidance model found in config but not converting it (missing checkpoint or `--skip_guidance_model`).") + + if args.dry_run: + print(json.dumps(metadata, indent=2)) + print(json.dumps({"transformer_config": transformer_config}, indent=2)) + if guidance_transformer_config is not None: + print(json.dumps({"guidance_transformer_config": guidance_transformer_config}, indent=2)) + return + + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + transformer_output_dir = output_path / args.transformer_subfolder + metadata["transformer"] = convert_transformer_state_dict( + transformer_config=transformer_config, + checkpoint_path=checkpoint_path, + checkpoint_key=args.checkpoint_key, + prefer_ema=not args.disable_ema, + output_dir=transformer_output_dir, + safe_serialization=args.safe_serialization, + verify_load=args.verify_load, + component_name="transformer", + ) + + if guidance_checkpoint_path is not None and guidance_transformer_config is not None: + guidance_output_dir = output_path / args.guidance_subfolder + metadata["guidance_transformer"] = convert_transformer_state_dict( + transformer_config=guidance_transformer_config, + checkpoint_path=guidance_checkpoint_path, + checkpoint_key=args.guidance_checkpoint_key, + prefer_ema=not args.disable_ema, + output_dir=guidance_output_dir, + safe_serialization=args.safe_serialization, + verify_load=args.verify_load, + component_name="guidance_transformer", + ) + + scheduler_output_dir = output_path / args.scheduler_subfolder + scheduler.save_pretrained(scheduler_output_dir) + + if args.vae_model_name_or_path is not None: + vae = load_autoencoder_rae(args.vae_model_name_or_path, cache_dir=args.cache_dir) + transformer = RAEDiT2DModel.from_pretrained(transformer_output_dir, low_cpu_mem_usage=False) + guidance_transformer = None + if "guidance_transformer" in metadata: + guidance_transformer = RAEDiT2DModel.from_pretrained(guidance_output_dir, low_cpu_mem_usage=False) + scheduler_for_pipe = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_output_dir) + + id2label = None + if args.id2label_json_path is not None: + with Path(args.id2label_json_path).expanduser().open("r", encoding="utf-8") as handle: + id2label = json.load(handle) + + pipe = RAEDiTPipeline( + transformer=transformer, + guidance_transformer=guidance_transformer, + vae=vae, + scheduler=scheduler_for_pipe, + id2label=id2label, + ) + pipe.save_pretrained(output_path, safe_serialization=args.safe_serialization) + metadata["pipeline"] = {"saved": True, "id2label_json_path": args.id2label_json_path} + + write_metadata(output_path, metadata) + + print(f"Saved transformer to: {transformer_output_dir}") + print(f"Saved scheduler to: {scheduler_output_dir}") + if "guidance_transformer" in metadata: + print(f"Saved guidance model to: {output_path / args.guidance_subfolder}") + if "pipeline" in metadata: + print(f"Saved pipeline to: {output_path}") + print(f"Saved metadata to: {output_path / 'conversion_metadata.json'}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert upstream RAE Stage-2 checkpoints to diffusers format") + parser.add_argument( + "--repo_or_path", + type=str, + required=True, + help="Hub repo id or local directory containing the upstream Stage-2 weights.", + ) + parser.add_argument( + "--config_repo_or_path", + type=str, + default=None, + help="Optional separate hub repo id or local directory containing the upstream YAML configs. Defaults to `repo_or_path`.", + ) + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Relative path to the upstream Stage-2 YAML config inside `config_repo_or_path`, or a direct local file path.", + ) + parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted components.") + parser.add_argument( + "--vae_model_name_or_path", + type=str, + default=None, + help="Optional diffusers AutoencoderRAE checkpoint to bundle into a full RAEDiTPipeline export.", + ) + parser.add_argument( + "--id2label_json_path", + type=str, + default=None, + help="Optional JSON mapping of class ids to label strings for the saved pipeline.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Optional Stage-2 checkpoint override. Interpreted relative to `repo_or_path` unless it is already local.", + ) + parser.add_argument( + "--guidance_checkpoint_path", + type=str, + default=None, + help="Optional autoguidance checkpoint override.", + ) + parser.add_argument( + "--checkpoint_key", + type=str, + default=None, + help="Optional dotted key path inside the Stage-2 checkpoint payload. By default the converter auto-prefers `ema` then `model`.", + ) + parser.add_argument( + "--guidance_checkpoint_key", + type=str, + default=None, + help="Optional dotted key path inside the guidance checkpoint payload.", + ) + parser.add_argument( + "--transformer_subfolder", + type=str, + default=DEFAULT_TRANSFORMER_SUBFOLDER, + help="Subfolder name used for the converted primary transformer.", + ) + parser.add_argument( + "--guidance_subfolder", + type=str, + default=DEFAULT_GUIDANCE_SUBFOLDER, + help="Subfolder name used for the converted autoguidance transformer.", + ) + parser.add_argument( + "--scheduler_subfolder", + type=str, + default=DEFAULT_SCHEDULER_SUBFOLDER, + help="Subfolder name used for the saved scheduler config.", + ) + parser.add_argument("--cache_dir", type=str, default=None, help="Optional Hugging Face Hub cache directory.") + parser.add_argument( + "--disable_ema", + action="store_true", + help="Do not prefer `ema` when the checkpoint stores both `ema` and `model` weights.", + ) + parser.add_argument( + "--skip_guidance_model", + action="store_true", + help="Do not convert `guidance.guidance_model` even if it is present in the config.", + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Save converted transformer weights as safetensors.", + ) + parser.add_argument( + "--verify_load", + action="store_true", + help="Reload each saved transformer with `from_pretrained(low_cpu_mem_usage=False)` after conversion.", + ) + parser.add_argument( + "--dry_run", + action="store_true", + help="Only resolve paths and print the derived config/metadata without saving anything.", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert(args) + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d6d557a4c224..4120ad22764c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -261,6 +261,7 @@ "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", + "RAEDiT2DModel", "SanaControlNetModel", "SanaTransformer2DModel", "SanaVideoTransformer3DModel", @@ -333,6 +334,8 @@ "LDMPipeline", "LDMSuperResolutionPipeline", "PNDMPipeline", + "RAEDiTPipeline", + "RAEDiTPipelineOutput", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", @@ -1036,6 +1039,7 @@ QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, + RAEDiT2DModel, SanaControlNetModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, @@ -1106,6 +1110,8 @@ LDMPipeline, LDMSuperResolutionPipeline, PNDMPipeline, + RAEDiTPipeline, + RAEDiTPipelineOutput, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e4bc95fdf884..c33e1d8222aa 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -118,6 +118,7 @@ _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] + _import_structure["transformers.transformer_rae_dit"] = ["RAEDiT2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] @@ -238,6 +239,7 @@ PriorTransformer, PRXTransformer2DModel, QwenImageTransformer2DModel, + RAEDiT2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..5a174d5e1b07 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -44,6 +44,7 @@ from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel + from .transformer_rae_dit import RAEDiT2DModel from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_rae_dit.py b/src/diffusers/models/transformers/transformer_rae_dit.py new file mode 100644 index 000000000000..324179cfca92 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_rae_dit.py @@ -0,0 +1,575 @@ +from __future__ import annotations + +from math import pi, sqrt + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..attention import FeedForward +from ..embeddings import PatchEmbed, apply_rotary_emb, get_2d_sincos_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +def _expand_conditioning_tokens(hidden_states: torch.Tensor, target_length: int) -> torch.Tensor: + if hidden_states.shape[1] == target_length: + return hidden_states + + if target_length % hidden_states.shape[1] != 0: + raise ValueError( + f"Cannot repeat sequence of length {hidden_states.shape[1]} to match target length {target_length}." + ) + + if hidden_states.shape[1] == 1: + return hidden_states.expand(-1, target_length, -1) + + source_side = int(sqrt(hidden_states.shape[1])) + target_side = int(sqrt(target_length)) + if ( + source_side * source_side == hidden_states.shape[1] + and target_side * target_side == target_length + and target_side % source_side == 0 + ): + scale = target_side // source_side + batch_size, _, channels = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, source_side, source_side, channels) + hidden_states = hidden_states.repeat_interleave(scale, dim=1).repeat_interleave(scale, dim=2) + return hidden_states.reshape(batch_size, target_length, channels) + + raise ValueError( + "Cannot expand conditioning tokens without preserving their 2D layout: " + f"source length {hidden_states.shape[1]} is incompatible with target length {target_length}." + ) + + +def _to_pair(value: int | tuple[int, int] | list[int], name: str) -> tuple[int, int]: + if isinstance(value, int): + return value, value + + if len(value) != 2: + raise ValueError(f"`{name}` must be an int or a pair, but got {value}.") + + return int(value[0]), int(value[1]) + + +def _swap_swiglu_projection_halves(feedforward: FeedForward) -> None: + projection = feedforward.net[0].proj + projection.weight.data = torch.cat( + projection.weight.data.chunk(2, dim=0)[::-1], + dim=0, + ) + if projection.bias is not None: + projection.bias.data = torch.cat( + projection.bias.data.chunk(2, dim=0)[::-1], + dim=0, + ) + + +class GaussianFourierEmbedding(nn.Module): + def __init__(self, hidden_size: int, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.embedding_size = embedding_size + self.scale = scale + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.mlp = nn.Sequential( + nn.Linear(embedding_size * 2, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, timestep: torch.Tensor) -> torch.Tensor: + timestep = timestep.to(self.W.dtype) + hidden_states = timestep[:, None] * self.W[None, :] * 2 * pi + hidden_states = torch.cat([torch.sin(hidden_states), torch.cos(hidden_states)], dim=-1) + hidden_states = self.mlp(hidden_states) + return hidden_states + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float): + super().__init__() + self.num_classes = num_classes + self.dropout_prob = dropout_prob + self.embedding_table = nn.Embedding(num_classes + int(dropout_prob > 0), hidden_size) + + def token_drop( + self, class_labels: torch.LongTensor, force_drop_ids: torch.Tensor | None = None + ) -> torch.LongTensor: + if force_drop_ids is None: + drop_ids = torch.rand(class_labels.shape[0], device=class_labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + return torch.where(drop_ids, self.num_classes, class_labels) + + def forward( + self, + class_labels: torch.LongTensor, + train: bool, + force_drop_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + if (train and self.dropout_prob > 0) or (force_drop_ids is not None): + class_labels = self.token_drop(class_labels, force_drop_ids=force_drop_ids) + return self.embedding_table(class_labels) + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, pt_seq_len: int, ft_seq_len: int | None = None, theta: float = 10000.0): + super().__init__() + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim)) + positions = torch.arange(ft_seq_len, dtype=torch.float32) / ft_seq_len * pt_seq_len + freqs = torch.einsum("n,d->nd", positions, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = torch.cat( + [ + freqs[:, None, :].expand(ft_seq_len, ft_seq_len, -1), + freqs[None, :, :].expand(ft_seq_len, ft_seq_len, -1), + ], + dim=-1, + ) + + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1])) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1])) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + _, _, sequence_length, _ = hidden_states.shape + base_sequence_length = self.freqs_cos.shape[0] + repeat = sequence_length // base_sequence_length + + freqs_cos = self.freqs_cos + freqs_sin = self.freqs_sin + if repeat != 1: + freqs_cos = freqs_cos.repeat_interleave(repeat, dim=0) + freqs_sin = freqs_sin.repeat_interleave(repeat, dim=0) + + return apply_rotary_emb(hidden_states, (freqs_cos, freqs_sin), sequence_dim=2) + + +class RAEDiTAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + qk_norm: bool = False, + use_rmsnorm: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}") + + self.num_heads = num_heads + self.head_dim = dim // num_heads + + norm_cls = RMSNorm if use_rmsnorm else nn.LayerNorm + norm_kwargs = {"eps": 1e-6} if use_rmsnorm else {"elementwise_affine": True, "eps": 1e-6} + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_cls(self.head_dim, **norm_kwargs) if qk_norm else nn.Identity() + self.k_norm = norm_cls(self.head_dim, **norm_kwargs) if qk_norm else nn.Identity() + self.proj = nn.Linear(dim, dim) + + def forward(self, hidden_states: torch.Tensor, rope: VisionRotaryEmbedding | None = None) -> torch.Tensor: + batch_size, _, channels = hidden_states.shape + query, key, value = self.qkv(hidden_states).chunk(3, dim=-1) + + query = query.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + key = key.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + + query = self.q_norm(query) + key = self.k_norm(key) + + if rope is not None: + query = rope(query) + key = rope(key) + + query = query.to(dtype=value.dtype) + key = key.to(dtype=value.dtype) + + hidden_states = F.scaled_dot_product_attention(query, key, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, channels) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class RAEDiTBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_qknorm: bool = False, + use_swiglu: bool = True, + use_rmsnorm: bool = True, + wo_shift: bool = False, + ): + super().__init__() + self.use_swiglu = use_swiglu + + if use_rmsnorm: + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + else: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.attn = RAEDiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=use_qknorm, + use_rmsnorm=use_rmsnorm, + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + mlp_inner_dim = int(2 * mlp_hidden_dim / 3) if use_swiglu else mlp_hidden_dim + self.mlp = FeedForward( + hidden_size, + inner_dim=mlp_inner_dim, + activation_fn="swiglu" if use_swiglu else "gelu-approximate", + bias=True, + ) + + self.wo_shift = wo_shift + if wo_shift: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 4 * hidden_size, bias=True)) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward( + self, + hidden_states: torch.Tensor, + conditioning: torch.Tensor, + feat_rope: VisionRotaryEmbedding | None = None, + ) -> torch.Tensor: + if conditioning.ndim < hidden_states.ndim: + conditioning = conditioning.unsqueeze(1) + + if self.wo_shift: + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(conditioning).chunk(4, dim=-1) + shift_msa = None + shift_mlp = None + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(conditioning).chunk( + 6, dim=-1 + ) + + if shift_msa is None: + shift_msa = torch.zeros_like(scale_msa) + if shift_mlp is None: + shift_mlp = torch.zeros_like(scale_mlp) + + if shift_msa.shape[1] != hidden_states.shape[1]: + shift_msa = _expand_conditioning_tokens(shift_msa, hidden_states.shape[1]) + scale_msa = _expand_conditioning_tokens(scale_msa, hidden_states.shape[1]) + gate_msa = _expand_conditioning_tokens(gate_msa, hidden_states.shape[1]) + if shift_mlp.shape[1] != hidden_states.shape[1]: + shift_mlp = _expand_conditioning_tokens(shift_mlp, hidden_states.shape[1]) + scale_mlp = _expand_conditioning_tokens(scale_mlp, hidden_states.shape[1]) + gate_mlp = _expand_conditioning_tokens(gate_mlp, hidden_states.shape[1]) + + norm_hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states + self.attn(norm_hidden_states * (1 + scale_msa) + shift_msa, rope=feat_rope) * gate_msa + + norm_hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(norm_hidden_states * (1 + scale_mlp) + shift_mlp) * gate_mlp + return hidden_states + + +class RAEDiTFinalLayer(nn.Module): + def __init__(self, hidden_size: int, out_channels: int, use_rmsnorm: bool = True): + super().__init__() + + if use_rmsnorm: + self.norm_final = RMSNorm(hidden_size, eps=1e-6) + else: + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + if conditioning.ndim < hidden_states.ndim: + conditioning = conditioning.unsqueeze(1) + + shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1) + if shift.shape[1] != hidden_states.shape[1]: + shift = _expand_conditioning_tokens(shift, hidden_states.shape[1]) + scale = _expand_conditioning_tokens(scale, hidden_states.shape[1]) + + hidden_states = self.norm_final(hidden_states) * (1 + scale) + shift + hidden_states = self.linear(hidden_states) + return hidden_states + + +class RAEDiT2DModel(ModelMixin, ConfigMixin): + r""" + Stage-2 latent diffusion transformer used by the RAE paper. + + The architecture mirrors the upstream two-stream `DiTwDDTHead` design: + an encoder path first builds conditioning tokens from the latent input, + then a decoder path denoises the latent tokens conditioned on those + encoded tokens. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "final_layer"] + + @register_to_config + def __init__( + self, + sample_size: int = 16, + patch_size: int | tuple[int, int] | list[int] = 1, + in_channels: int = 768, + hidden_size: int | tuple[int, int] | list[int] = (1152, 2048), + depth: int | tuple[int, int] | list[int] = (28, 2), + num_heads: int | tuple[int, int] | list[int] = (16, 16), + mlp_ratio: float = 4.0, + class_dropout_prob: float = 0.1, + num_classes: int = 1000, + use_qknorm: bool = False, + use_swiglu: bool = True, + use_rope: bool = True, + use_rmsnorm: bool = True, + wo_shift: bool = False, + use_pos_embed: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = in_channels + self.gradient_checkpointing = False + + encoder_hidden_size, decoder_hidden_size = _to_pair(hidden_size, "hidden_size") + encoder_num_layers, decoder_num_layers = _to_pair(depth, "depth") + encoder_num_attention_heads, decoder_num_attention_heads = _to_pair(num_heads, "num_heads") + + self.encoder_hidden_size = encoder_hidden_size + self.decoder_hidden_size = decoder_hidden_size + self.num_encoder_blocks = encoder_num_layers + self.num_decoder_blocks = decoder_num_layers + self.num_blocks = encoder_num_layers + decoder_num_layers + self.use_rope = use_rope + self.use_pos_embed = use_pos_embed + + self.s_patch_size, self.x_patch_size = _to_pair(patch_size, "patch_size") + + self.s_channel_per_token = in_channels * self.s_patch_size * self.s_patch_size + self.x_channel_per_token = in_channels * self.x_patch_size * self.x_patch_size + + self.s_embedder = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=self.s_patch_size, + in_channels=in_channels, + embed_dim=encoder_hidden_size, + bias=True, + pos_embed_type=None, + ) + self.x_embedder = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=self.x_patch_size, + in_channels=in_channels, + embed_dim=decoder_hidden_size, + bias=True, + pos_embed_type=None, + ) + + self.s_projector = ( + nn.Linear(encoder_hidden_size, decoder_hidden_size) + if encoder_hidden_size != decoder_hidden_size + else nn.Identity() + ) + self.t_embedder = GaussianFourierEmbedding(encoder_hidden_size) + self.y_embedder = LabelEmbedder(num_classes, encoder_hidden_size, class_dropout_prob) + self.final_layer = RAEDiTFinalLayer( + decoder_hidden_size, + out_channels=self.x_channel_per_token, + use_rmsnorm=use_rmsnorm, + ) + + num_patches = self.s_embedder.height * self.s_embedder.width + if use_pos_embed: + grid_size = int(num_patches**0.5) + pos_embed = get_2d_sincos_pos_embed(encoder_hidden_size, grid_size, output_type="pt") + self.register_buffer("pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) + self.x_pos_embed = None + else: + self.register_buffer("pos_embed", None, persistent=False) + self.x_pos_embed = None + + if use_rope: + encoder_rope_dim = encoder_hidden_size // encoder_num_attention_heads // 2 + decoder_rope_dim = decoder_hidden_size // decoder_num_attention_heads // 2 + encoder_side = int(sqrt(num_patches)) + decoder_side = int(sqrt(self.x_embedder.height * self.x_embedder.width)) + self.enc_feat_rope = VisionRotaryEmbedding(encoder_rope_dim, pt_seq_len=encoder_side) + self.dec_feat_rope = VisionRotaryEmbedding(decoder_rope_dim, pt_seq_len=decoder_side) + else: + self.enc_feat_rope = None + self.dec_feat_rope = None + + self.blocks = nn.ModuleList( + [ + RAEDiTBlock( + hidden_size=encoder_hidden_size if index < encoder_num_layers else decoder_hidden_size, + num_heads=encoder_num_attention_heads + if index < encoder_num_layers + else decoder_num_attention_heads, + mlp_ratio=mlp_ratio, + use_qknorm=use_qknorm, + use_swiglu=use_swiglu, + use_rmsnorm=use_rmsnorm, + wo_shift=wo_shift, + ) + for index in range(self.num_blocks) + ] + ) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + nn.init.xavier_uniform_(self.x_embedder.proj.weight.view(self.x_embedder.proj.weight.shape[0], -1)) + nn.init.constant_(self.x_embedder.proj.bias, 0) + nn.init.xavier_uniform_(self.s_embedder.proj.weight.view(self.s_embedder.proj.weight.shape[0], -1)) + nn.init.constant_(self.s_embedder.proj.bias, 0) + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + for block in self.blocks: + if block.use_swiglu: + _swap_swiglu_projection_halves(block.mlp) + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, hidden_states: torch.Tensor) -> torch.Tensor: + channels = self.in_channels + patch_size = self.x_embedder.patch_size + if isinstance(patch_size, tuple): + patch_size = patch_size[0] + + height = width = int(sqrt(hidden_states.shape[1])) + if height * width != hidden_states.shape[1]: + raise ValueError("Sequence length must form a square grid for unpatchify.") + + hidden_states = hidden_states.reshape(hidden_states.shape[0], height, width, patch_size, patch_size, channels) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + hidden_states = hidden_states.reshape( + hidden_states.shape[0], channels, height * patch_size, width * patch_size + ) + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor | None = None, + class_labels: torch.LongTensor | None = None, + conditioning_hidden_states: torch.Tensor | None = None, + return_dict: bool = True, + ) -> Transformer2DModelOutput | tuple[torch.Tensor]: + if timestep is None: + raise ValueError("`timestep` must be provided.") + if class_labels is None: + raise ValueError("`class_labels` must be provided.") + + timestep = timestep.reshape(-1).to(hidden_states.device) + class_labels = class_labels.reshape(-1).to(hidden_states.device) + + timestep_emb = self.t_embedder(timestep) + class_emb = self.y_embedder(class_labels, self.training) + conditioning = F.silu(timestep_emb + class_emb) + + if conditioning_hidden_states is None: + conditioning_hidden_states = self.s_embedder(hidden_states) + if self.use_pos_embed: + conditioning_hidden_states = conditioning_hidden_states + self.pos_embed.to( + device=conditioning_hidden_states.device, dtype=conditioning_hidden_states.dtype + ) + + for block_idx in range(self.num_encoder_blocks): + block = self.blocks[block_idx] + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def custom_forward( + hidden_states: torch.Tensor, + conditioning: torch.Tensor, + block: RAEDiTBlock = block, + feat_rope: VisionRotaryEmbedding | None = self.enc_feat_rope, + ) -> torch.Tensor: + return block(hidden_states, conditioning, feat_rope=feat_rope) + + conditioning_hidden_states = self._gradient_checkpointing_func( + custom_forward, conditioning_hidden_states, conditioning + ) + else: + conditioning_hidden_states = block( + conditioning_hidden_states, + conditioning, + feat_rope=self.enc_feat_rope, + ) + + conditioning_hidden_states = F.silu(timestep_emb.unsqueeze(1) + conditioning_hidden_states) + + projector_dtype = conditioning_hidden_states.dtype + projector_param = next(self.s_projector.parameters(), None) + if projector_param is not None: + projector_dtype = projector_param.dtype + + conditioning_hidden_states = conditioning_hidden_states.to(device=hidden_states.device, dtype=projector_dtype) + conditioning_hidden_states = self.s_projector(conditioning_hidden_states) + + hidden_states = self.x_embedder(hidden_states) + if self.use_pos_embed and self.x_pos_embed is not None: + hidden_states = hidden_states + self.x_pos_embed.to(device=hidden_states.device, dtype=hidden_states.dtype) + + for block_idx in range(self.num_encoder_blocks, self.num_blocks): + block = self.blocks[block_idx] + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def custom_forward( + hidden_states: torch.Tensor, + conditioning_hidden_states: torch.Tensor, + block: RAEDiTBlock = block, + feat_rope: VisionRotaryEmbedding | None = self.dec_feat_rope, + ) -> torch.Tensor: + return block(hidden_states, conditioning_hidden_states, feat_rope=feat_rope) + + hidden_states = self._gradient_checkpointing_func( + custom_forward, hidden_states, conditioning_hidden_states + ) + else: + hidden_states = block( + hidden_states, + conditioning_hidden_states, + feat_rope=self.dec_feat_rope, + ) + + hidden_states = self.final_layer(hidden_states, conditioning_hidden_states) + hidden_states = self.unpatchify(hidden_states) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8007035338b0..85e11552bda6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -147,6 +147,7 @@ "FluxKontextInpaintPipeline", ] _import_structure["prx"] = ["PRXPipeline"] + _import_structure["rae_dit"] = ["RAEDiTPipeline", "RAEDiTPipelineOutput"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -783,6 +784,7 @@ QwenImageLayeredPipeline, QwenImagePipeline, ) + from .rae_dit import RAEDiTPipeline, RAEDiTPipelineOutput from .sana import ( SanaControlNetPipeline, SanaPipeline, diff --git a/src/diffusers/pipelines/rae_dit/__init__.py b/src/diffusers/pipelines/rae_dit/__init__.py new file mode 100644 index 000000000000..9716a8197a59 --- /dev/null +++ b/src/diffusers/pipelines/rae_dit/__init__.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = { + "pipeline_output": ["RAEDiTPipelineOutput"], + "pipeline_rae_dit": ["RAEDiTPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_output import RAEDiTPipelineOutput + from .pipeline_rae_dit import RAEDiTPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/src/diffusers/pipelines/rae_dit/pipeline_output.py b/src/diffusers/pipelines/rae_dit/pipeline_output.py new file mode 100644 index 000000000000..63296155061e --- /dev/null +++ b/src/diffusers/pipelines/rae_dit/pipeline_output.py @@ -0,0 +1,36 @@ +# Copyright 2026 HuggingFace Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class RAEDiTPipelineOutput(BaseOutput): + """ + Output class for RAE DiT image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`) + Denoised images as PIL images, a NumPy array of shape `(batch_size, height, width, num_channels)`, or a + PyTorch tensor of shape `(batch_size, num_channels, height, width)`. Torch tensors may also represent + latent outputs when `output_type="latent"`. + """ + + images: list[PIL.Image.Image] | np.ndarray | torch.Tensor diff --git a/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py new file mode 100644 index 000000000000..15f733878c60 --- /dev/null +++ b/src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import torch + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderRAE +from ...models.transformers.transformer_rae_dit import RAEDiT2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import RAEDiTPipelineOutput + + +class RAEDiTPipeline(DiffusionPipeline): + r""" + Pipeline for class-conditioned image generation in RAE latent space. + + Parameters: + transformer ([`RAEDiT2DModel`]): + Class-conditioned latent transformer used for Stage-2 denoising in RAE latent space. + vae ([`AutoencoderRAE`]): + Representation autoencoder used to decode latent samples back to RGB images. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + Flow-matching scheduler used to integrate the latent denoising trajectory. + """ + + _optional_components = ["guidance_transformer"] + model_cpu_offload_seq = "transformer->vae" + + def __init__( + self, + transformer: RAEDiT2DModel, + vae: AutoencoderRAE, + scheduler: FlowMatchEulerDiscreteScheduler, + guidance_transformer: RAEDiT2DModel | None = None, + id2label: dict[int, str] | None = None, + ): + super().__init__() + self.register_modules( + transformer=transformer, + guidance_transformer=guidance_transformer, + vae=vae, + scheduler=scheduler, + ) + serialized_id2label = None + if id2label is not None: + serialized_id2label = {str(key): value for key, value in id2label.items()} + self.register_to_config(id2label=serialized_id2label) + + self.labels = {} + if self.config.id2label is not None: + for key, value in self.config.id2label.items(): + for label in value.split(","): + self.labels[label.strip()] = int(key) + self.labels = dict(sorted(self.labels.items())) + + self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_resize=False, do_normalize=False) + self._guidance_scale = 1.0 + + @property + def do_classifier_free_guidance(self) -> bool: + return self._guidance_scale > 1.0 + + def get_label_ids(self, label: str | list[str]) -> list[int]: + r""" + Map ImageNet-style label strings to class ids. + """ + + if not isinstance(label, list): + label = [label] + + for label_name in label: + if label_name not in self.labels: + raise ValueError( + f"{label_name} does not exist. Please make sure to select one of the following labels: \n {self.labels}." + ) + + return [self.labels[label_name] for label_name in label] + + def _prepare_class_labels( + self, + class_labels: int | list[int] | torch.Tensor, + num_images_per_prompt: int, + device: torch.device, + ) -> torch.LongTensor: + class_labels = torch.as_tensor(class_labels, device=device, dtype=torch.long).reshape(-1) + + if num_images_per_prompt > 1: + class_labels = class_labels.repeat_interleave(num_images_per_prompt) + + return class_labels + + def prepare_latents( + self, + batch_size: int, + latent_channels: int, + latent_size: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None, + latents: torch.Tensor | None, + ) -> torch.Tensor: + shape = (batch_size, latent_channels, latent_size, latent_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested a batch size of " + f"{batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + latents = latents.to(device=device, dtype=dtype) + if latents.shape != shape: + raise ValueError(f"Expected `latents` to have shape {shape}, but got {tuple(latents.shape)}.") + + return latents + + def _prepare_timesteps( + self, timestep: torch.Tensor | float, batch_size: int, sample: torch.Tensor + ) -> torch.Tensor: + if not torch.is_tensor(timestep): + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) + elif timestep.ndim == 0: + timestep = timestep[None].to(sample.device) + else: + timestep = timestep.to(sample.device) + + return timestep.expand(batch_size) + + @torch.no_grad() + def __call__( + self, + class_labels: int | list[int] | torch.Tensor, + guidance_scale: float = 1.0, + guidance_start: float = 0.0, + guidance_end: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + num_inference_steps: int = 50, + output_type: str = "pil", + return_dict: bool = True, + ) -> RAEDiTPipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + class_labels (`int`, `list[int]`, or `torch.Tensor`): + The class ids for the images to generate. + guidance_scale (`float`, *optional*, defaults to `1.0`): + Classifier-free guidance scale. Guidance is enabled when `guidance_scale > 1`. + guidance_start (`float`, *optional*, defaults to `0.0`): + Lower bound of the normalized timestep interval in which classifier-free guidance is active. + guidance_end (`float`, *optional*, defaults to `1.0`): + Upper bound of the normalized timestep interval in which classifier-free guidance is active. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + Number of images to generate per class label. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Random generator used for latent sampling. + latents (`torch.Tensor`, *optional*): + Pre-generated latent noise tensor of shape `(batch, channels, height, width)`. + num_inference_steps (`int`, *optional*, defaults to `50`): + Number of denoising steps. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format. Choose from `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return an [`RAEDiTPipelineOutput`] instead of a tuple. + """ + + if num_images_per_prompt < 1: + raise ValueError(f"`num_images_per_prompt` must be >= 1, but got {num_images_per_prompt}.") + if guidance_scale < 1.0: + raise ValueError(f"`guidance_scale` must be >= 1.0, but got {guidance_scale}.") + if not 0.0 <= guidance_start <= guidance_end <= 1.0: + raise ValueError( + f"`guidance_start` and `guidance_end` must satisfy 0 <= guidance_start <= guidance_end <= 1, but got " + f"{guidance_start} and {guidance_end}." + ) + if output_type not in {"latent", "np", "pil", "pt"}: + raise ValueError(f"Unsupported `output_type`: {output_type}.") + if guidance_scale > 1.0 and self.transformer.config.class_dropout_prob <= 0: + raise ValueError( + "Classifier-free guidance requires `transformer.config.class_dropout_prob > 0` so a null class token exists." + ) + + self._guidance_scale = guidance_scale + + device = self._execution_device + dtype = self.transformer.dtype + + class_labels = self._prepare_class_labels( + class_labels, num_images_per_prompt=num_images_per_prompt, device=device + ) + batch_size = class_labels.shape[0] + + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size, + latent_channels=latent_channels, + latent_size=latent_size, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + self._num_timesteps = len(self.scheduler.timesteps) + + for timestep in self.progress_bar(self.scheduler.timesteps): + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents, latents], dim=0) + null_class_labels = torch.full( + (batch_size,), + self.transformer.config.num_classes, + device=device, + dtype=class_labels.dtype, + ) + class_labels_input = torch.cat([class_labels, null_class_labels], dim=0) + else: + latent_model_input = latents + class_labels_input = class_labels + + timestep_input = self._prepare_timesteps(timestep, latent_model_input.shape[0], latent_model_input) + timestep_input = timestep_input / self.scheduler.config.num_train_timesteps + model_output = self.transformer( + latent_model_input, + timestep=timestep_input, + class_labels=class_labels_input, + ).sample + + if self.do_classifier_free_guidance: + cond_model_output, uncond_model_output = model_output.chunk(2, dim=0) + guided_model_output = uncond_model_output + guidance_scale * (cond_model_output - uncond_model_output) + guidance_mask = (timestep_input[:batch_size] >= guidance_start) & ( + timestep_input[:batch_size] <= guidance_end + ) + guidance_mask = guidance_mask.view(-1, *([1] * (cond_model_output.ndim - 1))) + model_output = torch.where(guidance_mask, guided_model_output, cond_model_output) + + latents = self.scheduler.step(model_output, timestep, latents).prev_sample + + if output_type == "latent": + output = latents + else: + images = self.vae.decode(latents).sample.clamp(0, 1) + output = self.image_processor.postprocess(images, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return RAEDiTPipelineOutput(images=output) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3425cc8d2b61..019a42f112ea 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1541,6 +1541,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RAEDiT2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SanaControlNetModel(metaclass=DummyObject): _backends = ["torch"] @@ -2413,6 +2428,28 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class RAEDiTPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class RAEDiTPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class RePaintPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/transformers/test_models_rae_dit_transformer2d.py b/tests/models/transformers/test_models_rae_dit_transformer2d.py new file mode 100644 index 000000000000..b5d8970820ad --- /dev/null +++ b/tests/models/transformers/test_models_rae_dit_transformer2d.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import RAEDiT2DModel +from diffusers.models.transformers.transformer_rae_dit import _expand_conditioning_tokens +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin, TrainingTesterMixin + + +enable_full_determinism() + + +def _initialize_non_zero_stage2_head(model: RAEDiT2DModel): + torch.manual_seed(0) + + for block in model.blocks: + block.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + block.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + + model.final_layer.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) + + +class RAEDiT2DTesterConfig(BaseModelTesterConfig): + model_class = RAEDiT2DModel + main_input_name = "hidden_states" + input_shape = (8, 4, 4) + output_shape = (8, 4, 4) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self): + return { + "sample_size": 4, + "patch_size": 1, + "in_channels": 8, + "hidden_size": (32, 64), + "depth": (2, 1), + "num_heads": (4, 4), + "mlp_ratio": 2.0, + "class_dropout_prob": 0.0, + "num_classes": 10, + "use_qknorm": True, + "use_swiglu": True, + "use_rope": True, + "use_rmsnorm": True, + "wo_shift": False, + "use_pos_embed": True, + } + + def get_dummy_inputs(self): + batch_size = 2 + in_channels = 8 + sample_size = 4 + scheduler_num_train_steps = 1000 + num_class_labels = 10 + + hidden_states = randn_tensor( + (batch_size, in_channels, sample_size, sample_size), generator=self.generator, device=torch_device + ) + timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,), generator=self.generator).to( + torch_device + ) + class_labels = torch.randint(0, num_class_labels, size=(batch_size,), generator=self.generator).to( + torch_device + ) + + return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_labels} + + +class TestRAEDiT2DModel(RAEDiT2DTesterConfig, ModelTesterMixin): + def test_swiglu_feedforward_matches_previous_chunk_order(self): + model = self.model_class(**self.get_init_dict()).to(torch_device).eval() + block = model.blocks[0] + + hidden_states = randn_tensor((2, 4, model.encoder_hidden_size), generator=self.generator, device=torch_device) + projection = block.mlp.net[0].proj + output_projection = block.mlp.net[2] + + unswapped_weight = torch.cat(projection.weight.data.chunk(2, dim=0)[::-1], dim=0) + unswapped_bias = None + if projection.bias is not None: + unswapped_bias = torch.cat(projection.bias.data.chunk(2, dim=0)[::-1], dim=0) + + projected = torch.nn.functional.linear(hidden_states, unswapped_weight, unswapped_bias) + first_half, second_half = projected.chunk(2, dim=-1) + expected = torch.nn.functional.linear( + torch.nn.functional.silu(first_half) * second_half, + output_projection.weight, + output_projection.bias, + ) + + actual = block.mlp(hidden_states) + assert torch.allclose(actual, expected, atol=1e-6, rtol=1e-5) + + def test_output_with_precomputed_conditioning_hidden_states(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device).eval() + _initialize_non_zero_stage2_head(model) + + batch_size = inputs_dict[self.main_input_name].shape[0] + num_patches = (init_dict["sample_size"] // init_dict["patch_size"]) ** 2 + conditioning_hidden_states = randn_tensor( + (batch_size, num_patches, init_dict["hidden_size"][0]), generator=self.generator, device=torch_device + ) + + with torch.no_grad(): + output = model(**inputs_dict, conditioning_hidden_states=conditioning_hidden_states).sample + + assert output.shape == inputs_dict[self.main_input_name].shape + + def test_precomputed_conditioning_matches_internal_encoder_path(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device).eval() + _initialize_non_zero_stage2_head(model) + + hidden_states = inputs_dict["hidden_states"] + timesteps = inputs_dict["timestep"] + class_labels = inputs_dict["class_labels"] + + with torch.no_grad(): + timestep_emb = model.t_embedder(timesteps.reshape(-1).to(torch_device)) + class_emb = model.y_embedder(class_labels.reshape(-1).to(torch_device), train=False) + conditioning = torch.nn.functional.silu(timestep_emb + class_emb) + + conditioning_hidden_states = model.s_embedder(hidden_states) + if model.use_pos_embed: + conditioning_hidden_states = conditioning_hidden_states + model.pos_embed + + for block_idx in range(model.num_encoder_blocks): + conditioning_hidden_states = model.blocks[block_idx]( + conditioning_hidden_states, + conditioning, + feat_rope=model.enc_feat_rope, + ) + + conditioning_hidden_states = torch.nn.functional.silu( + timestep_emb.unsqueeze(1) + conditioning_hidden_states + ) + + output_internal = model(**inputs_dict).sample + output_precomputed = model( + **inputs_dict, + conditioning_hidden_states=conditioning_hidden_states, + ).sample + + assert torch.allclose(output_internal, output_precomputed, atol=1e-5, rtol=1e-4) + + def test_expand_conditioning_tokens_preserves_2d_layout(self): + hidden_states = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]]) + + repeated = _expand_conditioning_tokens(hidden_states, target_length=16) + + expected = torch.tensor( + [ + [ + [1.0], + [1.0], + [2.0], + [2.0], + [1.0], + [1.0], + [2.0], + [2.0], + [3.0], + [3.0], + [4.0], + [4.0], + [3.0], + [3.0], + [4.0], + [4.0], + ] + ] + ) + assert torch.equal(repeated, expected) + + def test_expand_conditioning_tokens_broadcasts_global_conditioning(self): + hidden_states = torch.tensor([[[1.0, 2.0]]]) + + repeated = _expand_conditioning_tokens(hidden_states, target_length=4) + + expected = torch.tensor([[[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]]) + assert torch.equal(repeated, expected) + + def test_expand_conditioning_tokens_rejects_incompatible_multi_token_layouts(self): + hidden_states = torch.randn(1, 2, 4) + + with pytest.raises(ValueError): + _expand_conditioning_tokens(hidden_states, target_length=8) + + +class TestRAEDiT2DTraining(RAEDiT2DTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + super().test_gradient_checkpointing_is_applied(expected_set={"RAEDiT2DModel"}) + + def test_gradient_checkpointing_equivalence(self): + super().test_gradient_checkpointing_equivalence(loss_tolerance=1e-4) diff --git a/tests/others/test_rae_dit_conversion.py b/tests/others/test_rae_dit_conversion.py new file mode 100644 index 000000000000..046263cd521b --- /dev/null +++ b/tests/others/test_rae_dit_conversion.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile + +import torch + +from diffusers import AutoencoderRAE +from scripts.convert_rae_stage2_to_diffusers import ( + load_autoencoder_rae, + translate_transformer_state_dict, + unwrap_state_dict, +) + + +def test_unwrap_state_dict_strips_supported_prefixes(): + tensor = torch.randn(1) + + assert unwrap_state_dict({"model.module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + assert unwrap_state_dict({"model.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + assert unwrap_state_dict({"module.blocks.0.weight": tensor}) == {"blocks.0.weight": tensor} + + +def test_translate_transformer_state_dict_maps_feedforward_keys(): + weight = torch.arange(12, dtype=torch.float32).reshape(4, 3) + bias = torch.arange(4, dtype=torch.float32) + out_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3) + out_bias = torch.arange(2, dtype=torch.float32) + + translated = translate_transformer_state_dict( + { + "blocks.0.mlp.w12.weight": weight, + "blocks.0.mlp.w12.bias": bias, + "blocks.0.mlp.w3.weight": out_weight, + "blocks.0.mlp.w3.bias": out_bias, + } + ) + + assert "blocks.0.mlp.net.0.proj.weight" in translated + assert "blocks.0.mlp.net.0.proj.bias" in translated + assert "blocks.0.mlp.net.2.weight" in translated + assert "blocks.0.mlp.net.2.bias" in translated + assert torch.equal( + translated["blocks.0.mlp.net.0.proj.weight"], + torch.cat(weight.chunk(2, dim=0)[::-1], dim=0), + ) + assert torch.equal( + translated["blocks.0.mlp.net.0.proj.bias"], + torch.cat(bias.chunk(2, dim=0)[::-1], dim=0), + ) + assert torch.equal(translated["blocks.0.mlp.net.2.weight"], out_weight) + assert torch.equal(translated["blocks.0.mlp.net.2.bias"], out_bias) + + +def test_translate_transformer_state_dict_maps_gelu_keys(): + fc1_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3) + fc2_weight = torch.arange(6, dtype=torch.float32).reshape(3, 2) + + translated = translate_transformer_state_dict( + { + "blocks.0.mlp.fc1.weight": fc1_weight, + "blocks.0.mlp.fc2.weight": fc2_weight, + } + ) + + assert torch.equal(translated["blocks.0.mlp.net.0.proj.weight"], fc1_weight) + assert torch.equal(translated["blocks.0.mlp.net.2.weight"], fc2_weight) + + +def test_load_autoencoder_rae_loads_local_checkpoint_without_from_pretrained(): + model = AutoencoderRAE( + encoder_type="mae", + encoder_hidden_size=64, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + encoder_input_size=16, + patch_size=4, + image_size=16, + num_channels=3, + decoder_hidden_size=64, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=4, + decoder_intermediate_size=128, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained(tmpdir, safe_serialization=False) + loaded = load_autoencoder_rae(tmpdir) + + assert isinstance(loaded, AutoencoderRAE) + assert loaded.config.image_size == 16 + assert loaded.config.patch_size == 4 diff --git a/tests/pipelines/rae_dit/__init__.py b/tests/pipelines/rae_dit/__init__.py new file mode 100644 index 000000000000..0acff376f3a3 --- /dev/null +++ b/tests/pipelines/rae_dit/__init__.py @@ -0,0 +1 @@ +# Copyright 2026 HuggingFace Inc. diff --git a/tests/pipelines/rae_dit/test_pipeline_rae_dit.py b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py new file mode 100644 index 000000000000..e4e8ccc1fb81 --- /dev/null +++ b/tests/pipelines/rae_dit/test_pipeline_rae_dit.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +import torch.nn.functional as F + +import diffusers.models.autoencoders.autoencoder_rae as _rae_module +from diffusers import AutoencoderRAE, FlowMatchEulerDiscreteScheduler, RAEDiTPipeline, RAEDiTPipelineOutput +from diffusers.models.autoencoders.autoencoder_rae import _ENCODER_FORWARD_FNS, _build_encoder +from diffusers.models.transformers.transformer_rae_dit import RAEDiT2DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import ( + CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS, + CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +def _initialize_non_zero_stage2_head(model: RAEDiT2DModel): + torch.manual_seed(0) + + for block in model.blocks: + block.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + block.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + + model.final_layer.adaLN_modulation[-1].weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.adaLN_modulation[-1].bias.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.weight.data.normal_(mean=0.0, std=0.02) + model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02) + + +class _TinyTestEncoderModule(torch.nn.Module): + def __init__(self, hidden_size: int = 8, patch_size: int = 4, **kwargs): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + + def forward(self, images: torch.Tensor) -> torch.Tensor: + pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) + tokens = pooled.flatten(2).transpose(1, 2).contiguous() + return tokens.repeat(1, 1, self.hidden_size) + + +def _tiny_test_encoder_forward(model, images): + return model(images) + + +def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size) + + +_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward +_original_build_encoder = _build_encoder + + +def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + if encoder_type == "tiny_test": + return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + + +_rae_module._build_encoder = _patched_build_encoder + + +class RAEDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = RAEDiTPipeline + params = CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS + batch_params = CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS + test_attention_slicing = False + test_xformers_attention = False + + @classmethod + def tearDownClass(cls): + _rae_module._build_encoder = _original_build_encoder + _ENCODER_FORWARD_FNS.pop("tiny_test", None) + super().tearDownClass() + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = RAEDiT2DModel( + sample_size=2, + patch_size=1, + in_channels=8, + hidden_size=(16, 16), + depth=(1, 1), + num_heads=(2, 2), + mlp_ratio=2.0, + class_dropout_prob=0.1, + num_classes=4, + use_qknorm=False, + use_swiglu=True, + use_rope=True, + use_rmsnorm=True, + use_pos_embed=True, + ) + _initialize_non_zero_stage2_head(transformer) + + vae = AutoencoderRAE( + encoder_type="tiny_test", + encoder_hidden_size=8, + encoder_patch_size=4, + encoder_num_hidden_layers=1, + decoder_hidden_size=16, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=2, + decoder_intermediate_size=32, + patch_size=2, + encoder_input_size=8, + image_size=4, + num_channels=3, + encoder_norm_mean=[0.5, 0.5, 0.5], + encoder_norm_std=[0.5, 0.5, 0.5], + noise_tau=0.0, + reshape_to_2d=True, + scaling_factor=1.0, + ) + scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0) + + return { + "transformer": transformer.eval(), + "guidance_transformer": None, + "vae": vae.eval(), + "scheduler": scheduler, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + return { + "class_labels": [1], + "generator": generator, + "guidance_scale": 1.0, + "num_inference_steps": 2, + "output_type": "np", + } + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(torch_device))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + output_loaded = pipe_loaded(**self.get_dummy_inputs(torch_device))[0] + max_diff = np.abs(output_loaded - output).max() + self.assertLess(max_diff, expected_max_difference) + + def test_inference(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs("cpu")) + self.assertIsInstance(output, RAEDiTPipelineOutput) + image = output.images + image_slice = image[0, -2:, -2:, -1] + + self.assertEqual(image.shape, (1, 4, 4, 3)) + expected_slice = np.array([0.78739226, 0.79371649, 0.56565261, 0.78660309]) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-4) + + def test_inference_classifier_free_guidance(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs("cpu") + inputs.update({"guidance_scale": 2.0}) + + image = pipe(**inputs).images + self.assertEqual(image.shape, (1, 4, 4, 3)) + self.assertTrue(np.isfinite(image).all()) + + no_guidance = pipe(**self.get_dummy_inputs("cpu")).images + self.assertGreater(np.abs(image - no_guidance).max(), 1e-6) + + def test_guidance_interval_can_disable_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + base = pipe(**self.get_dummy_inputs("cpu")).images + + inputs = self.get_dummy_inputs("cpu") + inputs.pop("guidance_scale") + cfg_disabled = pipe( + **inputs, + guidance_scale=2.0, + guidance_start=0.25, + guidance_end=0.75, + ).images + + self.assertLessEqual(np.abs(base - cfg_disabled).max(), 1e-5) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-4) + + def test_latent_output(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs("cpu") + inputs.pop("output_type") + latents = pipe(**inputs, output_type="latent").images + self.assertEqual(latents.shape, (1, 8, 2, 2)) + self.assertTrue(torch.isfinite(latents).all().item()) + + def test_get_label_ids(self): + pipe = self.pipeline_class( + **self.get_dummy_components(), + id2label={ + 0: "zero", + 1: "one, first", + }, + ) + self.assertEqual(pipe.get_label_ids("first"), [1]) + self.assertEqual(pipe.get_label_ids(["zero", "one"]), [0, 1]) + + def test_save_load_preserves_label_ids(self): + pipe = self.pipeline_class( + **self.get_dummy_components(), + id2label={ + 0: "zero", + 1: "one, first", + }, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + + self.assertEqual(pipe_loaded.config.id2label, {"0": "zero", "1": "one, first"}) + self.assertEqual(pipe_loaded.get_label_ids("first"), [1]) + self.assertEqual(pipe_loaded.get_label_ids(["zero", "one"]), [0, 1]) + + def test_save_load_preserves_guidance_transformer(self): + components = self.get_dummy_components() + guidance_transformer = RAEDiT2DModel(**components["transformer"].config) + _initialize_non_zero_stage2_head(guidance_transformer) + components["guidance_transformer"] = guidance_transformer + + pipe = self.pipeline_class(**components) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "guidance_transformer"))) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + + self.assertIsNotNone(pipe_loaded.guidance_transformer) + self.assertIsInstance(pipe_loaded.guidance_transformer, RAEDiT2DModel) + self.assertEqual(pipe_loaded.guidance_transformer.config.sample_size, guidance_transformer.config.sample_size)