Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.common import recipe as te_recipe
from transformer_engine.jax.quantize import ScalingMode
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

# Global flag to track if distributed has been initialized
Expand Down Expand Up @@ -183,6 +185,36 @@ def _create_mesh(args):
return mesh


def get_scaling_mode_from_recipe_name(name: str) -> ScalingMode:
    """Map a quantization recipe class name to its ``ScalingMode``.

    Args:
        name: Recipe class name, e.g. ``"DelayedScaling"``.

    Returns:
        The ``ScalingMode`` enum member used by that recipe.

    Raises:
        ValueError: If ``name`` is not a recognized recipe name.
    """
    # Lookup table: recipe class name -> scaling mode enum.
    name_to_mode = {
        "DelayedScaling": ScalingMode.DELAYED_TENSOR_SCALING,
        "Float8CurrentScaling": ScalingMode.CURRENT_TENSOR_SCALING,
        "MXFP8BlockScaling": ScalingMode.MXFP8_1D_SCALING,
        "NVFP4BlockScaling": ScalingMode.NVFP4_1D_SCALING,
    }
    if name not in name_to_mode:
        raise ValueError(f"Invalid recipe name, got {name}")
    return name_to_mode[name]


def get_quantization_recipe_from_name_string(name: str):
    """Query recipe from a given name string.

    Args:
        name: Recipe class name, e.g. ``"DelayedScaling"``.

    Returns:
        A freshly constructed recipe instance from ``te_recipe``.

    Raises:
        ValueError: If ``name`` does not name a known recipe.
    """
    # Dispatch table of recipe constructors; instantiated only on a match
    # so an invalid name never touches te_recipe attributes it matched.
    constructors = {
        "DelayedScaling": te_recipe.DelayedScaling,
        "MXFP8BlockScaling": te_recipe.MXFP8BlockScaling,
        "Float8CurrentScaling": te_recipe.Float8CurrentScaling,
        "NVFP4BlockScaling": te_recipe.NVFP4BlockScaling,
    }
    ctor = constructors.get(name)
    if ctor is None:
        raise ValueError(f"Invalid quantization_recipe, got {name}")
    return ctor()


def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
"""Create common argument parser for all collective GEMM tests."""
parser = argparse.ArgumentParser(description=description)
Expand Down Expand Up @@ -229,13 +261,19 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
help="Type of collective operation",
)
parser.add_argument(
"--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
"--quantize-recipe", type=str, default="DelayedScaling", help="Quantization recipe to use"
)
parser.add_argument(
"--enable-data-parallel", action="store_true", help="Enable data parallelism"
)
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Enable FP8 quantization",
)

return parser
79 changes: 57 additions & 22 deletions examples/jax/collective_gemm/run_test_cgemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,43 @@ else
echo "NVLINK support detected"
fi

