Add RAE Diffusion Transformer inference/preliminary training pipelines#13231
plugyawn wants to merge 12 commits into huggingface:main from
Conversation
@kashif @sayakpaul it would be great if you could review. Please note the `no_init_weights()` fix (details in the PR body); if you prefer, that could be a separate PR, but considering
Thanks for the PR. To keep the scope manageable, could we break it down into separate PRs? For example,
could be a separate PR.
sayakpaul
left a comment
Thanks!
I left some initial comments, let me know if they make sense.
```
- `examples/dreambooth/train_dreambooth_flux.py`
  for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks.
- `examples/flux-control/train_control_flux.py`
  for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers.
```
```python
# Preserve the `torch.nn.init.*` return contract so third-party model
# constructors that chain on the returned tensor still work under
# `no_init_weights()`.
return args[0] if len(args) > 0 else None
```
```python
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
```

```python
@unittest.skip(
    "RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative."
)
```
I don't think this is the case? We can always skip layerwise casting for certain layers or layer groups here:
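For instance, assuming diffusers' layerwise-casting hook honors a `_skip_layerwise_casting_patterns` class attribute on `ModelMixin` subclasses (the pattern names below are illustrative, not from this PR):

```python
# Plain-class stand-in for the real ModelMixin subclass: modules whose
# qualified name matches any listed pattern are kept in their original
# dtype during layerwise casting, so the zero-initialized output head
# does not make the cosine-similarity check degenerate.
class RAEDiT2DModel:
    _skip_layerwise_casting_patterns = ["final_layer", "pos_embed"]
```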
```python
model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02)
```

```python
class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase):
```
The tests should use the newly added model tester mixins. You can find an example in #13046.
```python
if shift is None:
    shift = torch.zeros_like(scale)
```
This is a small function; could it just be inlined at the call sites?
We also probably don't need `_repeat_to_length()`.
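Inlined at a call site, the quoted default could look like this (sketch; the tensor shapes are illustrative, only the `scale`/`shift` names come from the quoted hunk):

```python
import torch

hidden_states = torch.randn(2, 16, 8)
scale = torch.randn(2, 1, 8)
shift = None  # the modulation shift is optional

# Inline the tiny helper: default the shift to zeros, then apply the
# usual AdaLN-style scale/shift modulation.
if shift is None:
    shift = torch.zeros_like(scale)
hidden_states = hidden_states * (1 + scale) + shift
```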
```python
if self.use_pos_embed:
    pos_embed = get_2d_sincos_pos_embed(
        self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt"
    )
    self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))
```
Can we initialize the position embeddings the way #13046 does?
Yeah, that makes sense, will do that.
```python
    )
    return hidden_states

def _run_block(
```
We don't need this. Let's instead follow this pattern:
```python
    return class_labels

def _prepare_latents(
```
It should be called `prepare_latents()`, similar to other pipelines.
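A minimal sketch of the renamed method (argument names assumed from other diffusers pipelines, not taken from this PR):

```python
import torch

def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator=None):
    # Draw the initial Gaussian latents that the sampler will denoise.
    shape = (batch_size, num_channels_latents, height, width)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

latents = prepare_latents(2, 4, 16, 16, torch.float32, "cpu")
```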
```python
if output_type == "pt":
    output = images
else:
    output = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        output = self.numpy_to_pil(output)
```
We should use an image processor instead here. See:
```python
if not return_dict:
    return (output,)

return ImagePipelineOutput(images=output)
```
Let's give this pipeline a separate output class: `RAEDiTPipelineOutput`.
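The output class could be a sketch like this (in-tree it would subclass `diffusers.utils.BaseOutput`; a plain dataclass is used here so the snippet is self-contained):

```python
from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class RAEDiTPipelineOutput:
    """
    Output of RAEDiTPipeline (sketch).

    images: a list of PIL images, or a numpy array / torch tensor batch,
    depending on the requested `output_type`.
    """
    images: Union[List[Any], Any]
```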
What does this PR do?
This PR adds support for Diffusion Transformers with Representation Autoencoders. As the authors say, "Representation Autoencoders (RAEs) reuse pretrained, frozen representation encoders together with lightweight trained decoders to provide high-fidelity, semantically rich latents for diffusion transformers."
This addresses #13225: it solves (a) inference, and adds a training example for (b), tested on an NVIDIA A100 SXM4 40 GB GPU.
Reference implementation: byteriper's repository
Implemented:
`RAEDiT2DModel` core model and `RAEDiTPipeline`, a checkpoint conversion script for the published upstream checkpoints, documentation, and a small training example. Inference output looks correct on visual inspection, and numerical parity with the official implementation is high:

`max_abs_error=0.00001717`, `mean_abs_error=0.00000122` for sampling with the same class, noise schedule, and initial latent noise.
Left to right: byteriper's RAE DiT implementation vs. the converted checkpoint in diffusers, using the same published Stage-2 checkpoint, the same class label, the same initial latent noise, and the same 25-step shifted Euler sampling schedule.

Inference is actually faster in `diffusers`; on a 40 GB A100, the timings are:

Note: there is also a change to `no_init_weights()`. It makes diffusers' skip-weight-init behave more like stock PyTorch: currently, while `no_init_weights()` is active, the patched `torch.nn.init.*` functions stop returning the tensor they were called on (stock PyTorch does return it). Most models never notice this, but the RAE-DiT implementation relies on the return value during construction, which can make otherwise valid checkpoints fail to load through the standard `from_pretrained()` path.

Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.