diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 080104ee89..d9b686b182 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -446,12 +446,16 @@ def _load_params_nnx(self, params, rng): rest_dict = rest_state.to_pure_dict() def _overlay(dst, src): - if isinstance(dst, dict): + if isinstance(dst, dict) and isinstance(src, dict): for k, v in dst.items(): if k in src: dst[k] = _overlay(v, src[k]) return dst - return src if not isinstance(src, dict) else dst + # On structural mismatch keep dst (PREFILL); swapping a leaf for a subtree + # (or the other way) would corrupt the model. Both-leaves is the overlay case. + if isinstance(dst, dict) or isinstance(src, dict): + return dst + return src rest_dict = _overlay(rest_dict, loaded_rest_dict) nnx.replace_by_pure_dict(rest_state, rest_dict) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index de48b60830..6dcdeea39a 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -982,7 +982,9 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding.value + # Use [...] not the deprecated .value: .value records the read in NNX's mutation + # tracking, which leaks a tracer out of vocab_tiling_nnx_loss's custom_vjp. + embedding_table = shared_embedding.embedding[...] else: embedding_table = shared_embedding.variables["params"]["embedding"] if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 2af0d560da..ac908c0f96 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -360,7 +360,6 @@ def __init__( else: decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) - self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) @@ -567,10 +566,6 @@ def __call__( mutable=mutable_collections, ) # pytype: disable=wrong-keyword-args - # Materialize hidden state when vocab tiling is enabled - if self.config.num_vocab_tiling > 1: - self.hidden_states = hidden_state - # If we are initializing the model AND MTP is enabled, we must create # dummy target tensors. This allows Flax to trace the MTPBlock and create # all its necessary parameters, without requiring the main training pipeline diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index e9d83f4db3..841f1cf257 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1041,17 +1041,16 @@ def _build_value_target(v): # Free memory used by initial sharded_state before restore, to make room for the incoming checkpoint arrays. def _free_device_memory(node): - val = node if isinstance(node, nnx.Variable) and not isinstance(node, nnx.RngState): inner = node.get_value() if hasattr(node, "get_value") else node[...] - # Same QTensor caveat as `_build_value_target`: AQT serve-mode `qrhs.frozen` - # wraps a QTensor whose `__getitem__` fails on `LogicallyPartitioned`. - # We only need to free a single jax.Array leaf — for composite values - # there's nothing to free at this level, so skip. - val = inner if hasattr(inner, "shape") else None - - if isinstance(val, jax.Array) and not val.is_deleted(): - val.delete() + # AQT serve-mode `qrhs.frozen` wraps a QTensor (composite pytree) rather + # than a single jax.Array. Walking via tree_leaves frees the qvalue/scale + # arrays too; the single-leaf case is a 1-element tree. + for leaf in jax.tree_util.tree_leaves(inner): + if isinstance(leaf, jax.Array) and not leaf.is_deleted(): + leaf.delete() + elif isinstance(node, jax.Array) and not node.is_deleted(): + node.delete() return node diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index 2610555941..b17e318570 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -30,6 +30,29 @@ from maxtext.utils import max_utils +# Submodule names whose params are used by logits_from_hidden_states_for_vocab_tiling: +# the final norm, the LM-head dense, and the embedding table when logits are tied. +# vocab_tiling_nnx_loss splits these out as the only params the loss differentiates. +_OUTPUT_HEAD_PATH_KEYS = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense") + + +def _is_output_head_param_path(path, _value): + """Filter for nnx.split: True when the param path belongs to the output head.""" + + # JAX path entries differ by key type: DictKey uses .key, GetAttrKey uses .name + # in newer Flax and .attr in older. Check all three so the filter survives + # version upgrades. + def _name(k): + for attr in ("key", "attr", "name"): + v = getattr(k, attr, None) + if v is not None: + return str(v) + return str(k) + + keys = [_name(k) for k in path] + return any(k in keys for k in _OUTPUT_HEAD_PATH_KEYS) + + def vocab_tiling_linen_loss( hidden_states, data, @@ -253,12 +276,12 @@ def _bwd_scan_body(grad_params_acc, chunk_data): def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): """Computes cross-entropy loss with vocab tiling for NNX models. - NNX equivalent of ``vocab_tiling_linen_loss``. Scans the vocab dimension - and calls ``model.logits_from_hidden_states_for_vocab_tiling`` per chunk. The NNX model - carries its own parameters, so no explicit gather is needed. - - Uses default autograd; a custom_vjp for backward memory savings can be - added later if needed. + NNX equivalent of `vocab_tiling_linen_loss`. A `custom_vjp` runs the loss in + vocab chunks via `jax.lax.scan` so the backward only holds one chunk's logits + at a time, matching the Linen path's memory profile. `nnx.split` separates the + output-head params (which the loss differentiates) from everything else; the + rest of the model is passed through but not differentiated, so the scan's + residuals stay small. Args: model: NNX model exposing ``logits_from_hidden_states_for_vocab_tiling``. @@ -320,42 +343,137 @@ def _reshape(inputs, out_shape, out_sharding): labels = _maybe_shard_with_name(labels, label_spec) segmentation = _maybe_shard_with_name(segmentation, label_spec) - batch_size, seq_len, emb_dim = hidden_states.shape - vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + # head_params is what the loss differentiates; other_params (transformer layers) and + # rest (rngs) are passed through the custom_vjp but not differentiated. They go through + # as primals rather than closure captures: capturing them leaks tracers across the + # custom_vjp + lax.scan boundary, which fails for tied embeddings. + graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...) - reshaped_hidden_states = _reshape( - hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec - ) - reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) - reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) - - # Rebuild the model per chunk inside the scan: the output head pulls an rng stream, and - # mutating the outer model's rng inside scan's sub-trace raises TraceContextError. - # nnx.merge(..., copy=True) makes fresh Variables local to each iteration. - graphdef, model_state = nnx.split(model) - - def _scan_body(accumulators, chunk_data): - loss_accumulator, z_loss_accumulator = accumulators - hidden_chunk, label_chunk, segmentation_chunk = chunk_data - hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) - label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) - segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) - - chunk_model = nnx.merge(graphdef, model_state, copy=True) - chunk_logits = chunk_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode) - chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) - one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) - chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( - chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk): + local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True) + chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode) + return _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + + @jax.custom_vjp + def chunked_cross_entropy_loss(chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation): + (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd( + chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation ) + return total_loss, total_z_loss + + def _chunked_cross_entropy_loss_fwd( + chunk_head_params, chunk_other_params, chunk_rest, hidden_states, labels, segmentation + ): + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _fwd_scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) - masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) - masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + chunk_logits = _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) - return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None - initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) - (total_loss, total_z_loss), _ = jax.lax.scan( - _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + # Always accumulate in fp32 — `cross_entropy_with_logits` returns fp32 regardless of + # logits dtype, and a bf16 carry would mismatch the body output type under lax.scan. + initial_acc = (jnp.zeros((), dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + residuals = ( + chunk_head_params, + chunk_other_params, + chunk_rest, + reshaped_hidden_states, + reshaped_labels, + reshaped_segmentation, + batch_size, + seq_len, + emb_dim, + ) + return (total_loss, total_z_loss), residuals + + def _chunked_cross_entropy_loss_bwd(residuals, cotangents): + # z_loss is folded into the xent loss inside cross_entropy_with_logits. + loss_cotangent, _ = cotangents + + ( + chunk_head_params, + chunk_other_params, + chunk_rest, + reshaped_hidden_states, + reshaped_labels, + reshaped_segmentation, + batch_size, + seq_len, + emb_dim, + ) = residuals + + def _single_chunk_loss_fn(input_head_params, input_hidden_chunk, input_label_chunk, input_segmentation_chunk): + chunk_logits = _logits_for_chunk(input_head_params, chunk_other_params, chunk_rest, input_hidden_chunk) + one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size) + xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier) + return jnp.sum(xent * (input_segmentation_chunk != 0)) + + def _bwd_scan_body(grad_head_acc, chunk_data): + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + # pylint: disable=unnecessary-lambda-assignment + loss_fn_for_vjp = lambda p, h: _single_chunk_loss_fn(p, h, label_chunk, segmentation_chunk) + _, vjp_fn = jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk) + (grad_head_update, grad_hidden_chunk) = vjp_fn(1.0) + grad_hidden_chunk = _maybe_shard_with_name(grad_hidden_chunk, chunked_hidden_spec) + + grad_head_acc = jax.tree_util.tree_map(lambda acc, update: acc + update, grad_head_acc, grad_head_update) + return grad_head_acc, grad_hidden_chunk + + initial_grad_head = jax.tree_util.tree_map(jnp.zeros_like, chunk_head_params) + + grad_head, grad_reshaped_hidden_states = jax.lax.scan( + _bwd_scan_body, initial_grad_head, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec) + grad_head = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_head) + grad_head = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), chunk_head_params, grad_head) + grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec) + + # Return explicit zeros for other_params and rest, not None. With None, JAX builds + # the zero cotangents with the wrong layer-axis order for scanned params, and the + # AOT trace fails the cotangent shape check. + grad_other = jax.tree_util.tree_map(jnp.zeros_like, chunk_other_params) + grad_rest = jax.tree_util.tree_map(jnp.zeros_like, chunk_rest) + return ( + grad_head, + grad_other, + grad_rest, + grad_reshaped_hidden_states.astype(reshaped_hidden_states.dtype), + None, + None, + ) + + chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd) + + total_loss, total_z_loss = chunked_cross_entropy_loss( + head_params, other_params, rest, hidden_states, labels, segmentation ) return total_loss, total_z_loss diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 510ce95f5a..899c1227e4 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -22,6 +22,8 @@ import pytest from flax import linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from jax.sharding import Mesh @@ -642,3 +644,332 @@ def test_vocab_tiling_gradient_context_parallelism(self): self.assert_pytrees_all_close( grads_non_tiling, grads_tiling, "Gradients of embedding table do not match for context parallelism." ) + + +class VocabTilingNNXTest(unittest.TestCase): + """Loss + gradient parity for the NNX vocab-tiling `custom_vjp` path. + + Compares two computations against the same NNX model: + - reference: full-vocab `model.logits_from_hidden_states_for_vocab_tiling(...)` then xent over the whole vocab. + - tiled: `vocab_tiling_nnx_loss(...)` which scans over `num_vocab_tiling` chunks + and uses a `custom_vjp` for the backward. + + Both paths share the same params; the test checks that loss values and parameter + gradients match within tolerance, exercising both forward and backward. + """ + + def setUp(self): + self.base_config = [None, get_test_config_path()] + self.rng = jax.random.PRNGKey(1234) + # Global batch must divide fsdp axis (= jax.device_count() by default), so the + # batch sharding constraints inside vocab_tiling_nnx_loss are satisfied. + self.batch_size = jax.device_count() + self.seq_len = 64 + self.rtol = 1e-2 + self.atol = 1e-2 + + def _build_cfg_and_model( + self, + *, + num_vocab_tiling=4, + logits_via_embedding=False, + z_loss_multiplier=1e-4, + ): + """Build a pyconfig + matching NNX `Transformer` for the test.""" + cfg = pyconfig.initialize( + self.base_config, + run_name=f"vt_nnx_n{num_vocab_tiling}_emb{logits_via_embedding}_z{z_loss_multiplier}", + enable_checkpointing=False, + enable_dropout=False, + max_target_length=self.seq_len, + per_device_batch_size=1, + logits_via_embedding=logits_via_embedding, + base_num_decoder_layers=0, + dtype="float32", + matmul_precision="high", + num_vocab_tiling=num_vocab_tiling, + z_loss_multiplier=z_loss_multiplier, + pure_nnx=True, + enable_nnx=True, + pure_nnx_decoder=True, + ) + mesh = maxtext_utils.get_mesh_from_config(cfg) + rngs = maxtext_utils_nnx.create_nnx_rngs(cfg) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs, model_mode=MODEL_MODE_TRAIN) + return cfg, model + + def _make_inputs(self, cfg, *, dtype=jnp.float32, pad_half=False): + """Synthetic hidden_states/labels/segmentation; `pad_half=True` zeros the back half of seg.""" + rng_hidden, rng_targets = jax.random.split(self.rng) + hidden_states = jax.random.normal(rng_hidden, (self.batch_size, self.seq_len, cfg.emb_dim), dtype=dtype) + labels = jax.random.randint(rng_targets, (self.batch_size, self.seq_len), 0, cfg.vocab_size) + if pad_half: + half = self.seq_len // 2 + segmentation = jnp.concatenate( + [ + jnp.ones((self.batch_size, half), dtype=jnp.int32), + jnp.zeros((self.batch_size, self.seq_len - half), dtype=jnp.int32), + ], + axis=1, + ) + else: + segmentation = jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32) + return hidden_states, labels, segmentation + + def _reference_loss_fn(self, cfg, graphdef, rest, hidden_states, labels, segmentation): + """Full-vocab xent loss closure (params, hidden_states) -> scalar loss.""" + + def loss_fn(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + logits = local_model.logits_from_hidden_states_for_vocab_tiling(h, True, "train") + one_hot = jax.nn.one_hot(labels, cfg.vocab_size) + xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot, z_loss=cfg.z_loss_multiplier) + return jnp.sum(xent * (segmentation != 0)) + + return loss_fn + + def _tiled_loss_fn(self, cfg, graphdef, rest, hidden_states, labels, segmentation): + """vocab_tiling_nnx_loss closure (params, hidden_states) -> scalar loss.""" + # hidden_states unused at the closure boundary (it comes via h), but kept in the + # signature so the two closures are callable interchangeably. + del hidden_states + data = {"targets": labels, "targets_segmentation": segmentation} + + def loss_fn(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + total_loss, _ = vocab_tiling_nnx_loss(local_model, h, data, cfg, is_train=True) + return total_loss + + return loss_fn + + def _split_and_axes(self, cfg, model): + """Common boilerplate: split the model and bind the logical axis rules.""" + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + return graphdef, params, rest + + def _assert_pytrees_close(self, ref, tiled, msg, *, rtol=None, atol=None): + rtol = self.rtol if rtol is None else rtol + atol = self.atol if atol is None else atol + leaves_close = jax.tree_util.tree_map(lambda x, y: jnp.allclose(x, y, rtol=rtol, atol=atol), ref, tiled) + if not all(jax.tree_util.tree_leaves(leaves_close)): + raise AssertionError(msg) + + @staticmethod + def _vg(fn, *, argnums=0): + """value_and_grad wrapped in jit. Eager value_and_grad trips an IndivisibleError + on the fsdp reshape inside vocab_tiling_nnx_loss; jit lets XLA reshard cleanly, + which is also how train.py runs it.""" + return jax.jit(jax.value_and_grad(fn, argnums=argnums)) + + @staticmethod + def _g(fn, *, argnums=0): + """grad wrapped in jit — see `_vg`.""" + return jax.jit(jax.grad(fn, argnums=argnums)) + + def _run_parity(self, *, logits_via_embedding): + """Compare full-vocab xent loss/grads against the tiled custom_vjp path.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4, logits_via_embedding=logits_via_embedding) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = self._vg(ref_loss_fn)(params, hidden_states) + tile_loss, tile_grads = self._vg(tile_loss_fn)(params, hidden_states) + + assert jnp.allclose( + ref_loss, tile_loss, rtol=self.rtol, atol=self.atol + ), f"Losses differ: ref={ref_loss} tiled={tile_loss}" + self._assert_pytrees_close(ref_grads, tile_grads, "Param gradients differ between full-vocab and tiled paths.") + + # ---------- Original parity tests (params gradient under both embedding modes) ---------- + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_non_tied_embedding(self): + """custom_vjp parity for non-tied embedding (separate logits_dense).""" + self._run_parity(logits_via_embedding=False) + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_tied_embedding(self): + """custom_vjp parity when logits share the input embedding table.""" + self._run_parity(logits_via_embedding=True) + + # ---------- Coverage expansion ---------- + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_total_z_loss_value_parity(self): + """The second tuple element (total_z_loss) must match the full-vocab reference.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + data = {"targets": labels, "targets_segmentation": segmentation} + + def _ref(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + logits = local_model.logits_from_hidden_states_for_vocab_tiling(h, True, "train") + one_hot = jax.nn.one_hot(labels, cfg.vocab_size) + xent_ref, z_ref = max_utils.cross_entropy_with_logits(logits, one_hot, z_loss=cfg.z_loss_multiplier) + return jnp.sum(xent_ref * (segmentation != 0)), jnp.sum(z_ref * (segmentation != 0)) + + def _tile(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + return vocab_tiling_nnx_loss(local_model, h, data, cfg, is_train=True) + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_total_loss, ref_total_z_loss = jax.jit(_ref)(params, hidden_states) + tile_total_loss, tile_total_z_loss = jax.jit(_tile)(params, hidden_states) + + assert jnp.allclose(ref_total_loss, tile_total_loss, rtol=self.rtol, atol=self.atol) + assert jnp.allclose( + ref_total_z_loss, tile_total_z_loss, rtol=self.rtol, atol=self.atol + ), f"total_z_loss differs: ref={ref_total_z_loss} tiled={tile_total_z_loss}" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_padded_segmentation(self): + """Half-padded segmentation: mask actually changes the loss, and parity holds.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + + # Compare unpadded vs padded loss to confirm the mask is wired through. + hs, labels, full_seg = self._make_inputs(cfg, pad_half=False) + _, _, pad_seg = self._make_inputs(cfg, pad_half=True) + graphdef, params, rest = self._split_and_axes(cfg, model) + + def _tile_loss_only(p, h, seg): + local_model = nnx.merge(graphdef, p, rest, copy=True) + total, _ = vocab_tiling_nnx_loss( + local_model, h, {"targets": labels, "targets_segmentation": seg}, cfg, is_train=True + ) + return total + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + full_loss = jax.jit(_tile_loss_only)(params, hs, full_seg) + pad_loss = jax.jit(_tile_loss_only)(params, hs, pad_seg) + assert float(pad_loss) < float( + full_loss + ), f"Padded loss should be strictly smaller (fewer tokens contribute). full={full_loss} pad={pad_loss}" + + # Now check parity against the full-vocab reference using the padded mask. + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hs, labels, pad_seg) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hs, labels, pad_seg) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = self._vg(ref_loss_fn)(params, hs) + tile_loss, tile_grads = self._vg(tile_loss_fn)(params, hs) + assert jnp.allclose(ref_loss, tile_loss, rtol=self.rtol, atol=self.atol) + self._assert_pytrees_close(ref_grads, tile_grads, "Padded-segmentation gradients differ.") + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_grad_over_hidden_states(self): + """Gradient w.r.t. hidden_states (argnums=1) matches the reference: exercises the + custom_vjp's hidden_states cotangent, which the params-only tests don't reach.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_grad_h = self._g(ref_loss_fn, argnums=1)(params, hidden_states) + tile_grad_h = self._g(tile_loss_fn, argnums=1)(params, hidden_states) + + assert ref_grad_h.shape == hidden_states.shape + assert tile_grad_h.shape == hidden_states.shape + assert ref_grad_h.dtype == hidden_states.dtype + assert tile_grad_h.dtype == hidden_states.dtype + assert jnp.allclose(ref_grad_h, tile_grad_h, rtol=self.rtol, atol=self.atol), "grad_hidden_states diverged" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_bf16_hidden_states(self): + """bf16 hidden_states: loss/grad parity holds and the grad keeps the bf16 dtype.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg, dtype=jnp.bfloat16) + graphdef, params, rest = self._split_and_axes(cfg, model) + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grad_h = self._vg(ref_loss_fn, argnums=1)(params, hidden_states) + tile_loss, tile_grad_h = self._vg(tile_loss_fn, argnums=1)(params, hidden_states) + + # bf16 has ~3 decimal digits — loosen tolerance. + assert jnp.allclose(ref_loss, tile_loss, rtol=5e-2, atol=5e-2) + assert tile_grad_h.dtype == jnp.bfloat16, f"grad cast to primal dtype expected bf16, got {tile_grad_h.dtype}" + assert jnp.allclose(ref_grad_h, tile_grad_h, rtol=5e-2, atol=5e-2) + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_z_loss_zero(self): + """z_loss=0: total_z_loss is exactly zero; loss/grad parity still holds.""" + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4, z_loss_multiplier=0.0) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + data = {"targets": labels, "targets_segmentation": segmentation} + + def _tile_fn(p, h): + local_model = nnx.merge(graphdef, p, rest, copy=True) + return vocab_tiling_nnx_loss(local_model, h, data, cfg, is_train=True) + + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + total_loss, total_z_loss = jax.jit(_tile_fn)(params, hidden_states) + assert float(total_z_loss) == 0.0, f"z_loss=0 but tile path returned {total_z_loss}" + assert float(total_loss) > 0.0 # cross-entropy on random logits should be positive + + ref_loss_fn = self._reference_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + ref_loss, ref_grads = self._vg(ref_loss_fn)(params, hidden_states) + tile_loss, tile_grads = self._vg(tile_loss_fn)(params, hidden_states) + assert jnp.allclose(ref_loss, tile_loss, rtol=self.rtol, atol=self.atol) + self._assert_pytrees_close(ref_grads, tile_grads, "z_loss=0 gradients differ.") + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_other_params_get_zero_grad(self): + """Carve-out invariant: non-head params get zero grad, head params don't. + + logits_from_hidden_states_for_vocab_tiling only uses the output-head params, so + the loss gradient for every other param must be exactly zero. The "at least one + head grad is non-zero" check guards against a bug that just zeros everything. + """ + cfg, model = self._build_cfg_and_model(num_vocab_tiling=4) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + _, tile_grads = self._vg(tile_loss_fn)(params, hidden_states) + + head_keywords = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense") + head_nonzero_seen = False + for path, leaf in jax.tree_util.tree_leaves_with_path(tile_grads): + path_str = jax.tree_util.keystr(path) + is_head = any(kw in path_str for kw in head_keywords) + if is_head: + if jnp.any(leaf != 0): + head_nonzero_seen = True + else: + assert jnp.all(leaf == 0), f"non-head leaf {path_str} has non-zero grad — carve-out is wrong" + assert head_nonzero_seen, "expected at least one head leaf with non-zero grad; got all zeros" + + @pytest.mark.tpu_only + def test_nnx_vocab_tiling_num_vocab_tiling_variants(self): + """Different num_vocab_tiling values (2, 4, 8) all produce identical loss + grads.""" + losses = [] + grads_list = [] + for n in (2, 4, 8): + cfg, model = self._build_cfg_and_model(num_vocab_tiling=n) + hidden_states, labels, segmentation = self._make_inputs(cfg) + graphdef, params, rest = self._split_and_axes(cfg, model) + tile_loss_fn = self._tiled_loss_fn(cfg, graphdef, rest, hidden_states, labels, segmentation) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + loss, grads = self._vg(tile_loss_fn)(params, hidden_states) + losses.append(loss) + grads_list.append(grads) + + base_loss = losses[0] + base_grads = grads_list[0] + for n, loss, grads in zip((2, 4, 8), losses, grads_list): + assert jnp.allclose( + loss, base_loss, rtol=self.rtol, atol=self.atol + ), f"num_vocab_tiling={n}: loss diverges from n=2 baseline ({loss} vs {base_loss})" + self._assert_pytrees_close(base_grads, grads, f"num_vocab_tiling={n}: grads diverge from n=2 baseline.") diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index a469a6fa70..1975ad1abf 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -1024,10 +1024,6 @@ def test_muon(self, muon_consistent_rms): @pytest.mark.cpu_only def test_vocab_tiling_bf16(self): """test vocab_tiling when weight_dtype=bfloat16""" - cfg = pyconfig.initialize([None, get_test_config_path()]) - if getattr(cfg, "enable_nnx", False): - pytest.skip("Vocab tiling not supported on NNX.") - compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle" train_compile_main( ( @@ -1123,3 +1119,31 @@ def test_zero1_optimizer_sharding(self): "shard_mode=explicit", ) ) + + @pytest.mark.cpu_only + def test_vocab_tiling_bf16_nnx(self): + """AOT compile vocab tiling on the NNX path (vocab_tiling_nnx_loss + custom_vjp). + + Sets `pure_nnx`/`enable_nnx`/`pure_nnx_decoder` explicitly so the NNX AOT + path is covered regardless of the default values. Once those defaults flip + to True, `test_vocab_tiling_bf16` above will already exercise this same + path via defaults. + """ + compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16_nnx.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=1", + "base_num_decoder_layers=2", + "per_device_batch_size=2", + "max_target_length=1024", + "num_vocab_tiling=4", + "weight_dtype=bfloat16", + "pure_nnx=true", + "enable_nnx=true", + "pure_nnx_decoder=true", + ) + )