From 36029ceb91217087f01a8ade67f4d935e606ab8a Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Mon, 22 Dec 2025 02:20:47 +0000
Subject: [PATCH] add sharding debug feature

---
 src/MaxText/train_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py
index 16c5d6d70b..cef39a5a3a 100644
--- a/src/MaxText/train_utils.py
+++ b/src/MaxText/train_utils.py
@@ -109,11 +109,6 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr
       donate_argnums=donate_argnums,
   )
 
-  # print weights sharding info under debug sharding mode
-  if config.debug_sharding:
-    max_utils.print_non_trivial_mesh_axis(model.mesh)
-    maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
-
   return p_train_step
 
 
@@ -219,6 +214,11 @@ def setup_train_loop(config, recorder, devices=None):
   # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
   sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
 
+  # print weights sharding info under debug sharding mode
+  if config.debug_sharding:
+    max_utils.print_non_trivial_mesh_axis(model.mesh)
+    maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
+
   if config.use_dpo:
     abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
     max_logging.log(
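
Note (not part of the patch): below is a minimal, self-contained JAX sketch of the kind of sharding debug output the `config.debug_sharding` flag gates, i.e. printing the non-trivial mesh axes and each parameter leaf's sharding. The helper names and the 1-D "data" mesh are illustrative assumptions, not the actual `max_utils.print_non_trivial_mesh_axis` / `maxtext_utils.print_state_mesh_shardings_params` implementations.

# Illustrative sketch only -- not MaxText's implementation of these helpers.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def print_non_trivial_mesh_axes(mesh):
  # Report only mesh axes that actually shard anything (size > 1).
  for name, size in mesh.shape.items():
    if size > 1:
      print(f"mesh axis {name!r}: size {size}")

def print_param_shardings(params):
  # Walk the parameter pytree and print each leaf's shape and sharding.
  leaves, _ = jax.tree_util.tree_flatten_with_path(params)
  for path, leaf in leaves:
    print(jax.tree_util.keystr(path), leaf.shape, leaf.sharding)

if __name__ == "__main__":
  # 1-D mesh over all local devices; the axis name "data" is an assumption.
  mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",))
  # Hypothetical [vocab, embed]-style weight, sharded along the "data" axis.
  params = {"embed": jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P("data", None)))}
  print_non_trivial_mesh_axes(mesh)
  print_param_shardings(params)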