2727from typing import Sequence
2828
2929from absl import app
30+ from flax import nnx
3031from flax .linen import partitioning as nn_partitioning
3132import jax
3233from jax .experimental .serialize_executable import serialize
3637from maxtext .configs import pyconfig
3738from maxtext .common .common_types import MODEL_MODE_TRAIN , ShardMode
3839from maxtext .layers import quantizations
40+ from maxtext .layers import train_state_nnx
3941from maxtext .models import models
4042from maxtext .optimizers import optimizers
4143from maxtext .trainers .diloco import diloco
4446from maxtext .utils import max_utils
4547from maxtext .utils import maxtext_utils
4648from maxtext .utils import sharding
49+ from maxtext .utils import maxtext_utils_nnx
50+ from maxtext .utils import model_creation_utils
4751
4852# pylint: disable=too-many-positional-arguments
4953
@@ -93,7 +97,10 @@ def get_shaped_inputs(topology_mesh, config):
9397 """Get shaped abstractions of inputs to train_step: state, batch and rng"""
9498 # Construct the model and optimizer to get shaped versions of the state
9599 quant = quantizations .configure_quantization (config )
96- model = Transformer (config , topology_mesh , quant = quant , model_mode = MODEL_MODE_TRAIN )
100+ if config .pure_nnx :
101+ _create_model_partial , model = model_creation_utils .create_nnx_abstract_model (config , topology_mesh )
102+ else :
103+ model = Transformer (config , topology_mesh , quant = quant , model_mode = MODEL_MODE_TRAIN )
97104 # The learning_rate_schedule is baked into the compiled object.
98105 learning_rate_schedule = maxtext_utils .create_learning_rate_schedule (config )
99106 # pass in model for muon
@@ -103,18 +110,39 @@ def get_shaped_inputs(topology_mesh, config):
103110 _ , example_rng = jax .random .split (jax .random .PRNGKey (0 ), 2 )
104111 shaped_rng = jax .ShapeDtypeStruct (example_rng .shape , example_rng .dtype )
105112
106- # Shaped state
107- abstract_state , _ , state_mesh_shardings = maxtext_utils .get_abstract_state (
108- model , tx , config , example_rng , topology_mesh
109- )
113+ if config .pure_nnx :
114+
115+ def create_train_state_fn ():
116+ nnx_model = _create_model_partial ()
117+ optimizer = nnx .Optimizer (nnx_model , tx , wrt = nnx .Param )
118+ return train_state_nnx .TrainStateNNX (nnx_model , optimizer )
119+
120+ init_state_fn = create_train_state_fn
121+ else :
122+ init_state_fn = functools .partial (maxtext_utils .init_initial_state , model , tx , config , True , example_rng )
110123
111- # unsharded logical annotations
112- logical_annotations = maxtext_utils .get_logical_annotations (model , tx , config , example_rng , topology_mesh )
124+ # Shaped state
125+ abstract_state , _ , state_mesh_shardings = maxtext_utils .get_abstract_state (config , topology_mesh , init_state_fn , True )
126+
127+ if config .pure_nnx :
128+ # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
129+ logical_annotations = maxtext_utils_nnx .get_partition_spec_nnx (state_mesh_shardings )
130+ # For NNX, get_functional_train_with_signature expects the graphdef (static structure),
131+ # not the raw model — mirroring how the training loop does nnx.split(train_state).
132+ with nn_partitioning .axis_rules (config .logical_axis_rules ):
133+ graphdef , _ = nnx .get_abstract_model (init_state_fn , topology_mesh )
134+ model = graphdef
135+ else :
136+ # unsharded logical annotations
137+ logical_annotations = maxtext_utils .get_logical_annotations (config , topology_mesh , init_state_fn )
113138
114139 # Shaped batch
115140 shaped_batch = maxtext_utils .get_shaped_batch (config )
116141
117- shaped_train_args = (abstract_state , shaped_batch , shaped_rng )
142+ if config .pure_nnx :
143+ shaped_train_args = (abstract_state , shaped_batch )
144+ else :
145+ shaped_train_args = (abstract_state , shaped_batch , shaped_rng )
118146 shaped_train_kwargs = {}
119147 return shaped_train_args , shaped_train_kwargs , state_mesh_shardings , logical_annotations , model
120148
@@ -277,12 +305,20 @@ def main(argv: Sequence[str]) -> None:
277305 # print weights sharding info under debug sharding mode
278306 if config .debug_sharding :
279307 max_utils .print_non_trivial_mesh_axis (topology_mesh )
280- maxtext_utils .print_shardings_params (
281- shaped_train_args [0 ].params ,
282- state_mesh_shardings .params ,
283- topology_mesh ,
284- logical_annotations .params ,
285- )
308+ if config .pure_nnx :
309+ maxtext_utils .print_shardings_params (
310+ shaped_train_args [0 ],
311+ state_mesh_shardings ,
312+ topology_mesh ,
313+ logical_annotations ,
314+ )
315+ else :
316+ maxtext_utils .print_shardings_params (
317+ shaped_train_args [0 ].params ,
318+ state_mesh_shardings .params ,
319+ topology_mesh ,
320+ logical_annotations .params ,
321+ )
286322
287323 # Compile
288324 print ("Jitting and compiling train step..." , flush = True )
0 commit comments