diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index ad7618868a..9a92545e11 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -23,10 +23,12 @@ from flax import nnx from flax.training import train_state import jax +import jax.numpy as jnp from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator +from maxtext.layers import train_state_nnx from maxtext.utils import exceptions from maxtext.utils import max_logging from maxtext.utils import gcs_utils @@ -165,6 +167,104 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs): process_count: Optional[int] = None +def _default_for_sds(sds): + """Returns a deterministic value matching `sds` shape/dtype/sharding. + + Used to fill NNX-only state (rngs/dropout) that the Linen on-disk layout never + carried. Materializes under jit with the target out_shardings so it works on + multi-host meshes (device_put can't place a global sharding whose devices aren't + all addressable from this process). + """ + if not (hasattr(sds, "dtype") and hasattr(sds, "shape")): + return sds + + def _make(): + if "key" in str(sds.dtype): + base = jax.random.key(0) + return base if sds.shape == () else jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape) + return jnp.zeros(sds.shape, dtype=sds.dtype) + + sharding = getattr(sds, "sharding", None) + if sharding is None: + return _make() + return jax.jit(_make, out_shardings=sharding)() + + +def _populate_pure_dict_from_partial(abstract_pure, partial_concrete): + """Fills `abstract_pure` with values from `partial_concrete` (by path), defaulting the rest. + + Paths present in `partial_concrete` take the restored value; paths absent from it + (NNX-only state the Linen checkpoint never had) get `_default_for_sds`. + """ + if isinstance(abstract_pure, dict): + return { + k: _populate_pure_dict_from_partial(v, partial_concrete.get(k) if isinstance(partial_concrete, dict) else None) + for k, v in abstract_pure.items() + } + if partial_concrete is not None and not isinstance(partial_concrete, dict): + return partial_concrete + return _default_for_sds(abstract_pure) + + +def _load_linen_checkpoint_into_nnx(path, abstract_nnx_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Restores a Linen-layout checkpoint into an NNX state (pure_nnx resume). + + Restores against a Linen-shape abstract, reshapes back via + `from_linen_checkpoint_dict`, then fills NNX-only rngs/dropout with defaults. + """ + max_logging.log(f"Restoring Linen-layout checkpoint into NNX state at {path}") + nnx_abstract_pure = abstract_nnx_state.to_pure_dict() + linen_abstract = train_state_nnx.to_linen_checkpoint_dict(nnx_abstract_pure) + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) + restored = ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True) + restored = ckptr.restore(epath.Path(path), args=restored) + partial_nnx = train_state_nnx.from_linen_checkpoint_dict(restored) + return _populate_pure_dict_from_partial(nnx_abstract_pure, partial_nnx) + + +def _rebuild_nnx_with_values(abstract_nnx_state, concrete_weights): + """Fills each Variable in `abstract_nnx_state` with the matching restored array.""" + leaves, treedef = jax.tree_util.tree_flatten(abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + concrete = jax.tree_util.tree_leaves(concrete_weights) + if len(leaves) != len(concrete): + raise ValueError(f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs {len(concrete)} restored.") + new_leaves = [v.replace(value=a) if isinstance(v, nnx.Variable) else a for v, a in zip(leaves, concrete)] + return jax.tree_util.tree_unflatten(treedef, new_leaves) + + +def _load_linen_params_into_nnx(path, nnx_params_abstract, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3): + """Weight-only load of a Linen-layout checkpoint into an NNX params state. + + Reuses `to_linen_checkpoint_dict` (wrapping the params under `model`) to build the + `params/params/...` restore target, then rebinds the restored weights into the + NNX params Variables. + """ + max_logging.log(f"Restoring Linen-layout params into NNX state at {path}") + linen_abstract = train_state_nnx.to_linen_checkpoint_dict({"model": nnx_params_abstract.to_pure_dict()}) + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=checkpoint_storage_concurrent_gb, + save_concurrent_gb=checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + ) + restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract) + restored = ckptr.restore( + epath.Path(path), + args=ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True), + ) + return _rebuild_nnx_with_values(nnx_params_abstract, restored["params"]["params"]) + + def _load_full_state_from_path( path, abstract_unboxed_pre_state, @@ -216,6 +316,12 @@ def combine_sharding(sds, shardings): else: raise ocp_v1.errors.InvalidLayoutError(f"Unknown checkpoint layout: {source_checkpoint_layout}") else: + # pure_nnx saves in the Linen on-disk layout; reshape it back into the NNX state. + if isinstance(abstract_unboxed_pre_state, nnx.State): + return _load_linen_checkpoint_into_nnx( + path, abstract_unboxed_pre_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3 + ) + # Original v0 logic. p = epath.Path(path) handler = ocp.PyTreeCheckpointHandler( @@ -640,6 +746,17 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + # pure_nnx saves in the Linen on-disk layout; reshape it back into the NNX state. + # (Emergency managers use their own restore path below.) + if isinstance(abstract_unboxed_pre_state, nnx.State) and not isinstance( + checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) + ): + checkpoint_path = str(checkpoint_manager.directory / str(step) / "items") + restored_nnx = _load_linen_checkpoint_into_nnx( + checkpoint_path, abstract_unboxed_pre_state, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3 + ) + return ({"items": restored_nnx}, None) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX restore_target = abstract_unboxed_pre_state if isinstance(abstract_unboxed_pre_state, nnx.State): @@ -742,6 +859,12 @@ def load_params_from_path( assert load_parameters_from_path, "load_parameters_from_path is not defined." max_logging.log(f"restoring params from {load_parameters_from_path}") + # NNX target: the on-disk checkpoint is in Linen layout; reshape it into the NNX params state. + if isinstance(abstract_unboxed_params, nnx.State): + return _load_linen_params_into_nnx( + load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt, use_zarr3 + ) + # *_concurrent_gb should be set for large models, the default is 96. max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") ckptr = ocp.Checkpointer( @@ -794,8 +917,8 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step actual_step = int(state.step) - 1 if config.pure_nnx: - # Convert nnx.State to dict. - state = state.to_pure_dict() + # Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable. + state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict()) # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..1e69588b63 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any from flax import nnx +import jax +import jax.numpy as jnp class TrainStateNNX(nnx.Module): @@ -46,3 +48,138 @@ def apply_gradients(self, grads: Any): "This usually happens when the state was created for inference only." ) self.optimizer.update(self.model, grads) + + +# On-disk checkpoint format. +# +# A pure_nnx run saves in the same on-disk layout as a Linen run, so the two are +# interchangeable. The NNX state pure dict differs from Linen's in three ways, all +# reshaped below at save time: +# 1. top-level keys: {model, optimizer:{step, opt_state}} -> {params:{params:...}, step, opt_state} +# 2. weights: model/... -> params/params/... (Linen `params` collection) +# 3. opt_state: int-keyed dict (empty entries skipped) -> list with None for EmptyState, +# and mu/nu wrapped under the `params` collection +# NNX-only rngs/dropout state is dropped (Linen never had it). + +_NNX_RNG_STATE_KEYS = ("rngs", "dropout") + + +def _cast_step(step, dtype): + """Casts the step's dtype, handling both concrete arrays and abstract ShapeDtypeStruct. + + NNX stores step as uint32 and Linen as int32; this also runs on the abstract + (SDS) state when building a restore target, so it can't assume concrete values. + """ + if isinstance(step, jax.ShapeDtypeStruct): + return jax.ShapeDtypeStruct(step.shape, dtype, sharding=getattr(step, "sharding", None)) + return jnp.asarray(step, dtype=dtype) + + +def _strip_rng_state(tree): + """Removes the NNX-only 'rngs'/'dropout' subtrees that Linen doesn't carry. + + Subtrees that become empty after stripping are dropped too, so the result has + no keys Linen wouldn't also have.2 + """ + if not isinstance(tree, dict): + return tree + out = {} + for k, v in tree.items(): + if k in _NNX_RNG_STATE_KEYS: + continue + stripped = _strip_rng_state(v) + if isinstance(stripped, dict) and not stripped: + continue + out[k] = stripped + return out + + +def _wrap_mu_nu_with_params(state): + """Wraps mu/nu under an inner 'params' key (the Linen collection).""" + if not isinstance(state, dict): + return state + return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()} + + +def _as_chain_index(key): + """Returns the int index for an int or digit-string key, else None.""" + if isinstance(key, int): + return key + if isinstance(key, str) and key.isdigit(): + return int(key) + return None + + +def _opt_state_to_linen(opt_state): + """Reshapes the NNX optax-chain opt_state to Linen's list-with-None layout. + + NNX serializes the chain as an int-keyed dict, skipping empty entries; Linen + uses a list with `None` for each `EmptyState`. A single-element chain is + returned unwrapped to match Linen's un-chained optimizers (e.g. adam_pax). + """ + if not isinstance(opt_state, dict): + return opt_state + indices = [_as_chain_index(k) for k in opt_state.keys()] + if not indices or any(i is None for i in indices): + return _wrap_mu_nu_with_params(opt_state) + chain = [None] * (max(indices) + 1) + for key, idx in zip(opt_state.keys(), indices): + chain[idx] = _wrap_mu_nu_with_params(opt_state[key]) + return chain[0] if len(chain) == 1 else chain + + +def to_linen_checkpoint_dict(nnx_pure_dict): + """Reshapes a TrainStateNNX pure dict ({model, optimizer}) into the Linen on-disk layout.""" + if not isinstance(nnx_pure_dict, dict): + return nnx_pure_dict + result = {} + if "model" in nnx_pure_dict: + result["params"] = {"params": _strip_rng_state(nnx_pure_dict["model"])} + optimizer = nnx_pure_dict.get("optimizer") + if isinstance(optimizer, dict): + if "step" in optimizer: + # NNX stores step as uint32; Linen uses int32. + result["step"] = _cast_step(optimizer["step"], jnp.int32) + if "opt_state" in optimizer: + result["opt_state"] = _opt_state_to_linen(optimizer["opt_state"]) + return result + + +def _strip_mu_nu_params(state): + """Inverse of `_wrap_mu_nu_with_params`: removes the inner 'params' wrap from mu/nu.""" + if not isinstance(state, dict): + return state + return {k: v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v for k, v in state.items()} + + +def _opt_state_from_linen(opt_state): + """Inverse of `_opt_state_to_linen`: Linen list-with-None -> NNX int-keyed dict.""" + if isinstance(opt_state, list): + return {i: _strip_mu_nu_params(e) for i, e in enumerate(opt_state) if isinstance(e, dict)} + if not isinstance(opt_state, dict): + return opt_state + return {0: _strip_mu_nu_params(opt_state)} + + +def from_linen_checkpoint_dict(linen_pure_dict): + """Inverse of `to_linen_checkpoint_dict`: Linen on-disk layout -> NNX layout. + + Doesn't restore NNX-only rngs/dropout (absent from Linen); callers fill those. + """ + if not isinstance(linen_pure_dict, dict): + return linen_pure_dict + result = {} + params = linen_pure_dict.get("params") + if isinstance(params, dict) and "params" in params: + result["model"] = params["params"] + elif params is not None: + result["model"] = params + optimizer = {} + if "step" in linen_pure_dict: + # Linen stores step as int32; NNX uses uint32. + optimizer["step"] = _cast_step(linen_pure_dict["step"], jnp.uint32) + if "opt_state" in linen_pure_dict: + optimizer["opt_state"] = _opt_state_from_linen(linen_pure_dict["opt_state"]) + if optimizer: + result["optimizer"] = optimizer + return result diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py index 622f19323a..936dd70ea4 100644 --- a/tests/unit/checkpointing_nnx_load_test.py +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -14,12 +14,16 @@ """Unit tests for the NNX branches of load_state_if_possible.""" +import os +import tempfile import unittest from unittest import mock +from etils import epath import jax import jax.numpy as jnp import optax +import orbax.checkpoint as ocp from flax import nnx from maxtext.common import checkpointing @@ -102,5 +106,28 @@ def test_no_paths_returns_none_none(self): self.assertIsNone(params) +class TestLoadParamsIntoNNX(unittest.TestCase): + """Weight-only load (load_parameters_path) of a Linen-layout checkpoint into NNX.""" + + def test_linen_layout_params_restore_into_nnx_state(self): + """load_params_from_path reshapes an on-disk Linen-layout checkpoint into the NNX params state.""" + model = _Model(rngs=nnx.Rngs(0)) + _, params_abstract, _ = nnx.split(model, nnx.Param, ...) + weights = {"linear": {"kernel": jnp.arange(2, dtype=jnp.float32).reshape(2, 1), "bias": jnp.array([5.0])}} + + with tempfile.TemporaryDirectory() as d: # pylint: disable=consider-using-with + path = os.path.join(d, "ckpt") + # On-disk Linen layout: params/params/ plus an unrelated `step`. + ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True).save( + epath.Path(path), {"params": {"params": weights}, "step": jnp.array(3)}, force=True + ) + restored = checkpointing.load_params_from_path(path, params_abstract, 8) + + self.assertIsInstance(restored, nnx.State) + pure = restored.to_pure_dict() + self.assertTrue(jnp.array_equal(pure["linear"]["kernel"], weights["linear"]["kernel"])) + self.assertTrue(jnp.array_equal(pure["linear"]["bias"], weights["linear"]["bias"])) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py index 0f7dc22d68..fed96eb635 100644 --- a/tests/unit/train_state_nnx_checkpoint_test.py +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -385,20 +385,23 @@ def test_nnx_and_linen_agree_on_actual_step(self): self._invoke_maybe_save(linen_state, pure_nnx=False)["step"], ) - def test_nnx_state_is_converted_to_pure_dict_before_save(self): - """For pure_nnx=True, maybe_save_checkpoint must pass a plain dict to save_checkpoint, not an nnx.State.""" + def test_nnx_state_is_saved_in_linen_layout(self): + """For pure_nnx=True, maybe_save_checkpoint reshapes the NNX state to the Linen on-disk layout.""" state = self._build_nnx_state(self.N_STEPS) self.assertIsInstance(state, nnx.State) # precondition: NNX train_step returns an nnx.State captured = self._invoke_maybe_save(state, pure_nnx=True) - # save_checkpoint should have received a plain Python dict (the result of - # nnx.State.to_pure_dict()), not the original nnx.State. + # save_checkpoint should receive a plain dict in Linen layout, not the nnx.State. self.assertIsInstance(captured["state"], dict) self.assertNotIsInstance(captured["state"], nnx.State) - # Sanity: the converted dict still mirrors the TrainStateNNX structure. - self.assertIn("model", captured["state"]) - self.assertIn("optimizer", captured["state"]) + # Linen layout: {params: {params: ...}, step, opt_state}; not the NNX {model, optimizer}. + self.assertIn("params", captured["state"]) + self.assertIn("step", captured["state"]) + self.assertIn("opt_state", captured["state"]) + self.assertNotIn("model", captured["state"]) + self.assertNotIn("optimizer", captured["state"]) + self.assertIn("params", captured["state"]["params"]) def test_linen_state_is_passed_through_unchanged(self): """For pure_nnx=False, maybe_save_checkpoint must pass the original TrainState object through.""" @@ -408,5 +411,54 @@ def test_linen_state_is_passed_through_unchanged(self): self.assertIs(captured["state"], state) +class TestLinenCheckpointFormatConverters(unittest.TestCase): + """to_linen_checkpoint_dict / from_linen_checkpoint_dict (NNX <-> Linen on-disk layout).""" + + def _nnx_pure(self): + # A 3-element optax chain (e.g. adamw): index 1 is an EmptyState (absent in the int-keyed dict). + return { + "model": { + "decoder": {"norm": {"scale": jnp.ones((3,))}}, + "dropout": {"rngs": {"default": {"key": jnp.ones((2,), dtype=jnp.uint32)}}}, # NNX-only + }, + "optimizer": { + "step": jnp.asarray(7, dtype=jnp.uint32), + "opt_state": { + 0: {"count": jnp.asarray(7), "mu": {"decoder": jnp.ones((3,))}, "nu": {"decoder": jnp.ones((3,))}}, + 2: {"count": jnp.asarray(7)}, + }, + }, + } + + def test_to_linen_layout(self): + linen = train_state_nnx.to_linen_checkpoint_dict(self._nnx_pure()) + self.assertEqual(set(linen.keys()), {"params", "step", "opt_state"}) + self.assertIn("params", linen["params"]) # params/params/ collection wrap + self.assertNotIn("dropout", linen["params"]["params"]) # NNX-only rngs/dropout stripped + self.assertEqual(linen["step"].dtype, jnp.int32) # Linen step is int32 + # opt_state is a list with None for the EmptyState slot, mu/nu wrapped under params. + self.assertIsInstance(linen["opt_state"], list) + self.assertEqual(len(linen["opt_state"]), 3) + self.assertIsNone(linen["opt_state"][1]) + self.assertIn("params", linen["opt_state"][0]["mu"]) + + def test_round_trip_preserves_values(self): + nnx_pure = self._nnx_pure() + back = train_state_nnx.from_linen_checkpoint_dict(train_state_nnx.to_linen_checkpoint_dict(nnx_pure)) + self.assertEqual(set(back.keys()), {"model", "optimizer"}) + self.assertEqual(back["optimizer"]["step"].dtype, jnp.uint32) # NNX step back to uint32 + self.assertEqual(set(back["optimizer"]["opt_state"].keys()), {0, 2}) # int-keyed dict, EmptyState dropped + self.assertNotIn("params", back["optimizer"]["opt_state"][0]["mu"]) # mu/nu unwrapped + self.assertTrue( + jnp.array_equal(nnx_pure["model"]["decoder"]["norm"]["scale"], back["model"]["decoder"]["norm"]["scale"]) + ) + + def test_cast_step_handles_shapedtypestruct(self): + sds = jax.ShapeDtypeStruct((), jnp.uint32) + out = train_state_nnx._cast_step(sds, jnp.int32) # pylint: disable=protected-access + self.assertIsInstance(out, jax.ShapeDtypeStruct) + self.assertEqual(out.dtype, jnp.int32) + + if __name__ == "__main__": unittest.main()