Skip to content

Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411

Open
Perseus14 wants to merge 1 commit into
mainfrom
wan_vae_opt
Open

Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411
Perseus14 wants to merge 1 commit into
mainfrom
wan_vae_opt

Conversation

@Perseus14
Copy link
Copy Markdown
Collaborator

@Perseus14 Perseus14 commented May 18, 2026

Overview

This pull request introduces significant memory, compilation, and bandwidth optimizations for the Wan (T2V, I2V, Vace and Animate) VAE encoding and decoding execution graphs on TPUs. By enforcing static shapes and moving postprocessing operations on-device, this PR significantly improves scaling efficiency and execution robustness.

Key Changes & Optimizations

1. VAE Temporal Chunking & Static Compilation

  • Introduced configurable chunk size parameters (vae_decode_chunk and vae_encode_chunk) in the base configuration files.
  • Added static temporal zero-padding prior to jax.lax.scan loops and precise output trimming post-scan. This guarantees uniform static shapes across iterations, allowing seamless static JIT/XLA compilation without dynamic shape fallbacks or recompilation overhead.

2. On-TPU Output Quantization

  • Transitioned the final video post-processing stages (range rescaling to [0, 1], clipping, and scaling to uint8) directly into the JAX/TPU execution graph across all Wan pipelines (WanPipeline, WanPipelineI2V, and WanAnimatePipeline).
  • Gathering sharded arrays as uint8 instead of floating-point cuts device-to-host memory interconnect transfer bandwidth by up to 4x.
  • Completely eliminated the need for torch.Tensor conversions and PyTorch processor overhead during generation output assembly.

3. Distributed Sharding Robustness

  • Standardized VAE 1D spatial sharding rules to use P("redundant", None, None, "vae_spatial", None). Explicitly marking the batch axis as redundant prevents unintended cross-mesh synchronization across replicas under XLA SPMD partitioning.
  • Added defensive dimension existence checks when accessing mesh properties in attention_flax.py to support diverse hardware topologies gracefully.
  • Cleaned up Attention QKV splitting logic with direct matrix reshaping, establishing optimal contiguous memory alignment for dot-product attention kernels.

4. Configurable Activation & Weight Precision

  • Integrated robust type casting in pyconfig.py for user-defined numerical precision (vae_dtype, vae_weights_dtype, and scheduler_dtype), defaulting VAE runtime memory down to 16-bit precision to save 50% High Bandwidth Memory (HBM) capacity.

Performance

Note

Test Configuration: 720p | 81 frames | 40 steps
Hardware: TPU 7x-8
JAX Version: v0.10.0

Model Variant BaselIne Generation Time Current Generation Time
WAN2.2 T2V 132.2s 130.6s
WAN2.2 I2V 133.3s 132.3s
WAN2.1 T2V 132.1s 130.6s
WAN2.1 I2V 142.7s 141.6s

Conclusion: No visual change from baseline across all tested variants.

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 18, 2026 10:02
@github-actions
Copy link
Copy Markdown

@Perseus14 Perseus14 self-assigned this May 18, 2026
@Perseus14 Perseus14 requested review from eltsai and ninatu May 18, 2026 10:03
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Comment thread src/maxdiffusion/configs/base_wan_14b.yml Outdated
Comment thread src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@Perseus14 Perseus14 force-pushed the wan_vae_opt branch 4 times, most recently from 7406dd0 to 0c91d18 Compare May 18, 2026 17:29
@Perseus14 Perseus14 requested review from mbohlool and ninatu May 18, 2026 18:10
@github-actions
Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR introduces significant memory, compilation, and bandwidth optimizations for the Wan VAE on TPUs. By implementing temporal chunking, enforcing static shapes, and moving postprocessing operations directly into the JAX execution graph, the PR achieves better scaling efficiency and execution robustness.

🔍 General Feedback

  • Optimization: The transition to on-TPU postprocessing (quantizing to uint8 on-device) is an excellent optimization that significantly reduces host-device bandwidth.
  • Architectural Clarity: Standardizing the spatial sharding rules and using temporal chunking with static padding/trimming is a very high-quality improvement for JAX/XLA performance.
  • Breaking Change: Note that the pipeline's output has changed from floating-point values (typically [0, 1]) to uint8 values ([0, 255]), and it now retains the batch dimension. This is well-handled in the included utility updates but may affect external users.
  • Test Integrity: The dramatic increase in tolerance for the KV cache test (atol=180) is concerning and should be reviewed to ensure it's not masking a functional regression.

Comment thread src/maxdiffusion/tests/wan_kv_cache_test.py Outdated
Comment thread src/maxdiffusion/models/wan/autoencoder_kl_wan.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline.py
enc = self.quant_conv(out)
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
enc = jnp.concatenate([mu, logvar], axis=-1)
feat_cache.init_cache()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is feat_cache.init_cache() needed here? feat_cache.init_cache() is always called in the beginning of the function, therefore it seems like here it should not be needed

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eltsai Could you confirm? I tested by removing it and it works, video is generated with no quality difference.

Comment thread src/maxdiffusion/models/wan/autoencoder_kl_wan.py Outdated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like feat_cache.init_cache() is unnecessary here as well

ninatu
ninatu previously approved these changes May 19, 2026
…ic shapes, and JIT support

- Address VAE review comments: remove redundant init_cache and fix ceiling division in iter_

- Fix KVCache test
@github-actions
Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions
Copy link
Copy Markdown

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants