From 7d78580971d2232e6d36d3078cbc90bab1c4c9ae Mon Sep 17 00:00:00 2001
From: Branden Vandermoon
Date: Sat, 21 Feb 2026 00:42:57 +0000
Subject: [PATCH] PR #3204: Migrate MaxText/train_compile.py to maxtext/trainers/pre_train/train_compile.py

Imported from GitHub PR https://github.com/AI-Hypercomputer/maxtext/pull/3204

# Description

* Move `MaxText/train_compile.py` to `maxtext/trainers/pre_train/train_compile.py`
* Create a shim at `MaxText/train_compile.py` so the old command keeps working, and emit a deprecation warning when it is used
* TODO: Add deprecation dates to this and the other shims we are adding
* Next steps: Update [existing `EXECUTABLE` commands](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/tpu/v4/22b.sh#L58) to call the new path for train/train_compile. Leaving this out for now to keep the PR smaller; the old commands still work for those scripts.

NOTE: We will fix the ungrouped-import lint errors at the end of the restructuring. Otherwise we would just be swapping the order of imports repeatedly.

# Tests

Both the new and old commands work as expected.

New command:
```
python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml \
  compile_topology=v5e-256 \
  compile_topology_num_slices=2 \
  global_parameter_scale=16 \
  per_device_batch_size=4
```

Old command:
```
python3 -m MaxText.train_compile src/maxtext/configs/base.yml \
  compile_topology=v5e-256 \
  compile_topology_num_slices=2 \
  global_parameter_scale=16 \
  per_device_batch_size=4
```

# Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the docs if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).
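For reference, Python imports move the same way as the CLI entry point (see the `estimator.py` and test changes in the diff); below is a minimal before/after sketch, with module and function names taken from this PR:

```
# Old import location (after this PR, MaxText/train_compile.py only hosts a CLI shim):
#   from MaxText import train_compile
# New import location:
from maxtext.trainers.pre_train import train_compile

# The module's public surface is unchanged by the move, e.g.:
#   train_compile.validate_config(config)    # asserts compile_topology settings are set
#   train_compile.get_topology_mesh(config)  # builds the target-hardware device mesh
#   train_compile.main(argv)                 # AOT-compiles and optionally saves the train step
```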
Copybara import of the project: -- 9dc270447a04df72b898edf5fc5494eece702897 by Branden Vandermoon : Migrate MaxText/train_compile.py to maxtext/trainers/pre_train/train_compile.py Merging this change closes #3204 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/AI-Hypercomputer/maxtext/pull/3204 from AI-Hypercomputer:bvandermoon-repo-restructure 9dc270447a04df72b898edf5fc5494eece702897 PiperOrigin-RevId: 873125760 --- .../features_and_diagnostics.md | 6 +- src/MaxText/estimator.py | 2 +- src/MaxText/train_compile.py | 307 +---------------- .../trainers/pre_train/train_compile.py | 311 ++++++++++++++++++ tests/end_to_end/tpu/gemma/2b/test_gemma.sh | 2 +- tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh | 2 +- tests/integration/aot_identical_test.py | 2 +- tests/integration/xaot_test.py | 2 +- tests/unit/diloco_test.py | 2 +- tests/unit/sharding_compare_test.py | 2 +- tests/unit/train_compile_test.py | 2 +- tests/utils/sharding_dump.py | 2 +- 12 files changed, 339 insertions(+), 303 deletions(-) create mode 100644 src/maxtext/trainers/pre_train/train_compile.py diff --git a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md index ddf1b570c5..586eb06efa 100644 --- a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md +++ b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md @@ -56,7 +56,7 @@ After installing the dependencies listed above, you are ready to compile ahead o ```sh # Run the below on a single machine, e.g. a CPU -python3 MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \ +python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \ global_parameter_scale=16 per_device_batch_size=4 ``` @@ -71,7 +71,7 @@ Here is an example that saves then loads the compiled `train_step`, starting wit ```sh # Run the below on a single machine, e.g. 
a CPU export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true" -python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 \ +python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 \ compile_topology_num_slices=2 \ compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \ per_device_batch_size=4 steps=10000 learning_rate=1e-3 @@ -109,7 +109,7 @@ This example illustrates the flags to use for a multihost GPU compilation target ```sh # Run the below on a single A3 machine export XLA_FLAGS="--xla_gpu_enable_async_collectives=true" -python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=a3 \ +python3 -m maxtext.trainers.pre_train.train_compile src/maxtext/configs/base.yml compile_topology=a3 \ compile_topology_num_slices=4 \ compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \ attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3 diff --git a/src/MaxText/estimator.py b/src/MaxText/estimator.py index 36ac174ef9..ac3e1ad002 100644 --- a/src/MaxText/estimator.py +++ b/src/MaxText/estimator.py @@ -40,7 +40,7 @@ import jax from MaxText import pyconfig -from MaxText import train_compile +from maxtext.trainers.pre_train import train_compile def generate_priority_list(config, provided_tensor_names): diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index c39f671dd8..4247f401b1 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,300 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step -Generates shaped versions of state and data without ever constructing them, so its possible -to compile with target hardware (e.g. hundreds/thousands of chips), without using the hardware. -This helpfully detects if your configuration would run into memory problems (OOM) on the target hardware, -before having to use the target hardware - you will see the same OOM error message during this compilation -as you would on the target hardware. 
-""" +"""Shim for pre-training trainers in `src/maxtext/trainers/pre_train`.""" -import functools -import os -import pickle -from typing import Sequence +import sys +import importlib -from absl import app -from flax.linen import partitioning as nn_partitioning -import jax -from jax.experimental.serialize_executable import serialize -from jax.experimental.topologies import get_topology_desc -from jax.sharding import AxisType, Mesh -from MaxText import accelerator_to_spec_map -from MaxText import optimizers -from MaxText import pyconfig -from MaxText import sharding -from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode -from maxtext.layers import quantizations -from maxtext.models import models -from maxtext.trainers.diloco import diloco -from maxtext.trainers.pre_train import train -from maxtext.utils import gcs_utils -from maxtext.utils import max_utils -from maxtext.utils import maxtext_utils +from absl import logging -# pylint: disable=too-many-positional-arguments +from maxtext.utils import max_logging -Transformer = models.transformer_as_linen - - -def validate_config(config): - """Validates the config is is setup correctly to compile, returning a useful error message if not.""" - assert ( - config.compile_topology != "" - ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256" - assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" - - -def get_topology_mesh(config): - """Get the target hardware devices, and create configured mesh with them""" - target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) - if target_hardware.platform == "gpu": - # Disable sharded autotuning. This is an optimization to distribute - # autotuning across the fleet, but can cause hangs with AoT compilation. - os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" - jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) - topology_devices = jax.devices() - else: - topology_devices = get_topology_desc( - platform=target_hardware.platform, - topology_name=target_hardware.topology_name, - chip_config_name=target_hardware.chip_config_name, - chips_per_host_bounds=target_hardware.chips_per_host_bounds, - num_slices=config.compile_topology_num_slices, - wrap=target_hardware.wrap, - ).devices - if config.shard_mode == ShardMode.EXPLICIT: - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) - mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto - topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes)) - return topology_mesh - - -def get_shaped_inputs(topology_mesh, config): - """Get shaped abstractions of inputs to train_step: state, batch and rng""" - # Construct the model and optimizer to get shaped versions of the state - quant = quantizations.configure_quantization(config) - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - # The learning_rate_schedule is baked into the compiled object. 
- learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) - # pass in model for muon - tx = optimizers.get_optimizer(config, learning_rate_schedule, model) - - # Shaped RNG keys - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) - - # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, example_rng, topology_mesh - ) - - # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) - - # Shaped batch - shaped_batch = maxtext_utils.get_shaped_batch(config) - - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) - shaped_train_kwargs = {} - return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model - - -def jit_and_compile( - func, - func_input_args, - func_input_kwargs, - mesh, - in_shardings, - out_shardings, - static_argnums, - donate_argnums, - config, - logical_axis_rules, -): - """Jit, lower, and compile func.""" - with jax.set_mesh(mesh), logical_axis_rules: - jitted = jax.jit( - func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args) - lowered = jitted.lower(*func_input_args, **func_input_kwargs) - compiled = lowered.compile() - return compiled - - -def save_compiled(compiled, save_name): - """Serialize and save the compiled function.""" - serialized, _, _ = serialize(compiled) - with open(save_name, "wb") as f: - pickle.dump(serialized, f) - - -def is_oom(argv: Sequence[str]) -> bool: - """Function returns a boolean indicating whether OOM happens""" - # Parse and validate configuration - config = pyconfig.initialize(argv) - validate_config(config) - - # Create target mesh - topology_mesh = get_topology_mesh(config) - - # Print system information after building the compile topology to avoid - # prematurely initializing the backend. - max_utils.print_system_information() - - # Get shaped inputs - ( - shaped_train_args, - shaped_train_kwargs, - state_mesh_shardings, - _, - model, - ) = get_shaped_inputs(topology_mesh, config) - - # Get data sharding - data_sharding = sharding.get_input_data_sharding(config, topology_mesh) - - # Get function to compile and shardings - func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( - maxtext_utils.get_functional_train_with_signature( - train.train_step, data_sharding, state_mesh_shardings, model, config - ) - ) +OLD_MODULE_PATH = "MaxText.train_compile" +NEW_MODULE_PATH = "maxtext.trainers.pre_train.train_compile" +if __name__ == "__main__": try: - _ = jit_and_compile( - func_to_compile, - shaped_train_args, - shaped_train_kwargs, - topology_mesh, - in_shard, - out_shard, - static_argnums, - donate_argnums, - config, - nn_partitioning.axis_rules(config.logical_axis_rules), - ) - return False - except Exception as e: - # return true if OOM error happens - # OOM error looks like - # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ... - # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
- message = str(e).lower() - if "resource_exhausted" in message or "hbm" in message: - return True + logging.set_verbosity(logging.INFO) + _new_module = importlib.import_module(NEW_MODULE_PATH) + if hasattr(_new_module, "main"): + max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n") + _new_module.main(sys.argv) + except ImportError as e: + max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n") raise e - - -def main(argv: Sequence[str]) -> None: - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" - ) - print("Starting train_compile.py...", flush=True) - - # Parse and validate configuration - config = pyconfig.initialize(argv) - validate_config(config) - - # Create target mesh - topology_mesh = get_topology_mesh(config) - - # Print system information after building the compile topology to avoid - # prematurely initializing the backend. - max_utils.print_system_information() - - # Get shaped inputs - ( - shaped_train_args, - shaped_train_kwargs, - state_mesh_shardings, - logical_annotations, - model, - ) = get_shaped_inputs(topology_mesh, config) - - # Get data sharding - data_sharding = sharding.get_input_data_sharding(config, topology_mesh) - if config.enable_diloco: - # Build abstract DiLoCo state and shardings for AOT compilation - abstract_state = shaped_train_args[0] - diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( - config, abstract_state, state_mesh_shardings, topology_mesh - ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) - - # Wrap train_step with diloco - train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) - train_step_fn = diloco.build_diloco_train_step(config, train_step_partial) - - # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng) - func_to_compile = train_step_fn - func_to_compile.__name__ = "train_step" - in_shard = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shard = (state_mesh_shardings, None) # State, metrics - static_argnums = () - donate_argnums = 0 - else: - # Get function to compile and shardings - func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( - maxtext_utils.get_functional_train_with_signature( - train.train_step, data_sharding, state_mesh_shardings, model, config - ) - ) - - # print weights sharding info under debug sharding mode - if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(topology_mesh) - maxtext_utils.print_shardings_params( - shaped_train_args[0].params, - state_mesh_shardings.params, - topology_mesh, - logical_annotations.params, - ) - - # Compile - print("Jitting and compiling train step...", flush=True) - compiled = jit_and_compile( - func_to_compile, - shaped_train_args, - shaped_train_kwargs, - topology_mesh, - in_shard, - out_shard, - static_argnums, - donate_argnums, - config, - nn_partitioning.axis_rules(config.logical_axis_rules), - ) - print("Jitting and compilation complete!", flush=True) - - # Serialize and save the compiled object - if config.compiled_trainstep_file != "": - print("Saving compiled object...") - save_compiled(compiled, config.compiled_trainstep_file) - print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") - print("Finished train_compile.py successfully!", flush=True) - 
print(f"Cost analysis: {compiled.cost_analysis()}") - print(f"Memory analysis: {compiled.memory_analysis()}") - - # Dump HLO if requested - if config.dump_hlo: - gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py new file mode 100644 index 0000000000..c39f671dd8 --- /dev/null +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -0,0 +1,311 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step +Generates shaped versions of state and data without ever constructing them, so its possible +to compile with target hardware (e.g. hundreds/thousands of chips), without using the hardware. +This helpfully detects if your configuration would run into memory problems (OOM) on the target hardware, +before having to use the target hardware - you will see the same OOM error message during this compilation +as you would on the target hardware. +""" + +import functools +import os +import pickle +from typing import Sequence + +from absl import app +from flax.linen import partitioning as nn_partitioning +import jax +from jax.experimental.serialize_executable import serialize +from jax.experimental.topologies import get_topology_desc +from jax.sharding import AxisType, Mesh +from MaxText import accelerator_to_spec_map +from MaxText import optimizers +from MaxText import pyconfig +from MaxText import sharding +from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.layers import quantizations +from maxtext.models import models +from maxtext.trainers.diloco import diloco +from maxtext.trainers.pre_train import train +from maxtext.utils import gcs_utils +from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils + +# pylint: disable=too-many-positional-arguments + +Transformer = models.transformer_as_linen + + +def validate_config(config): + """Validates the config is is setup correctly to compile, returning a useful error message if not.""" + assert ( + config.compile_topology != "" + ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256" + assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" + + +def get_topology_mesh(config): + """Get the target hardware devices, and create configured mesh with them""" + target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) + if target_hardware.platform == "gpu": + # Disable sharded autotuning. This is an optimization to distribute + # autotuning across the fleet, but can cause hangs with AoT compilation. 
+ os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" + jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) + topology_devices = jax.devices() + else: + topology_devices = get_topology_desc( + platform=target_hardware.platform, + topology_name=target_hardware.topology_name, + chip_config_name=target_hardware.chip_config_name, + chips_per_host_bounds=target_hardware.chips_per_host_bounds, + num_slices=config.compile_topology_num_slices, + wrap=target_hardware.wrap, + ).devices + if config.shard_mode == ShardMode.EXPLICIT: + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) + mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto + topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes)) + return topology_mesh + + +def get_shaped_inputs(topology_mesh, config): + """Get shaped abstractions of inputs to train_step: state, batch and rng""" + # Construct the model and optimizer to get shaped versions of the state + quant = quantizations.configure_quantization(config) + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + # The learning_rate_schedule is baked into the compiled object. + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) + # pass in model for muon + tx = optimizers.get_optimizer(config, learning_rate_schedule, model) + + # Shaped RNG keys + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) + + # Shaped state + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( + model, tx, config, example_rng, topology_mesh + ) + + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) + + # Shaped batch + shaped_batch = maxtext_utils.get_shaped_batch(config) + + shaped_train_args = (abstract_state, shaped_batch, shaped_rng) + shaped_train_kwargs = {} + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model + + +def jit_and_compile( + func, + func_input_args, + func_input_kwargs, + mesh, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + config, + logical_axis_rules, +): + """Jit, lower, and compile func.""" + with jax.set_mesh(mesh), logical_axis_rules: + jitted = jax.jit( + func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + ) + maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args) + lowered = jitted.lower(*func_input_args, **func_input_kwargs) + compiled = lowered.compile() + return compiled + + +def save_compiled(compiled, save_name): + """Serialize and save the compiled function.""" + serialized, _, _ = serialize(compiled) + with open(save_name, "wb") as f: + pickle.dump(serialized, f) + + +def is_oom(argv: Sequence[str]) -> bool: + """Function returns a boolean indicating whether OOM happens""" + # Parse and validate configuration + config = pyconfig.initialize(argv) + validate_config(config) + + # Create target mesh + topology_mesh = get_topology_mesh(config) + + # Print system information after building the compile topology to avoid + # prematurely initializing the backend. 
+ max_utils.print_system_information() + + # Get shaped inputs + ( + shaped_train_args, + shaped_train_kwargs, + state_mesh_shardings, + _, + model, + ) = get_shaped_inputs(topology_mesh, config) + + # Get data sharding + data_sharding = sharding.get_input_data_sharding(config, topology_mesh) + + # Get function to compile and shardings + func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( + maxtext_utils.get_functional_train_with_signature( + train.train_step, data_sharding, state_mesh_shardings, model, config + ) + ) + + try: + _ = jit_and_compile( + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + config, + nn_partitioning.axis_rules(config.logical_axis_rules), + ) + return False + except Exception as e: + # return true if OOM error happens + # OOM error looks like + # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ... + # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ... + message = str(e).lower() + if "resource_exhausted" in message or "hbm" in message: + return True + raise e + + +def main(argv: Sequence[str]) -> None: + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + print("Starting train_compile.py...", flush=True) + + # Parse and validate configuration + config = pyconfig.initialize(argv) + validate_config(config) + + # Create target mesh + topology_mesh = get_topology_mesh(config) + + # Print system information after building the compile topology to avoid + # prematurely initializing the backend. + max_utils.print_system_information() + + # Get shaped inputs + ( + shaped_train_args, + shaped_train_kwargs, + state_mesh_shardings, + logical_annotations, + model, + ) = get_shaped_inputs(topology_mesh, config) + + # Get data sharding + data_sharding = sharding.get_input_data_sharding(config, topology_mesh) + if config.enable_diloco: + # Build abstract DiLoCo state and shardings for AOT compilation + abstract_state = shaped_train_args[0] + diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( + config, abstract_state, state_mesh_shardings, topology_mesh + ) + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + + # Wrap train_step with diloco + train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) + train_step_fn = diloco.build_diloco_train_step(config, train_step_partial) + + # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng) + func_to_compile = train_step_fn + func_to_compile.__name__ = "train_step" + in_shard = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shard = (state_mesh_shardings, None) # State, metrics + static_argnums = () + donate_argnums = 0 + else: + # Get function to compile and shardings + func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( + maxtext_utils.get_functional_train_with_signature( + train.train_step, data_sharding, state_mesh_shardings, model, config + ) + ) + + # print weights sharding info under debug sharding mode + if config.debug_sharding: + max_utils.print_non_trivial_mesh_axis(topology_mesh) + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + logical_annotations.params, + ) + + # Compile + print("Jitting and 
compiling train step...", flush=True) + compiled = jit_and_compile( + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + config, + nn_partitioning.axis_rules(config.logical_axis_rules), + ) + print("Jitting and compilation complete!", flush=True) + + # Serialize and save the compiled object + if config.compiled_trainstep_file != "": + print("Saving compiled object...") + save_compiled(compiled, config.compiled_trainstep_file) + print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") + print("Finished train_compile.py successfully!", flush=True) + print(f"Cost analysis: {compiled.cost_analysis()}") + print(f"Memory analysis: {compiled.memory_analysis()}") + + # Dump HLO if requested + if config.dump_hlo: + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tests/end_to_end/tpu/gemma/2b/test_gemma.sh b/tests/end_to_end/tpu/gemma/2b/test_gemma.sh index 59ca5ace20..1d3073dbbb 100644 --- a/tests/end_to_end/tpu/gemma/2b/test_gemma.sh +++ b/tests/end_to_end/tpu/gemma/2b/test_gemma.sh @@ -66,4 +66,4 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAX # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B. # To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. -python3 -m MaxText.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 +python3 -m maxtext.trainers.pre_train.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh index 5264b7878f..bed3528e97 100644 --- a/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh @@ -63,4 +63,4 @@ python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. # To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. 
-python3 -m MaxText.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 +python3 -m maxtext.trainers.pre_train.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/tests/integration/aot_identical_test.py b/tests/integration/aot_identical_test.py index f30153ff96..ca95593cf3 100644 --- a/tests/integration/aot_identical_test.py +++ b/tests/integration/aot_identical_test.py @@ -27,7 +27,7 @@ import re import jax from tests.utils.test_helpers import get_test_config_path -from MaxText import train_compile +from maxtext.trainers.pre_train import train_compile from maxtext.trainers.pre_train import train diff --git a/tests/integration/xaot_test.py b/tests/integration/xaot_test.py index 1278b599a6..edb5cd039a 100644 --- a/tests/integration/xaot_test.py +++ b/tests/integration/xaot_test.py @@ -25,7 +25,7 @@ import shutil import jax from tests.utils.test_helpers import get_test_config_path -from MaxText import train_compile +from maxtext.trainers.pre_train import train_compile from maxtext.trainers.pre_train import train diff --git a/tests/unit/diloco_test.py b/tests/unit/diloco_test.py index d80020d244..787c87664c 100644 --- a/tests/unit/diloco_test.py +++ b/tests/unit/diloco_test.py @@ -30,7 +30,7 @@ import pytest from MaxText.pyconfig import initialize_pydantic -from MaxText.train_compile import main as train_compile_main +from maxtext.trainers.pre_train.train_compile import main as train_compile_main from maxtext.trainers.diloco import diloco from tests.utils.test_helpers import get_test_config_path diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index 434844eb8e..054e9656c6 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -27,7 +27,7 @@ from MaxText.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config +from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json import pytest diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 58d23e9427..609914eed9 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -25,7 +25,7 @@ import pytest -from MaxText.train_compile import main as train_compile_main +from maxtext.trainers.pre_train.train_compile import main as train_compile_main from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_training] diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py index 1646a871d7..88cddbf814 100644 --- a/tests/utils/sharding_dump.py +++ b/tests/utils/sharding_dump.py @@ -31,7 +31,7 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_REPO_ROOT from maxtext.models import models -from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config +from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config Transformer = 
models.Transformer