Skip to content

Commit 3fc5161

Browse files
xibinliuecnal-cienet
authored and committed
NNX migration preparation: pure_nnx flag and init_state_fn
- pure_nnx: a flag to choose pure NNX logic when NNX and Linen models co-exist. - init_state_fn: a function to initialize the model state for training. It will be set to a different function for NNX and Linen.
1 parent 69c077d commit 3fc5161

File tree

21 files changed

+538
-128
lines changed

21 files changed

+538
-128
lines changed

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"""
3636

3737
import argparse
38+
import functools
3839
import gc
3940
import os
4041
import sys
@@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
8788
mesh = Mesh(devices_array, cfg.mesh_axes)
8889

8990
quant = quantizations.configure_quantization(cfg)
90-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
91+
if cfg.pure_nnx:
92+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
93+
else:
94+
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
9195
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
9296
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
9397

@@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
98102
cfg.checkpoint_period,
99103
)
100104

101-
state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
105+
if cfg.pure_nnx:
106+
# NNX has a different function to init the training state.
107+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
108+
else:
109+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
110+
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
102111
max_logging.log("start")
103112
max_utils.print_mem_stats("After params initialized")
104113

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,7 @@ subslice_shape: ""
10881088
# NNX
10891089
enable_nnx: True
10901090
pure_nnx_decoder: True
1091+
pure_nnx: True
10911092

10921093
################################## Qwen3-Next Specific Configs ##################################
10931094
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ class HardwareAndMesh(BaseModel):
786786
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
787787
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
788788
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
789+
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
789790

790791

791792
class LayoutAndSharding(BaseModel):

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -546,23 +546,43 @@ def setup_train_loop(
546546
max_logging.log("Training mesh used for the workload")
547547
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
548548
training_devices = jax.devices()[num_inference_devices:]
549-
model = mt.from_config(config, devices=training_devices)
549+
if config.pure_nnx:
550+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
551+
else:
552+
model = mt.from_config(config, devices=training_devices)
550553
mesh = model.mesh
551554
max_logging.log("Inference mesh used for the workload")
552555
inference_devices = jax.devices()[:num_inference_devices]
553-
inference_model = mt.from_config(config_inference, devices=inference_devices)
556+
if config_inference.pure_nnx:
557+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
558+
else:
559+
inference_model = mt.from_config(config_inference, devices=inference_devices)
554560
inference_mesh = inference_model.mesh
555-
init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh)
561+
init_rng = jax.random.PRNGKey(config.init_weights_seed)
562+
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
563+
if config.pure_nnx:
564+
# NNX has a different function to init the training state.
565+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
566+
else:
567+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
568+
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)
556569

557570
with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
558571
data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh)
559572
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
560-
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
573+
data_iterator, config, mesh, checkpoint_manager, init_state_fn
561574
)
562575

563576
# create inference_state_mesh_shardings from inference_mesh
577+
if config_inference.pure_nnx:
578+
# NNX has a different function to init the training state.
579+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
580+
else:
581+
init_inference_state_fn = functools.partial(
582+
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
583+
)
564584
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
565-
inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False
585+
config_inference, inference_mesh, init_inference_state_fn, is_training=False
566586
)[2]
567587
if not config.using_pipeline_parallelism:
568588
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
@@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None):
697717
data_buffer = []
698718
data_buffer_lock = threading.Lock()
699719

700-
start_step = get_first_step(state) # this is the start_step for training
720+
start_step = get_first_step(model, state) # this is the start_step for training
701721
prof = profiler.Profiler(config, offset_step=start_step)
702722
inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
703723
data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None):
113113

114114
# Model and Optimizer definition
115115
quant = quantizations.configure_quantization(config)
116-
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
116+
if config.pure_nnx:
117+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
118+
else:
119+
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
117120
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))
118121

119122
self.abstract_params = None
@@ -229,17 +232,25 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
229232
rng1, rng2, rng3 = jax.random.split(rng, 3)
230233
if params:
231234
print("Resharding given params")
235+
if self.config.pure_nnx:
236+
# NNX has a different function to init the training state.
237+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
238+
else:
239+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
232240
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
233-
self.model, None, self.config, rng, self._mesh, False
241+
self.config, self._mesh, init_state_fn, False
234242
)
235243
# reshard given params based on shardings from config in MaxEngine
236244
params = jax.device_put(params, state_mesh_shardings.params)
237245
state = maxtext_utils.init_decode_state(None, params)
238246
state = max_utils.unbox_logicallypartioned(state)
239247
else:
240-
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(
241-
self.model, self.config, rng1, self._mesh, None
242-
)
248+
if self.config.pure_nnx:
249+
# NNX has a different function to init the training state.
250+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
251+
else:
252+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
253+
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
243254
# pylint: disable=isinstance-second-argument-not-valid-type
244255
self.abstract_params = jax.tree_util.tree_map(
245256
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from typing import Sequence
2828

2929
from absl import app
30+
from flax import nnx
3031
from flax.linen import partitioning as nn_partitioning
3132
import jax
3233
from jax.experimental.serialize_executable import serialize
@@ -36,6 +37,7 @@
3637
from maxtext.configs import pyconfig
3738
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
3839
from maxtext.layers import quantizations
40+
from maxtext.layers import train_state_nnx
3941
from maxtext.models import models
4042
from maxtext.optimizers import optimizers
4143
from maxtext.trainers.diloco import diloco
@@ -44,6 +46,8 @@
4446
from maxtext.utils import max_utils
4547
from maxtext.utils import maxtext_utils
4648
from maxtext.utils import sharding
49+
from maxtext.utils import maxtext_utils_nnx
50+
from maxtext.utils import model_creation_utils
4751

4852
# pylint: disable=too-many-positional-arguments
4953

@@ -93,7 +97,10 @@ def get_shaped_inputs(topology_mesh, config):
9397
"""Get shaped abstractions of inputs to train_step: state, batch and rng"""
9498
# Construct the model and optimizer to get shaped versions of the state
9599
quant = quantizations.configure_quantization(config)
96-
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
100+
if config.pure_nnx:
101+
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
102+
else:
103+
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
97104
# The learning_rate_schedule is baked into the compiled object.
98105
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
99106
# pass in model for muon
@@ -103,18 +110,39 @@ def get_shaped_inputs(topology_mesh, config):
103110
_, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
104111
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
105112

106-
# Shaped state
107-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
108-
model, tx, config, example_rng, topology_mesh
109-
)
113+
if config.pure_nnx:
114+
115+
def create_train_state_fn():
116+
nnx_model = _create_model_partial()
117+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
118+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
119+
120+
init_state_fn = create_train_state_fn
121+
else:
122+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng)
110123

111-
# unsharded logical annotations
112-
logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
124+
# Shaped state
125+
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True)
126+
127+
if config.pure_nnx:
128+
# NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
129+
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
130+
# For NNX, get_functional_train_with_signature expects the graphdef (static structure),
131+
# not the raw model — mirroring how the training loop does nnx.split(train_state).
132+
with nn_partitioning.axis_rules(config.logical_axis_rules):
133+
graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh)
134+
model = graphdef
135+
else:
136+
# unsharded logical annotations
137+
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)
113138

114139
# Shaped batch
115140
shaped_batch = maxtext_utils.get_shaped_batch(config)
116141

117-
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
142+
if config.pure_nnx:
143+
shaped_train_args = (abstract_state, shaped_batch)
144+
else:
145+
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
118146
shaped_train_kwargs = {}
119147
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model
120148

@@ -277,12 +305,20 @@ def main(argv: Sequence[str]) -> None:
277305
# print weights sharding info under debug sharding mode
278306
if config.debug_sharding:
279307
max_utils.print_non_trivial_mesh_axis(topology_mesh)
280-
maxtext_utils.print_shardings_params(
281-
shaped_train_args[0].params,
282-
state_mesh_shardings.params,
283-
topology_mesh,
284-
logical_annotations.params,
285-
)
308+
if config.pure_nnx:
309+
maxtext_utils.print_shardings_params(
310+
shaped_train_args[0],
311+
state_mesh_shardings,
312+
topology_mesh,
313+
logical_annotations,
314+
)
315+
else:
316+
maxtext_utils.print_shardings_params(
317+
shaped_train_args[0].params,
318+
state_mesh_shardings.params,
319+
topology_mesh,
320+
logical_annotations.params,
321+
)
286322

287323
# Compile
288324
print("Jitting and compiling train step...", flush=True)

src/maxtext/utils/generate_param_only_checkpoint.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
2323
"""
2424

