
Add RAE Diffusion Transformer inference/preliminary training pipelines#13231

Open
plugyawn wants to merge 12 commits into huggingface:main from plugyawn:rae-dit-training

Conversation

@plugyawn commented Mar 9, 2026

What does this PR do?

This PR adds support for Diffusion Transformers with Representation Autoencoders. As the authors say, "Representation Autoencoders (RAEs) reuse pretrained, frozen representation encoders together with lightweight trained decoders to provide high-fidelity, semantically rich latents for diffusion transformers."

This addresses #13225: it implements (a) inference, and adds a preliminary training example for (b), tested on an NVIDIA A100 SXM4 40GB GPU.

Reference implementation: byteriper's repository

Implemented: the RAEDiT2DModel core model, the RAEDiTPipeline, a checkpoint conversion script for the published upstream checkpoints, documentation, and a small training example.

Inference looks correct on visual inspection, and numerical parity with the official implementation is high: max_abs_error=0.00001717, mean_abs_error=0.00000122, when sampling with the same class label, noise schedule, and initial latent noise.
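For reference, parity numbers of this kind come from a max/mean absolute-error comparison of the two sampled outputs. A minimal sketch, flattened to plain float lists with made-up stand-in values (the real comparison would run on the sampled tensors):

```python
# Illustrative parity check between two sampled outputs, flattened to
# plain float lists. The values below are stand-ins for this sketch,
# not the actual sampled latents.
def abs_errors(reference, candidate):
    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    return max(diffs), sum(diffs) / len(diffs)

upstream = [0.10, 0.20, 0.30]          # stand-in for the upstream sample
converted = [0.10, 0.200017, 0.29999]  # stand-in for the diffusers sample
max_err, mean_err = abs_errors(upstream, converted)
```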

Left to right: byteriper's RAE DiT implementation vs. the converted checkpoint in diffusers, using the same published Stage-2 checkpoint, the same class label, the same initial latent noise, and the same 25-step shifted Euler sampling schedule.

Inference is actually faster in diffusers. On a 40GB A100, the timings are:

| Precision | CFG | Steps | Diffusers sec/img | Upstream sec/img | Diffusers img/s | Delta |
|-----------|-----|-------|-------------------|------------------|-----------------|-------|
| bf16 | 1.0 | 25 | 0.817 | 0.913 | 1.225 | +11.8% |
| bf16 | 4.0 | 25 | 0.852 | 0.931 | 1.174 | +9.3% |
| bf16 | 1.0 | 50 | 1.568 | 1.761 | 0.638 | +12.3% |
| bf16 | 4.0 | 50 | 1.649 | 1.853 | 0.606 | +12.4% |
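The derived columns follow directly from the per-image timings: img/s is the reciprocal of Diffusers sec/img, and Delta is the upstream-to-diffusers wall-time ratio. A quick check on the first row (small last-digit differences come from rounding the reported timings):

```python
# Sanity-check the bf16 / CFG 1.0 / 25-step row of the table above.
diffusers_sec = 0.817
upstream_sec = 0.913

img_per_s = 1.0 / diffusers_sec                       # ~1.224 img/s (table: 1.225)
delta_pct = (upstream_sec / diffusers_sec - 1) * 100  # ~+11.8%
```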

Note: there is also a change to no_init_weights(). It makes Diffusers' skip-weight-init behave like stock PyTorch: previously, while no_init_weights() was active, the patched torch.nn.init.* functions returned None instead of the tensor they were called on (stock PyTorch returns it). Most models never notice this, but the RAE-DiT implementation relies on the return value during construction, which could make otherwise valid checkpoints fail to load through the standard from_pretrained() path.
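A minimal sketch of the return contract in question, with stub names invented for illustration (this is not the actual Diffusers patch):

```python
# Stub that mimics stock torch.nn.init behavior under a skip-init
# context: do no initialization work, but still return the tensor,
# so call sites that chain on the return value keep working.
def skip_init_(tensor, *args, **kwargs):
    # Returning None here instead of `tensor` is exactly what broke
    # constructors written as `w = init.trunc_normal_(w, ...)`.
    return tensor

class FakeTensor:  # stand-in for torch.Tensor in this sketch
    pass

weight = FakeTensor()
initialized = skip_init_(weight, 0.0, 0.02)  # chains on the return value
assert initialized is weight
```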

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@plugyawn changed the title from "Add Stage-2 RAE DiT support with pipeline, conversion, and training tooling" to "RAE DiT inference, checkpoint conversion, and preliminary training tooling" Mar 9, 2026
@plugyawn changed the title from "RAE DiT inference, checkpoint conversion, and preliminary training tooling" to "Add RAE Diffusion Transformer inference/preliminary training pipelines" Mar 9, 2026
@plugyawn marked this pull request as draft March 9, 2026 05:46
@plugyawn marked this pull request as ready for review March 9, 2026 05:51
@plugyawn (Author) commented Mar 9, 2026

@kashif @sayakpaul it would be great if you could review. Please note the no_init_weights() fix (details in the PR body); if you prefer, that could be a separate PR, but since diffusers is meant to behave as an extension of torch, it arguably makes sense to include it here.

@sayakpaul (Member)

Thanks for the PR. To keep the scope manageable, could we break it down into separate PRs?

For example,

> there is also a change to no_init_weights( ). Specifically: it makes Diffusers’ skip-weight-init behave more like normal PyTorch. [...] Most models never notice this, but the RAE-DiT implementation does rely on the return value during construction, which can make otherwise valid checkpoints fail to load through the standard from_pretrained() path.

could be a separate PR.

@sayakpaul (Member) left a comment

Thanks!

I left some initial comments, let me know if they make sense.

Comment on lines +13 to +16
- `examples/dreambooth/train_dreambooth_flux.py`
for the flow-matching training loop structure, checkpoint resume flow, and `accelerate.save_state(...)` hooks.
- `examples/flux-control/train_control_flux.py`
for the transformer-only save layout and SD3-style flow-matching timestep weighting helpers.
Member

Doesn't belong here.

Comment on lines +218 to +221
# Preserve the `torch.nn.init.*` return contract so third-party model
# constructors that chain on the returned tensor still work under
# `no_init_weights()`.
return args[0] if len(args) > 0 else None
Member

Can you provide an example?

super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

@unittest.skip(
"RAEDiT initializes the output head to zeros, so cosine-based layerwise casting checks are uninformative."
Member

I don't think this is the case? We can always skip layerwise casting for certain layer or layer groups here:

_skip_layerwise_casting_patterns = None

model.final_layer.linear.bias.data.normal_(mean=0.0, std=0.02)


class RAEDiT2DModelTests(ModelTesterMixin, unittest.TestCase):
Member

Test should use the newly added model tester mixins. You can find an example in #13046

Comment on lines +48 to +49
if shift is None:
shift = torch.zeros_like(scale)
Member

This is a small function; it's probably fine to inline it at the call sites?

We also probably don't need _repeat_to_length().

Comment on lines +466 to +470
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1], int(sqrt(self.pos_embed.shape[1])), output_type="pt"
)
self.pos_embed.data.copy_(pos_embed.float().unsqueeze(0))
Member

Can we use how #13046 initialized the position embeddings?

Author

Yeah, that makes sense, will do that.

)
return hidden_states

def _run_block(
Member

We don't need this. Let's instead follow this pattern:

for index_block, block in enumerate(self.transformer_blocks):


return class_labels

def _prepare_latents(
Member

It should be called prepare_latents() similar to other pipelines.

Comment on lines +247 to +252
if output_type == "pt":
output = images
else:
output = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
output = self.numpy_to_pil(output)
Member

We should use an image processor instead here. See:

image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (output,)

return ImagePipelineOutput(images=output)
Member

Let's give this pipeline a separate output class: RAEDiTPipelineOutput.

@sayakpaul requested review from dg845 and kashif March 9, 2026 11:33