Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 125 additions & 2 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from flax import nnx
from flax.training import train_state
import jax
import jax.numpy as jnp
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator
from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
from maxtext.layers import train_state_nnx
from maxtext.utils import exceptions
from maxtext.utils import max_logging
from maxtext.utils import gcs_utils
Expand Down Expand Up @@ -165,6 +167,104 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
process_count: Optional[int] = None


def _default_for_sds(sds):
"""Returns a deterministic value matching `sds` shape/dtype/sharding.

Used to fill NNX-only state (rngs/dropout) that the Linen on-disk layout never
carried. Materializes under jit with the target out_shardings so it works on
multi-host meshes (device_put can't place a global sharding whose devices aren't
all addressable from this process).
"""
if not (hasattr(sds, "dtype") and hasattr(sds, "shape")):
return sds

def _make():
if "key" in str(sds.dtype):
base = jax.random.key(0)
return base if sds.shape == () else jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape)
return jnp.zeros(sds.shape, dtype=sds.dtype)

sharding = getattr(sds, "sharding", None)
if sharding is None:
return _make()
return jax.jit(_make, out_shardings=sharding)()


def _populate_pure_dict_from_partial(abstract_pure, partial_concrete):
"""Fills `abstract_pure` with values from `partial_concrete` (by path), defaulting the rest.

Paths present in `partial_concrete` take the restored value; paths absent from it
(NNX-only state the Linen checkpoint never had) get `_default_for_sds`.
"""
if isinstance(abstract_pure, dict):
return {
k: _populate_pure_dict_from_partial(v, partial_concrete.get(k) if isinstance(partial_concrete, dict) else None)
for k, v in abstract_pure.items()
}
if partial_concrete is not None and not isinstance(partial_concrete, dict):
return partial_concrete
return _default_for_sds(abstract_pure)


def _load_linen_checkpoint_into_nnx(path, abstract_nnx_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3):
  """Reads a Linen-layout checkpoint and returns it in the NNX pure-dict layout.

  Steps: convert the abstract NNX state to a pure dict, reshape that into the
  Linen layout to use as the restore target, restore with `partial_restore=True`,
  reshape the result back with `from_linen_checkpoint_dict`, and finally fill
  any path the checkpoint did not provide (NNX-only rngs/dropout) with defaults.

  Args:
    path: Location of the checkpoint to restore.
    abstract_nnx_state: Abstract NNX state; must provide `to_pure_dict()`.
    checkpoint_storage_concurrent_gb: Passed to the handler as both
      `restore_concurrent_gb` and `save_concurrent_gb`.
    use_ocdbt: Forwarded to `ocp.PyTreeCheckpointHandler`.
    use_zarr3: Forwarded to `ocp.PyTreeCheckpointHandler`.

  Returns:
    A nested dict with the same structure as `abstract_nnx_state.to_pure_dict()`.
  """
  max_logging.log(f"Restoring Linen-layout checkpoint into NNX state at {path}")

  nnx_target = abstract_nnx_state.to_pure_dict()
  linen_target = train_state_nnx.to_linen_checkpoint_dict(nnx_target)

  handler = ocp.PyTreeCheckpointHandler(
      restore_concurrent_gb=checkpoint_storage_concurrent_gb,
      save_concurrent_gb=checkpoint_storage_concurrent_gb,
      use_ocdbt=use_ocdbt,
      use_zarr3=use_zarr3,
  )
  checkpointer = ocp.Checkpointer(handler)

  restore_request = ocp.args.PyTreeRestore(
      item=linen_target,
      restore_args=ocp.checkpoint_utils.construct_restore_args(linen_target),
      partial_restore=True,
  )
  linen_restored = checkpointer.restore(epath.Path(path), args=restore_request)

  # Back to the NNX layout; paths missing from the checkpoint are defaulted.
  nnx_partial = train_state_nnx.from_linen_checkpoint_dict(linen_restored)
  return _populate_pure_dict_from_partial(nnx_target, nnx_partial)


