From e0905bdcd8f077dd1ec94ad332d5b1f2c0e42925 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 5 Mar 2026 14:20:49 -0800 Subject: [PATCH 01/12] add cgemm + FP8 tests Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 40 ++++++- .../jax/collective_gemm/run_test_cgemm.sh | 79 +++++++++---- .../jax/collective_gemm/test_dense_grad.py | 111 +++++++++++++++++- examples/jax/collective_gemm/test_gemm.py | 111 +++++++++++++++++- .../test_layernorm_mlp_grad.py | 65 +++++++++- transformer_engine/jax/cpp_extensions/gemm.py | 6 + 6 files changed, 375 insertions(+), 37 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 2965896d07..e3904221a4 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -77,6 +77,8 @@ def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e- import argparse import jax from jax.experimental import mesh_utils +from transformer_engine.common import recipe as te_recipe +from transformer_engine.jax.quantize import ScalingMode from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap # Global flag to track if distributed has been initialized @@ -183,6 +185,36 @@ def _create_mesh(args): return mesh +def get_scaling_mode_from_recipe_name(name: str) -> ScalingMode: + """Get ScalingMode from a recipe name string.""" + match name: + case "DelayedScaling": + return ScalingMode.DELAYED_TENSOR_SCALING + case "Float8CurrentScaling": + return ScalingMode.CURRENT_TENSOR_SCALING + case "MXFP8BlockScaling": + return ScalingMode.MXFP8_1D_SCALING + case "NVFP4BlockScaling": + return ScalingMode.NVFP4_1D_SCALING + case _: + raise ValueError(f"Invalid recipe name, got {name}") + + +def get_quantization_recipe_from_name_string(name: str): + """Query recipe from a given name string""" + match name: + case "DelayedScaling": + return te_recipe.DelayedScaling() + case "MXFP8BlockScaling": + return te_recipe.MXFP8BlockScaling() + case "Float8CurrentScaling": + return te_recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return te_recipe.NVFP4BlockScaling() + case _: + raise ValueError(f"Invalid quantization_recipe, got {name}") + + def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): """Create common argument parser for all collective GEMM tests.""" parser = argparse.ArgumentParser(description=description) @@ -229,7 +261,7 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + "--quantize-recipe", type=str, default="DelayedScaling", help="Quantization recipe to use" ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" @@ -237,5 +269,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="Enable FP8 quantization", + ) return parser diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 388c878376..08fc92d2d8 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -23,11 +23,43 @@ else echo "NVLINK support detected" fi -# Define the test files to run -TEST_FILES=( -"test_gemm.py" -"test_dense_grad.py" -"test_layernorm_mlp_grad.py" +# Define individual test cases to run (file::class::method) +# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all +# the time. +TEST_CASES=( +# test_gemm.py cases +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" +# TODO(Phuong): Enable when supported +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" + +# test_dense_grad.py cases +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +# TODO(Phuong): Enable when supported +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" + +# test_layernorm_mlp_grad.py cases +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +# TODO(Phuong): Enable when supported +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo @@ -57,24 +89,27 @@ cleanup() { # Set up signal handlers to cleanup on exit trap cleanup EXIT INT TERM -# Run each test file across all GPUs -for TEST_FILE in "${TEST_FILES[@]}"; do +# Run each test case across all GPUs +for TEST_CASE in "${TEST_CASES[@]}"; do echo - echo "=== Starting test file: $TEST_FILE ..." + echo "=== Starting test: $TEST_CASE ..." - # Clear PIDs array for this test file + # Extract just the test method name for log/xml file naming + TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}') + + # Clear PIDs array for this test case PIDS=() for i in $(seq 0 $(($NUM_GPUS - 1))); do # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" + LOG_FILE="${TEST_NAME}_gpu_${i}.log" if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -82,7 +117,7 @@ for TEST_FILE in "${TEST_FILES[@]}"; do else # For other processes: redirect to log files only pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & PID=$! @@ -93,22 +128,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Wait for all processes to finish wait - # Check and print the log content from process 0 (now has log file thanks to tee) - if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE SKIPPED" - elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE FAILED" + # Check and print the log content from process 0 + if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE PASSED" + elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" else - echo "... $TEST_FILE INVALID" + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 fi # Remove the log files after processing them wait - rm ${TEST_FILE}_gpu_*.log + rm ${TEST_NAME}_gpu_*.log done wait diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 94c7dc5b66..980a9a7df9 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -20,11 +20,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.dense import dense -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOp, CollectiveOpSet, @@ -56,7 +63,7 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): output = dense( x, weight, @@ -66,13 +73,14 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv kernel_axes=weight_axes, output_axes=output_axes, collective_op_set=collective_op_set, + quantizer_set=quantizer_set, ) return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( - x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -98,11 +106,16 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -123,6 +136,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, noop_collective_op_set, + quantizer_set, ) output, sharded_grads = _value_and_grad_dense( x_sharded, @@ -132,6 +146,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, collective_op_set, + quantizer_set, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -187,6 +202,90 @@ def test_te_bf16_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_all_gather(self): + # """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_reduce_scatter(self): + # """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_all_gather(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_reduce_scatter(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index ea119713e3..abefafe0a0 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -29,10 +29,17 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp from transformer_engine.jax.sharding import MeshResource @@ -72,13 +79,14 @@ def _get_dp_and_tp_sizes(args): @partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) -def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): +def _jitted_cgemm(x, weight, bias, quantizer_set, contracting_dims, collective_op, output_sharding): output = tex.gemm( x, weight, bias=bias, contracting_dims=contracting_dims, collective_op=collective_op, + quantizer_set=quantizer_set, ) if output_sharding is not None: output = jax.lax.with_sharding_constraint(output, output_sharding) @@ -107,11 +115,20 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + + # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource + # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() can read the global recipe + # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw + # tex.gemm() calls, so we must pass quantizer_set explicitly. + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -125,6 +142,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=CollectiveOp.NONE, output_sharding=output_sharding, @@ -133,6 +151,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=collective_op, output_sharding=output_sharding, @@ -186,6 +205,90 @@ def test_te_bf16_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_all_gather_with_dp(self): + # """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_reduce_scatter_with_dp(self): + # """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_all_gather_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_reduce_scatter_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 84cb011da1..7a46487325 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -20,11 +20,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOpSet, CollectiveOp, @@ -68,6 +75,7 @@ def _mean_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): output = layernorm_mlp( x, @@ -82,6 +90,7 @@ def _mean_layernorm_mlp( kernel_2_axes=weight_2_axes, activation_type=("gelu",), collective_op_sets=collective_op_sets, + quantizer_sets=quantizer_sets, ) return jnp.mean(output) @@ -98,6 +107,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): return jax.jit( jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) @@ -113,6 +123,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) @@ -149,11 +160,17 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + use_fp8 = getattr(args, "use_fp8", False) + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_fp8, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). + quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_sets = (quantizer_set, quantizer_set) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -181,6 +198,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, noop_collective_op_sets, + quantizer_sets, ) output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, @@ -194,6 +212,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -240,9 +259,47 @@ def tearDown(self): os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) def test_te_bf16_layernorm_mlp_grad(self): - """Test Collective Dense Gradient with AllGather""" + """Test Collective LayerNorm MLP Gradient with BF16""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_current_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported + # def test_te_mxfp8_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + # self.args.quantize_recipe = "MXFP8BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported + # def test_te_nvfp4_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.use_fp8 = True + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 70557f29c7..6eabe8a691 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1198,6 +1198,12 @@ def _te_gemm( rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + if not collective_op.is_none and scaling_mode.is_1d_block_scaling(): + raise ValueError( + f"Collective GEMM is not yet supported with {scaling_mode} quantization. " + "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." + ) + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype if bias is None: bias = jnp.empty(0, dtype=out_dtype) From f314655e20abe0bf0f1e5be4dfd35321cd1ea8f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:22:46 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/collective_gemm/test_dense_grad.py | 24 ++++++++++++++----- examples/jax/collective_gemm/test_gemm.py | 16 +++++++++---- .../test_layernorm_mlp_grad.py | 8 +++++-- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 980a9a7df9..dcb51b34f4 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -63,7 +63,9 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): +def _mean_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): output = dense( x, weight, @@ -78,7 +80,9 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set): +def _value_and_grad_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -205,7 +209,9 @@ def test_te_bf16_reduce_scatter(self): def test_te_delayed_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -215,7 +221,9 @@ def test_te_delayed_scaling_fp8_all_gather(self): def test_te_delayed_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -225,7 +233,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self): def test_te_current_scaling_fp8_all_gather(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -235,7 +245,9 @@ def test_te_current_scaling_fp8_all_gather(self): def test_te_current_scaling_fp8_reduce_scatter(self): """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index abefafe0a0..f8c9dd18a8 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -208,7 +208,9 @@ def test_te_bf16_reduce_scatter_with_dp(self): def test_te_delayed_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + AllGather""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -218,7 +220,9 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self): def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -228,7 +232,9 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): def test_te_current_scaling_fp8_all_gather_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -238,7 +244,9 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self): def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 7a46487325..ebf55ea5df 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -265,7 +265,9 @@ def test_te_bf16_layernorm_mlp_grad(self): def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" self.args.quantize_recipe = "DelayedScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True @@ -274,7 +276,9 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): def test_te_current_scaling_fp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" self.args.quantize_recipe = "Float8CurrentScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.use_fp8 = True From e4c60bbdd6804fb12637e769b6639c71234c3b1d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:00:58 -0700 Subject: [PATCH 03/12] cgemm+mxfp8 passed for AG Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 2 + .../jax/collective_gemm/run_test_cgemm.sh | 64 +++++++++---------- examples/jax/collective_gemm/test_gemm.py | 19 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 60 ++++++++++++++--- .../jax/csrc/extensions/gemm.cpp | 5 +- transformer_engine/jax/quantize/helper.py | 10 +-- 6 files changed, 105 insertions(+), 55 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index e3904221a4..95452e85bc 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -250,7 +250,9 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + # parser.add_argument("--batch-size", type=int, default=2, help="Batch size for testing") parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + # parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length for testing") parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") parser.add_argument( diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 08fc92d2d8..b1a8816703 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,38 +28,38 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" -"test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" -# TODO(Phuong): Enable when supported -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" - -# test_dense_grad.py cases -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -# TODO(Phuong): Enable when supported -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" - -# test_layernorm_mlp_grad.py cases -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" -"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" -# TODO(Phuong): Enable when supported -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" +# # TODO(Phuong): Enable when supported +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# +# # test_dense_grad.py cases +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +# # TODO(Phuong): Enable when supported +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" +# +# # test_layernorm_mlp_grad.py cases +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +# # TODO(Phuong): Enable when supported +# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index f8c9dd18a8..2a968f7ba4 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -253,16 +253,15 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_all_gather_with_dp(self): - # """Test Collective GEMM with MXFP8BlockScaling + AllGather""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.use_fp8 = True - # self.args.collective_type = "all_gather" - # run_gemm_tests(self.args, self.mesh) + def test_te_mxfp8_all_gather_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported # def test_te_mxfp8_reduce_scatter_with_dp(self): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 6eabe8a691..3a1e4f5e48 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -649,6 +649,34 @@ def impl( reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) lhs = reordered.reshape(original_shape) + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not lhs_scale_inv.shape[0] == 1 + and scaling_mode.is_1d_block_scaling() + ): + + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = lhs_scale_inv.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs_scale_inv.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs_scale_inv = reordered.reshape(original_shape) + (output, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -775,6 +803,7 @@ def _parse_operand_output_specs( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -911,11 +940,23 @@ def _parse_operand_output_specs( # Bias sharding is based on GEMM output before any scatter bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4 + # Scale shardings are based on the scaling_mode and collective_op + lhs_scale_specs = rhs_scale_specs = (None,) + if scaling_mode.is_1d_block_scaling(): + rhs_scale_specs = rhs_specs + if collective_op.is_all_gather: + lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) + else: + lhs_scale_specs = lhs_specs + print(lhs_scale_specs) + print(rhs_scale_specs) + + if not collective_op.is_none: assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" return ( - (lhs_specs, rhs_specs, bias_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_specs), out_specs, reduce_spec, sequence_dim, @@ -937,7 +978,6 @@ def infer_sharding_from_operands( ): del ( out_dtype, - scaling_mode, use_split_accumulator, result_infos, is_outer, @@ -945,7 +985,7 @@ def infer_sharding_from_operands( ) (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op + arg_infos, contracting_dims, transpose_batch_sequence, collective_op, scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -968,7 +1008,7 @@ def partition( del result_infos, is_outer, sequence_dim ( - (lhs_specs, rhs_specs, bias_input_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_input_specs), out_specs, reduce_spec, inferred_sequence_dim, @@ -977,17 +1017,21 @@ def partition( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ) # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + lhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*lhs_scale_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + rhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*rhs_scale_specs)) + arg_shardings = ( lhs_sharding, - lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + lhs_scale_sharding, rhs_sharding, - rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_scale_sharding, ) # Bias @@ -1198,8 +1242,8 @@ def _te_gemm( rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv - if not collective_op.is_none and scaling_mode.is_1d_block_scaling(): - raise ValueError( + if not collective_op.is_none: + assert not scaling_mode.is_nvfp4_scaling, ( f"Collective GEMM is not yet supported with {scaling_mode} quantization. " "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 737dd65622..268307ea83 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -58,6 +58,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape = {1}; auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), @@ -73,7 +74,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } input.set_with_gemm_swizzled_scales(true); - } else if (is_nvfp4) { // Swizzle for NVFP4 + } + else if (is_nvfp4) { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor @@ -202,6 +204,7 @@ Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; + if (is_nvfp4_scaling(config.scaling_mode)) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c5256aef5c..a79c8e6e90 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -927,10 +927,12 @@ def apply_padding_to_scale_inv( unpadded_scale_shape = scaling_mode.get_scale_shape( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) - assert scale_inv.shape == unpadded_scale_shape, ( - f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " - f"{scale_inv.shape}." - ) + + # TODO + # assert scale_inv.shape == unpadded_scale_shape, ( + # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + # f"{scale_inv.shape}." + # ) # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) From 6b836aff9607695a2d1e4da6bffb66eed9c48328 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:13:00 -0700 Subject: [PATCH 04/12] refactor code Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 100 ++++++++---------- 1 file changed, 45 insertions(+), 55 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3a1e4f5e48..c27a063c25 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -396,6 +396,48 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): ) +def _reorder_tpsp_leading(tensor, original_shape): + """Reorder tensor so the tpsp axis is leading: reshape (dp, n, tpsp, m, ...), transpose (2, 0, 1, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) + + +def _reorder_dp_leading(tensor, original_shape): + """Reorder tensor so the dp axis is leading: reshape (tpsp, dp, n, m, ...), transpose (1, 2, 0, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = tensor.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + return reordered.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -630,24 +672,7 @@ def impl( and not lhs.shape[0] == 1 ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = lhs.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs = reordered.reshape(original_shape) + lhs = _reorder_tpsp_leading(lhs, lhs.shape) if ( collective_op.is_all_gather @@ -656,26 +681,8 @@ def impl( and not lhs_scale_inv.shape[0] == 1 and scaling_mode.is_1d_block_scaling() ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = lhs_scale_inv.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs_scale_inv.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs_scale_inv = reordered.reshape(original_shape) + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) (output, _) = GemmPrimitive.inner_primitive.bind( lhs, @@ -702,24 +709,7 @@ def impl( and not output.shape[0] == 1 ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = output.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = output.reshape( - tpsp_axis_size(), - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) - output = reordered.reshape(original_shape) + output = _reorder_dp_leading(output, output.shape) return (output,) From a63db5675eb06e1685c9f4665518c768985281b1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:22:33 -0700 Subject: [PATCH 05/12] mxfp8 + rs passed Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 2 -- .../jax/collective_gemm/run_test_cgemm.sh | 4 ++-- examples/jax/collective_gemm/test_gemm.py | 19 +++++++++---------- transformer_engine/jax/cpp_extensions/gemm.py | 10 ++++++++++ 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 95452e85bc..e3904221a4 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -250,9 +250,7 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") - # parser.add_argument("--batch-size", type=int, default=2, help="Batch size for testing") parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") - # parser.add_argument("--seq-len", type=int, default=16384, help="Sequence length for testing") parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") parser.add_argument( diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index b1a8816703..d5de55287b 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -35,8 +35,8 @@ TEST_CASES=( # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" # # TODO(Phuong): Enable when supported -"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" -# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" # diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 2a968f7ba4..fed69ec636 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -263,16 +263,15 @@ def test_te_mxfp8_all_gather_with_dp(self): self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_reduce_scatter_with_dp(self): - # """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.use_fp8 = True - # self.args.collective_type = "reduce_scatter" - # run_gemm_tests(self.args, self.mesh) + def test_te_mxfp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather_with_dp(self): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c27a063c25..121195a2bb 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -674,6 +674,16 @@ def impl( assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs_scale_inv.shape[0] == 1 + and scaling_mode.is_1d_block_scaling() + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) + if ( collective_op.is_all_gather and not transpose_batch_sequence From 5c005dcaa25b497f0257e556c6c81b11f53341a3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 08:55:02 -0700 Subject: [PATCH 06/12] simplify the conditions Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 35 ++++--------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 121195a2bb..f9550dc8be 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -664,33 +664,15 @@ def impl( lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) + # Determine if we need to reorder the tensor so that the input/output are in the correct layout for the collective operation + need_reorder = not transpose_batch_sequence and not is_outer and not collective_op.is_none + # Alter lhs blocks so that CGEMM RS outputs correctly - if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs.shape[0] == 1 - ): + if need_reorder and collective_op.is_reduce_scatter and lhs.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) - if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs_scale_inv.shape[0] == 1 - and scaling_mode.is_1d_block_scaling() - ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) - - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not lhs_scale_inv.shape[0] == 1 - and scaling_mode.is_1d_block_scaling() - ): + if need_reorder and (collective_op.is_reduce_scatter or collective_op.is_all_gather) and lhs_scale_inv.shape[0] != 1 and scaling_mode.is_1d_block_scaling(): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) @@ -712,12 +694,7 @@ def impl( collective_op=collective_op, ) # Alter output blocks for CGEMM AG - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not output.shape[0] == 1 - ): + if need_reorder and collective_op.is_all_gather and output.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" output = _reorder_dp_leading(output, output.shape) From 1ba99fa692188a7238436bdf73a4f53a3208537d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 10:30:35 -0700 Subject: [PATCH 07/12] added size check for mxfp8 Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index f9550dc8be..c6dee60a0c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -652,12 +652,26 @@ def impl( lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis - ) + if not collective_op.is_none and not is_outer: + # MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128. + # No padding is needed in this case + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod(lhs.shape[lhs_flatten_axis:]) + assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( + f"MXFP8 + Collective AG requires LHS dimensions before and after the flatten axis to be multiples of 128. " + f"Got lhs.shape={lhs.shape}, lhs_flatten_axis={lhs_flatten_axis}" + ) + # The scale needs to be in good shape for reordering + assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( + f"MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be multiples of tpsp_axis_size. " + f"Got lhs_scale_inv.shape={lhs_scale_inv.shape}, tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + ) + else: + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis, + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel if scaling_mode.is_mxfp8_scaling: @@ -925,9 +939,6 @@ def _parse_operand_output_specs( lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) else: lhs_scale_specs = lhs_specs - print(lhs_scale_specs) - print(rhs_scale_specs) - if not collective_op.is_none: assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" From 38fbf562d6d51d36fd9c8df6e6c89fcaf17aa0fc Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:12:03 -0700 Subject: [PATCH 08/12] added tols for assertions Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 14 ++++++++++++++ examples/jax/collective_gemm/run_test_cgemm.sh | 2 +- examples/jax/collective_gemm/test_dense_grad.py | 6 ++++-- examples/jax/collective_gemm/test_gemm.py | 3 ++- .../jax/collective_gemm/test_layernorm_mlp_grad.py | 6 ++++-- 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index e3904221a4..355b5f11b7 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -21,10 +21,24 @@ def dtype_tols(dtype, rtol=None, atol=None): return {"rtol": 1e-3, "atol": 1e-6} elif dtype in [jnp.bfloat16, "bfloat16"]: return {"rtol": 1e-2, "atol": 1e-5} + elif dtype in [jnp.float8_e4m3fn, "float8_e4m3fn", jnp.float8_e5m2, "float8_e5m2"]: + # FP8 quantization introduces ~1% error; match C++ getTolerances for fp8 types + return {"rtol": 1e-2, "atol": 1e-2} else: return {"rtol": 1e-5, "atol": 1e-8} +def get_tolerance_dtype(quantizer_set): + """Return the dtype used to select numerical tolerances based on the active quantizer. + + Reads q_dtype from quantizer_set.x; falls back to bfloat16 when no quantizer is + active (NO_SCALING / noop path, where quantizer_set.x is None). + """ + if quantizer_set.x is not None: + return quantizer_set.x.q_dtype + return jnp.bfloat16 + + def assert_allclose( actual, desired, diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index d5de55287b..04553a0174 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -35,7 +35,7 @@ TEST_CASES=( # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" # "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" # # TODO(Phuong): Enable when supported -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index dcb51b34f4..7c2ec5d607 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -167,9 +168,10 @@ def run_dense_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveDenseGradient(unittest.TestCase): diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index fed69ec636..3df85ab87c 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -22,6 +22,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -169,7 +170,7 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output) + assert_allclose(gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)) class TestCollectiveGemmWithDP(unittest.TestCase): diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index ebf55ea5df..927f3e99b2 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -229,9 +230,10 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveLayerNormMLPGradient(unittest.TestCase): From 018f4014b2cf546b410d1e236466c24d0c66b9a9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:26:07 -0700 Subject: [PATCH 09/12] update tests with recipes Signed-off-by: Phuong Nguyen --- examples/jax/collective_gemm/common.py | 12 +++++----- .../jax/collective_gemm/test_dense_grad.py | 20 +++++++---------- examples/jax/collective_gemm/test_gemm.py | 22 +++++++++---------- .../test_layernorm_mlp_grad.py | 14 +++++------- 4 files changed, 29 insertions(+), 39 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 355b5f11b7..483f3e60af 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -275,7 +275,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--quantize-recipe", type=str, default="DelayedScaling", help="Quantization recipe to use" + "--quantize-recipe", + type=str, + default=None, + choices=["DelayedScaling", "Float8CurrentScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"], + help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" @@ -283,11 +287,5 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) - parser.add_argument( - "--use-fp8", - action="store_true", - default=False, - help="Enable FP8 quantization", - ) return parser diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 7c2ec5d607..adc97b1790 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -111,16 +111,16 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -216,7 +216,7 @@ def test_te_delayed_scaling_fp8_all_gather(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_dense_grad_tests(self.args, self.mesh) @@ -228,7 +228,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) @@ -240,7 +240,7 @@ def test_te_current_scaling_fp8_all_gather(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_dense_grad_tests(self.args, self.mesh) @@ -252,7 +252,7 @@ def test_te_current_scaling_fp8_reduce_scatter(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) @@ -263,7 +263,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) @@ -274,7 +273,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_dense_grad_tests(self.args, self.mesh) @@ -285,7 +283,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) @@ -296,7 +293,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_dense_grad_tests(self.args, self.mesh) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 3df85ab87c..d06fa9e75e 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -116,20 +116,20 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() can read the global recipe # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw # tex.gemm() calls, so we must pass quantizer_set explicitly. - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -214,7 +214,7 @@ def test_te_delayed_scaling_fp8_all_gather_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -226,7 +226,7 @@ def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -238,7 +238,7 @@ def test_te_current_scaling_fp8_all_gather_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -250,7 +250,7 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -260,7 +260,7 @@ def test_te_mxfp8_all_gather_with_dp(self): is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "all_gather" run_gemm_tests(self.args, self.mesh) @@ -270,7 +270,7 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) @@ -281,7 +281,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "all_gather" # run_gemm_tests(self.args, self.mesh) @@ -292,7 +291,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # self.args.collective_type = "reduce_scatter" # run_gemm_tests(self.args, self.mesh) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 927f3e99b2..599c88e3d9 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -161,16 +161,16 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) - use_fp8 = getattr(args, "use_fp8", False) - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None + use_quantization = args.quantize_recipe is not None + recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None with mesh, autocast( - enabled=use_fp8, + enabled=use_quantization, recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): # Build quantizer_set inside autocast so create_set() reads the global recipe # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). - quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set quantizer_sets = (quantizer_set, quantizer_set) # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() @@ -272,7 +272,7 @@ def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) def test_te_current_scaling_fp8_layernorm_mlp_grad(self): @@ -283,7 +283,7 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): ) if not is_supported: self.skipTest(reason) - self.args.use_fp8 = True + run_layernorm_mlp_grad_tests(self.args, self.mesh) # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported @@ -293,7 +293,6 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # run_layernorm_mlp_grad_tests(self.args, self.mesh) # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported @@ -303,7 +302,6 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) # if not is_supported: # self.skipTest(reason) - # self.args.use_fp8 = True # run_layernorm_mlp_grad_tests(self.args, self.mesh) From 05ed98d56bd4a9b6fd282fe48a1060aa09876ae8 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:26:25 -0700 Subject: [PATCH 10/12] shape check when padding mxfp8 scales Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/helper.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index a79c8e6e90..e584b5a452 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -928,11 +928,10 @@ def apply_padding_to_scale_inv( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) - # TODO - # assert scale_inv.shape == unpadded_scale_shape, ( - # f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " - # f"{scale_inv.shape}." - # ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) From 68a63abfd18577a6c2cd115658dc7f69f3b5e37d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 10 Mar 2026 11:38:20 -0700 Subject: [PATCH 11/12] cleanup Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 41 ++++++++----------- .../jax/collective_gemm/test_dense_grad.py | 36 ++++++++-------- examples/jax/collective_gemm/test_gemm.py | 2 - .../test_layernorm_mlp_grad.py | 16 ++++---- 4 files changed, 40 insertions(+), 55 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 04553a0174..a098515af9 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -28,38 +28,31 @@ fi # the time. TEST_CASES=( # test_gemm.py cases -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp" -# "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp" -# # TODO(Phuong): Enable when supported +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" # # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" # # # test_dense_grad.py cases -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" -# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" -# # TODO(Phuong): Enable when supported -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" -# # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" # # # test_layernorm_mlp_grad.py cases -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" -# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" -# # TODO(Phuong): Enable when supported -# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" -# # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index adc97b1790..35a4b9457f 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -256,27 +256,24 @@ def test_te_current_scaling_fp8_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_all_gather(self): - # """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.collective_type = "all_gather" - # run_dense_grad_tests(self.args, self.mesh) + def test_te_mxfp8_all_gather(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_reduce_scatter(self): - # """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # self.args.collective_type = "reduce_scatter" - # run_dense_grad_tests(self.args, self.mesh) + def test_te_mxfp8_reduce_scatter(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" @@ -286,7 +283,6 @@ def test_te_current_scaling_fp8_reduce_scatter(self): # self.args.collective_type = "all_gather" # run_dense_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_reduce_scatter(self): # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d06fa9e75e..c969062376 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -274,7 +274,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_all_gather_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" # self.args.quantize_recipe = "NVFP4BlockScaling" @@ -284,7 +283,6 @@ def test_te_mxfp8_reduce_scatter_with_dp(self): # self.args.collective_type = "all_gather" # run_gemm_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_reduce_scatter_with_dp(self): # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" # self.args.quantize_recipe = "NVFP4BlockScaling" diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 599c88e3d9..fe245f436b 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -286,16 +286,14 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): run_layernorm_mlp_grad_tests(self.args, self.mesh) - # TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported - # def test_te_mxfp8_layernorm_mlp_grad(self): - # """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" - # self.args.quantize_recipe = "MXFP8BlockScaling" - # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) - # if not is_supported: - # self.skipTest(reason) - # run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_mxfp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) - # TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported # def test_te_nvfp4_layernorm_mlp_grad(self): # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" # self.args.quantize_recipe = "NVFP4BlockScaling" From b3f43de5fc295755721f1166cf69ebf425f866ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:39:54 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/collective_gemm/common.py | 7 +++- .../jax/collective_gemm/test_dense_grad.py | 12 ++++-- examples/jax/collective_gemm/test_gemm.py | 16 ++++++-- .../test_layernorm_mlp_grad.py | 8 +++- transformer_engine/jax/cpp_extensions/gemm.py | 37 ++++++++++++++----- .../jax/csrc/extensions/gemm.cpp | 3 +- 6 files changed, 62 insertions(+), 21 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 483f3e60af..d0fad8c48b 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -278,7 +278,12 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para "--quantize-recipe", type=str, default=None, - choices=["DelayedScaling", "Float8CurrentScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"], + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 35a4b9457f..b6a5422470 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -112,7 +112,9 @@ def run_dense_grad_tests(args, mesh=None): collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( enabled=use_quantization, recipe=recipe, @@ -259,7 +261,9 @@ def test_te_current_scaling_fp8_reduce_scatter(self): def test_te_mxfp8_all_gather(self): """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.collective_type = "all_gather" @@ -268,7 +272,9 @@ def test_te_mxfp8_all_gather(self): def test_te_mxfp8_reduce_scatter(self): """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) self.args.collective_type = "reduce_scatter" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index c969062376..8f0e9a44cf 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -117,7 +117,9 @@ def run_gemm_tests(args, mesh=None): ) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource # (via global_shard_guard) required for collective GEMM sharding axis resolution. @@ -170,7 +172,9 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set)) + assert_allclose( + gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set) + ) class TestCollectiveGemmWithDP(unittest.TestCase): @@ -257,7 +261,9 @@ def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): def test_te_mxfp8_all_gather_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + AllGather""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) @@ -267,7 +273,9 @@ def test_te_mxfp8_all_gather_with_dp(self): def test_te_mxfp8_reduce_scatter_with_dp(self): """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index fe245f436b..f242840ba0 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -162,7 +162,9 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) use_quantization = args.quantize_recipe is not None - recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( enabled=use_quantization, recipe=recipe, @@ -289,7 +291,9 @@ def test_te_current_scaling_fp8_layernorm_mlp_grad(self): def test_te_mxfp8_layernorm_mlp_grad(self): """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" self.args.quantize_recipe = "MXFP8BlockScaling" - is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) if not is_supported: self.skipTest(reason) run_layernorm_mlp_grad_tests(self.args, self.mesh) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c6dee60a0c..8df1d9995f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -655,19 +655,27 @@ def impl( if not collective_op.is_none and not is_outer: # MXFP8 + Collective AG/RS: both sides of flatten_axis must be multiples of 128. # No padding is needed in this case - lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod(lhs.shape[lhs_flatten_axis:]) + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod( + lhs.shape[lhs_flatten_axis:] + ) assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( - f"MXFP8 + Collective AG requires LHS dimensions before and after the flatten axis to be multiples of 128. " - f"Got lhs.shape={lhs.shape}, lhs_flatten_axis={lhs_flatten_axis}" + "MXFP8 + Collective AG requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got lhs.shape={lhs.shape}," + f" lhs_flatten_axis={lhs_flatten_axis}" ) # The scale needs to be in good shape for reordering assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( - f"MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be multiples of tpsp_axis_size. " - f"Got lhs_scale_inv.shape={lhs_scale_inv.shape}, tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + "MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be" + f" multiples of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape}," + f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" ) else: lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis, + lhs_scale_inv, + scaling_mode, + lhs.shape, + lhs_transposed, + lhs_flatten_axis, ) rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis @@ -686,7 +694,12 @@ def impl( assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs = _reorder_tpsp_leading(lhs, lhs.shape) - if need_reorder and (collective_op.is_reduce_scatter or collective_op.is_all_gather) and lhs_scale_inv.shape[0] != 1 and scaling_mode.is_1d_block_scaling(): + if ( + need_reorder + and (collective_op.is_reduce_scatter or collective_op.is_all_gather) + and lhs_scale_inv.shape[0] != 1 + and scaling_mode.is_1d_block_scaling() + ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) @@ -936,7 +949,9 @@ def _parse_operand_output_specs( if scaling_mode.is_1d_block_scaling(): rhs_scale_specs = rhs_specs if collective_op.is_all_gather: - lhs_scale_specs = tuple(None if i == sequence_dim else s for i, s in enumerate(lhs_specs)) + lhs_scale_specs = tuple( + None if i == sequence_dim else s for i, s in enumerate(lhs_specs) + ) else: lhs_scale_specs = lhs_specs @@ -973,7 +988,11 @@ def infer_sharding_from_operands( ) (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op, scaling_mode, + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 268307ea83..2acefa2d30 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -74,8 +74,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } input.set_with_gemm_swizzled_scales(true); - } - else if (is_nvfp4) { // Swizzle for NVFP4 + } else if (is_nvfp4) { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor