From df9c7b0b925b044fd918e108a5a0beb907641b79 Mon Sep 17 00:00:00 2001 From: maxdiffusion authors Date: Fri, 6 Feb 2026 00:10:11 +0000 Subject: [PATCH 1/3] source sync PiperOrigin-RevId: 866156279 --- src/maxdiffusion/__init__.py | 4 ++++ src/maxdiffusion/configuration_utils.py | 6 +++--- src/maxdiffusion/trainers/wan_trainer.py | 26 +++++++++++++----------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index a1a2c2f5..30092857 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -64,6 +64,10 @@ ], } +if is_flax_available(): + from flax import config as flax_config + flax_config.update('flax_always_shard_variable', False) + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 0e8c9968..c432d674 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -376,11 +376,11 @@ def load_config( if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + if subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): # Load from a PyTorch checkpoint config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - elif subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)): - config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) else: raise EnvironmentError(f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.") else: diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 29042728..e432ce0b 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -444,19 +444,21 @@ def loss_fn(params): noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - model_pred = model( - hidden_states=noisy_latents, - timestep=timesteps, - encoder_hidden_states=encoder_hidden_states, - deterministic=False, - rngs=nnx.Rngs(dropout_rng), - ) + with jax.named_scope('forward_pass'): + model_pred = model( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + deterministic=False, + rngs=nnx.Rngs(dropout_rng), + ) - training_target = scheduler.training_target(latents, noise, timesteps) - training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) - loss = (training_target - model_pred) ** 2 - loss = loss * training_weight - loss = jnp.mean(loss) + with jax.named_scope('loss'): + training_target = scheduler.training_target(latents, noise, timesteps) + training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) + loss = (training_target - model_pred) ** 2 + loss = loss * training_weight + loss = jnp.mean(loss) return loss From f51786df41947aa390c42d3b8e885b45e887b675 Mon Sep 17 00:00:00 2001 From: maxdiffusion authors Date: Fri, 6 Feb 2026 02:02:59 +0000 Subject: [PATCH 2/3] source sync PiperOrigin-RevId: 866197055 --- src/maxdiffusion/__init__.py | 1 + src/maxdiffusion/trainers/wan_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 30092857..9e249543 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -66,6 +66,7 @@ if is_flax_available(): from flax import config as flax_config + flax_config.update('flax_always_shard_variable', False) try: diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index e432ce0b..e2d3b5ce 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -444,7 +444,7 @@ def loss_fn(params): noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - with jax.named_scope('forward_pass'): + with jax.named_scope("forward_pass"): model_pred = model( hidden_states=noisy_latents, timestep=timesteps, @@ -453,7 +453,7 @@ def loss_fn(params): rngs=nnx.Rngs(dropout_rng), ) - with jax.named_scope('loss'): + with jax.named_scope("loss"): training_target = scheduler.training_target(latents, noise, timesteps) training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) loss = (training_target - model_pred) ** 2 From cd5c94da428c0305089e1bb2bee37ab34bb5fe76 Mon Sep 17 00:00:00 2001 From: maxdiffusion authors Date: Fri, 6 Feb 2026 18:58:45 +0000 Subject: [PATCH 3/3] source sync PiperOrigin-RevId: 866530641 --- src/maxdiffusion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 9e249543..e9addadc 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -67,7 +67,7 @@ if is_flax_available(): from flax import config as flax_config - flax_config.update('flax_always_shard_variable', False) + flax_config.update("flax_always_shard_variable", False) try: if not is_onnx_available():