# Define the test files to run
TEST_FILES=(
"test_gemm.py"
"test_dense_grad.py"
"test_layernorm_mlp_grad.py"
# Define individual test cases to run, as pytest node IDs (file::class::method).
# DelayedScalingFP8 and CurrentScalingFP8 exercise the same underlying GEMM, so we do not need
# to run both variants in every configuration.
TEST_CASES=(
    # test_gemm.py cases
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_all_gather_with_dp"
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_bf16_reduce_scatter_with_dp"
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_all_gather_with_dp"
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_delayed_scaling_fp8_reduce_scatter_with_dp"
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_all_gather_with_dp"
    "test_gemm.py::TestCollectiveGemmWithDP::test_te_current_scaling_fp8_reduce_scatter_with_dp"
    # TODO(Phuong): Enable when supported
    # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_all_gather_with_dp"
    # "test_gemm.py::TestCollectiveGemmWithDP::test_te_mxfp8_reduce_scatter_with_dp"
    # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_all_gather_with_dp"
    # "test_gemm.py::TestCollectiveGemmWithDP::test_te_nvfp4_reduce_scatter_with_dp"

    # test_dense_grad.py cases
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_all_gather"
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_bf16_reduce_scatter"
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_all_gather"
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_delayed_scaling_fp8_reduce_scatter"
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_all_gather"
    "test_dense_grad.py::TestCollectiveDenseGradient::test_te_current_scaling_fp8_reduce_scatter"
    # TODO(Phuong): Enable when supported
    # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_all_gather"
    # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_mxfp8_reduce_scatter"
    # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_all_gather"
    # "test_dense_grad.py::TestCollectiveDenseGradient::test_te_nvfp4_reduce_scatter"

    # test_layernorm_mlp_grad.py cases
    "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_bf16_layernorm_mlp_grad"
    "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_delayed_scaling_fp8_layernorm_mlp_grad"
    "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_current_scaling_fp8_layernorm_mlp_grad"
    # TODO(Phuong): Enable when supported
    # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_mxfp8_layernorm_mlp_grad"
    # "test_layernorm_mlp_grad.py::TestCollectiveLayerNormMLPGradient::test_te_nvfp4_layernorm_mlp_grad"
)

echo
Expand Down Expand Up @@ -57,32 +89,35 @@ cleanup() {
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM

# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
echo "=== Starting test file: $TEST_FILE ..."
echo "=== Starting test: $TEST_CASE ..."

# Clear PIDs array for this test file
# Extract just the test method name for log/xml file naming
TEST_NAME=$(echo "$TEST_CASE" | awk -F'::' '{print $NF}')

# Clear PIDs array for this test case
PIDS=()

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"
LOG_FILE="${TEST_NAME}_gpu_${i}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
Expand All @@ -93,22 +128,22 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
# Wait for all processes to finish
wait

# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE PASSED"
elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "... $TEST_FILE INVALID"
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi

# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
rm ${TEST_NAME}_gpu_*.log
done

wait
Expand Down
123 changes: 117 additions & 6 deletions examples/jax/collective_gemm/test_dense_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
get_quantization_recipe_from_name_string,
get_scaling_mode_from_recipe_name,
)

from transformer_engine.jax.dense import dense

from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.quantize import (
autocast,
is_scaling_mode_supported,
QuantizerFactory,
noop_quantizer_set,
)
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
Expand Down Expand Up @@ -56,7 +63,9 @@ def _get_operand_sharding(mesh, collective_op):
return x_sharding, weight_sharding, bias_sharding


def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
def _mean_dense(
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
):
output = dense(
x,
weight,
Expand All @@ -66,13 +75,16 @@ def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collectiv
kernel_axes=weight_axes,
output_axes=output_axes,
collective_op_set=collective_op_set,
quantizer_set=quantizer_set,
)
return jnp.mean(output.astype(jnp.float32))


def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
def _value_and_grad_dense(
    x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
):
    """Return (mean_output, grads) of ``_mean_dense`` w.r.t. x, weight, and bias."""
    # Differentiate only the array operands (argnums 0-2); the axis specs and the
    # collective-op set are compile-time constants for jit (static_argnums 3-6).
    # NOTE(review): quantizer_set (argnum 7) is traced, not static — presumably it
    # is a pytree of quantizer state; confirm it should not be in static_argnums.
    grad_fn = jax.value_and_grad(_mean_dense, (0, 1, 2))
    jitted_grad_fn = jax.jit(grad_fn, static_argnums=(3, 4, 5, 6))
    return jitted_grad_fn(
        x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set, quantizer_set
    )
)


Expand All @@ -98,11 +110,16 @@ def run_dense_grad_tests(args, mesh=None):
)
collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)