25+
import functools
2526
import os.path
2627
from typing import Sequence
2728

@@ -42,8 +43,6 @@
4243
from maxtext.utils import max_utils
4344
from maxtext.utils import maxtext_utils
4445

45-
Transformer = models.transformer_as_linen
46-
4746

4847
def _possibly_unroll_params(config, training_state, training_state_annotations, mesh):
4948
"""Unroll scanned input layers when force_unroll is set."""
@@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
9392
"""Read training checkpoint at path defined by load_full_state_path."""
9493
# Model and Optimizer definition
9594
quant = quantizations.configure_quantization(config)
96-
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
95+
if config.pure_nnx:
96+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
97+
else:
98+
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
9799
rng = random.PRNGKey(0)
98100
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
99101
tx = optimizers.get_optimizer(config, learning_rate_schedule)
102+
if config.pure_nnx:
103+
# NNX has a different function to init the training state.
104+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
105+
else:
106+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng)
100107
state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
101-
model, None, tx, config, rng, mesh, checkpoint_manager
108+
None, config, mesh, checkpoint_manager, init_state_fn
102109
)
103110
num_params = max_utils.calculate_num_params_from_pytree(state.params)
104111
max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion")
@@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh):
109116
"""Read lora checkpoints checkpoint at path defined by load_full_state_path."""
110117
# Model and Optimizer definition
111118
quant = quantizations.configure_quantization(config)
112-
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
119+
if config.pure_nnx:
120+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
121+
else:
122+
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
113123
rng = random.PRNGKey(0)
114124
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
115125
tx = optimizers.get_optimizer(config, learning_rate_schedule)