def _rebuild_nnx_with_values(abstract_nnx_state, concrete_weights):
  """Returns a copy of `abstract_nnx_state` whose Variables hold the restored arrays.

  The abstract state is flattened with `nnx.Variable` treated as a leaf, the
  restored tree is flattened normally, and the two leaf lists are paired by
  position.

  NOTE(review): pairing is purely positional, so it assumes both trees flatten
  in the same order. Confirm this holds for the key types used in practice.

  Args:
    abstract_nnx_state: Pytree whose leaves are (mostly) `nnx.Variable`s.
    concrete_weights: Pytree of restored arrays.

  Returns:
    A pytree with the structure of `abstract_nnx_state`.

  Raises:
    ValueError: If the two trees have a different number of leaves.
  """

  def _is_variable(node):
    return isinstance(node, nnx.Variable)

  abstract_leaves, structure = jax.tree_util.tree_flatten(abstract_nnx_state, is_leaf=_is_variable)
  restored_arrays = jax.tree_util.tree_leaves(concrete_weights)

  if len(abstract_leaves) != len(restored_arrays):
    raise ValueError(
        f"Params load leaf-count mismatch: {len(abstract_leaves)} abstract Variables vs {len(restored_arrays)} restored."
    )

  rebound = []
  for slot, array in zip(abstract_leaves, restored_arrays):
    # Variables keep their metadata and only swap the value; other leaves are replaced outright.
    rebound.append(slot.replace(value=array) if _is_variable(slot) else array)
  return jax.tree_util.tree_unflatten(structure, rebound)


def _load_linen_params_into_nnx(path, nnx_params_abstract, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3):
  """Loads only the weights of a Linen-layout checkpoint into an NNX params state.

  The abstract params are wrapped under a `model` key so that
  `to_linen_checkpoint_dict` yields the `params/params/...` restore target. The
  restored arrays are then bound back into the NNX params Variables.

  Args:
    path: Location of the checkpoint to restore.
    nnx_params_abstract: Abstract NNX params state; must provide `to_pure_dict()`.
    checkpoint_storage_concurrent_gb: Passed to the handler as both
      `restore_concurrent_gb` and `save_concurrent_gb`.
    use_ocdbt: Forwarded to `ocp.PyTreeCheckpointHandler`.
    use_zarr3: Forwarded to `ocp.PyTreeCheckpointHandler`.

  Returns:
    The result of `_rebuild_nnx_with_values`: `nnx_params_abstract` with its
    Variables filled from the checkpoint's `params/params` subtree.
  """
  max_logging.log(f"Restoring Linen-layout params into NNX state at {path}")

  wrapped_params = {"model": nnx_params_abstract.to_pure_dict()}
  linen_target = train_state_nnx.to_linen_checkpoint_dict(wrapped_params)

  handler = ocp.PyTreeCheckpointHandler(
      restore_concurrent_gb=checkpoint_storage_concurrent_gb,
      save_concurrent_gb=checkpoint_storage_concurrent_gb,
      use_ocdbt=use_ocdbt,
      use_zarr3=use_zarr3,
  )
  checkpointer = ocp.Checkpointer(handler)

  restore_request = ocp.args.PyTreeRestore(
      item=linen_target,
      restore_args=ocp.checkpoint_utils.construct_restore_args(linen_target),
      partial_restore=True,
  )
  linen_restored = checkpointer.restore(epath.Path(path), args=restore_request)

  restored_weights = linen_restored["params"]["params"]
  return _rebuild_nnx_with_values(nnx_params_abstract, restored_weights)


def _load_full_state_from_path(
path,
abstract_unboxed_pre_state,
Expand Down Expand Up @@ -216,6 +316,12 @@ def combine_sharding(sds, shardings):
else:
raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}")
else:
# pure_nnx saves in the Linen on-disk layout; reshape it back into the NNX state.
if isinstance(abstract_unboxed_pre_state, nnx.State):
return _load_linen_checkpoint_into_nnx(
path, abstract_unboxed_pre_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3
)