use_fp8 = getattr(args, "use_fp8", False)
recipe = get_quantization_recipe_from_name_string(args.quantize_recipe) if use_fp8 else None
with mesh, autocast(
enabled=False,
recipe=None,
enabled=use_fp8,
recipe=recipe,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Build quantizer_set inside autocast so create_set() reads the global recipe
# for correct fwd/bwd dtypes.
quantizer_set = QuantizerFactory.create_set() if use_fp8 else noop_quantizer_set
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
Expand All @@ -123,6 +140,7 @@ def run_dense_grad_tests(args, mesh=None):
weight_axes,
output_axes,
noop_collective_op_set,
quantizer_set,
)
output, sharded_grads = _value_and_grad_dense(
x_sharded,
Expand All @@ -132,6 +150,7 @@ def run_dense_grad_tests(args, mesh=None):
weight_axes,
output_axes,
collective_op_set,
quantizer_set,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
Expand Down Expand Up @@ -187,6 +206,98 @@ def test_te_bf16_reduce_scatter(self):
self.args.collective_type = "reduce_scatter"
run_dense_grad_tests(self.args, self.mesh)

def test_te_delayed_scaling_fp8_all_gather(self):
    """Test Collective Dense Gradient with FP8 DelayedScaling + AllGather"""
    self.args.quantize_recipe = "DelayedScaling"
    # Skip early on hardware/software stacks without this scaling mode.
    scaling_mode = get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
    supported, skip_reason = is_scaling_mode_supported(scaling_mode)
    if not supported:
        self.skipTest(skip_reason)
    self.args.use_fp8 = True
    self.args.collective_type = "all_gather"
    run_dense_grad_tests(self.args, self.mesh)

def test_te_delayed_scaling_fp8_reduce_scatter(self):
    """Test Collective Dense Gradient with FP8 DelayedScaling + ReduceScatter"""
    self.args.quantize_recipe = "DelayedScaling"
    # Skip early on hardware/software stacks without this scaling mode.
    scaling_mode = get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
    supported, skip_reason = is_scaling_mode_supported(scaling_mode)
    if not supported:
        self.skipTest(skip_reason)
    self.args.use_fp8 = True
    self.args.collective_type = "reduce_scatter"
    run_dense_grad_tests(self.args, self.mesh)

def test_te_current_scaling_fp8_all_gather(self):
    """Test Collective Dense Gradient with FP8 Float8CurrentScaling + AllGather"""
    self.args.quantize_recipe = "Float8CurrentScaling"
    # Skip early on hardware/software stacks without this scaling mode.
    scaling_mode = get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
    supported, skip_reason = is_scaling_mode_supported(scaling_mode)
    if not supported:
        self.skipTest(skip_reason)
    self.args.use_fp8 = True
    self.args.collective_type = "all_gather"
    run_dense_grad_tests(self.args, self.mesh)

def test_te_current_scaling_fp8_reduce_scatter(self):
    """Test Collective Dense Gradient with FP8 Float8CurrentScaling + ReduceScatter"""
    self.args.quantize_recipe = "Float8CurrentScaling"
    # Skip early on hardware/software stacks without this scaling mode.
    scaling_mode = get_scaling_mode_from_recipe_name(self.args.quantize_recipe)
    supported, skip_reason = is_scaling_mode_supported(scaling_mode)
    if not supported:
        self.skipTest(skip_reason)
    self.args.use_fp8 = True
    self.args.collective_type = "reduce_scatter"
    run_dense_grad_tests(self.args, self.mesh)

# TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported
# def test_te_mxfp8_all_gather(self):
# """Test Collective Dense Gradient with MXFP8BlockScaling + AllGather"""
# self.args.quantize_recipe = "MXFP8BlockScaling"
# is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
# if not is_supported:
# self.skipTest(reason)
# self.args.use_fp8 = True
# self.args.collective_type = "all_gather"
# run_dense_grad_tests(self.args, self.mesh)

# TODO: Enable when MXFP8BlockScaling + Collective GEMM is supported
# def test_te_mxfp8_reduce_scatter(self):
# """Test Collective Dense Gradient with MXFP8BlockScaling + ReduceScatter"""
# self.args.quantize_recipe = "MXFP8BlockScaling"
# is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
# if not is_supported:
# self.skipTest(reason)
# self.args.use_fp8 = True
# self.args.collective_type = "reduce_scatter"
# run_dense_grad_tests(self.args, self.mesh)

# TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported
# def test_te_nvfp4_all_gather(self):
# """Test Collective Dense Gradient with NVFP4BlockScaling + AllGather"""
# self.args.quantize_recipe = "NVFP4BlockScaling"
# is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
# if not is_supported:
# self.skipTest(reason)
# self.args.use_fp8 = True
# self.args.collective_type = "all_gather"
# run_dense_grad_tests(self.args, self.mesh)

# TODO: Enable when NVFP4BlockScaling + Collective GEMM is supported
# def test_te_nvfp4_reduce_scatter(self):
# """Test Collective Dense Gradient with NVFP4BlockScaling + ReduceScatter"""
# self.args.quantize_recipe = "NVFP4BlockScaling"
# is_supported, reason = is_scaling_mode_supported(get_scaling_mode_from_recipe_name(self.args.quantize_recipe))
# if not is_supported:
# self.skipTest(reason)
# self.args.use_fp8 = True
# self.args.collective_type = "reduce_scatter"
# run_dense_grad_tests(self.args, self.mesh)


if __name__ == "__main__":
import sys
Expand Down
Loading
Loading