src/maxtext/utils/layerwise_quantization.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
3131
"""
3232

33+
import functools
3334
import os
3435
from typing import Any, Sequence
3536

@@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType):
174175

175176
# Model and quantization config
176177
self.quant = quantizations.configure_quantization(config)
177-
model = models.transformer_as_linen(
178-
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
179-
)
180-
self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(
181-
model, None, self.config, self.rng, self._mesh, False
182-
)
178+
if self.config.pure_nnx:
179+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
180+
else:
181+
model = models.transformer_as_linen(
182+
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
183+
)
184+
if self.config.pure_nnx:
185+
# NNX has a different function to init the training state.
186+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
187+
else:
188+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng)
189+
190+
self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False)
183191

184192
def load_and_quantize(self) -> None:
185193
"""

src/maxtext/utils/lora_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
""" Common LoRA utils needed to support LoRA adapters."""
1616

17+
from functools import partial
1718
import json
1819

1920
import jax
@@ -166,7 +167,12 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp
166167

167168
if lora_adapter_path:
168169
max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}")
169-
unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh, True)
170+
if config.pure_nnx:
171+
# NNX has a different function to init the training state.
172+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
173+
else:
174+
init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng)
175+
unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True)
170176

171177
lora_config_path = lora_adapter_path + "adapter_config.json"
172178

0 commit comments

Comments
 (0)