@@ -137,6 +137,127 @@ def reshape_to_microbatch_accumulations(batch_arr):
     return loss, aux, raw_grads


+# ---------------------------------------------------------------------------
+# Gradient accumulation helper for NNX
+# ---------------------------------------------------------------------------
+
+
+def nnx_gradient_accumulation_loss_and_grad(_loss_fn, model, config, data, dropout_rng):
146+ """
147+ Calculates gradients using gradient accumulation.
148+
149+ This function computes the gradient of `_loss_fn` over multiple microbatches
150+ and accumulates them before returning a single, averaged gradient. It uses
151+ `jax.lax.scan` for efficient accumulation on device.
152+
153+ It also supports a `shard_optimizer_over_data` mode (e.g., ZeRO-1) where
154+ parameters are cast to bf16 and sharded *before* the accumulation loop
155+ to perform the all-gather in lower precision.
156+
157+ Args:
158+ _loss_fn: The loss function to differentiate. Its signature is expected
159+ to be: `(model, config, data, dropout_rng, is_train=True)`.
160+ config: Model and training configuration object. Must contain
161+ `gradient_accumulation_steps` and `shard_optimizer_over_data`.
162+ model: The model module.
163+ data: A PyTree of batched data. The leading dimension is assumed
164+ to be the total batch size (microbatch_size * num_accumulations).
165+ dropout_rng: JAX PRNGKey for dropout.
166+ extra_dpo_args: A tuple of extra arguments to pass to the loss function.
167+
168+ Returns:
169+ A tuple containing:
170+ - total_loss (Array): The mean loss, averaged over all microbatches.
171+ - final_aux (PyTree): Auxiliary outputs, summed across microbatches.
172+ - raw_grads (PyTree): The accumulated and averaged gradients.
173+ """
+
+  # For more efficient DP/ZeRO-1 + GA:
+  # if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+  #   ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
+  #   grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
+  # else:
+  #   ga_params_shardings = grad_shardings = params_shardings
+
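+  # Split the model into its static graph structure, trainable parameters, and
+  # the remaining mutable state (e.g., RNG counters) so the parameters can be
+  # handled functionally inside the scan below.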
+  graphdef, params, rest = nnx.split(model, nnx.Param, ...)
+
+  # When using ZeRO-1 optimizer sharding, cast params to lower precision and
+  # apply sharding constraints so that the all-gather is done once in the
+  # lower precision before the gradient accumulation loop.
+  if config.shard_optimizer_over_data:
+
+    def convert_to_bf16(param):
+      if param.dtype == jnp.float32:
+        return param.astype(jnp.bfloat16)
+      return param
+
+    ga_params = jax.tree.map(convert_to_bf16, params)
+  else:
+    ga_params = params
+
+  # ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
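+  # Differentiate the loss with respect to the model (argument 0);
+  # has_aux=True returns the metrics dictionary alongside the loss value.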
+  grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True)
+
+  def accumulate_gradient(acc_grad_and_loss, data):
+    ga_params = acc_grad_and_loss["ga_params"]
+    # Reconstruct the model using the fixed parameters (ga_params) and the
+    # advancing non-parameter state (RNGs) from the carry. Since ga_params
+    # may change during train_step, always create a fresh local_model.
+    local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"])
+    (_, aux), cur_batch_gradient = grad_func(local_model, config, data, dropout_rng, is_train=True)
+    _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...)
+
+    acc_grad_and_loss["rest_state"] = next_rest_state
+    acc_grad_and_loss["loss"] += aux["total_loss"]
+    acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
+    acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
+    acc_grad_and_loss["grad"] = jax.tree.map(lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"])
+    acc_grad_and_loss["total_weights"] += aux["total_weights"]
+    return acc_grad_and_loss, aux
+
+  def reshape_to_microbatch_accumulations(batch_arr):
+    """Reshape [B, ...] → [num_microbatches, B // num_microbatches, ...]."""
+    num_microbatches = config.gradient_accumulation_steps
+    microbatch_shape = (num_microbatches, batch_arr.shape[0] // num_microbatches) + batch_arr.shape[1:]
+    return jnp.reshape(batch_arr, microbatch_shape)
+
+  # Alternative layout that interleaves examples across microbatches instead
+  # of taking contiguous chunks:
+  # def reshape_to_microbatch_accumulations(batch_arr):
+  #   """Reshape global batch to microbatches, assuming batch axis is leading."""
+  #   num_microbatches = config.gradient_accumulation_steps
+  #   microbatch_shape = (batch_arr.shape[0] // num_microbatches, num_microbatches) + batch_arr.shape[1:]
+  #   reshaped_batch_arr = jnp.reshape(batch_arr, microbatch_shape)
+  #   return jnp.swapaxes(reshaped_batch_arr, 0, 1)
+
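+  # For example, with a global batch of 32 and gradient_accumulation_steps=4,
+  # each [32, ...] leaf becomes [4, 8, ...] and `scan` steps over the 4
+  # microbatches (numbers are illustrative).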
+  data = jax.tree.map(reshape_to_microbatch_accumulations, data)
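+  # Zero-initialized gradient accumulator matching the structure (and dtype)
+  # of the possibly bf16-cast parameters.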
+  init_grad = jax.tree.map(jnp.zeros_like, ga_params)
+  # init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings)
+  init_grad_and_loss = {
+      "loss": 0.0,
+      "grad": init_grad,
+      "total_weights": 0,
+      "moe_lb_loss": 0.0,
+      "mtp_loss": 0.0,
+      "ga_params": ga_params,
+      "rest_state": rest,
+  }
+
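+  # Scan over the leading microbatch axis: the carry accumulates gradients and
+  # summed losses, while the per-step `aux` outputs are stacked along a new
+  # leading axis.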
+  grad_and_loss, aux = jax.lax.scan(
+      accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
+  )
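+  # Normalize the summed token loss by the total token weights; the MoE
+  # load-balancing and MTP losses are averaged over the microbatch count.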
+  loss = (
+      grad_and_loss["loss"] / grad_and_loss["total_weights"]
+      + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps
+      + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps
+  )
+  raw_grads = grad_and_loss["grad"]
+  raw_grads = jax.tree.map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads)
+  aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)  # pytype: disable=module-attr
+
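+  # Write the advanced non-parameter state (e.g., RNG counts) back into the
+  # caller's model so the dropout RNG stream keeps advancing across steps.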
+  nnx.update(model, grad_and_loss["rest_state"])
+
+  return loss, aux, raw_grads
+
+
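+# Illustrative usage (a sketch; `loss_fn`, `model`, `config`, `batch`, and
+# `rng` are assumed to come from the surrounding training step):
+#
+#   loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(
+#       loss_fn, model, config, batch, rng
+#   )
+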
 # GA helper functions
 def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding:
   """