# Original v0 logic.
p = epath.Path(path)
handler = ocp.PyTreeCheckpointHandler(
Expand Down Expand Up @@ -640,6 +746,17 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

# pure_nnx saves in the Linen on-disk layout; reshape it back into the NNX state.
# (Emergency managers use their own restore path below.)
if isinstance(abstract_unboxed_pre_state, nnx.State) and not isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
restored_nnx = _load_linen_checkpoint_into_nnx(
checkpoint_path, abstract_unboxed_pre_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3
)
return ({"items": restored_nnx}, None)

# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
Expand Down Expand Up @@ -742,6 +859,12 @@ def load_params_from_path(
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")

# NNX target: the on-disk checkpoint is in Linen layout; reshape it into the NNX params state.
if isinstance(abstract_unboxed_params, nnx.State):
return _load_linen_params_into_nnx(
load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3
)

# *_concurrent_gb should be set for large models, the default is 96.
max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")
ckptr = ocp.Checkpointer(
Expand Down Expand Up @@ -794,8 +917,8 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
actual_step = int(state.step) - 1

if config.pure_nnx:
# Convert nnx.State to dict.
state = state.to_pure_dict()
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
Expand Down
139 changes: 138 additions & 1 deletion src/maxtext/layers/train_state_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" The NNX Unified TrainState. """
"""The NNX Unified TrainState."""

from typing import Any

from flax import nnx
import jax
import jax.numpy as jnp


class TrainStateNNX(nnx.Module):
Expand Down Expand Up @@ -46,3 +48,138 @@ def apply_gradients(self, grads: Any):
"This usually happens when the state was created for inference only."
)
self.optimizer.update(self.model, grads)


# On-disk checkpoint format.
#
# A pure_nnx run saves in the same on-disk layout as a Linen run, so the two are
# interchangeable. The NNX state pure dict differs from Linen's in three ways, all
# reshaped below at save time:
# 1. top-level keys: {model, optimizer:{step, opt_state}} -> {params:{params:...}, step, opt_state}
# 2. weights: model/... -> params/params/... (Linen `params` collection)
# 3. opt_state: int-keyed dict (empty entries skipped) -> list with None for EmptyState,
# and mu/nu wrapped under the `params` collection
# NNX-only rngs/dropout state is dropped (Linen never had it).

# Dict keys whose subtrees hold NNX-only RNG/dropout state; `_strip_rng_state`
# removes any subtree stored under one of these keys.
_NNX_RNG_STATE_KEYS = ("rngs", "dropout")


def _cast_step(step, dtype):
"""Casts the step's dtype, handling both concrete arrays and abstract ShapeDtypeStruct.

NNX stores step as uint32 and Linen as int32; this also runs on the abstract
(SDS) state when building a restore target, so it can't assume concrete values.
"""
if isinstance(step, jax.ShapeDtypeStruct):
return jax.ShapeDtypeStruct(step.shape, dtype, sharding=getattr(step, "sharding", None))
return jnp.asarray(step, dtype=dtype)


def _strip_rng_state(tree):
"""Removes the NNX-only 'rngs'/'dropout' subtrees that Linen doesn't carry.

Subtrees that become empty after stripping are dropped too, so the result has
no keys Linen wouldn't also have.2
"""
if not isinstance(tree, dict):
return tree
out = {}
for k, v in tree.items():
if k in _NNX_RNG_STATE_KEYS:
continue
stripped = _strip_rng_state(v)
if isinstance(stripped, dict) and not stripped:
continue
out[k] = stripped
return out


def _wrap_mu_nu_with_params(state):
"""Wraps mu/nu under an inner 'params' key (the Linen collection)."""
if not isinstance(state, dict):
return state
return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()}


def _as_chain_index(key):
"""Returns the int index for an int or digit-string key, else None."""
if isinstance(key, int):
return key
if isinstance(key, str) and key.isdigit():
return int(key)
return None


def _opt_state_to_linen(opt_state):
"""Reshapes the NNX optax-chain opt_state to Linen's list-with-None layout.

NNX serializes the chain as an int-keyed dict, skipping empty entries; Linen
uses a list with `None` for each `EmptyState`. A single-element chain is
returned unwrapped to match Linen's un-chained optimizers (e.g. adam_pax).
"""
if not isinstance(opt_state, dict):
return opt_state
indices = [_as_chain_index(k) for k in opt_state.keys()]
if not indices or any(i is None for i in indices):
return _wrap_mu_nu_with_params(opt_state)
chain = [None] * (max(indices) + 1)
for key, idx in zip(opt_state.keys(), indices):
chain[idx] = _wrap_mu_nu_with_params(opt_state[key])
return chain[0] if len(chain) == 1 else chain


def to_linen_checkpoint_dict(nnx_pure_dict):
  """Converts a TrainStateNNX pure dict into the Linen on-disk layout.

  Mapping performed:
    * `model` -> `params/params`, with NNX-only rngs/dropout state removed.
    * `optimizer/step` -> `step`, cast to int32.
    * `optimizer/opt_state` -> `opt_state`, in Linen's chain layout.
  Any other top-level key is not carried over.

  Args:
    nnx_pure_dict: Dict of the form `{model, optimizer: {step, opt_state}}`; each
      part is optional. Works on concrete and abstract trees.

  Returns:
    The Linen-layout dict, or the input unchanged when it is not a dict.
  """
  if not isinstance(nnx_pure_dict, dict):
    return nnx_pure_dict

  linen = {}
  if "model" in nnx_pure_dict:
    weights = _strip_rng_state(nnx_pure_dict["model"])
    linen["params"] = {"params": weights}

  optimizer = nnx_pure_dict.get("optimizer")
  if not isinstance(optimizer, dict):
    return linen

  if "step" in optimizer:
    # NNX stores step as uint32; Linen uses int32.
    linen["step"] = _cast_step(optimizer["step"], jnp.int32)
  if "opt_state" in optimizer:
    linen["opt_state"] = _opt_state_to_linen(optimizer["opt_state"])
  return linen


def _strip_mu_nu_params(state):
"""Inverse of `_wrap_mu_nu_with_params`: removes the inner 'params' wrap from mu/nu."""
if not isinstance(state, dict):
return state
return {k: v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v for k, v in state.items()}


def _opt_state_from_linen(opt_state):
"""Inverse of `_opt_state_to_linen`: Linen list-with-None -> NNX int-keyed dict."""
if isinstance(opt_state, list):
return {i: _strip_mu_nu_params(e) for i, e in enumerate(opt_state) if isinstance(e, dict)}
if not isinstance(opt_state, dict):
return opt_state
return {0: _strip_mu_nu_params(opt_state)}


def from_linen_checkpoint_dict(linen_pure_dict):
  """Converts a Linen on-disk layout dict back into the NNX layout.

  Mapping performed (the inverse of `to_linen_checkpoint_dict`):
    * `params/params` (or `params` itself if the inner key is absent) -> `model`.
    * `step` -> `optimizer/step`, cast to uint32.
    * `opt_state` -> `optimizer/opt_state`, as a position-keyed dict.
  NNX-only rngs/dropout state is absent from Linen checkpoints and is not
  recreated here; callers are expected to fill it in.

  Args:
    linen_pure_dict: Dict in the Linen layout; each part is optional.

  Returns:
    The NNX-layout dict, or the input unchanged when it is not a dict. The
    `optimizer` key is present only if `step` or `opt_state` was.
  """
  if not isinstance(linen_pure_dict, dict):
    return linen_pure_dict

  nnx_layout = {}

  weights = linen_pure_dict.get("params")
  if weights is not None:
    has_collection_wrap = isinstance(weights, dict) and "params" in weights
    nnx_layout["model"] = weights["params"] if has_collection_wrap else weights

  optimizer_part = {}
  if "step" in linen_pure_dict:
    # Linen stores step as int32; NNX uses uint32.
    optimizer_part["step"] = _cast_step(linen_pure_dict["step"], jnp.uint32)
  if "opt_state" in linen_pure_dict:
    optimizer_part["opt_state"] = _opt_state_from_linen(linen_pure_dict["opt_state"])
  if optimizer_part:
    nnx_layout["optimizer"] = optimizer_part

  return nnx_layout
27 changes: 27 additions & 0 deletions tests/unit/checkpointing_nnx_load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

"""Unit tests for the NNX branches of load_state_if_possible."""

import os
import tempfile
import unittest
from unittest import mock

from etils import epath
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
from flax import nnx

from maxtext.common import checkpointing
Expand Down Expand Up @@ -102,5 +106,28 @@ def test_no_paths_returns_none_none(self):
self.assertIsNone(params)


class TestLoadParamsIntoNNX(unittest.TestCase):
  """Covers the weight-only load path (load_parameters_path) for an NNX target."""

  def test_linen_layout_params_restore_into_nnx_state(self):
    """A Linen-layout checkpoint on disk is loaded into the NNX params state with its values intact."""
    kernel = jnp.arange(2, dtype=jnp.float32).reshape(2, 1)
    bias = jnp.array([5.0])
    # On-disk Linen layout: params/params/<weights> plus an unrelated `step`.
    linen_tree = {"params": {"params": {"linear": {"kernel": kernel, "bias": bias}}}, "step": jnp.array(3)}

    _, abstract_params, _ = nnx.split(_Model(rngs=nnx.Rngs(0)), nnx.Param, ...)

    with tempfile.TemporaryDirectory() as tmp_dir:  # pylint: disable=consider-using-with
      ckpt_path = os.path.join(tmp_dir, "ckpt")
      writer = ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True)
      writer.save(epath.Path(ckpt_path), linen_tree, force=True)
      # Load while the temporary directory still exists.
      loaded = checkpointing.load_params_from_path(ckpt_path, abstract_params, 8)

    self.assertIsInstance(loaded, nnx.State)
    loaded_linear = loaded.to_pure_dict()["linear"]
    self.assertTrue(jnp.array_equal(loaded_linear["kernel"], kernel))
    self.assertTrue(jnp.array_equal(loaded_linear["bias"], bias))

# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()
Loading
Loading