
Commit 94324cd

Charles Li authored and committed
Support gradient_accumulation
1 parent e3d4720 commit 94324cd

2 files changed

Lines changed: 129 additions & 2 deletions

src/maxtext/trainers/pre_train/nnx_train.py

Lines changed: 8 additions & 2 deletions
@@ -101,6 +101,7 @@
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
+from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad
 from maxtext.utils.rampup_batch import create_rampup_manager

 _diag_modules = _cloud_diag()
@@ -287,8 +288,11 @@ def train_step(
   # Compute loss and gradients w.r.t. model parameters.
   # nnx.value_and_grad differentiates only through nnx.Param variables,
   # keeping non-differentiable state (RNGs, cache, etc.) frozen.
-  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
-  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)
+  if config.gradient_accumulation_steps > 1:
+    loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng)
+  else:
+    grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
+    (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)

   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
@@ -612,6 +616,7 @@ def train_loop(config, recorder, state=None):
   if config.compiled_trainstep_file == "":
     compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile()
     compiled_stats = compiled.memory_analysis()
+    max_logging.info(f"print_compiled_memory_stats:")
     max_utils.print_compiled_memory_stats(compiled_stats)

   # ---- Profiler / logger ----------------------------------------------------
@@ -625,6 +630,7 @@ def train_loop(config, recorder, state=None):
   _job_completed_gracefully = False
   try:
     last_step_completion = datetime.datetime.now()
+    max_logging.info(f"Entering train loop from start_step={start_step}")

     for step in np.arange(start_step, config.steps):
       prof.maybe_activate_profiler(step, opt_state)
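
A note on the `else` branch kept above: `nnx.value_and_grad` takes gradients only with respect to the `nnx.Param` leaves of the module passed as `argnums=0`, and other module state receives no gradient. Below is a minimal, self-contained sketch of that pattern using a toy module and loss function (illustrative names only, not MaxText code; the real `loss_fn` additionally takes `config`, `data`, and `dropout_rng`):

from flax import nnx
import jax.numpy as jnp

class ToyModel(nnx.Module):
  """Tiny stand-in for the real model: one Param plus non-Param state."""

  def __init__(self):
    self.w = nnx.Param(jnp.ones((2,)))        # differentiated by nnx.value_and_grad
    self.calls = nnx.BatchStat(jnp.array(0))  # non-Param state, excluded from gradients

  def __call__(self, x):
    return jnp.sum(self.w * x)

def toy_loss_fn(model, x):
  loss = model(x)
  return loss, {"total_loss": loss}

model = ToyModel()
grad_fn = nnx.value_and_grad(toy_loss_fn, argnums=0, has_aux=True)
(loss, aux), grads = grad_fn(model, jnp.array([1.0, 2.0]))
# `grads` is an nnx.State containing only the Param leaf `w`; `calls` receives no gradient.

Both branches of the new dispatch use the same `loss_fn` and aux dictionary, so `gradient_accumulation_steps` changes only how many microbatches contribute to one optimizer update, not the loss definition.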

src/maxtext/utils/gradient_accumulation.py

Lines changed: 121 additions & 0 deletions
@@ -137,6 +137,127 @@ def reshape_to_microbatch_accumulations(batch_arr):
   return loss, aux, raw_grads


+# ---------------------------------------------------------------------------
+# Gradient accumulation helper for NNX
+# ---------------------------------------------------------------------------
+
+
+def nnx_gradient_accumulation_loss_and_grad(_loss_fn, model, config, data, dropout_rng):
+  """
+  Calculates gradients using gradient accumulation.
+
+  This function computes the gradient of `_loss_fn` over multiple microbatches
+  and accumulates them before returning a single, averaged gradient. It uses
+  `jax.lax.scan` for efficient accumulation on device.
+
+  It also supports a `shard_optimizer_over_data` mode (e.g., ZeRO-1) where
+  parameters are cast to bf16 and sharded *before* the accumulation loop
+  to perform the all-gather in lower precision.
+
+  Args:
+    _loss_fn: The loss function to differentiate. Its signature is expected
+      to be: `(model, config, data, dropout_rng, is_train=True)`.
+    model: The model module.
+    config: Model and training configuration object. Must contain
+      `gradient_accumulation_steps` and `shard_optimizer_over_data`.
+    data: A PyTree of batched data. The leading dimension is assumed
+      to be the total batch size (microbatch_size * num_accumulations).
+    dropout_rng: JAX PRNGKey for dropout.
+
+  Returns:
+    A tuple containing:
+      - total_loss (Array): The mean loss, averaged over all microbatches.
+      - final_aux (PyTree): Auxiliary outputs, summed across microbatches.
+      - raw_grads (PyTree): The accumulated and averaged gradients.
+  """
+
+  # For more efficient DP/ZeRO-1 + GA
+  # if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+  #   ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
+  #   grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
+  # else:
+  #   ga_params_shardings = grad_shardings = params_shardings
+
+  graphdef, params, rest = nnx.split(model, nnx.Param, ...)
+
+  # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints
+  # so that all-gather is done once in the lower precision before the gradient accumulation loop
+  if config.shard_optimizer_over_data:
+
+    def convert_to_bf16(param):
+      if param.dtype == jnp.float32:
+        return param.astype(jnp.bfloat16)
+      return param
+
+    ga_params = jax.tree.map(convert_to_bf16, params)
+  else:
+    ga_params = params
+
+  # ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
+  grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True)
+
+  def accumulate_gradient(acc_grad_and_loss, data):
+    ga_params = acc_grad_and_loss["ga_params"]
+    # Reconstruct the model using the fixed parameters (ga_params)
+    # and the advancing non-parameter state (RNGs) from the carry.

+    # as ga_params will change during train_step, always create a local_model
+    local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"])
+    (_, aux), cur_batch_gradient = grad_func(local_model, config, data, dropout_rng, is_train=True)
+    _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...)
+
+    acc_grad_and_loss["rest_state"] = next_rest_state
+    acc_grad_and_loss["loss"] += aux["total_loss"]
+    acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
+    acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
+    acc_grad_and_loss["grad"] = jax.tree.map(lambda x, y: x + y, cur_batch_gradient, acc_grad_and_loss["grad"])
+    acc_grad_and_loss["total_weights"] += aux["total_weights"]
+    return acc_grad_and_loss, aux
+
+  def reshape_to_microbatch_accumulations(batch_arr):
+    """Reshape [B, ...] → [num_microbatches, B//num_microbatches, ...]."""
+    num_microbatches = config.gradient_accumulation_steps
+    microbatch_shape = (num_microbatches, batch_arr.shape[0] // num_microbatches) + batch_arr.shape[1:]
+    return jnp.reshape(batch_arr, microbatch_shape)
+
+  # def reshape_to_microbatch_accumulations(batch_arr):
+  #   """Reshape global batch to microbatches, assuming batch axis is leading."""
+  #   num_microbatches = config.gradient_accumulation_steps
+  #   microbatch_shape = (batch_arr.shape[0] // num_microbatches, num_microbatches) + batch_arr.shape[1:]
+  #   reshaped_batch_arr = jnp.reshape(batch_arr, microbatch_shape)
+  #   return jnp.swapaxes(reshaped_batch_arr, 0, 1)
+
+  data = jax.tree.map(reshape_to_microbatch_accumulations, data)
+  init_grad = jax.tree.map(jnp.zeros_like, ga_params)
+  # init_grad = jax.tree.map(_maybe_shard_with_name, init_grad, grad_shardings)
+  init_grad_and_loss = {
+      "loss": 0.0,
+      "grad": init_grad,
+      "total_weights": 0,
+      "moe_lb_loss": 0.0,
+      "mtp_loss": 0.0,
+      "ga_params": ga_params,
+  }
+  init_grad_and_loss["rest_state"] = rest
+
+  grad_and_loss, aux = jax.lax.scan(
+      accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
+  )
+  loss = (
+      grad_and_loss["loss"] / grad_and_loss["total_weights"]
+      + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps
+      + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps
+  )
+  raw_grads = grad_and_loss["grad"]
+  raw_grads = jax.tree.map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads)
+  aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)  # pytype: disable=module-attr
+
+  nnx.update(model, grad_and_loss["rest_state"])
+
+  return loss, aux, raw_grads
+
+
 # GA helper functions
 def update_sharding_for_reduced(sharding: NamedSharding) -> NamedSharding:
   """

0 commit comments