def get_tolerance_dtype(quantizer_set):
    """Pick the dtype that drives numerical-tolerance selection for result checks.

    The dtype comes from the active quantizer (``quantizer_set.x.q_dtype``); when
    no quantizer is in effect (NO_SCALING / noop path, where ``quantizer_set.x``
    is ``None``) the comparison falls back to bfloat16 tolerances.
    """
    x_quantizer = quantizer_set.x
    return jnp.bfloat16 if x_quantizer is None else x_quantizer.q_dtype
got {name}") + + +def get_quantization_recipe_from_name_string(name: str): + """Query recipe from a given name string""" + match name: + case "DelayedScaling": + return te_recipe.DelayedScaling() + case "MXFP8BlockScaling": + return te_recipe.MXFP8BlockScaling() + case "Float8CurrentScaling": + return te_recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return te_recipe.NVFP4BlockScaling() + case _: + raise ValueError(f"Invalid quantization_recipe, got {name}") + + def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): """Create common argument parser for all collective GEMM tests.""" parser = argparse.ArgumentParser(description=description) @@ -229,7 +275,16 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para help="Type of collective operation", ) parser.add_argument( - "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + "--quantize-recipe", + type=str, + default=None, + choices=[ + "DelayedScaling", + "Float8CurrentScaling", + "MXFP8BlockScaling", + "NVFP4BlockScaling", + ], + help="Quantization recipe to use. Omit for BF16 (no quantization).", ) parser.add_argument( "--enable-data-parallel", action="store_true", help="Enable data parallelism" diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 388c878376..a098515af9 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -23,11 +23,36 @@ else echo "NVLINK support detected" fi -# Define the test files to run -TEST_FILES=( -"test_gemm.py" -"test_dense_grad.py" -"test_layernorm_mlp_grad.py" +# Define individual test cases to run (file::class::method) +# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all +# the time. 
+TEST_CASES=( +# test_gemm.py cases +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp" +"test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp" +# # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp" +# +# # test_dense_grad.py cases +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather" +"test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather" +# "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter" +# +# # test_layernorm_mlp_grad.py cases +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad" +"test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad" +# "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad" ) echo @@ -57,24 +82,27 @@ cleanup() { # 
Set up signal handlers to cleanup on exit trap cleanup EXIT INT TERM -# Run each test file across all GPUs -for TEST_FILE in "${TEST_FILES[@]}"; do +# Run each test case across all GPUs +for TEST_CASE in "${TEST_CASES[@]}"; do echo - echo "=== Starting test file: $TEST_FILE ..." + echo "=== Starting test: $TEST_CASE ..." + + # Extract just the test method name for log/xml file naming + TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}') - # Clear PIDs array for this test file + # Clear PIDs array for this test case PIDS=() for i in $(seq 0 $(($NUM_GPUS - 1))); do # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" + LOG_FILE="${TEST_NAME}_gpu_${i}.log" if [ $i -eq 0 ]; then # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -82,7 +110,7 @@ for TEST_FILE in "${TEST_FILES[@]}"; do else # For other processes: redirect to log files only pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ --num-processes=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & PID=$! @@ -93,22 +121,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Wait for all processes to finish wait - # Check and print the log content from process 0 (now has log file thanks to tee) - if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE SKIPPED" - elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then - echo "... 
$TEST_FILE FAILED" + # Check and print the log content from process 0 + if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE PASSED" + elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then + echo "... $TEST_CASE PASSED" else - echo "... $TEST_FILE INVALID" + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 fi # Remove the log files after processing them wait - rm ${TEST_FILE}_gpu_*.log + rm ${TEST_NAME}_gpu_*.log done wait diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 94c7dc5b66..b6a5422470 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -20,11 +21,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.dense import dense -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOp, CollectiveOpSet, @@ -56,7 +64,9 @@ def _get_operand_sharding(mesh, collective_op): return x_sharding, weight_sharding, bias_sharding -def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _mean_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): output = dense( x, weight, @@ -66,13 +76,16 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv kernel_axes=weight_axes, output_axes=output_axes, 
collective_op_set=collective_op_set, + quantizer_set=quantizer_set, ) return jnp.mean(output.astype(jnp.float32)) -def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): +def _value_and_grad_dense( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set +): return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( - x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set ) @@ -98,11 +111,18 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + use_quantization = args.quantize_recipe is not None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set # Get the base axis rules and extend them with TE's rules. 
This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -123,6 +143,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, noop_collective_op_set, + quantizer_set, ) output, sharded_grads = _value_and_grad_dense( x_sharded, @@ -132,6 +153,7 @@ def run_dense_grad_tests(args, mesh=None): weight_axes, output_axes, collective_op_set, + quantizer_set, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -148,9 +170,10 @@ def run_dense_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveDenseGradient(unittest.TestCase): @@ -187,6 +210,94 @@ def test_te_bf16_reduce_scatter(self): self.args.collective_type = "reduce_scatter" run_dense_grad_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + 
self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_all_gather(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter(self): + """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_mxfp8_all_gather(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_mxfp8_reduce_scatter(self): + """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_all_gather(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather""" + # 
self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "all_gather" + # run_dense_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_reduce_scatter(self): + # """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "reduce_scatter" + # run_dense_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index ea119713e3..8f0e9a44cf 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -22,6 +22,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -29,10 +30,17 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp from transformer_engine.jax.sharding import MeshResource @@ -72,13 +80,14 @@ def _get_dp_and_tp_sizes(args): @partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) -def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): +def _jitted_cgemm(x, weight, bias, quantizer_set, contracting_dims, collective_op, output_sharding): 
output = tex.gemm( x, weight, bias=bias, contracting_dims=contracting_dims, collective_op=collective_op, + quantizer_set=quantizer_set, ) if output_sharding is not None: output = jax.lax.with_sharding_constraint(output, output_sharding) @@ -107,11 +116,22 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) + use_quantization = args.quantize_recipe is not None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) + + # autocast sets the global recipe (fwd/bwd dtypes) AND the global MeshResource + # (via global_shard_guard) required for collective GEMM sharding axis resolution. with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() can read the global recipe + # for correct fwd/bwd dtypes. autocast does not inject quantizers into raw + # tex.gemm() calls, so we must pass quantizer_set explicitly. 
+ quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set print(f"Device mesh: {mesh}") x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( @@ -125,6 +145,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=CollectiveOp.NONE, output_sharding=output_sharding, @@ -133,6 +154,7 @@ def run_gemm_tests(args, mesh=None): x_sharded, weight_sharded, bias_sharded, + quantizer_set, contracting_dims=((2,), (0,)), collective_op=collective_op, output_sharding=output_sharding, @@ -150,7 +172,9 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - assert_allclose(gathered_ref_output, gathered_output) + assert_allclose( + gathered_ref_output, gathered_output, dtype=get_tolerance_dtype(quantizer_set) + ) class TestCollectiveGemmWithDP(unittest.TestCase): @@ -186,6 +210,96 @@ def test_te_bf16_reduce_scatter_with_dp(self): self.args.collective_type = "reduce_scatter" run_gemm_tests(self.args, self.mesh) + def test_te_delayed_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + AllGather""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 DelayedScaling + ReduceScatter""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + 
def test_te_current_scaling_fp8_all_gather_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + AllGather""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with FP8 Float8CurrentScaling + ReduceScatter""" + self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + def test_te_mxfp8_all_gather_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + AllGather""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_mxfp8_reduce_scatter_with_dp(self): + """Test Collective GEMM with MXFP8BlockScaling + ReduceScatter""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + # def test_te_nvfp4_all_gather_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + AllGather""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) 
+ # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "all_gather" + # run_gemm_tests(self.args, self.mesh) + + # def test_te_nvfp4_reduce_scatter_with_dp(self): + # """Test Collective GEMM with NVFP4BlockScaling + ReduceScatter""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # self.args.collective_type = "reduce_scatter" + # run_gemm_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 84cb011da1..f242840ba0 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -13,6 +13,7 @@ from common import ( assert_allclose, + get_tolerance_dtype, _initialize_distributed, _get_dp_and_tp_sizes, _create_mesh, @@ -20,11 +21,18 @@ TPSP_AXIS, PARAMS_KEY, cgemm_parser, + get_quantization_recipe_from_name_string, + get_scaling_mode_from_recipe_name, ) from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.quantize import ( + autocast, + is_scaling_mode_supported, + QuantizerFactory, + noop_quantizer_set, +) from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOpSet, CollectiveOp, @@ -68,6 +76,7 @@ def _mean_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): output = layernorm_mlp( x, @@ -82,6 +91,7 @@ def _mean_layernorm_mlp( kernel_2_axes=weight_2_axes, activation_type=("gelu",), collective_op_sets=collective_op_sets, + quantizer_sets=quantizer_sets, ) return jnp.mean(output) @@ -98,6 +108,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ): return jax.jit( 
jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) @@ -113,6 +124,7 @@ def _value_and_grad_layernorm_mlp( weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) @@ -149,11 +161,19 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + use_quantization = args.quantize_recipe is not None + recipe = ( + get_quantization_recipe_from_name_string(args.quantize_recipe) if use_quantization else None + ) with mesh, autocast( - enabled=False, - recipe=None, + enabled=use_quantization, + recipe=recipe, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): + # Build quantizer_set inside autocast so create_set() reads the global recipe + # for correct fwd/bwd dtypes. One set per dense layer (GEMM1=AG, GEMM2=RS). + quantizer_set = QuantizerFactory.create_set() if use_quantization else noop_quantizer_set + quantizer_sets = (quantizer_set, quantizer_set) # Get the base axis rules and extend them with TE's rules. 
This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) @@ -181,6 +201,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, noop_collective_op_sets, + quantizer_sets, ) output, sharded_grads = _value_and_grad_layernorm_mlp( x_sharded, @@ -194,6 +215,7 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): weight_1_axes, weight_2_axes, collective_op_sets, + quantizer_sets, ) jax.block_until_ready(ref_output) jax.block_until_ready(output) @@ -210,9 +232,10 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): jax.block_until_ready(gathered_ref_grads) if args.enable_result_check and args.process_id == 0: - assert_allclose(ref_output, output, dtype=jnp.bfloat16) + tol_dtype = get_tolerance_dtype(quantizer_set) + assert_allclose(ref_output, output, dtype=tol_dtype) for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): - assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + assert_allclose(ref_grad, gathered_grad, dtype=tol_dtype) class TestCollectiveLayerNormMLPGradient(unittest.TestCase): @@ -240,9 +263,49 @@ def tearDown(self): os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) def test_te_bf16_layernorm_mlp_grad(self): - """Test Collective Dense Gradient with AllGather""" + """Test Collective LayerNorm MLP Gradient with BF16""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_delayed_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 DelayedScaling""" + self.args.quantize_recipe = "DelayedScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + def test_te_current_scaling_fp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with FP8 Float8CurrentScaling""" + 
self.args.quantize_recipe = "Float8CurrentScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) + def test_te_mxfp8_layernorm_mlp_grad(self): + """Test Collective LayerNorm MLP Gradient with MXFP8BlockScaling""" + self.args.quantize_recipe = "MXFP8BlockScaling" + is_supported, reason = is_scaling_mode_supported( + get_scaling_mode_from_recipe_name(self.args.quantize_recipe) + ) + if not is_supported: + self.skipTest(reason) + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + # def test_te_nvfp4_layernorm_mlp_grad(self): + # """Test Collective LayerNorm MLP Gradient with NVFP4BlockScaling""" + # self.args.quantize_recipe = "NVFP4BlockScaling" + # is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe)) + # if not is_supported: + # self.skipTest(reason) + # run_layernorm_mlp_grad_tests(self.args, self.mesh) + if __name__ == "__main__": import sys diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 70557f29c7..8df1d9995f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -396,6 +396,48 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): ) +def _reorder_tpsp_leading(tensor, original_shape): + """Reorder tensor so the tpsp axis is leading: reshape (dp, n, tpsp, m, ...), transpose (2, 0, 1, 3, ...).""" + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + 
def _reorder_dp_leading(tensor, original_shape):
    """Reorder tensor so the dp axis is leading.

    Views the first two logical dims as (tpsp, dp, n, m) blocks and transposes
    to (dp, n, tpsp, m) so that, after flattening back to ``original_shape``,
    the data-parallel shards come first. Inverse counterpart of
    ``_reorder_tpsp_leading``.
    """
    dp_size = dp_or_fsdp_axis_size()
    tpsp_size = tpsp_axis_size()
    assert original_shape[0] % dp_size == 0 or original_shape[0] == 1, (
        f"Original_shape[0]={original_shape[0]} is not divisible by"
        f" dp_or_fsdp_axis_size()={dp_size}"
    )
    assert original_shape[1] % tpsp_size == 0 or original_shape[1] == 1, (
        f"Original_shape[1]={original_shape[1]} is not divisible by"
        f" tpsp_axis_size()={tpsp_size}"
    )
    # Use floor division: int(a / b) round-trips through float and is not exact
    # for large dimensions, while // stays in integer arithmetic.
    reshaped = tensor.reshape(
        tpsp_size,
        dp_size,
        original_shape[0] // dp_size,
        original_shape[1] // tpsp_size,
        *original_shape[2:],
    )
    reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim))
    return reordered.reshape(original_shape)
+ # No padding is needed in this case + lhs_first, lhs_last = math.prod(lhs.shape[:lhs_flatten_axis]), math.prod( + lhs.shape[lhs_flatten_axis:] + ) + assert lhs_first % 128 == 0 and lhs_last % 128 == 0, ( + "MXFP8 + Collective AG requires LHS dimensions before and after the flatten" + f" axis to be multiples of 128. Got lhs.shape={lhs.shape}," + f" lhs_flatten_axis={lhs_flatten_axis}" + ) + # The scale needs to be in good shape for reordering + assert lhs_scale_inv.shape[sequence_dim] % tpsp_axis_size() == 0, ( + "MXFP8 + Collective AG/RS requires LHS scale inv sequence dimension to be" + f" multiples of tpsp_axis_size. Got lhs_scale_inv.shape={lhs_scale_inv.shape}," + f" tpsp_axis_size={tpsp_axis_size()}, sequence_dim={sequence_dim}" + ) + else: + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + lhs_transposed, + lhs_flatten_axis, + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel if scaling_mode.is_mxfp8_scaling: lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) + # Determine if we need to reorder the tensor so that the input/output are in the correct layout for the collective operation + need_reorder = not transpose_batch_sequence and not is_outer and not collective_op.is_none + # Alter lhs blocks so that CGEMM RS outputs correctly + if need_reorder and collective_op.is_reduce_scatter and lhs.shape[0] != 1: + assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + lhs = _reorder_tpsp_leading(lhs, lhs.shape) + if ( - collective_op.is_reduce_scatter - and not transpose_batch_sequence - and not is_outer - and not lhs.shape[0] == 1 + need_reorder + and (collective_op.is_reduce_scatter or collective_op.is_all_gather) + and lhs_scale_inv.shape[0] != 1 + and scaling_mode.is_1d_block_scaling() ): assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" - original_shape = lhs.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = lhs.reshape( - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - tpsp_axis_size(), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) - lhs = reordered.reshape(original_shape) + lhs_scale_inv = _reorder_tpsp_leading(lhs_scale_inv, lhs_scale_inv.shape) (output, _) = GemmPrimitive.inner_primitive.bind( lhs, @@ -667,31 +721,9 @@ def impl( collective_op=collective_op, ) # Alter output blocks for CGEMM AG - if ( - collective_op.is_all_gather - and not transpose_batch_sequence - and not is_outer - and not output.shape[0] == 1 - ): + if need_reorder and collective_op.is_all_gather and output.shape[0] != 1: assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" - original_shape = output.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) - reshaped = output.reshape( - tpsp_axis_size(), - dp_or_fsdp_axis_size(), - int(original_shape[0] / dp_or_fsdp_axis_size()), - int(original_shape[1] / tpsp_axis_size()), - *original_shape[2:], - ) - reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) - output = reordered.reshape(original_shape) + output = _reorder_dp_leading(output, output.shape) return (output,) @@ -775,6 +807,7 @@ def _parse_operand_output_specs( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -911,11 +944,22 @@ def _parse_operand_output_specs( # Bias sharding is based on GEMM output before any scatter bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4 + # Scale shardings are based on the scaling_mode and collective_op + lhs_scale_specs = rhs_scale_specs = (None,) + if scaling_mode.is_1d_block_scaling(): + rhs_scale_specs = rhs_specs + if collective_op.is_all_gather: + lhs_scale_specs = tuple( + None if i == sequence_dim else s for i, s in enumerate(lhs_specs) + ) + else: + lhs_scale_specs = lhs_specs + if not collective_op.is_none: assert sequence_dim >= 0, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" return ( - (lhs_specs, rhs_specs, bias_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_specs), out_specs, reduce_spec, sequence_dim, @@ -937,7 +981,6 @@ def infer_sharding_from_operands( ): del ( out_dtype, - scaling_mode, use_split_accumulator, result_infos, is_outer, @@ -945,7 +988,11 @@ def infer_sharding_from_operands( ) (_, out_specs, *_) = GemmPrimitive._parse_operand_output_specs( - arg_infos, contracting_dims, transpose_batch_sequence, collective_op + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + scaling_mode, ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -968,7 +1015,7 @@ def partition( del result_infos, is_outer, sequence_dim ( - (lhs_specs, rhs_specs, bias_input_specs), + (lhs_specs, lhs_scale_specs, rhs_specs, rhs_scale_specs, bias_input_specs), out_specs, reduce_spec, inferred_sequence_dim, @@ -977,17 +1024,21 @@ def partition( contracting_dims, transpose_batch_sequence, collective_op, + scaling_mode, ) # Block scale inverses match their operands, but tensor scale inverses are unsharded. 
none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + lhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*lhs_scale_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + rhs_scale_sharding = NamedSharding(mesh, PartitionSpec(*rhs_scale_specs)) + arg_shardings = ( lhs_sharding, - lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + lhs_scale_sharding, rhs_sharding, - rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_scale_sharding, ) # Bias @@ -1198,6 +1249,12 @@ def _te_gemm( rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + if not collective_op.is_none: + assert not scaling_mode.is_nvfp4_scaling, ( + f"Collective GEMM is not yet supported with {scaling_mode} quantization. " + "Only DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING are supported." + ) + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype if bias is None: bias = jnp.empty(0, dtype=out_dtype) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 737dd65622..2acefa2d30 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -58,6 +58,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape = {1}; auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), @@ -202,6 +203,7 @@ Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = 
move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; + if (is_nvfp4_scaling(config.scaling_mode)) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c5256aef5c..e584b5a452 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -927,6 +927,7 @@ def apply_padding_to_scale_inv( unpadded_scale_shape = scaling_mode.get_scale_shape( data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) + assert scale_inv.shape == unpadded_scale_shape, ( f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " f"{scale_inv.shape}."