From a699abedbb8ca7cb404734e6cbe8f434ca79ca92 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Tue, 19 May 2026 08:40:32 +0000 Subject: [PATCH] Optimize WAN VAE: add chunking, 2D spatial sharding, padding for static shapes, and JIT support - Address VAE review comments: remove redundant init_cache and fix ceiling division in iter_ - Fix KVCache test --- src/maxdiffusion/configs/base_wan_14b.yml | 12 +- src/maxdiffusion/configs/base_wan_1_3b.yml | 12 + src/maxdiffusion/configs/base_wan_27b.yml | 12 +- src/maxdiffusion/configs/base_wan_animate.yml | 10 + src/maxdiffusion/configs/base_wan_i2v_14b.yml | 12 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 12 +- src/maxdiffusion/generate_wan.py | 8 +- src/maxdiffusion/models/attention_flax.py | 6 +- .../models/wan/autoencoder_kl_wan.py | 210 ++++++++++++------ .../pipelines/wan/wan_pipeline.py | 56 ++--- .../pipelines/wan/wan_pipeline_2_1.py | 2 +- .../pipelines/wan/wan_pipeline_2_2.py | 2 +- .../pipelines/wan/wan_pipeline_animate.py | 14 +- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 6 +- .../pipelines/wan/wan_pipeline_i2v_2p2.py | 7 +- .../pipelines/wan/wan_vace_pipeline_2_1.py | 18 +- src/maxdiffusion/pyconfig.py | 18 +- .../scheduling_unipc_multistep_flax.py | 3 +- src/maxdiffusion/tests/wan_kv_cache_test.py | 18 +- src/maxdiffusion/tests/wan_vae_test.py | 82 +++---- src/maxdiffusion/utils/export_utils.py | 16 +- 21 files changed, 347 insertions(+), 189 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f432928aa..911ccfe33 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -41,10 +41,20 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False -vae_spatial: -1 # default to total_device * 2 // (dp) + +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 +vae_spatial: -1 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 0e0552656..56ef45cab 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -41,10 +41,21 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 +vae_spatial: -1 + # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" # fp32 activations and fp32 weights with HIGHEST will provide the best precision @@ -159,6 +170,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'context'], ] + vae_logical_axis_rules: [ ['activation_batch', 'redundant'], ['activation_length', 'vae_spatial'], diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index bf29fa867..546052ded 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -41,10 +41,20 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False -vae_spatial: -1 # default to total_device * 2 // (dp) + +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 +vae_spatial: -1 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 8f95c8558..c18a0eecf 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -41,9 +41,19 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False + +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 # Number of devices to shard VAE spatial activations across. -1 uses all devices. vae_spatial: -1 diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index ca2d239ab..aba232db9 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -41,10 +41,20 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False -vae_spatial: -1 # default to total_device * 2 // (dp) + +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 +vae_spatial: -1 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 90799524c..201a97874 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -41,10 +41,20 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +vae_weights_dtype: 'float32' +vae_dtype: 'float32' +scheduler_dtype: 'float32' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False -vae_spatial: -1 # default to total_device * 2 // (dp) + +# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory. +vae_decode_chunk: 1 + +# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value. +# Increase to improve encode time at the cost of memory. +vae_encode_chunk: 4 +vae_spatial: -1 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 78a58c2d6..883393310 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -311,11 +311,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): f" Inference: {generation_time:>7.1f}s", ] if trace: + vae_decode_total = trace.get("vae_decode", 0.0) + vae_decode_tpu = trace.get("vae_decode_tpu", 0.0) + vae_decode_post = vae_decode_total - vae_decode_tpu summary.extend([ f" {'─' * 40}", f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s", + f" - VAE Encode: {trace.get('vae_encode', 0.0):>7.1f}s", f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s", - f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s", + f" VAE Decode: {vae_decode_total:>7.1f}s", + f" - TPU Compute: {vae_decode_tpu:>7.1f}s", + f" - Host Formatting: {vae_decode_post:>7.1f}s", ]) summary.append(f"{'=' * 50}") max_logging.log("\n".join(summary)) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 6ff4b4fed..839f1718d 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -325,7 +325,7 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - num_context_shards = mesh.shape["context"] + num_context_shards = mesh.shape["context"] if "context" in mesh.shape else 1 query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) key, _ = _reshape_data_for_flash(key, heads, num_context_shards) value, _ = _reshape_data_for_flash(value, heads, num_context_shards) @@ -491,7 +491,9 @@ def ring_scan_body(carry, _): raise ValueError("ring attention requires context > 1") return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) + data_dim = mesh.shape["data"] if "data" in mesh.shape else 1 + fsdp_dim = mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1 + devices_in_batch_sharding = data_dim * fsdp_dim # This warning might show up when doing model eval for example, when calculating model flops # and that is expected. if not (query.shape[0] / devices_in_batch_sharding).is_integer(): diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 8948a24ad..3970aa3db 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -121,7 +121,7 @@ def __init__( def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: # Sharding Width (index 3) # Spec: (Batch, Time, Height, Width, Channels) - spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + spatial_sharding = NamedSharding(self.mesh, P("redundant", None, None, "vae_spatial", None)) x = jax.lax.with_sharding_constraint(x, spatial_sharding) current_padding = list(self._causal_padding) @@ -159,26 +159,30 @@ def __init__( images: bool = True, eps: float = 1e-6, use_bias: bool = False, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, ): broadcastable_dims = (1, 1, 1) if not images else (1, 1) shape = (dim, *broadcastable_dims) if channel_first else (dim,) self.eps = eps self.channel_first = channel_first self.scale = dim**0.5 + self.dtype = dtype # Initialize gamma as parameter - self.gamma = nnx.Param(jnp.ones(shape)) + self.gamma = nnx.Param(jnp.ones(shape, dtype=weights_dtype)) if use_bias: - self.bias = nnx.Param(jnp.zeros(shape)) + self.bias = nnx.Param(jnp.zeros(shape, dtype=weights_dtype)) else: self.bias = 0 def __call__(self, x: jax.Array) -> jax.Array: + x = x.astype(self.dtype) normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True) normalized = x / jnp.maximum(normalized, self.eps) normalized = normalized * self.scale * self.gamma if self.bias: - return normalized + self.bias.value - return normalized + return (normalized + self.bias.value).astype(self.dtype) + return normalized.astype(self.dtype) class WanUpsample(nnx.Module): @@ -254,6 +258,7 @@ def __init__( ): self.dim = dim self.mode = mode + self.dtype = dtype self.time_conv = nnx.data(None) if mode == "upsample2d": @@ -340,6 +345,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # Input x: (N, D, H, W, C), assume C = self.dim b, t, h, w, c = x.shape assert c == self.dim + x = x.astype(self.dtype) if self.mode == "upsample3d": if feat_cache is not None: @@ -353,7 +359,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel): - cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) + cache_x = jnp.concatenate([jnp.zeros(cache_x.shape, dtype=cache_x.dtype), cache_x], axis=1) if isinstance(feat_cache[idx], RepSentinel): x = self.time_conv(x) else: @@ -402,7 +408,9 @@ def __init__( self.nonlinearity = get_activation(non_linearity) # layers - self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False) + self.norm1 = WanRMS_norm( + dim=in_dim, rngs=rngs, images=False, channel_first=False, dtype=dtype, weights_dtype=weights_dtype + ) self.conv1 = WanCausalConv3d( rngs=rngs, in_channels=in_dim, @@ -414,7 +422,9 @@ def __init__( weights_dtype=weights_dtype, precision=precision, ) - self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False) + self.norm2 = WanRMS_norm( + dim=out_dim, rngs=rngs, images=False, channel_first=False, dtype=dtype, weights_dtype=weights_dtype + ) self.conv2 = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -488,7 +498,8 @@ def __init__( precision: jax.lax.Precision = None, ): self.dim = dim - self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False) + + self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False, dtype=dtype, weights_dtype=weights_dtype) self.to_qkv = nnx.Conv( in_features=dim, out_features=dim * 3, @@ -510,6 +521,7 @@ def __init__( precision=precision, ) + @jax.named_scope("WanVAEAttentionBlock") def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): identity = x batch_size, time, height, width, channels = x.shape @@ -517,14 +529,15 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): x = x.reshape(batch_size * time, height, width, channels) x = self.norm(x) - qkv = self.to_qkv(x) # Output: (N*D, H, W, C * 3) - # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) - qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3) - qkv = jnp.transpose(qkv, (0, 1, 3, 2)) - q, k, v = jnp.split(qkv, 3, axis=-2) - q = jnp.transpose(q, (0, 1, 3, 2)) - k = jnp.transpose(k, (0, 1, 3, 2)) - v = jnp.transpose(v, (0, 1, 3, 2)) + qkv = self.to_qkv(x) # Output: (N*T, H, W, C * 3) + qkv = qkv.reshape(batch_size * time, height * width, 3, channels) + q = qkv[:, :, 0, :] # (B*T, H*W, C) + k = qkv[:, :, 1, :] + v = qkv[:, :, 2, :] + q = q[:, None, :, :] # (B*T, 1, H*W, C) + k = k[:, None, :, :] + v = v[:, None, :, :] + x = jax.nn.dot_product_attention(q, k, v) x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels) @@ -748,7 +761,9 @@ def __init__( ) # output blocks - self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs) + self.norm_out = WanRMS_norm( + out_dim, channel_first=False, images=False, rngs=rngs, dtype=dtype, weights_dtype=weights_dtype + ) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -896,7 +911,9 @@ def __init__( self.up_blocks = nnx.data(self.up_blocks) # output blocks - self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) + self.norm_out = WanRMS_norm( + dim=out_dim, images=False, rngs=rngs, channel_first=False, dtype=dtype, weights_dtype=weights_dtype + ) self.conv_out = WanCausalConv3d( rngs=rngs, in_channels=out_dim, @@ -1045,10 +1062,23 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, + vae_decode_chunk: int = 1, + vae_encode_chunk: int = 4, ): self.z_dim = z_dim + assert vae_decode_chunk >= 1 or vae_decode_chunk == -1, f"vae_decode_chunk must be >= 1 or -1, got {vae_decode_chunk}" + assert vae_encode_chunk >= 1 or vae_encode_chunk == -1, f"vae_encode_chunk must be >= 1 or -1, got {vae_encode_chunk}" + self.vae_decode_chunk = vae_decode_chunk + self.vae_encode_chunk = vae_encode_chunk self.temperal_downsample = temperal_downsample self.temporal_upsample = temperal_downsample[::-1] + self.temporal_downsample_factor = 2 ** sum(temperal_downsample) + + if self.vae_encode_chunk != -1: + assert ( + self.vae_encode_chunk % self.temporal_downsample_factor == 0 + ), f"vae_encode_chunk ({self.vae_encode_chunk}) must be a multiple of the temporal downsampling factor ({self.temporal_downsample_factor})." + self.latents_mean = latents_mean self.latents_std = latents_std @@ -1111,52 +1141,82 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}" + x = x.astype(self.encoder.conv_in.conv.dtype) t = x.shape[1] - iter_ = 1 + (t - 1) // 4 + CHUNK_SIZE = self.vae_encode_chunk + if CHUNK_SIZE == -1: + CHUNK_SIZE = max(1, t - 1) + assert ( + CHUNK_SIZE % self.temporal_downsample_factor == 0 + ), f"When vae_encode_chunk is -1, the input sequence length - 1 ({CHUNK_SIZE}) must be a multiple of the temporal downsampling factor ({self.temporal_downsample_factor})." + # Number of chunk iterations: 1 for init frame, then ceil((t-1)/CHUNK_SIZE) for the rest + iter_ = 1 + ((t - 1 + CHUNK_SIZE - 1) // CHUNK_SIZE) if t > 1 else 1 enc_feat_map = feat_cache._enc_feat_map - spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + spatial_sharding = NamedSharding(self.mesh, P("redundant", None, None, "vae_spatial", None)) def finalize(out, enc_feat_map): feat_cache._enc_feat_map = enc_feat_map enc = self.quant_conv(out) - mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] - enc = jnp.concatenate([mu, logvar], axis=-1) - feat_cache.init_cache() return enc # First iteration (i=0): size 1 - chunk_0 = x[:, :1, ...] - out_0, enc_feat_map, _ = self.encoder(chunk_0, feat_cache=enc_feat_map, feat_idx=0) - out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) + with jax.named_scope("AutoencoderKLWan_encode_chunk_0"): + chunk_0 = x[:, :1, ...] + out_0, enc_feat_map, _ = self.encoder(chunk_0, feat_cache=enc_feat_map, feat_idx=0) + out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) if iter_ <= 1: return finalize(out_0, enc_feat_map) - CHUNK_SIZE = 4 # We must adjust enc_feat_map from None/'Rep'/'zeros' for scan shapes. # By running chunk 1 outside the scan, the PyTree shapes will reach their stable state. - chunk_1 = x[:, 1 : (1 + CHUNK_SIZE), ...] - out_1, enc_feat_map, _ = self.encoder(chunk_1, feat_cache=enc_feat_map, feat_idx=0) - out_1 = jax.lax.with_sharding_constraint(out_1, spatial_sharding) + with jax.named_scope("AutoencoderKLWan_encode_chunk_1"): + chunk_1 = x[:, 1 : (1 + CHUNK_SIZE), ...] + out_1, enc_feat_map, _ = self.encoder(chunk_1, feat_cache=enc_feat_map, feat_idx=0) + out_1 = jax.lax.with_sharding_constraint(out_1, spatial_sharding) if iter_ <= 2: out = jnp.concatenate([out_0, out_1], axis=1) out = jax.lax.with_sharding_constraint(out, spatial_sharding) return finalize(out, enc_feat_map) - # Prepare the remaining chunks (each size 4) to be scanned over - # x_rest shape: (B, (iter_-2)*4, H, W, C) - x_rest = x[:, 5:, ...] - # Reshape to (iter_-2, B, 4, H, W, C) for jax.lax.scan - x_scannable = x_rest.reshape(x_rest.shape[0], iter_ - 2, 4, x_rest.shape[2], x_rest.shape[3], x_rest.shape[4]) + # Prepare the remaining chunks to be scanned over + x_rest = x[:, 1 + CHUNK_SIZE :, ...] + T_rest = x_rest.shape[1] + + # Pad T_rest up to the next multiple of CHUNK_SIZE + pad_amount = (-T_rest) % CHUNK_SIZE + if pad_amount > 0: + pad_shape = (x_rest.shape[0], pad_amount, *x_rest.shape[2:]) + x_rest_padded = jnp.concatenate([x_rest, jnp.zeros(pad_shape, dtype=x_rest.dtype)], axis=1) + else: + x_rest_padded = x_rest + T_padded = T_rest + pad_amount + num_scan_iters = T_padded // CHUNK_SIZE + + x_scannable = x_rest_padded.reshape( + x_rest_padded.shape[0], + num_scan_iters, + CHUNK_SIZE, + x_rest_padded.shape[2], + x_rest_padded.shape[3], + x_rest_padded.shape[4], + ) x_scannable = jnp.transpose(x_scannable, (1, 0, 2, 3, 4, 5)) + graphdef, state = nnx.split(self.encoder) + + @jax.named_scope("AutoencoderKLWan_encode_scan_body") def scan_fn(carry, chunk): current_feat_map = carry - out_chunk, next_feat_map, _ = self.encoder(chunk, feat_cache=current_feat_map, feat_idx=0) + local_encoder = nnx.merge(graphdef, state) + out_chunk, next_feat_map, _ = local_encoder(chunk, feat_cache=current_feat_map, feat_idx=0) out_chunk = jax.lax.with_sharding_constraint(out_chunk, spatial_sharding) + next_feat_map = jax.tree_util.tree_map( + lambda x: jax.lax.with_sharding_constraint(x, spatial_sharding) if isinstance(x, jax.Array) else x, next_feat_map + ) return next_feat_map, out_chunk enc_feat_map, out_rest = jax.lax.scan(scan_fn, enc_feat_map, x_scannable) @@ -1164,11 +1224,14 @@ def scan_fn(carry, chunk): out_rest = jnp.transpose(out_rest, (1, 0, 2, 3, 4, 5)) # reshape to (B, (iter_-2)*T', H, W, C) out_rest = out_rest.reshape(out_rest.shape[0], -1, out_rest.shape[3], out_rest.shape[4], out_rest.shape[5]) + # Trim padding from the output + out_rest = out_rest[:, : T_rest // self.temporal_downsample_factor, ...] out = jnp.concatenate([out_0, out_1, out_rest], axis=1) out = jax.lax.with_sharding_constraint(out, spatial_sharding) return finalize(out, enc_feat_map) + @jax.named_scope("AutoencoderKLWan_encode") def encode( self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: @@ -1185,56 +1248,73 @@ def _decode( ) -> Union[FlaxDecoderOutput, jax.Array]: feat_cache.init_cache() iter_ = z.shape[1] + z = z.astype(self.post_quant_conv.conv.dtype) x = self.post_quant_conv(z) dec_feat_map = feat_cache._feat_map # NamedSharding for the Width axis (axis 3) - spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + spatial_sharding = NamedSharding(self.mesh, P("redundant", None, None, "vae_spatial", None)) # First chunk (i=0) - chunk_in_0 = jax.lax.with_sharding_constraint(x[:, 0:1, ...], spatial_sharding) - out_0, dec_feat_map, _ = self.decoder(chunk_in_0, feat_cache=dec_feat_map, feat_idx=0) - out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) + with jax.named_scope("AutoencoderKLWan_decode_chunk_0"): + chunk_in_0 = jax.lax.with_sharding_constraint(x[:, 0:1, ...], spatial_sharding) + out_0, dec_feat_map, _ = self.decoder(chunk_in_0, feat_cache=dec_feat_map, feat_idx=0) + out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) if iter_ > 1: # Run chunk 1 outside scan to properly form the cache shape - chunk_in_1 = jax.lax.with_sharding_constraint(x[:, 1:2, ...], spatial_sharding) - out_chunk_1, dec_feat_map, _ = self.decoder(chunk_in_1, feat_cache=dec_feat_map, feat_idx=0) - out_chunk_1 = jax.lax.with_sharding_constraint(out_chunk_1, spatial_sharding) - - # Frame re-sync logic for chunk 1 - fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...] - axis = 1 if fm1.shape[0] > 1 else 0 - fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] - out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1) + with jax.named_scope("AutoencoderKLWan_decode_chunk_1"): + chunk_in_1 = jax.lax.with_sharding_constraint(x[:, 1:2, ...], spatial_sharding) + out_chunk_1, dec_feat_map, _ = self.decoder(chunk_in_1, feat_cache=dec_feat_map, feat_idx=0) + out_chunk_1 = jax.lax.with_sharding_constraint(out_chunk_1, spatial_sharding) + out_1 = out_chunk_1 out_list = [out_0, out_1] if iter_ > 2: x_rest = x[:, 2:, ...] - # Reshape for scan: (iter_-2, B, 1, H, W, C) - x_scannable = jnp.transpose(x_rest, (1, 0, 2, 3, 4)) - x_scannable = jnp.expand_dims(x_scannable, axis=2) + T_rest = x_rest.shape[1] + K = self.vae_decode_chunk + if K == -1: + K = max(1, T_rest) + + # Pad T_rest up to the next multiple of K so the scan has uniform chunks. + # This avoids data-dependent branches and dynamic shapes under JIT. + pad_amount = (-T_rest) % K # 0 when already divisible + if pad_amount > 0: + pad_shape = (x_rest.shape[0], pad_amount, *x_rest.shape[2:]) + x_rest_padded = jnp.concatenate([x_rest, jnp.zeros(pad_shape, dtype=x_rest.dtype)], axis=1) + else: + x_rest_padded = x_rest + T_padded = T_rest + pad_amount + num_chunks = T_padded // K + + # Reshape to (num_chunks, B, K, H, W, C) for scan + x_scannable = x_rest_padded.reshape( + x_rest_padded.shape[0], num_chunks, K, x_rest_padded.shape[2], x_rest_padded.shape[3], x_rest_padded.shape[4] + ) + x_scannable = jnp.transpose(x_scannable, (1, 0, 2, 3, 4, 5)) + graphdef, state = nnx.split(self.decoder) + + @jax.named_scope("AutoencoderKLWan_decode_scan_body") def scan_fn(carry, chunk_in): current_feat_map = carry + local_decoder = nnx.merge(graphdef, state) chunk_in = jax.lax.with_sharding_constraint(chunk_in, spatial_sharding) - out_chunk, next_feat_map, _ = self.decoder(chunk_in, feat_cache=current_feat_map, feat_idx=0) + out_chunk, next_feat_map, _ = local_decoder(chunk_in, feat_cache=current_feat_map, feat_idx=0) out_chunk = jax.lax.with_sharding_constraint(out_chunk, spatial_sharding) - - # Frame re-sync logic - fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...] - axis = 1 if fm1.shape[0] > 1 else 0 - fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] - new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1) - - return next_feat_map, new_chunk + next_feat_map = jax.tree_util.tree_map( + lambda x: jax.lax.with_sharding_constraint(x, spatial_sharding) if isinstance(x, jax.Array) else x, + next_feat_map, + ) + return next_feat_map, out_chunk dec_feat_map, out_rest = jax.lax.scan(scan_fn, dec_feat_map, x_scannable) - - # out_rest is (iter_-2, B, 4, H, W, C) -> transpose back out_rest = jnp.transpose(out_rest, (1, 0, 2, 3, 4, 5)) out_rest = out_rest.reshape(out_rest.shape[0], -1, out_rest.shape[3], out_rest.shape[4], out_rest.shape[5]) + # Trim padding from the output + out_rest = out_rest[:, : T_rest * self.temporal_downsample_factor, ...] out_list.append(out_rest) out = jnp.concatenate(out_list, axis=1) @@ -1245,12 +1325,12 @@ def scan_fn(carry, chunk_in): feat_cache._feat_map = dec_feat_map out = jnp.clip(out, min=-1.0, max=1.0) - feat_cache.init_cache() if not return_dict: return (out,) return FlaxDecoderOutput(sample=out) + @jax.named_scope("AutoencoderKLWan_decode") def decode( self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxDecoderOutput, jax.Array]: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa293..07831607e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -20,6 +20,7 @@ import math import jax import jax.numpy as jnp +import time from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import flax import flax.linen as nn @@ -316,8 +317,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): subfolder="vae", rngs=rngs, mesh=mesh, - dtype=jnp.float32, - weights_dtype=jnp.float32, + dtype=config.vae_dtype, + weights_dtype=config.vae_weights_dtype, + vae_decode_chunk=config.vae_decode_chunk, + vae_encode_chunk=config.vae_encode_chunk, ) return wan_vae @@ -459,6 +462,7 @@ def load_scheduler(cls, config): config.pretrained_model_name_or_path, subfolder="scheduler", flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p + dtype=config.scheduler_dtype, ) return scheduler, scheduler_state @@ -593,6 +597,7 @@ def prepare_latents_i2v_base( num_frames: int, dtype: jnp.dtype, last_image: Optional[jax.Array] = None, + trace: Optional[dict] = None, ) -> Tuple[jax.Array, jax.Array]: """ Encodes the initial image(s) into latents to be used as conditioning. @@ -630,12 +635,18 @@ def prepare_latents_i2v_base( vae_dtype = getattr(self.vae, "dtype", jnp.float32) video_condition = video_condition.astype(vae_dtype) + t_vae_encode_start = time.perf_counter() with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): data_mesh_size = self.mesh.shape[self.config.mesh_axes[0]] if video_condition.shape[0] % data_mesh_size == 0: sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec) encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() + if hasattr(encoded_output, "block_until_ready"): + encoded_output.block_until_ready() + + if trace is not None: + trace["vae_encode"] = time.perf_counter() - t_vae_encode_start # Normalize latents latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) @@ -648,21 +659,27 @@ def prepare_latents_i2v_base( def _denormalize_latents(self, latents: jax.Array) -> jax.Array: """Denormalizes latents using VAE statistics.""" - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + dtype = self.config.activations_dtype + latents_mean = jnp.array(self.vae.latents_mean, dtype=dtype).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std, dtype=dtype).reshape(1, self.vae.z_dim, 1, 1, 1) latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) return latents - def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: + def _decode_latents_to_video(self, latents: jax.Array, trace: Optional[dict] = None) -> np.ndarray: """Decodes latents to video frames and postprocesses.""" + t_vae_tpu_start = time.perf_counter() with self.vae_mesh, nn_partitioning.axis_rules(self.vae_logical_axis_rules): video = self.vae.decode(latents, self.vae_cache)[0] + video = (video / 2.0) + 0.5 + video = jnp.clip(video, 0.0, 1.0) + video = (video * 255.0).astype(jnp.uint8) + video.block_until_ready() + if trace is not None: + trace["vae_decode_tpu"] = time.perf_counter() - t_vae_tpu_start - video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - return self.video_processor.postprocess_video(video, output_type="np") + video = np.array(video) + return video @classmethod def _create_common_components(cls, config, vae_only=False, i2v=False): @@ -672,11 +689,8 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): vae_spatial = getattr(config, "vae_spatial", -1) total_devices = math.prod(devices_array.shape) - if vae_spatial <= 0: - dp_size = mesh.shape.get("data", 1) - if dp_size == -1 or dp_size == 0: - dp_size = 1 - vae_spatial = (2 * total_devices) // dp_size + if vae_spatial == -1: + vae_spatial = total_devices assert ( total_devices % vae_spatial == 0 @@ -692,20 +706,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): # logical axis rules for VAE encoding/decoding vae_logical_axis_rules = getattr(config, "vae_logical_axis_rules", None) - if vae_logical_axis_rules is None: - vae_logical_axis_rules = ( - ("activation_batch", "redundant"), - ("activation_length", "vae_spatial"), - ("activation_heads", None), - ("activation_kv_length", None), - ("embed", None), - ("heads", None), - ("norm", None), - ("conv_batch", "redundant"), - ("out_channels", "vae_spatial"), - ("conv_out", "vae_spatial"), - ) - rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 355ba6ae6..e444686e3 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -176,7 +176,7 @@ def __call__( trace["denoise_total"] = time.perf_counter() - t_denoise_start t_decode_start = time.perf_counter() - video = self._decode_latents_to_video(latents) + video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): video.block_until_ready() trace["vae_decode"] = time.perf_counter() - t_decode_start diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 2c294a124..ae7fbf497 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -207,7 +207,7 @@ def __call__( trace["denoise_total"] = time.perf_counter() - t_denoise_start t_decode_start = time.perf_counter() - video = self._decode_latents_to_video(latents) + video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): video.block_until_ready() trace["vae_decode"] = time.perf_counter() - t_decode_start diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py index 57ddac7e6..e2c3864f8 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py @@ -702,7 +702,7 @@ def _decode_segment_to_pixels(self, latents_cl: jnp.ndarray) -> jnp.ndarray: """ latents_cf = jnp.transpose(latents_cl, (0, 4, 1, 2, 3)) # (B, z_dim, T, H, W) latents_cf = self._denormalize_latents(latents_cf) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + with self.vae_mesh, nn_partitioning.axis_rules(self.vae_logical_axis_rules): video_cl = self.vae.decode(latents_cf, self.vae_cache)[0] # (B, T, H, W, C) return jnp.transpose(video_cl, (0, 4, 1, 2, 3)) # (B, C, T, H, W) @@ -1041,5 +1041,13 @@ def __call__( return seg_latents # Postprocess to [0, 1] numpy. - video_torch = torch.from_numpy(np.array(video_cf.astype(jnp.float32))).to(torch.bfloat16) - return self.video_processor.postprocess_video(video_torch, output_type="np") + with self.vae_mesh, nn_partitioning.axis_rules(self.vae_logical_axis_rules): + video = (video_cf / 2.0) + 0.5 + video = jnp.clip(video, 0.0, 1.0) + video = jnp.transpose(video, (0, 2, 3, 4, 1)) + video = (video * 255.0).astype(jnp.uint8) + video.block_until_ready() + + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = np.array(video) + return video diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index aa4bbba27..73221fe9b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -99,6 +99,7 @@ def prepare_latents( latents: Optional[jax.Array] = None, last_image: Optional[jax.Array] = None, num_videos_per_prompt: int = 1, + trace: Optional[dict] = None, ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: if hasattr(image, "detach"): image = image.detach().cpu().numpy() @@ -131,7 +132,7 @@ def prepare_latents( latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) else: latents = latents.astype(dtype) - latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image, trace=trace) mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) if last_image is None: mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) @@ -242,6 +243,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): latents=latents, last_image=last_image_tensor, num_videos_per_prompt=num_videos_per_prompt, + trace=trace, ) latents.block_until_ready() condition.block_until_ready() @@ -303,7 +305,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): return latents, trace t_decode_start = time.perf_counter() - video = self._decode_latents_to_video(latents) + video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): video.block_until_ready() trace["vae_decode"] = time.perf_counter() - t_decode_start diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 1ba54f2eb..54f3630eb 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -117,6 +117,7 @@ def prepare_latents( latents: Optional[jax.Array] = None, last_image: Optional[jax.Array] = None, num_videos_per_prompt: int = 1, + trace: Optional[dict] = None, ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: if hasattr(image, "detach"): image = image.detach().cpu().numpy() @@ -145,7 +146,7 @@ def prepare_latents( else: latents = latents.astype(dtype) - latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image, trace=trace) mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) if last_image is None: mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) @@ -267,6 +268,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): rng=latents_rng, latents=latents, last_image=last_image_tensor, + num_videos_per_prompt=num_videos_per_prompt, + trace=trace, ) latents.block_until_ready() condition.block_until_ready() @@ -334,7 +337,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): return latents, trace t_decode_start = time.perf_counter() - video = self._decode_latents_to_video(latents) + video = self._decode_latents_to_video(latents, trace=trace) if hasattr(video, "block_until_ready"): video.block_until_ready() trace["vae_decode"] = time.perf_counter() - t_decode_start diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index b12ae4142..2b73f160d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -655,18 +655,12 @@ def __call__( control_hidden_states_scale=conditioning_scale, ) latents = latents[:, :, num_reference_images:] - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] - - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") + latents = self._denormalize_latents(latents) + latents.block_until_ready() + + video = self._decode_latents_to_video(latents) + if hasattr(video, "block_until_ready"): + video.block_until_ready() return video def prepare_video_latents( diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9c6c91251..6d664bbf3 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -196,8 +196,18 @@ def calculate_global_batch_sizes(per_device_batch_size): @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" - raw_keys["weights_dtype"] = jax.numpy.dtype(raw_keys["weights_dtype"]) - raw_keys["activations_dtype"] = jax.numpy.dtype(raw_keys["activations_dtype"]) + # Set defaults for dtypes if they weren't explicitly provided + if "vae_dtype" not in raw_keys: + raw_keys["vae_dtype"] = "float32" + if "vae_weights_dtype" not in raw_keys: + raw_keys["vae_weights_dtype"] = "float32" + if "scheduler_dtype" not in raw_keys: + raw_keys["scheduler_dtype"] = "float32" + + # Cast all dtype configs to jax.numpy.dtype + for dtype_key in ["weights_dtype", "activations_dtype", "scheduler_dtype", "vae_dtype", "vae_weights_dtype"]: + if dtype_key in raw_keys: + raw_keys[dtype_key] = jax.numpy.dtype(raw_keys[dtype_key]) if raw_keys["run_name"] == "": raw_keys["run_name"] = os.environ.get("JOBSET_NAME") # using XPK default run_name = raw_keys["run_name"] @@ -281,8 +291,8 @@ def user_init(raw_keys): raw_keys["global_batch_size_to_train_on"], ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - if raw_keys.get("vae_spatial", -1) == -1: - raw_keys["vae_spatial"] = 1 + if "vae_spatial" not in raw_keys: + raw_keys["vae_spatial"] = -1 def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index ef1832996..bc5929e5d 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -674,7 +674,7 @@ def step( the multistep UniPC. """ - sample = sample.astype(jnp.float32) + sample = sample.astype(self.dtype) if state.timesteps is None: raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") @@ -694,6 +694,7 @@ def step( # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type model_output_for_history = self.convert_model_output(state, model_output, sample) + model_output_for_history = model_output_for_history.astype(self.dtype) # Apply corrector if applicable sample = jax.lax.cond( diff --git a/src/maxdiffusion/tests/wan_kv_cache_test.py b/src/maxdiffusion/tests/wan_kv_cache_test.py index 38589db12..7c880eabf 100644 --- a/src/maxdiffusion/tests/wan_kv_cache_test.py +++ b/src/maxdiffusion/tests/wan_kv_cache_test.py @@ -40,15 +40,10 @@ def setUp(self): [ None, os.path.join(THIS_DIR, "..", "configs", "base_wan_1_3b.yml"), - "pretrained_model_name_or_path=Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "num_inference_steps=2", # Reduced steps for speed - "height=240", # Reduced resolution for speed (divisible by 16) - "width=416", # Reduced resolution for speed (divisible by 16) + "height=256", # Reduced resolution for speed (divisible by 16) + "width=256", # Reduced resolution for speed (divisible by 16) "num_frames=9", # Reduced num_frames for speed - "attention=flash", - "scan_layers=False", - "jit_initializers=False", - "skip_jax_distributed_system=True", ], unittest=True, ) @@ -82,7 +77,7 @@ def mock_transformer_load_config(pretrained_model_name_or_path, return_unused_kw "freq_dim": 256, "image_dim": None, "in_channels": 16, - "num_attention_heads": 12, + "num_attention_heads": 40, "num_layers": 2, "out_channels": 16, "patch_size": [1, 2, 2], @@ -209,8 +204,11 @@ def mock_load_scheduler(config): self.assertEqual(len(video_with_cache), batch_size) self.assertEqual(video_with_cache[0].shape, (num_frames, height, width, 3)) - # Compare outputs - np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=0.7) + # Calculate the average absolute difference across all pixels + mae = np.mean(np.abs(video_no_cache.astype(np.float32) - video_with_cache.astype(np.float32))) + + # Ensure average pixel drift is less than 3 units + self.assertLess(mae, 3.0, f"KV Cache caused an unacceptably high average pixel drift of {mae}") if __name__ == "__main__": diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 0bc13854e..f114c51ab 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -282,13 +282,11 @@ def test_3d_conv(self): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) batch_size = 1 in_depth, in_height, in_width = 10, 32, 32 @@ -305,7 +303,7 @@ def test_3d_conv(self): dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) # Instantiate the module - with self.mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules): + with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules): causal_conv_layer = WanCausalConv3d( in_channels=in_channels, out_channels=out_channels, @@ -341,13 +339,11 @@ def test_wan_residual(self): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) # --- Test Case 1: same in/out dim --- in_dim = out_dim = 96 batch = 1 @@ -399,18 +395,16 @@ def test_wan_midblock(self): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) batch = 1 t = 1 dim = 384 height = 60 - width = 90 + width = 96 input_shape = (batch, t, height, width, dim) with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules): wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) @@ -430,13 +424,11 @@ def test_wan_decode(self): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -459,7 +451,7 @@ def test_wan_decode(self): t = 13 channels = 16 height = 60 - width = 90 + width = 96 input_shape = (batch, t, height, width, channels) input = jnp.ones(input_shape) @@ -467,7 +459,7 @@ def test_wan_decode(self): latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) input = input / latents_std + latents_mean dummy_output = wan_vae.decode(input, feat_cache=vae_cache) - assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) + assert dummy_output.sample.shape == (batch, 49, 480, 768, 3) def test_wan_encode(self): key = jax.random.key(0) @@ -481,13 +473,11 @@ def test_wan_encode(self): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) dim = 96 z_dim = 16 dim_mult = [1, 2, 4, 4] @@ -533,14 +523,12 @@ def vae_encode(video, wan_vae, vae_cache, key): ) config = pyconfig.config devices_array = create_device_mesh(config) - # Add vae_spatial axis to mesh for VAE operations - mesh_axes = list(config.mesh_axes) - if "vae_spatial" not in mesh_axes: - mesh_axes.append("vae_spatial") - # Reshape devices to include vae_spatial (size 1 for test) - devices_array = devices_array.reshape(devices_array.shape + (1,)) - mesh = Mesh(devices_array, mesh_axes) - with self.mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules): + vae_spatial = getattr(config, "vae_spatial", -1) + if vae_spatial == -1: + vae_spatial = devices_array.size + vae_devices_array = devices_array.flatten().reshape(-1, vae_spatial) + mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) + with mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules): wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index d1ff27b03..279ad1e90 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -141,7 +141,7 @@ def _legacy_export_to_video( output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] + video_frames = [(frame * 255).astype(np.uint8) if frame.dtype != np.uint8 else frame for frame in video_frames] elif isinstance(video_frames[0], PIL.Image.Image): video_frames = [np.array(frame) for frame in video_frames] @@ -157,7 +157,7 @@ def _legacy_export_to_video( def export_to_video( - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], + video_frames: Union[np.ndarray, List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10, quality: float = 5.0, @@ -212,11 +212,15 @@ def export_to_video( if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name - if isinstance(video_frames[0], np.ndarray): - video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] - + if isinstance(video_frames, np.ndarray): + if video_frames.dtype != np.uint8: + video_frames = (video_frames * 255).astype(np.uint8) + elif isinstance(video_frames[0], np.ndarray): + video_frames = np.stack(video_frames) + if video_frames.dtype != np.uint8: + video_frames = (video_frames * 255).astype(np.uint8) elif isinstance(video_frames[0], PIL.Image.Image): - video_frames = [np.array(frame) for frame in video_frames] + video_frames = np.stack([np.asarray(frame) for frame in video_frames]) with imageio.get_writer( output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size