From f9c125356147bb74512a4ebb5ebefa38e8c0d1ed Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Fri, 15 May 2026 13:47:47 +0000 Subject: [PATCH] Fix VAE spatial sharding dynamic calculation bug in Wan pipeline. Previously, setting `vae_spatial: -1` in the config (intended to trigger dynamic calculation of the VAE spatial sharding axis size) was ineffective because `pyconfig.py` prematurely overrode any `-1` or missing `vae_spatial` value with `1`. Furthermore, the dynamic calculation formula in `wan_pipeline.py` (`vae_spatial = (2 * total_devices) // dp_size`) was not robust. On single-device runs (where `total_devices=1` and `dp_size=1`) or configurations with odd data-parallel (DP) sizes, it would calculate a `vae_spatial` value (e.g., 2) that does not divide `total_devices`, failing the mesh validation assertion. This change makes `pyconfig.py` default only a missing `vae_spatial` to `-1` (an explicit `-1` is now preserved), and makes `wan_pipeline.py` resolve `-1` to `total_devices`, which always satisfies the divisibility check. Only `-1` now triggers the dynamic value; other non-positive values such as `0` are no longer remapped. --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 7 ++----- src/maxdiffusion/pyconfig.py | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa29..500e4208 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -672,11 +672,8 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): vae_spatial = getattr(config, "vae_spatial", -1) total_devices = math.prod(devices_array.shape) - if vae_spatial <= 0: - dp_size = mesh.shape.get("data", 1) - if dp_size == -1 or dp_size == 0: - dp_size = 1 - vae_spatial = (2 * total_devices) // dp_size + if vae_spatial == -1: + vae_spatial = total_devices assert ( total_devices % vae_spatial == 0 diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9c6c9125..19f3ec30 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -281,8 +281,8 @@ def user_init(raw_keys): raw_keys["global_batch_size_to_train_on"], ) = 
_HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - if raw_keys.get("vae_spatial", -1) == -1: - raw_keys["vae_spatial"] = 1 + if "vae_spatial" not in raw_keys: + raw_keys["vae_spatial"] = -1 def get_num_slices(raw_keys):