NNX Migration: all commits #3075

Draft
xibinliu wants to merge 13 commits into main from xibin/nnx_all

Conversation


@xibinliu commented Feb 3, 2026

Description

Contains all the commits needed to make the NNX-migrated MaxText work.

Tests

Current focus is on:

  • pre-training with NNX, using these flags: "pure_nnx_decoder=True enable_nnx=True pure_nnx=True fused_qkv=False"

Example:

python3 -m MaxText.train src/MaxText/configs/base.yml run_name="xibin-nnx" model_name="gpt3-52k" dataset_type=synthetic steps=10 enable_checkpointing=False pure_nnx_decoder=True enable_nnx=True pure_nnx=True fused_qkv=False

More features and models will be tested and compared with Linen.
Test track doc

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@xibinliu force-pushed the xibin/nnx_all branch 3 times, most recently from b7ffb2c to 00b198c on February 5, 2026 16:15
@xibinliu force-pushed the xibin/nnx_all branch 2 times, most recently from 1b54aa5 to ce2571a on February 5, 2026 17:14
hsuan-lun-chiang and others added 12 commits February 6, 2026 17:58
- pure_nnx: a flag to choose pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for
  training. It will be set to a different function for NNX and Linen.
- Add utils to manipulate the NNX shardings with abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  (a hedged sketch follows this list)
  - also add unit tests for this func
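
A minimal sketch of what a mesh-from-config helper can look like; the
argument names, axis names, and the use of a plain tuple instead of the real
MaxText config object are all assumptions for illustration:

```python
# Hedged sketch in the spirit of maxtext_utils.get_mesh_from_config();
# not the actual MaxText API.
import jax
from jax.experimental import mesh_utils

def get_mesh_from_config(ici_parallelism, axis_names=("data", "fsdp", "tensor")):
  # e.g. ici_parallelism = (2, 2, 1); the product must equal the device count.
  devices = mesh_utils.create_device_mesh(ici_parallelism)
  return jax.sharding.Mesh(devices, axis_names)
```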

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
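
For illustration, the access-pattern change behind the first warning; a tiny
sketch, assuming flax v0.12 semantics:

```python
# Sketch of the '.value' deprecation noted above; not MaxText code.
import jax.numpy as jnp
from flax import nnx

w = nnx.Param(jnp.ones((2, 2)))
a = w.value  # emits DeprecationWarning on flax v0.12
b = w[...]   # preferred array-style read on newer flax
```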
A TrainState for NNX, which includes model and optimizer
Unit tests include checkpoint tests:
- restore a saved state
- convert linen TrainState to NNX TrainState
- Parameter only restore (no opt_state)
Also added unit tests.
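
For orientation, a minimal sketch of such a combined state using flax.nnx's
Optimizer (the model here is a placeholder, not the MaxText TrainState):

```python
# Hedged sketch: nnx.Optimizer bundles a model with its optax optimizer
# state, which is the shape of the combined TrainState described above.
import optax
from flax import nnx

class TinyModel(nnx.Module):  # placeholder model for illustration
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 2, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

model = TinyModel(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
# nnx.state(optimizer) carries params and opt_state together, so they can be
# checkpointed and restored as one tree.
```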
Refactored model_creation_utils to provide a common
create_nnx_abstract_model() function.

src/maxtext/utils/model_creation_utils.py
1. A new function get_abstract_state_nnx() is added to maxtext_utils.py.
It will be called during training to create the NNX training state.

As with the Linen version, it handles shard_optimizer_over_data,
optimizer_memory_host_offload, and parameter_memory_host_offload.

Unit tests are added to this NNX func.
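
A rough sketch of what such a function can look like; the signature is
invented and the offload/sharding flags are elided:

```python
# Hypothetical shape of get_abstract_state_nnx(); mirrors the description
# above but is not the actual MaxText signature.
import jax
from flax import nnx

def get_abstract_state_nnx(make_optimizer, mesh):
  # Trace initialization without allocating device memory.
  abstract_optimizer = nnx.eval_shape(make_optimizer)
  abstract_state = nnx.state(abstract_optimizer)
  # Turn the PartitionSpec metadata into concrete NamedShardings on the mesh.
  pspecs = nnx.get_partition_spec(abstract_state)
  shardings = jax.tree.map(
      lambda spec: jax.sharding.NamedSharding(mesh, spec),
      pspecs,
      is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec),
  )
  return abstract_optimizer, abstract_state, shardings
```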

2. Add nnx train_state handling in train_utils.py

DPO handling will be supported (or removed) later in train_utils.py
Also added unit tests for the NNX model.
NNX: loss fn and train_step, eval_step, train_loop
  - Convert nnx.State to pure dict for checkpoint saving
  - Restore pure dict back to nnx.State after loading
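
A minimal sketch of that round trip, using the flax.nnx State helpers
(the Linear model is a stand-in):

```python
# Sketch of the nnx.State <-> pure-dict conversion described above.
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

pure = state.to_pure_dict()       # plain nested dict, friendly to checkpointers
state.replace_by_pure_dict(pure)  # write dict values back into the State
model = nnx.merge(graphdef, state)
```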
…store

Add a bidirectional Linen <-> NNX checkpoint converter tool that handles:
  - Auto-detection of checkpoint format
  - Conversion of params structure (double nesting vs flat)
  - Stacking/unstacking per-layer parameters
  - Value wrapper handling for NNX format
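
A hedged sketch of the stacking step; the "layers_{i}" key layout is
illustrative, not the actual checkpoint format:

```python
# Stack N per-layer parameter subtrees into one tree of [N, ...] arrays,
# and slice them back apart for the reverse direction.
import jax
import jax.numpy as jnp

def stack_layers(params, num_layers, prefix="layers_"):
  layer_trees = [params[f"{prefix}{i}"] for i in range(num_layers)]
  return jax.tree.map(lambda *leaves: jnp.stack(leaves), *layer_trees)

def unstack_layers(stacked, num_layers, prefix="layers_"):
  return {
      f"{prefix}{i}": jax.tree.map(lambda leaf: leaf[i], stacked)
      for i in range(num_layers)
  }
```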
