
[NNX] NNX migration prep (8/N): NNX native lora grpo #3824

Open: ecnal-cienet wants to merge 3 commits into main from feat/nnx-native-lora-grpo

@ecnal-cienet (Collaborator) commented May 6, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. 🔄 [This PR] NNX-native LoRA + GRPO. NNX-native serving / decode-checkpoint LoRA via apply_lora_on_base_params_nnx / unapply_lora_from_base_params_nnx / get_lora_abstract_state_nnx (the maxengine pure_nnx + LoRA carve-out from PR7 is cleared); NNX-native GRPO trainer via grpo_loss_fn_nnx + compute_log_probs_nnx + NNX setup_train_loop/train_step/eval_step paths. Stacks on PR7.
  9. ❌ NNX-aware QK-Clip + remaining checkpoint utilities.
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR implements NNX-native LoRA serving and NNX-native GRPO by adding NNX-shape walkers and step helpers alongside the existing Linen ones, then dispatching on config.pure_nnx. Every NNX modification is gated by `if config.pure_nnx:`, preserving the Linen path byte-for-byte. The diff spans +551 / −84 across 5 source files, plus 2 new test files (515 lines).

Part 1: NNX-shape LoRA Walkers

New helpers in src/maxtext/utils/lora_utils.py operating on nnx.State pure trees (no {"params": ...} outer wrap):

  • apply_lora_on_base_params_nnx mutates base_params in place: W += B @ A * scale at target attention paths
  • unapply_lora_from_base_params_nnx is the symmetric inverse
  • get_lora_abstract_state_nnx walks the abstract state.model substate and emits a parallel tree with lora_a.kernel/lora_b.kernel leaves at target attention paths and None elsewhere
  • _nnx_param_subtree drops the outer TrainStateNNX wrapping

The base model stays pristine; "apply" merges the delta into the kernel, "unapply" reverses. No nnx.LoRA wrapper, no model surgery. The on-disk format (HuggingFace PEFT-style lora_a.kernel / lora_b.kernel) round-trips between Linen and NNX consumers unchanged.
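The apply/unapply pair can be illustrated with a minimal sketch on plain nested dicts. It is not the code in lora_utils.py: NumPy stands in for jax.numpy, the function names `apply_lora_sketch` / `unapply_lora_sketch`, the tree layout, and the shape convention (B is `[d_out, r]`, A is `[r, d_in]`, so `B @ A` matches the kernel) are all assumptions made for illustration.

```python
import numpy as np


def _merge_delta(base, lora, scale, sign):
    """Walk the LoRA tree; at each adapter leaf add sign * scale * (B @ A) in place."""
    for name, lora_node in lora.items():
        if lora_node is None:
            continue  # None marks a non-target path: leave the base kernel alone
        if "lora_a" in lora_node:
            a = lora_node["lora_a"]["kernel"]  # assumed shape [r, d_in]
            b = lora_node["lora_b"]["kernel"]  # assumed shape [d_out, r]
            base[name]["kernel"] += sign * scale * (b @ a)
        else:
            _merge_delta(base[name], lora_node, scale, sign)


def apply_lora_sketch(base, lora, scale):
    """Merge the low-rank delta into the base kernels (W += B @ A * scale)."""
    _merge_delta(base, lora, scale, sign=1.0)


def unapply_lora_sketch(base, lora, scale):
    """Symmetric inverse: subtract the same delta again."""
    _merge_delta(base, lora, scale, sign=-1.0)
```

Because the delta is merged into the kernel and later subtracted, the round trip is exact only up to floating-point rounding, which is presumably why the apply→unapply identity test listed under Tests matters.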

Part 2: LoRA Dispatch in setup_initial_lora_state and load_adapter

Both top-level entry points in lora_utils.py branch on config.pure_nnx:

  • NNX init builds the abstract base via model_creation_utils.create_nnx_abstract_model + TrainStateNNX(model, optimizer)
  • Linen branch is the original init_initial_state + get_lora_abstract_state path, untouched

Part 3: MaxEngine LoRA Carve-out Cleared

src/maxtext/inference/maxengine/maxengine.py:

  • load_single_adapter no longer raises NotImplementedError on pure_nnx
  • apply_adapter / unapply_adapter branch on config.pure_nnx to call the _nnx siblings
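Parts 2 and 3 both reduce to one branch on config.pure_nnx. The sketch below shows only the shape difference that branch has to absorb: per Part 1, the NNX tree has no outer {"params": ...} wrap, while the Linen tree does. The helper name `select_param_tree` is hypothetical and does not appear in the PR.

```python
from types import SimpleNamespace


def select_param_tree(config, params):
    """Return the tree a LoRA walker should traverse for the active path."""
    if config.pure_nnx:
        # NNX path: params is already the pure tree (no outer wrap).
        return params
    # Linen path: unwrap the conventional {"params": ...} outer dict.
    return params["params"]


linen_cfg = SimpleNamespace(pure_nnx=False)
nnx_cfg = SimpleNamespace(pure_nnx=True)
```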

Part 4: GRPO Loss and Step Helpers

src/maxtext/experimental/rl/grpo_trainer.py:

  • grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train). Signature matches Linen grpo_loss_fn so callers dispatch on the same shape. dropout_rng and params are unused on NNX; reference_model is a frozen nnx.Module and the reference forward is wrapped in stop_gradient. Returns (loss, LossAux), same dataclass as Linen.
  • _train_step_nnx: nnx.merge(graphdef, state) to reconstruct TrainStateNNX, value_and_grad over policy params, state.apply_gradients(grads), return nnx.state(new_state, nnx.Not(nnx.Intermediate)).
  • _eval_step_nnx: same merge + loss-fn call, no state update.
  • train_step / eval_step early-dispatch on config.pure_nnx; the Linen branches are kept verbatim.
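For readers unfamiliar with GRPO, the following forward-only sketch shows one common formulation of the loss: group-normalized advantages, a clipped ratio surrogate, and a per-token KL penalty against the reference. It is an assumption-laden illustration, not the body of grpo_loss_fn_nnx: the name `grpo_loss_sketch`, the `old_logp` argument, the clip value, and the specific KL estimator are not taken from the PR. Only two behaviours are grounded in the test list: identical policy and reference give zero KL, and beta = 0 yields a KL of None.

```python
import numpy as np


def grpo_loss_sketch(policy_logp, old_logp, ref_logp, rewards, mask, beta, clip=0.2):
    """Inputs are [G, T] per-token log-probs for G completions of one prompt;
    rewards is [G]; mask is [G, T] with 1 on completion tokens."""
    # Group-relative advantage: normalize rewards within the group.
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    adv = adv[:, None]
    ratio = np.exp(policy_logp - old_logp)
    surrogate = np.minimum(ratio * adv, np.clip(ratio, 1 - clip, 1 + clip) * adv)
    denom = max(float(mask.sum()), 1.0)
    if beta == 0:
        avg_kl = None  # KL term skipped entirely
        per_token = -surrogate
    else:
        # In a JAX version ref_logp would pass through jax.lax.stop_gradient.
        delta = ref_logp - policy_logp
        kl = np.exp(delta) - delta - 1.0  # non-negative per-token estimator
        avg_kl = float((kl * mask).sum() / denom)
        per_token = -(surrogate - beta * kl)
    loss = float((per_token * mask).sum() / denom)
    return loss, avg_kl
```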

Part 5: GRPO setup_train_loop on NNX

grpo_trainer.py::setup_train_loop:

  • Builds training and inference models via mt.from_config(rngs=create_nnx_rngs(...))
  • Initializes state via create_nnx_abstract_model + TrainStateNNX(model, optimizer, reference_model=...)
  • Reference uses the same init seed as policy and is never updated by apply_gradients (sibling field on TrainStateNNX, not embedded in params)
  • The "WARNING: GRPO RL trainer does not yet support pure_nnx natively" log is removed
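The "sibling field" arrangement can be pictured with a toy state container. This is a sketch under assumptions, not TrainStateNNX: the class `ToyTrainState`, the plain SGD update, and `init_state` are hypothetical. It shows the two properties stated above: policy and reference start from the same seed, and apply_gradients never touches the reference.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyTrainState:
    params: dict      # trainable policy weights
    reference: dict   # frozen reference weights, held beside params
    step: int = 0

    def apply_gradients(self, grads, lr=0.1):
        """Update only the policy; the reference is carried through unchanged."""
        new_params = {k: v - lr * grads[k] for k, v in self.params.items()}
        return ToyTrainState(new_params, self.reference, self.step + 1)


def init_state(seed):
    """Initialize policy and reference from the same seed."""
    def init():
        return {"w": np.random.default_rng(seed).normal(size=(4, 4))}
    return ToyTrainState(params=init(), reference=init())
```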

Part 6: GRPO train_loop NNX Branches

grpo_trainer.py::train_loop — three Linen-coupled spots branched on pure_nnx:

  • Initial reference seeding is skipped on NNX (already set up by init_state_fn)
  • metric_logger.write_setup_info_to_tensorboard receives a flat nnx.Param state on NNX
  • Checkpoint save passes the whole TrainStateNNX on NNX; the Linen _split_grpo_state(state)[0] strip is bypassed

The reshard call routes to pathways_reshard_nnx when pure_nnx. New helpers in grpo_utils.py:

  • compute_log_probs_nnx: NNX model is called directly; intermediates pulled via nnx.state(model, nnx.Intermediate).to_pure_dict()
  • pathways_reshard_nnx: splits state.model to a flat nnx.Param state, reshards onto the inference mesh, calls inference_engine.update_params(...)
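The test list notes that compute_log_probs_nnx maps a [B, S] input to a [B, S-1] output. A minimal sketch of why the sequence axis shrinks by one, assuming the usual next-token alignment (logits at position t score the token at position t+1); `token_log_probs` is a hypothetical name, NumPy stands in for jax.numpy, and any masking the real helper applies is omitted.

```python
import numpy as np


def token_log_probs(logits, tokens):
    """logits: [B, S, V] floats, tokens: [B, S] ints -> per-token log-probs [B, S-1]."""
    shifted = logits[:, :-1, :]                      # drop the last position's prediction
    shifted = shifted - shifted.max(axis=-1, keepdims=True)  # numerically stable softmax
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    targets = tokens[:, 1:]                          # each position predicts the next token
    picked = np.take_along_axis(log_softmax, targets[..., None], axis=-1)
    return picked[..., 0]
```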

Part 7: Carve-outs (NotImplementedError Sites)

| Feature | Tracked in |
| --- | --- |
| GRPO + gradient_accumulation_steps > 1 | Follow-up |
| GRPO + scan_layers=False | Follow-up (needs an NNX-aware unscan helper) |
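A guard for these two carve-outs might look like the sketch below. The function name `check_grpo_nnx_carveouts`, the error messages, and the choice to gate the checks on pure_nnx are assumptions; the PR states only that these combinations raise NotImplementedError.

```python
from types import SimpleNamespace


def check_grpo_nnx_carveouts(config):
    """Raise early for GRPO configurations the NNX path does not yet cover."""
    if not config.pure_nnx:
        return  # assumed: the Linen path is unaffected
    if config.gradient_accumulation_steps > 1:
        raise NotImplementedError("GRPO on NNX: gradient_accumulation_steps > 1 is a follow-up")
    if not config.scan_layers:
        raise NotImplementedError("GRPO on NNX: scan_layers=False needs an NNX-aware unscan helper")


supported = SimpleNamespace(pure_nnx=True, gradient_accumulation_steps=1, scan_layers=True)
check_grpo_nnx_carveouts(supported)  # passes silently
```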

Tests

New unit tests (tests/unit/lora_utils_nnx_test.py, 10 tests):

  • 5 on get_lora_abstract_state_nnx: q/k/v/o shape derivation, target-vs-non-target masking, sharding propagation, leaf type validation, error paths
  • 3 on apply_lora_on_base_params_nnx: apply→unapply identity, target-only mutation, numerical parity vs Linen apply_lora_on_base_params on the same random inputs
  • 2 Linen regression smoke tests on apply_lora_on_base_params and unapply_lora_from_base_params (no existing unit test for these helpers in the tree)

New unit tests (tests/unit/grpo_nnx_test.py, 8 tests):

  • 5 on grpo_loss_fn_nnx: LossAux shape parity, signature compatibility, identical-policy/reference → zero KL, grpo_beta=0 → aux.avg_kl=None, finite policy grads
  • 1 on compute_log_probs_nnx: shape [B, S] → [B, S-1]
  • 2 Linen regression smoke tests on grpo_loss_fn and compute_log_probs (the existing Linen integration test is TPU-only and currently @pytest.mark.skip)

Modified test: tests/unit/maxengine_test.py swaps test_lora_raises_for_nnx (asserted NotImplementedError) for test_lora_load_single_adapter_reaches_loader_on_nnx (asserts FileNotFoundError from the loader).

Existing Linen tests: untouched and still pass; pure_nnx=False stays default.

Test results: 198 passed, 1 skipped (pre-existing CPU-only skip) across the broader NNX regression sweep — maxengine_test, dpo_nnx_test, train_nnx_test, lora_utils_nnx_test, grpo_nnx_test, train_state_nnx_test, train_utils_nnx_test, gradient_accumulation_nnx_test, linen_nnx_converter_test, compare_linen_nnx_checkpoint_test.

Linting: bash lint.sh — pyink + pylint 10.00/10.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet changed the title from "Feat/nnx native lora grpo" to "[NNX] NNX migration prep (8/N): Feat/nnx native lora grpo" on May 6, 2026
@ecnal-cienet changed the title from "[NNX] NNX migration prep (8/N): Feat/nnx native lora grpo" to "[NNX] NNX migration prep (8/N): native lora grpo" on May 6, 2026
@ecnal-cienet changed the title from "[NNX] NNX migration prep (8/N): native lora grpo" to "[NNX] NNX migration prep (8/N): NNX native lora grpo" on May 6, 2026
codecov Bot commented May 6, 2026

Codecov Report

❌ Patch coverage is 64.00000% with 45 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/utils/lora_utils.py | 61.20% | 38 Missing and 7 partials ⚠️ |


…e.py)

PR5 audited maxengine.py and routed the inference path to the Linen
implementation regardless of pure_nnx, with a comment block explaining
that "the flag affects training, not inference serving." That kept the
Linen serving path unchanged but meant pure_nnx=True users silently got
the Linen engine. This change replaces the route with a real NNX flow:
when config.pure_nnx=True, the engine builds an NNX Transformer, splits
out (params, cache, rest) with nnx.split, and at every JIT body merges
the model concretely with nnx.merge to run the forward pass. Linen is
preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:`
and pure_nnx=False is still the default.

maxengine.py (__init__):
- Build two abstract NNX Transformers on the NNX path: self.model with
  model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar
  with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on,
  decode_state shape). Both are needed because NNX cache vars inherit
  CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode,
  and bulk_insert searches for the substring "cache_batch" in the
  AR-mode logical-axes tuple. nnx.eval_shape is called directly inside
  nn_partitioning.axis_rules rather than through create_nnx_abstract_model
  to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only
  axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh).
- Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT
  bodies can pass (params, cache, rest) separately to nnx.merge. The
  rest slot (RNG vars etc.) is materialized concretely in load_params.

maxengine.py (cache adapter + _nnx_run_model):
- bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the
  cache via tree_map_with_path and switch on path[-1].key (the cache
  variable name like "cached_prefill_key"). Linen mutable cache is a
  plain nested dict. NNX Cache state would expose a ".value" accessor
  at that position. Bridge via nnx.State.to_pure_dict() (after the
  model run) and nnx.replace_by_pure_dict (before nnx.merge), so the
  cache plumbing helpers see the same shape on both paths.
- Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True)
  -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True
  avoids reusing Variable objects across traces (TraceContextError),
  mirroring train.py's diff_wrapper workaround.
- Add _nnx_cache_state_template / _nnx_init_cache_dict helpers
  parametrised by mode so prefill (batch 1) and decode_state (batch N)
  pull from the right abstract model.

maxengine.py (load_params):
- New _load_params_nnx: accepts user-provided NNX-shape params or loads
  via from_pretrained. For user-provided params, materializes a concrete
  model once via _create_model_fn() to capture a real rest state for
  nnx.merge (wasteful but simple; the from_pretrained branch avoids
  this). Refreshes self.graphdef from the concrete model so subsequent
  merges line up exactly.
- Builds self.abstract_params, populates self.prefill_kv_cache_annotations
  and self.kv_cache_annotations (using model_ar for the latter so
  bulk_insert's substring lookup hits), wraps both into NamedSharding.
- pure_nnx + quantization, pure_nnx + LoRA, pure_nnx +
  stack_prefill_result_cache=True, pure_nnx + prefill_multisampling,
  and pure_nnx + prefill_concat raise NotImplementedError for now;
  the Linen path is the workaround. AOT compilation
  (aot_compile / _compile_generate_and_get_layouts) is not gated and
  may work as-is; not exercised by tests yet.

maxengine.py (init_decode_state, _prefill_jit, _generate_jit):
- _init_decode_state_nnx zero-initializes a pure-dict cache from
  model_ar (so the leading batch dim matches generate's input shape)
  and builds kv_cache_annotations_named per leaf by reading
  nnx.Cache.metadata. Tries "out_sharding", "sharding", and
  "sharding_names" because Flax 0.12.6 renamed these.
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch
  that calls _nnx_run_model in place of self.model.apply with
  mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict
  cache directly (no params|{"cache":...} dict-merge — params is an
  nnx.State, not a dict).
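
The "try several metadata names" step in _init_decode_state_nnx can be sketched as a first-match lookup. This is an illustration only: the helper name `first_sharding_field` is hypothetical, and treating the metadata as a plain dict is an assumption about how nnx.Cache.metadata is read.

```python
def first_sharding_field(metadata, keys=("out_sharding", "sharding", "sharding_names")):
    """Return the first non-None sharding annotation found under any known key."""
    for key in keys:
        value = metadata.get(key)
        if value is not None:
            return value
    return None  # no annotation recorded under any of the candidate names


# A leaf annotated under one of the alternative field names still resolves.
axes = first_sharding_field({"sharding_names": ("cache_batch", "cache_heads")})
```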

maxtext_utils.py:
- New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx
  that mirror the Linen helpers' return shape (per-leaf PartitionSpec
  tree). Both delegate to _nnx_cache_partition_specs which extracts
  nnx.Cache state via nnx.split, calls
  get_nnx_named_sharding_with_scan_axis inside
  nn_partitioning.axis_rules so logical axes ("layers", "cache_batch",
  "norm", ...) resolve to physical mesh axes, and converts the result
  to a pure-dict tree.

tests/unit/maxengine_test.py:
- New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and
  per-layer cache shape checks), test_basic_decode_nnx (4-step generate
  with next_pos advancement check), test_quantize_raises_for_nnx,
  test_lora_raises_for_nnx.
- New test_linen_nnx_parity_prefill: bridges Linen-init params into
  the NNX engine via linen_nnx_converter (convert_linen_to_nnx ->
  _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the
  NNX engine's prefill matches Linen on the same weights — logits
  within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses
  bf16 compute) and exact greedy first-token argmax.
- Existing Linen tests untouched.
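
The parity criterion used by test_linen_nnx_parity_prefill (logits within rtol=0.05, atol=0.1, plus an exact greedy-token match) can be sketched as below. The helper name `check_prefill_parity` and the use of NumPy's testing utilities are assumptions; only the tolerances and the argmax requirement come from the commit message.

```python
import numpy as np


def check_prefill_parity(nnx_logits, linen_logits, rtol=0.05, atol=0.1):
    """Assert two logit arrays agree within loose bf16-style tolerance and pick the same token."""
    nnx_logits = np.asarray(nnx_logits, dtype=np.float32)
    linen_logits = np.asarray(linen_logits, dtype=np.float32)
    # Loose elementwise check: bf16 compute makes tight tolerances unrealistic.
    np.testing.assert_allclose(nnx_logits, linen_logits, rtol=rtol, atol=atol)
    # Strict check: greedy decoding must choose the same token on both paths.
    assert np.array_equal(nnx_logits.argmax(-1), linen_logits.argmax(-1)), "greedy token mismatch"
```

Pairing a loose numeric tolerance with a strict argmax check guards against the case where every logit is within tolerance but the top two candidates swap order.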

Test summary: 9 passed, 1 skipped (test_chunked_prefill is a
pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink
all green.
…acked prefill cache)

PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert
path work under pure_nnx=True but left three serving features raising
NotImplementedError on the NNX path. This promotes all three to NNX-native.
Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"])
calls are unchanged, just moved into else: branches, and every NNX edit is
gated `if config.pure_nnx:`.

maxengine.py:
- _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx
  branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1)
  with a fresh _nnx_init_cache_dict. The loop that draws num_samples first
  tokens from the shared logits is unchanged.
- prefill_concat: same swap; the packed positions and segment ids thread
  through _nnx_run_model unchanged.
- stack_prefill_result_cache=True: now supported for both scan_layers values.
  scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen
  post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are
  no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False
  keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int
  keys), so _maybe_stack stacks them into one subtree with a leading layer axis,
  _maybe_unstack splits it back into the int-keyed per-layer dict that
  bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to
  each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) +
  ["decoder"]["layers_0"] reshape).
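
For the scan_layers=False case, the stack/unstack round trip described above can be sketched on plain dicts. This is not the _maybe_stack/_maybe_unstack code: the function names and the flat per-layer dict of arrays are assumptions, and NumPy stands in for jax.numpy.

```python
import numpy as np


def stack_layers(per_layer):
    """{0: {name: arr}, 1: {name: arr}, ...} -> {name: arr with a leading layer axis}."""
    order = sorted(per_layer)  # int layer keys, in layer order
    names = per_layer[order[0]].keys()
    return {n: np.stack([per_layer[i][n] for i in order], axis=0) for n in names}


def unstack_layers(stacked):
    """Inverse of stack_layers: split the leading layer axis back into int-keyed dicts."""
    num_layers = next(iter(stacked.values())).shape[0]
    return {i: {n: arr[i] for n, arr in stacked.items()} for i in range(num_layers)}
```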

tests/integration/maxengine_test.py:
- New _build_linen_params helper and a shared _stack_prefill_roundtrip helper.
- test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen
  result-shape parity, finite logits + cache.
- test_stack_prefill_result_cache_nnx (scan_layers=True) and
  test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False):
  prefill -> insert -> generate round-trip, layer-stacked leaves, finite
  logits, next_pos advances.

Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8),
which are other PRs' scope.
@ecnal-cienet force-pushed the feat/nnx-native-lora-grpo branch from a07256c to a17e792 on May 26, 2026 22:29