diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py
index a1a2c2f5..e9addadc 100644
--- a/src/maxdiffusion/__init__.py
+++ b/src/maxdiffusion/__init__.py
@@ -64,6 +64,11 @@
   ],
 }
 
+if is_flax_available():
+  from flax import config as flax_config
+
+  flax_config.update("flax_always_shard_variable", False)
+
 try:
   if not is_onnx_available():
     raise OptionalDependencyNotAvailable()
diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py
index 0e8c9968..c432d674 100644
--- a/src/maxdiffusion/configuration_utils.py
+++ b/src/maxdiffusion/configuration_utils.py
@@ -376,11 +376,11 @@ def load_config(
     if os.path.isfile(pretrained_model_name_or_path):
       config_file = pretrained_model_name_or_path
     elif os.path.isdir(pretrained_model_name_or_path):
-      if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+      if subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
+        config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+      elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
         # Load from a PyTorch checkpoint
         config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
-      elif subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
-        config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
       else:
         raise EnvironmentError(f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.")
     else:
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 29042728..e2d3b5ce 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -444,19 +444,21 @@ def loss_fn(params):
     noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
 
-    model_pred = model(
-        hidden_states=noisy_latents,
-        timestep=timesteps,
-        encoder_hidden_states=encoder_hidden_states,
-        deterministic=False,
-        rngs=nnx.Rngs(dropout_rng),
-    )
+    with jax.named_scope("forward_pass"):
+      model_pred = model(
+          hidden_states=noisy_latents,
+          timestep=timesteps,
+          encoder_hidden_states=encoder_hidden_states,
+          deterministic=False,
+          rngs=nnx.Rngs(dropout_rng),
+      )
 
-    training_target = scheduler.training_target(latents, noise, timesteps)
-    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
-    loss = (training_target - model_pred) ** 2
-    loss = loss * training_weight
-    loss = jnp.mean(loss)
+    with jax.named_scope("loss"):
+      training_target = scheduler.training_target(latents, noise, timesteps)
+      training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+      loss = (training_target - model_pred) ** 2
+      loss = loss * training_weight
+      loss = jnp.mean(loss)
 
     return loss
 
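Note on the `__init__.py` hunk: `flax.config` is Flax's runtime feature-flag registry, and flags are flipped with `update(name, value)`. A minimal sketch of the same guard in isolation, assuming an installed Flax version that defines the `flax_always_shard_variable` flag (`is_flax_available()` is maxdiffusion's own import guard; plain `try/except ImportError` stands in for it here):

```python
# Sketch: flip a Flax runtime flag only when Flax is importable.
# Assumes the installed Flax defines `flax_always_shard_variable`;
# a release without that flag would error on the update() call.
try:
  from flax import config as flax_config

  # Presumed motivation: stop Flax from eagerly sharding variables so the
  # trainer can apply its own partitioning later.
  flax_config.update("flax_always_shard_variable", False)
except ImportError:
  pass  # Flax not installed; nothing to configure.
```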
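Note on the `configuration_utils.py` hunk: the branches are reordered so that when a `subfolder` is passed, a config inside that subfolder now wins over a same-named config at the model root; previously a root-level `config.json` would shadow the per-component one in, say, a `transformer/` subfolder. A standalone sketch of the new resolution order (`resolve_config` is a hypothetical helper, not maxdiffusion code):

```python
import os

def resolve_config(root: str, config_name: str, subfolder: str | None) -> str:
  """Hypothetical helper mirroring the reordered branches in load_config."""
  if subfolder is not None and os.path.isfile(os.path.join(root, subfolder, config_name)):
    # The subfolder copy now takes priority over a root-level copy.
    return os.path.join(root, subfolder, config_name)
  if os.path.isfile(os.path.join(root, config_name)):
    return os.path.join(root, config_name)
  raise EnvironmentError(f"Error no file named {config_name} found in directory {root}.")
```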
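Note on the `wan_trainer.py` hunk: the change is profiling-only. `jax.named_scope` is a JAX context manager that attaches a name to every operation traced under it, so the forward pass and the loss computation show up as labeled regions in the lowered HLO and in profiler traces (TensorBoard/Perfetto); the computed loss is numerically unchanged. A self-contained sketch of the same pattern on toy data:

```python
import jax
import jax.numpy as jnp

@jax.jit
def toy_step(x):
  # Ops traced inside a named_scope carry that name in HLO metadata,
  # which profilers surface as grouped, labeled blocks.
  with jax.named_scope("forward_pass"):
    pred = jnp.tanh(x) * 2.0
  with jax.named_scope("loss"):
    loss = jnp.mean((pred - 1.0) ** 2)
  return loss

print(toy_step(jnp.ones((4, 4))))  # scopes only affect trace metadata, not the result
```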