Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,32 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold)
else:
grads = raw_grads

# fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
# Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
# amax_history and corrupts future steps. Replace NaN OWG entries with the current state
# values (skip the amax update for that step) instead of letting NaN flow through.
# Also restore OWG values after apply_gradients to bypass optimizer corruption
# (Adam should not update fp8 scale/amax_history).
fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
if fp8_stats is not None:
if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
fp8_stats = jax.tree_util.tree_map(
lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
fp8_stats,
current_fp8,
)
else:
fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
grads = dict(grads)
grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
# Zero out any remaining NaN in float gradients to prevent param corruption
grads = jax.tree_util.tree_map(
lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
grads,
)

if config.optimizer_memory_host_offload:
state = state.replace(
opt_state=jax.device_put(
Expand Down Expand Up @@ -414,6 +440,12 @@ def move(path, value):
else:
new_state = state.apply_gradients(grads=grads)

# fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
if fp8_stats is not None:
new_params = dict(new_state.params)
new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
new_state = new_state.replace(params=new_params)

# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
Expand Down
29 changes: 28 additions & 1 deletion src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import functools
import pickle
import os
from typing import Sequence

from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
Expand All @@ -27,6 +28,7 @@

from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load
from jax.sharding import AxisType, Mesh

import jax
import jax.numpy as jnp
Expand All @@ -36,7 +38,8 @@
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
from maxtext.configs import pyconfig
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode
from maxtext.configs import types
from maxtext.inference.page_manager import PageState
from maxtext.common import checkpointing
Expand Down Expand Up @@ -1650,3 +1653,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging
all_host_upload=False, # Only upload from lead host (Host 0)
)


def get_mesh_from_config(
    config: pyconfig.HyperParameters,
    devices: Sequence[jax.Device] | None = None,
) -> Mesh:
  """
  Get the device mesh from the configuration.

  Args:
    config: the configuration; this reads `config.shard_mode` and
      `config.mesh_axes` (and whatever `create_device_mesh` consumes).
    devices: optional sequence of JAX devices to build the mesh from;
      when None, `create_device_mesh` picks the devices itself.

  Returns:
    A `jax.sharding.Mesh` over the configured axes, with every axis typed
    `AxisType.Explicit` when `config.shard_mode == ShardMode.EXPLICIT` and
    `AxisType.Auto` otherwise.
  """
  devices_array = create_device_mesh(config, devices)

  # All mesh axes share one axis type, selected by the shard mode.
  axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
  axis_types = (axis_type,) * len(config.mesh_axes)

  return Mesh(devices_array, config.mesh_axes, axis_types=axis_types)
9 changes: 6 additions & 3 deletions src/maxtext/utils/maxtext_utils_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@

from maxtext.utils import max_logging
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN


def create_nnx_rngs(
config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None
config: pyconfig.HyperParameters, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None
) -> nnx.Rngs:
"""
Create NNX Rngs

Args:
config: the configuration
is_training: if the Rngs are for training
model_mode: the model mode. See maxtext.common.common_types for valid values.
rng_key: the Rng key

Returns:
Expand All @@ -41,7 +42,9 @@ def create_nnx_rngs(
if rng_key is None:
rng_key = jax.random.PRNGKey(config.init_weights_seed)

if is_training:
if model_mode == MODEL_MODE_TRAIN:
# Use fold_in to derive independent keys for each stream from a single seed.
# aqt is needed for quantization-aware training.
return nnx.Rngs(
params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2)
)
Expand Down
43 changes: 18 additions & 25 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,18 @@
from collections.abc import Sequence
from functools import partial
from typing import overload

from etils import epath
from flax import nnx
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, Mesh
from jax.sharding import Mesh
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx
from orbax import checkpoint as ocp

try:
Expand Down Expand Up @@ -154,6 +152,7 @@ def from_config(
mesh: Mesh | None = None,
*,
model_mode: str = MODEL_MODE_TRAIN,
rngs: None = None,
) -> nn.Module:
...

Expand Down Expand Up @@ -194,15 +193,7 @@ def from_config(
model = from_config(config)
"""
if mesh is None:
devices_array = maxtext_utils.create_device_mesh(config, devices)

if config.shard_mode == ShardMode.EXPLICIT:
axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
else:
axis_types = tuple([AxisType.Auto] * len(config.mesh_axes))

mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types)

mesh = maxtext_utils.get_mesh_from_config(config, devices)
model = create_model(config, mesh, model_mode=model_mode, rngs=rngs)

# Return only the model
Expand Down Expand Up @@ -245,9 +236,7 @@ def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key
"""

def _create_model(rng_key=None):
if rng_key is None:
rng_key = jax.random.PRNGKey(config.init_weights_seed)
rngs = nnx.Rngs(params=rng_key, dropout=1)
rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key)
return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode)

_create_model_partial = partial(_create_model, rng_key=rng_key)
Expand All @@ -262,14 +251,7 @@ def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAI
"""Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""

def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
if rng_key is None:
rng_key = jax.random.PRNGKey(config.init_weights_seed)

if model_mode == MODEL_MODE_TRAIN:
rngs = nnx.Rngs(params=rng_key, dropout=1)
else:
rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference

rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key)
return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)

_create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key)
Expand All @@ -282,6 +264,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
if mesh is None:
mesh = abstract_model.mesh

# Note for pure_nnx:
# Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and
# we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen
# LogicallyPartitioned structure.
# In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned
# structure in the abstract state and we can get the sharded state with the following code:
# graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh)
# abstract_model = nnx.merge(graphdef, state)
# model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh)
# sharded_state = nnx.state(model)

# JIT a function that creates the model state with proper sharding from the start.
# By providing out_shardings, we instruct JAX to produce sharded output directly,
# avoiding a large intermediate allocation on a single device.
Expand Down
Loading
Loading