Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes #411
Perseus14 wants to merge 1 commit into
7406dd0 to 0c91d18
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR introduces significant memory, compilation, and bandwidth optimizations for the Wan VAE on TPUs. By implementing temporal chunking, enforcing static shapes, and moving postprocessing operations directly into the JAX execution graph, the PR achieves better scaling efficiency and execution robustness.
🔍 General Feedback
- Optimization: The transition to on-TPU postprocessing (quantizing to `uint8` on-device) is an excellent optimization that significantly reduces host-device bandwidth.
- Architectural Clarity: Standardizing the spatial sharding rules and using temporal chunking with static padding/trimming is a very high-quality improvement for JAX/XLA performance.
- Breaking Change: Note that the pipeline's output has changed from floating-point values (typically `[0, 1]`) to `uint8` values (`[0, 255]`), and it now retains the batch dimension. This is well-handled in the included utility updates but may affect external users.
- Test Integrity: The dramatic increase in tolerance for the KV cache test (`atol=180`) is concerning and should be reviewed to ensure it's not masking a functional regression.
enc = self.quant_conv(out)
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
enc = jnp.concatenate([mu, logvar], axis=-1)
feat_cache.init_cache()
Is `feat_cache.init_cache()` needed here? It is always called at the beginning of the function, so it seems it should not be needed at this point.
@eltsai Could you confirm? I tested removing it and it works: the video is generated with no quality difference.
It seems like `feat_cache.init_cache()` is unnecessary here as well.
…ic shapes, and JIT support
- Address VAE review comments: remove redundant init_cache and fix ceiling division in iter_
- Fix KVCache test
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
Overview
This pull request introduces significant memory, compilation, and bandwidth optimizations for the Wan (T2V, I2V, Vace and Animate) VAE encoding and decoding execution graphs on TPUs. By enforcing static shapes and moving postprocessing operations on-device, this PR significantly improves scaling efficiency and execution robustness.
Key Changes & Optimizations
1. VAE Temporal Chunking & Static Compilation
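As a rough illustration of the pad, scan, and trim pattern this section describes, the sketch below pads the time axis to a multiple of the chunk size, runs `jax.lax.scan` over equally shaped chunks, and trims the padded frames afterwards. It is not the PR's implementation: `decode_chunk`, `chunked_apply`, the array layout, and the chunk size are hypothetical, and the real VAE would likely thread its feature cache through the scan carry, which is left unused here.

```python
import jax
import jax.numpy as jnp


def decode_chunk(chunk):
  # Hypothetical stand-in for the VAE decoder applied to one temporal chunk.
  return chunk * 2.0


def chunked_apply(frames, chunk_size):
  """Applies decode_chunk along the time axis (axis 0) in fixed-size chunks."""
  num_frames = frames.shape[0]
  num_chunks = -(-num_frames // chunk_size)  # ceiling division
  pad = num_chunks * chunk_size - num_frames

  # Pad the time axis so that every chunk has the same static shape.
  pad_width = [(0, pad)] + [(0, 0)] * (frames.ndim - 1)
  padded = jnp.pad(frames, pad_width)
  chunks = padded.reshape((num_chunks, chunk_size) + frames.shape[1:])

  def step(carry, chunk):
    # The carry is unused in this sketch; a stateful decoder would update it.
    return carry, decode_chunk(chunk)

  _, out = jax.lax.scan(step, None, chunks)

  # Merge the chunk axis back into time and drop the padded frames.
  out = out.reshape((num_chunks * chunk_size,) + out.shape[2:])
  return out[:num_frames]


# chunk_size must be static because it determines array shapes.
chunked_apply_jit = jax.jit(chunked_apply, static_argnames="chunk_size")

frames = jnp.arange(7 * 4, dtype=jnp.float32).reshape(7, 4)
result = chunked_apply_jit(frames, chunk_size=3)
```

Because the number of frames and the chunk size are both known at trace time, the padding amount and the final slice are static, so the function compiles once per input shape.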
- … `vae_decode_chunk` and `vae_encode_chunk`) in the base configuration files.
- … `jax.lax.scan` loops and precise output trimming post-scan. This guarantees uniform static shapes across iterations, allowing seamless static JIT/XLA compilation without dynamic shape fallbacks or recompilation overhead.
2. On-TPU Output Quantization
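The on-device postprocessing in this section can be sketched roughly as follows. This is an illustration under assumptions, not the PR's code: it assumes the decoder output lies approximately in `[-1, 1]`, and the function name `to_uint8` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def to_uint8(video):
  # Assumption: decoder output lies roughly in [-1, 1].
  video = (video + 1.0) / 2.0        # normalize to [0, 1]
  video = jnp.clip(video, 0.0, 1.0)  # clip out-of-range values
  # Scale to [0, 255] and cast on-device, before any host transfer.
  return jnp.round(video * 255.0).astype(jnp.uint8)


video = jnp.array([-1.5, -1.0, 0.0, 1.0, 2.0], dtype=jnp.float32)
quantized = to_uint8(video)

# One byte per element instead of four for float32.
bytes_ratio = jnp.dtype(jnp.float32).itemsize // jnp.dtype(jnp.uint8).itemsize
```

The 4x figure applies when the alternative is transferring `float32`; against a 16-bit output the reduction would be 2x. Callers that still expect floats can recover them on the host by dividing by 255.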
- … `[0, 1]`, clipping, and scaling to `uint8`) directly into the JAX/TPU execution graph across all Wan pipelines (`WanPipeline`, `WanPipelineI2V`, and `WanAnimatePipeline`).
- … `uint8` instead of floating-point cuts device-to-host memory interconnect transfer bandwidth by up to 4x.
- … `torch.Tensor` conversions and PyTorch processor overhead during generation output assembly.
3. Distributed Sharding Robustness
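For readers unfamiliar with named sharding, the sketch below shows how a partition spec of the shape used in this section can be applied to a 5-D array. The mesh layout is an assumption made for illustration: the axis names come from the spec itself, but the mesh shape, the `(batch, time, height, width, channels)` layout, and the array sizes are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed mesh: one "redundant" replica group, all devices on "vae_spatial".
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("redundant", "vae_spatial"))

# Batch axis -> "redundant", width axis -> "vae_spatial", others unsharded.
spec = P("redundant", None, None, "vae_spatial", None)
sharding = NamedSharding(mesh, spec)

# Assumed layout: (batch, time, height, width, channels).
num_spatial = devices.shape[1]
x = jnp.zeros((1, 2, 4, 8 * num_spatial, 3), dtype=jnp.bfloat16)
x = jax.device_put(x, sharding)
```

The width is chosen as a multiple of the number of devices so that the sharded axis divides evenly; on a single device the array is simply placed whole.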
- … `P("redundant", None, None, "vae_spatial", None)`. Explicitly marking the batch axis as redundant prevents unintended cross-mesh synchronization across replicas under XLA SPMD partitioning.
- … `attention_flax.py` to support diverse hardware topologies gracefully.
4. Configurable Activation & Weight Precision
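A minimal sketch of how string-valued dtype settings might be resolved and what the 16-bit saving amounts to. The key names are the ones listed in this section; the values, the `DTYPES` table, and `resolve_dtype` are hypothetical, and `bfloat16` is assumed as the 16-bit type since the description does not name one.

```python
import jax.numpy as jnp

# Hypothetical settings; only the key names come from the PR description.
config = {
    "vae_dtype": "bfloat16",
    "vae_weights_dtype": "bfloat16",
    "scheduler_dtype": "float32",
}

# Hypothetical lookup table from config strings to JAX dtypes.
DTYPES = {"bfloat16": jnp.bfloat16, "float16": jnp.float16, "float32": jnp.float32}


def resolve_dtype(name):
  """Maps a config string to a JAX dtype, failing loudly on unknown names."""
  if name not in DTYPES:
    raise ValueError(f"Unsupported dtype: {name}")
  return DTYPES[name]


activations_fp32 = jnp.zeros((2, 4, 8, 8, 16), dtype=jnp.float32)
activations = activations_fp32.astype(resolve_dtype(config["vae_dtype"]))

# Bytes per element relative to float32: 2 / 4.
ratio = activations.dtype.itemsize / activations_fp32.dtype.itemsize
```

The 50% figure follows directly from the element size: a 16-bit value occupies half the bytes of a 32-bit one, so the saving is relative to a `float32` baseline.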
- … `pyconfig.py` for user-defined numerical precision (`vae_dtype`, `vae_weights_dtype`, and `scheduler_dtype`), defaulting VAE runtime memory down to 16-bit precision to save 50% High Bandwidth Memory (HBM) capacity.
Performance
Note
Test Configuration: 720p | 81 frames | 40 steps
Hardware: TPU 7x-8
JAX Version: v0.10.0
Conclusion: No visual change from baseline across all tested variants.