diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 56880a428d..ce94eb6e7c 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable(test_operator test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu + test_dequantize_mxfp8_grouped.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_dequantize_mxfp8_grouped.cu b/tests/cpp/operator/test_dequantize_mxfp8_grouped.cu new file mode 100644 index 0000000000..16f192c919 --- /dev/null +++ b/tests/cpp/operator/test_dequantize_mxfp8_grouped.cu @@ -0,0 +1,487 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + +enum ScalingDirection { ROWWISE = 0, COLWISE = 1 }; + +/** + * Compare grouped dequantize output against single-tensor nvte_dequantize + * called in a loop for each tensor. Results must be bitwise identical. 
+ */ +template +void performTest(const ShapeRepresentation shape_rep, const size_t num_tensors, + const std::vector &logical_shape_vec, + const std::vector &first_dims_h, const std::vector &last_dims_h, + const std::vector &offsets_h, const bool rowwise) { + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = logical_shape_vec[0]; + const size_t cols = logical_shape_vec[1]; + + // Compute total elements and per-tensor scale sizes + size_t elts_num = 0; + size_t total_scales = 0; + + std::vector per_tensor_scales_first_dim(num_tensors); + std::vector per_tensor_scales_last_dim(num_tensors); + std::vector per_tensor_scales_offset(num_tensors + 1, 0); + + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + elts_num += M * K; + + size_t unpadded_scales_Y, unpadded_scales_X; + if (rowwise) { + unpadded_scales_Y = M; + unpadded_scales_X = divide_round_up(K, 32); + per_tensor_scales_first_dim[t] = + round_up_to_nearest_multiple(unpadded_scales_Y, scale_tensor_alignment_Y_rowwise); + per_tensor_scales_last_dim[t] = + round_up_to_nearest_multiple(unpadded_scales_X, scale_tensor_alignment_X_rowwise); + } else { + unpadded_scales_Y = divide_round_up(M, 32); + unpadded_scales_X = K; + per_tensor_scales_first_dim[t] = + round_up_to_nearest_multiple(unpadded_scales_Y, scale_tensor_alignment_Y_colwise); + per_tensor_scales_last_dim[t] = + round_up_to_nearest_multiple(unpadded_scales_X, scale_tensor_alignment_X_colwise); + } + + const size_t tensor_scales = per_tensor_scales_first_dim[t] * per_tensor_scales_last_dim[t]; + total_scales += tensor_scales; + per_tensor_scales_offset[t + 1] = total_scales; + } + + // Allocate host data + std::vector in_data_h(elts_num); + std::vector in_scales_h(total_scales); + + // Generate random FP8 data and scales + static std::mt19937 gen(42); + const double minAbs = Numeric_Traits::minNorm; + const double maxAbs = Numeric_Traits::maxNorm; + 
std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + std::uniform_int_distribution int_dis(0, 255); + + for (size_t i = 0; i < elts_num; ++i) { + const bool is_negative = (dis_sign(gen) < 0.0); + double val = dis(gen); + if (is_negative) val = -val; + in_data_h[i] = static_cast(val); + } + for (size_t i = 0; i < total_scales; ++i) { + in_scales_h[i] = int_dis(gen); + } + + // Allocate device memory + const size_t in_data_size = elts_num * sizeof(InputType); + const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t scales_size = total_scales * sizeof(fp8e8m0); + const size_t first_dims_size = num_tensors * sizeof(size_t); + const size_t last_dims_size = num_tensors * sizeof(size_t); + const size_t offsets_size = (num_tensors + 1) * sizeof(size_t); + + InputType *in_data_d; + OutputType *out_grouped_d; + fp8e8m0 *in_scales_d; + size_t *first_dims_d; + size_t *last_dims_d; + size_t *offsets_d; + + cudaMalloc((void **)&in_data_d, in_data_size); + cudaMalloc((void **)&out_grouped_d, out_data_size); + cudaMalloc((void **)&in_scales_d, scales_size); + cudaMalloc((void **)&first_dims_d, first_dims_size); + cudaMalloc((void **)&last_dims_d, last_dims_size); + cudaMalloc((void **)&offsets_d, offsets_size); + + cudaMemcpy(in_data_d, in_data_h.data(), in_data_size, cudaMemcpyHostToDevice); + cudaMemcpy(in_scales_d, in_scales_h.data(), scales_size, cudaMemcpyHostToDevice); + cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice); + cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); + + // Set up grouped input tensor + NVTEShape logical_shape = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + NVTEShape first_dims_shape; + NVTEShape last_dims_shape; + NVTEShape offsets_shape; + first_dims_shape.ndim = 1; + last_dims_shape.ndim = 1; + 
offsets_shape.ndim = 1; + first_dims_shape.data[0] = num_tensors; + last_dims_shape.data[0] = num_tensors; + offsets_shape.data[0] = num_tensors; + + // Data tensors must be 1D (flattened) + std::vector data_1d_shape = {elts_num}; + NVTEShape data_shape = nvte_make_shape(data_1d_shape.data(), data_1d_shape.size()); + + std::vector scales_1d_shape = {total_scales}; + NVTEShape scales_shape = nvte_make_shape(scales_1d_shape.data(), scales_1d_shape.size()); + + NVTEGroupedTensor in_group_tensor = + nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape); + + // Set input data (rowwise or columnwise) - data shape must be 1D + NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), data_shape}; + if (rowwise) { + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor, + sizeof(in_data_tensor)); + } else { + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, + &in_data_tensor, sizeof(in_data_tensor)); + } + + // Set scales + NVTEBasicTensor in_scales_tensor = {in_scales_d, NVTEDType::kNVTEFloat8E8M0, scales_shape}; + if (rowwise) { + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, + &in_scales_tensor, sizeof(in_scales_tensor)); + } else { + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, + &in_scales_tensor, sizeof(in_scales_tensor)); + } + + // Set shape arrays + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape}; + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + } + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape}; 
+ nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor, + sizeof(last_dims_tensor)); + } + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(in_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + } + + // Set up grouped output tensor + NVTEGroupedTensor out_group_tensor = + nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape); + + NVTEBasicTensor out_data_tensor = {out_grouped_d, static_cast(otype), data_shape}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_tensor, + sizeof(out_data_tensor)); + + // Set shape arrays on output too + if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + } + if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { + NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor, + sizeof(last_dims_tensor)); + } + if (shape_rep != SAME_BOTH_DIMS) { + NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + } + + // Run grouped dequantize + nvte_group_dequantize(in_group_tensor, out_group_tensor, 0); + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Copy grouped output to host + std::vector 
out_grouped_h(elts_num); + cudaMemcpy(out_grouped_h.data(), out_grouped_d, out_data_size, cudaMemcpyDeviceToHost); + + // Now compute reference: run single-tensor nvte_dequantize for each tensor + std::vector out_ref_h(elts_num); + + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + const size_t data_offset = offsets_h[t]; + const size_t scales_offset = per_tensor_scales_offset[t]; + const size_t tensor_scales_count = + per_tensor_scales_first_dim[t] * per_tensor_scales_last_dim[t]; + + const size_t single_data_size = M * K * sizeof(InputType); + const size_t single_out_size = M * K * sizeof(OutputType); + const size_t single_scales_size = tensor_scales_count * sizeof(fp8e8m0); + + // Allocate per-tensor device memory + InputType *single_in_d; + OutputType *single_out_d; + fp8e8m0 *single_scales_d; + + cudaMalloc((void **)&single_in_d, single_data_size); + cudaMalloc((void **)&single_out_d, single_out_size); + cudaMalloc((void **)&single_scales_d, single_scales_size); + + cudaMemcpy(single_in_d, in_data_h.data() + data_offset, single_data_size, + cudaMemcpyHostToDevice); + cudaMemcpy(single_scales_d, in_scales_h.data() + scales_offset, single_scales_size, + cudaMemcpyHostToDevice); + cudaMemset(single_out_d, 0, single_out_size); + + // Build single-tensor NVTETensor using TensorWrapper directly + std::vector single_shape = {M, K}; + std::vector scale_shape_vec = {per_tensor_scales_first_dim[t], + per_tensor_scales_last_dim[t]}; + + TensorWrapper input_w(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_w.set_rowwise_data(single_in_d, itype, single_shape); + input_w.set_rowwise_scale_inv(single_scales_d, DType::kFloat8E8M0, scale_shape_vec); + } else { + input_w.set_columnwise_data(single_in_d, itype, single_shape); + input_w.set_columnwise_scale_inv(single_scales_d, DType::kFloat8E8M0, scale_shape_vec); + } + + TensorWrapper output_w; + output_w.set_rowwise_data(single_out_d, otype, single_shape); + + 
nvte_dequantize(input_w.data(), output_w.data(), 0); + cudaDeviceSynchronize(); + err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << "Single-tensor dequantize failed for tensor " << t << ": " + << cudaGetErrorString(err); + + // Copy reference output to host + cudaMemcpy(out_ref_h.data() + data_offset, single_out_d, single_out_size, + cudaMemcpyDeviceToHost); + + cudaFree(single_in_d); + cudaFree(single_out_d); + cudaFree(single_scales_d); + } + + // Bitwise comparison + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; + const size_t data_offset = offsets_h[t]; + const size_t tensor_elts = M * K; + + int result = memcmp(out_grouped_h.data() + data_offset, out_ref_h.data() + data_offset, + tensor_elts * sizeof(OutputType)); + if (result != 0) { + // Find first mismatch for error reporting + for (size_t i = 0; i < tensor_elts; ++i) { + if (out_grouped_h[data_offset + i] != out_ref_h[data_offset + i]) { + GTEST_FAIL() << "Bitwise mismatch at tensor " << t << " element " << i + << " (global offset " << (data_offset + i) << "): grouped=" + << static_cast(out_grouped_h[data_offset + i]) + << " vs reference=" << static_cast(out_ref_h[data_offset + i]); + } + } + } + } + + // Cleanup + cudaFree(in_data_d); + cudaFree(out_grouped_d); + cudaFree(in_scales_d); + cudaFree(first_dims_d); + cudaFree(last_dims_d); + cudaFree(offsets_d); +} + +// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} +std::vector> input_configs = { + {SAME_BOTH_DIMS, 1, 128, 128}, + {SAME_BOTH_DIMS, 2, 256, 128}, + {VARYING_FIRST_DIM, 2, 512, 128, 128, 384}, + {VARYING_FIRST_DIM, 2, 384, 128, 128, 256}, + {VARYING_FIRST_DIM, 5, 4096, 512, 128, 256, 384, 1024, 2304}, + {VARYING_LAST_DIM, 3, 256, 896, 128, 256, 512}, + {VARYING_BOTH_DIMS, 2, 1, (128 * 128) + (256 * 256), 128, 256, 128, 256}, + {VARYING_BOTH_DIMS, 2, 1, (256 * 128) + (512 * 640), 256, 512, 128, 640}, + // Non-128-aligned constant 
dimensions + {SAME_BOTH_DIMS, 1, 160, 192}, + {SAME_BOTH_DIMS, 2, 256, 96}, + {VARYING_FIRST_DIM, 2, 384, 160, 128, 256}, + {VARYING_FIRST_DIM, 3, 768, 96, 256, 256, 256}, + {VARYING_LAST_DIM, 2, 160, 384, 128, 256}, + {VARYING_LAST_DIM, 3, 96, 512, 128, 128, 256}, +}; + +std::vector scaling_directions = { + ScalingDirection::ROWWISE, + ScalingDirection::COLWISE, +}; + +} // namespace + +class GroupedDequantizeMXFP8TestSuite + : public ::testing::TestWithParam, // Config + transformer_engine::DType, // InputType + transformer_engine::DType // OutputType + >> {}; + +TEST_P(GroupedDequantizeMXFP8TestSuite, TestGroupedDequantizeMXFP8) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ScalingDirection scaling_direction = std::get<0>(GetParam()); + const std::vector config = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + + const ShapeRepresentation shape_rep = static_cast(config[0]); + const size_t num_tensors = config[1]; + const std::vector logical_shape = {config[2], config[3]}; + + const bool rowwise = (scaling_direction == ScalingDirection::ROWWISE); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + std::vector offsets(num_tensors + 1, 0); + + for (size_t t = 0; t < num_tensors; ++t) { + switch (shape_rep) { + case SAME_BOTH_DIMS: { + first_dims[t] = logical_shape[0] / num_tensors; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_FIRST_DIM: { + first_dims[t] = config[t + 4]; + last_dims[t] = logical_shape[1]; + break; + } + case VARYING_LAST_DIM: { + first_dims[t] = logical_shape[0]; + last_dims[t] = config[t + 4]; + break; + } + case VARYING_BOTH_DIMS: { + first_dims[t] = config[t + 4]; + last_dims[t] = config[t + (4 + num_tensors)]; + break; + } + } + offsets[t + 1] = 
offsets[t] + first_dims[t] * last_dims[t]; + + // Skip tests if varying dimensions are not 128-aligned + const bool first_dim_varies = + (shape_rep == VARYING_FIRST_DIM || shape_rep == VARYING_BOTH_DIMS); + const bool last_dim_varies = + (shape_rep == VARYING_LAST_DIM || shape_rep == VARYING_BOTH_DIMS); + if (first_dim_varies && (first_dims[t] % 128 != 0)) { + GTEST_SKIP(); + } + if (last_dim_varies && (last_dims[t] % 128 != 0)) { + GTEST_SKIP(); + } + // TMA requires last_dim * sizeof(FP8) to be 16-byte aligned + if (last_dims[t] % 16 != 0) { + GTEST_SKIP(); + } + // For colwise: first dim must be divisible by 32 + if (!rowwise && (first_dims[t] % 32 != 0)) { + GTEST_SKIP(); + } + // For rowwise: last dim must be divisible by 32 + if (rowwise && (last_dims[t] % 32 != 0)) { + GTEST_SKIP(); + } + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + output_type, OutputType, + performTest(shape_rep, num_tensors, logical_shape, first_dims, + last_dims, offsets, rowwise););); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, GroupedDequantizeMXFP8TestSuite, + ::testing::Combine(::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_configs), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo &info) { + std::string name; + switch (std::get<0>(info.param)) { + case ScalingDirection::ROWWISE: + name += "ROWWISE_"; + break; + case ScalingDirection::COLWISE: + name += "COLWISE_"; + break; + } + + const std::vector input = std::get<1>(info.param); + switch (static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: + name += "SAME_BOTH_DIMS"; + break; + case ShapeRepresentation::VARYING_FIRST_DIM: + name += "VARYING_FIRST_DIM"; + break; + case ShapeRepresentation::VARYING_LAST_DIM: + name += "VARYING_LAST_DIM"; + break; + case ShapeRepresentation::VARYING_BOTH_DIMS: + 
name += "VARYING_BOTH_DIMS"; + break; + } + + name += "_N_" + std::to_string(input[1]); + name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]); + name += "_" + test::typeName(std::get<2>(info.param)); + name += "_" + test::typeName(std::get<3>(info.param)); + return name; + }); diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 9dd965fa94..f02af0b6ff 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -442,6 +442,82 @@ def test_group_quantize_cudagraph_capturable(self) -> None: assert torch.equal(static_output.rowwise_data, expected.rowwise_data) assert torch.equal(static_output.scale_inv, expected.scale_inv) + @pytest.mark.parametrize( + "shape", + [[(512, 1024), (512, 1024)], [(256, 512), (512, 512), (768, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_dequantize(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped dequantization for MXFP8 back to BF16.""" + num_tensors = len(shape) + + # Create BF16 input tensors and quantize them with MXFP8. + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device="cuda") + + # Quantize. + quantized = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims) + + # Dequantize. + dequantized = tex.group_dequantize(quantized, tex.DType.kBFloat16) + + # Verify output metadata. + assert dequantized.num_tensors == num_tensors + assert dequantized.logical_shape == quantized.logical_shape + assert torch.equal(dequantized.first_dims, quantized.first_dims) + assert torch.equal(dequantized.tensor_offsets, quantized.tensor_offsets) + + # Verify dequantized values are close to original. 
+ dequantized_bf16 = dequantized.data.reshape(grouped_input.shape) + torch.testing.assert_close(dequantized_bf16, grouped_input, atol=0.125, rtol=0.1) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_dequantize_cudagraph_capturable(self) -> None: + """Ensure group_dequantize is CUDA graph capturable.""" + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + # Quantize to get MXFP8 grouped tensor. + quantized = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims) + + # Warmup dequantize. + torch.cuda.synchronize() + _ = tex.group_dequantize(quantized, tex.DType.kBFloat16) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output = tex.group_dequantize(quantized, tex.DType.kBFloat16) + + # Replay with different input data. 
+ fresh_input = torch.cat( + [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape], + dim=0, + ) + fresh_quantized = tex.group_quantize(fresh_input, quantizer, num_tensors, first_dims) + quantized.data.copy_(fresh_quantized.data) + quantized.scale_inv.copy_(fresh_quantized.scale_inv) + + graph.replay() + torch.cuda.synchronize() + + expected = tex.group_dequantize(quantized, tex.DType.kBFloat16) + assert torch.equal(static_output.data, expected.data) + def test_clear(self) -> None: """Test clear method""" num_tensors = 3 diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 57404ae8a5..07fd41006c 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -89,6 +89,14 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str stream); } +void nvte_group_dequantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dequantize); + using namespace transformer_engine; + dispatch::group_dequantize_helper(*convertNVTEGroupedTensorCheck(input), + convertNVTEGroupedTensorCheck(output), stream); +} + void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_configs, const size_t num_tensors, cudaStream_t stream) { diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 81304981d3..12787d609f 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -16,6 +16,7 @@ #include "../../common.h" #include "../fp8/dequantize_fp8.cuh" #include "../mxfp8/dequantize_mxfp8.cuh" +#include "../mxfp8/group_dequantize_mxfp8.cuh" #include "../nvfp4/dequantize_nvfp4.cuh" namespace transformer_engine { @@ -50,6 +51,26 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t } } +inline 
void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *output, + cudaStream_t stream) { + CheckInputGroupedTensor(input, "group_dequantize_input"); + CheckOutputGroupedTensor(*output, "group_dequantize_output"); + + switch (input.scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::group_dequantize(&input, output, stream); + } else { + NVTE_ERROR("MXFP8 Grouped Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + default: + NVTE_ERROR("Grouped dequantize not implemented for scaling mode: " + + to_string(input.scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh new file mode 100644 index 0000000000..f603f39718 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh @@ -0,0 +1,469 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file group_dequantize_mxfp8.cuh + * \brief CUDA kernels to dequantize grouped tensors from MXFP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_GROUP_DEQUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GROUP_DEQUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "group_quantize_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace group_dequantize_kernel { + +constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; +__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +__device__ alignas(128) CUtensorMap g_tensor_maps_output[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; + +// Reuse helper types and functions from group_quantize_kernel namespace +using group_quantize_kernel::fence_acquire_tensormap; +using group_quantize_kernel::get_current_tensor_id; +using group_quantize_kernel::get_tensor_cols_num; +using group_quantize_kernel::get_tensor_rows_num; +using group_quantize_kernel::modify_base_tensor_map; +using group_quantize_kernel::ShapeRepresentation; + +template +__global__ void update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_output, + const IType *const __restrict__ input_data_ptr, + const OType *const __restrict__ output_data_ptr, + const ShapeRepresentation shape_rep, + const size_t num_tensors, const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { + const bool leading_thread = (threadIdx.x == 0); + const size_t tensor_id = blockIdx.x; + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + const size_t offset_elts = offsets_ptr[tensor_id]; + + if 
(leading_thread && (tensor_id < num_tensors)) { + { + const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], + global_data_ptr, rows, cols, sizeof(IType)); + } + { + const uintptr_t global_data_ptr = reinterpret_cast(output_data_ptr + offset_elts); + modify_base_tensor_map(base_tensor_map_output, &g_tensor_maps_output[tensor_id], + global_data_ptr, rows, cols, sizeof(OType)); + } + } +} + +template +__global__ void __launch_bounds__(128) + group_dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_static, + const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, + const e8m0_t *const __restrict__ scales_ptr) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + constexpr size_t BUFFERS_NUM = 2; + constexpr size_t ELEMS_PER_THREAD = 16; + constexpr size_t BUFFER_DIM_Y = 16; + constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; + constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; + constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; + constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; + constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; + constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; + constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / 
SCALE_DIM_Y; + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); + + // Group-awareness: determine which tensor this block belongs to + const bool is_single_tensor = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + size_t tensor_id; + size_t block_id_Y, block_id_X; + + if (is_single_tensor) { + // SAME_BOTH_DIMS or VARYING_FIRST_DIM: simple 2D tiling over single logical tensor + const size_t chunks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + block_id_Y = blockIdx.x / chunks_X; + block_id_X = blockIdx.x % chunks_X; + const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; + tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, block_id_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + } else if (shape_rep == ShapeRepresentation::VARYING_LAST_DIM) { + // Virtual 2D grid: DIVUP(R,128) row-tiles x (total_cols/128) col-tiles + const size_t chunks_X_total = last_logical_dim / CHUNK_DIM_X; + const size_t col_chunk_global = blockIdx.x % chunks_X_total; + block_id_Y = blockIdx.x / chunks_X_total; + // Search using column-based element offset (works with existing binary search) + const size_t search_offset = col_chunk_global * CHUNK_DIM_X * first_logical_dim; + tensor_id = get_current_tensor_id(shape_rep, num_tensors, search_offset, block_id_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t tensor_col_start = static_cast(offsets_ptr[tensor_id]) / first_logical_dim; + block_id_X = col_chunk_global - tensor_col_start / CHUNK_DIM_X; + } else { + // VARYING_BOTH_DIMS: 1D grid, element-offset-based (both dims 128-aligned) + const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK; + const size_t chunks_X_for_id = DIVUP(last_logical_dim, CHUNK_DIM_X); + tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset, + blockIdx.x / 
chunks_X_for_id, first_logical_dim, + last_logical_dim, offsets_ptr); + const size_t vb_tensor_base = static_cast(offsets_ptr[tensor_id]); + const size_t vb_cols = + get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t chunks_X = DIVUP(vb_cols, CHUNK_DIM_X); + const size_t block_id_in_tensor = blockIdx.x - vb_tensor_base / ELTS_PER_CHUNK; + block_id_Y = block_id_in_tensor / chunks_X; + block_id_X = block_id_in_tensor % chunks_X; + } + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + // Compute per-tensor scale stride from cols (matches group_quantize kernel) + const size_t scale_stride = USE_ROWWISE_SCALING + ? DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4) + : DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + + // Select TMA descriptors (static for single tensor, per-tensor for multi-tensor) + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_output = + is_single_tensor ? tensor_map_output_static : g_tensor_maps_output[tensor_id]; + + if (!is_single_tensor) { + fence_acquire_tensormap(&tensor_map_input); + fence_acquire_tensormap(&tensor_map_output); + } + + const int chunk_offset_Y = block_id_Y * CHUNK_DIM_Y; + const int chunk_offset_X = block_id_X * CHUNK_DIM_X; + + // Per-tensor scale offset + constexpr size_t SCALE_DIVISOR = USE_ROWWISE_SCALING ? 
SCALE_DIM_X : SCALE_DIM_Y; + size_t scales_base_offset; + if (is_single_tensor) { + scales_base_offset = 0; + } else if (shape_rep == ShapeRepresentation::VARYING_LAST_DIM) { + const size_t sum_prev_cols = tensor_base / first_logical_dim; + if constexpr (USE_ROWWISE_SCALING) { + // Scale layout: DIVUP_TO_MULTIPLE(R, 128) rows x (Ki/32) cols per tensor + const size_t padded_rows = DIVUP_TO_MULTIPLE(first_logical_dim, static_cast(128)); + scales_base_offset = (padded_rows / SCALE_DIM_X) * sum_prev_cols; + } else { + // Scale layout: DIVUP_TO_MULTIPLE(ceil(R/32), 4) rows x Ki cols per tensor + const size_t padded_scale_rows = DIVUP_TO_MULTIPLE( + DIVUP(first_logical_dim, static_cast(SCALE_DIM_Y)), static_cast(4)); + scales_base_offset = padded_scale_rows * sum_prev_cols; + } + } else { + // VARYING_BOTH_DIMS: both dims 128-padded, original formula is exact + scales_base_offset = tensor_base / SCALE_DIVISOR; + } + const e8m0_t *const tensor_scales_ptr = scales_ptr + scales_base_offset; + + const int scales_rowwise_chunk_offset_Y = block_id_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = block_id_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = block_id_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = block_id_X * SCALES_COLWISE_PER_CHUNK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + + // Static shared memory (matching single-tensor dequantize) + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; + 
constexpr int transaction_size = shmem_buff_size; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + int parity = 0; + constexpr int iteration_zero = 0; + constexpr int buffer_zero = 0; + if (is_master_thread) { + const int chunk_stage_offset_Y = chunk_offset_Y; + const int chunk_stage_offset_X = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buffer_zero]), + reinterpret_cast(&tensor_map_input), chunk_stage_offset_X, + chunk_stage_offset_Y, &mbar[iteration_zero]); + + ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size); + } else { + ptx::mbarrier_arrive(&mbar[iteration_zero]); + } + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % BUFFERS_NUM; + const int next_iter = iter + 1; + if (next_iter < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_buff]), + reinterpret_cast(&tensor_map_input), chunk_it_offset_x, + chunk_it_offset_y, &mbar[next_iter]); + + ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size); + } else { + ptx::mbarrier_arrive(&mbar[next_iter]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + const int scale_offset_Y = + USE_ROWWISE_SCALING ? 
(scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + + const int scale_offset_X = + USE_ROWWISE_SCALING + ? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) + : (scales_colwise_chunk_offset_X + tid_colwise_X); + + const int scale_idx = scale_offset_Y * scale_stride + scale_offset_X; + const e8m0_t biased_exponent = tensor_scales_ptr[scale_idx]; + const float block_scale = ptx::exp2f(biased_exponent); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out; + + const int shmem_offset_y = thread_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } + out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); + } else { +#pragma unroll + for (int i = 0; i < BUFFER_DIM_Y; ++i) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } + } + + // Wait for shared memory writes to be visible to TMA engine. 
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + ptx::cp_async_bulk_commit_group(); + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace group_dequantize_kernel + +inline void group_dequantize(const GroupedTensor *input, GroupedTensor *output, + cudaStream_t stream) { + using namespace group_dequantize_kernel; + using group_quantize_kernel::ShapeRepresentation; + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + constexpr size_t SHMEM_DIM_Y = 16; + constexpr size_t SHMEM_DIM_X = 128; + + checkCuDriverContext(stream); + + const bool use_rowwise_scaling = input->has_data(); + const bool use_colwise_scaling = input->has_columnwise_data(); + NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling, + "Input tensor must have either rowwise or columnwise data."); + NVTE_CHECK(!(use_rowwise_scaling && use_colwise_scaling), + "Dequantize only supports rowwise or columnwise scaling, not both simultaneously."); + + NVTE_CHECK(!input->with_gemm_swizzled_scales, "Input must have scales in compact format."); + NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input must have FP8 type."); + + NVTE_CHECK(input->num_tensors == output->num_tensors, + "Number of 
input and output tensors must be same."); + + ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + if (input->all_same_shape()) { + shape_rep = ShapeRepresentation::SAME_BOTH_DIMS; + } else if (input->all_same_first_dim()) { + shape_rep = ShapeRepresentation::VARYING_LAST_DIM; + } else if (input->all_same_last_dim()) { + shape_rep = ShapeRepresentation::VARYING_FIRST_DIM; + } else if (input->varying_both_dims()) { + shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; + } + + const bool is_single_tensor = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || + shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); + + const size_t first_logical_dim = input->logical_shape.data[0]; + const size_t last_logical_dim = input->logical_shape.data[1]; + const size_t elts_total = first_logical_dim * last_logical_dim; + + const size_t num_tensors = input->num_tensors; + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 
32 : 1; + + size_t blocks = 0; + if (is_single_tensor) { + const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + blocks = blocks_Y * blocks_X; + } else { + NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + "Number of tensors in a group is larger than " + "the MAX number of supported descriptors (64)."); + NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0, + "Last dimension of a grouped tensor should be divisible by 128."); + if (shape_rep == ShapeRepresentation::VARYING_LAST_DIM) { + blocks = DIVUP(first_logical_dim, CHUNK_DIM_Y) * (last_logical_dim / CHUNK_DIM_X); + } else { + blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + } + + const dim3 grid(blocks); + const dim3 block(THREADS_PER_CHUNK); + + const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input->scale_inv.dptr) + : reinterpret_cast(input->columnwise_scale_inv.dptr); + + const SimpleTensor &input_data = use_rowwise_scaling ? 
input->data : input->columnwise_data; + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, first_logical_dim, + last_logical_dim, SHMEM_DIM_Y, SHMEM_DIM_X, last_logical_dim, + 0, typeToNumBits(input->dtype())); + create_2D_tensor_map(tensor_map_output, output->data, first_logical_dim, + last_logical_dim, SHMEM_DIM_Y, SHMEM_DIM_X, last_logical_dim, + 0, typeToNumBits(output->dtype())); + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = + reinterpret_cast(input_data.dptr); + OType *const output_dptr = reinterpret_cast(output->data.dptr); + + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_output, input_dptr, output_dptr, shape_rep, + num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr); + } + + group_dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, shape_rep, + num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, + scales_ptr);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GROUP_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 04712d3003..9d75c077a5 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -394,8 +394,6 @@ 
 void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, NVTETensor dbias,
                                       NVTETensor workspace, cudaStream_t stream);
 
 /*! \brief Casts input tensor from reduced to higher precision.
- * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
- * the block dequantization (MXFP8) of the specified shape of the block will be used.
  * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise
  * data of the output tensor, regardless of whether the row- or columnwise scaling is used.
  *
@@ -405,6 +403,17 @@ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
  */
 void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Casts input grouped tensor from reduced to higher precision.
+ * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise
+ * data of the output tensor, regardless of whether the row- or columnwise scaling is used.
+ *
+ * \param[in]     input   Input grouped FP8/MXFP8 tensor to be cast.
+ * \param[in,out] output  Output grouped tensor.
+ * \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_group_dequantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                           cudaStream_t stream);
+
 /*! \brief Casts multiple input tensors to quantized output tensors.
  *
  * \param[in] inputs List of input tensors to be cast.
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4d4e5094c..c9ff8071a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -290,6 +290,8 @@ py::object dequantize(const py::handle &input, DType otype); py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); +py::object group_dequantize(const py::handle &input, DType otype); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..befa765c4f 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -251,6 +251,80 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) return out; } +py::object group_dequantize(const py::handle &input, transformer_engine::DType otype) { + using namespace pybind11::literals; + init_extension(); + + // Extract fields from the Python GroupedTensor. + const auto num_tensors = input.attr("num_tensors").cast(); + const auto logical_shape_py = input.attr("logical_shape").cast(); + const auto logical_first_dim = logical_shape_py[0].cast(); + const auto logical_last_dim = logical_shape_py[1].cast(); + const std::vector logical_shape = {logical_first_dim, logical_last_dim}; + const auto &quantizer = convert_quantizer(input.attr("quantizer")); + + // Extract optional tensor attributes. 
+ auto get_optional_tensor = [&input](const char *name) -> std::optional { + auto attr = input.attr(name); + if (attr.is_none()) return std::nullopt; + return attr.cast(); + }; + auto rowwise_data = get_optional_tensor("data"); + auto columnwise_data = get_optional_tensor("columnwise_data"); + auto rowwise_scale_inv = get_optional_tensor("scale_inv"); + auto columnwise_scale_inv = get_optional_tensor("columnwise_scale_inv"); + auto first_dims = get_optional_tensor("first_dims"); + auto tensor_offsets = get_optional_tensor("tensor_offsets"); + + // Early-return for empty input. + if (logical_first_dim == 0 || logical_last_dim == 0) { + NoneQuantizer q{py::none()}; + auto [out_cpp, out_py] = + q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(), first_dims, + logical_first_dim, logical_last_dim); + return py::reinterpret_borrow(out_py); + } + + // Build input GroupedTensorWrapper. + auto input_cpp = GroupedTensorWrapper(num_tensors, logical_shape, quantizer->get_scaling_mode()); + if (rowwise_data.has_value()) { + input_cpp.set_rowwise_data(rowwise_data->data_ptr(), + GetTransformerEngineDType(rowwise_data->scalar_type()), + getTensorShape(*rowwise_data)); + if (rowwise_scale_inv.has_value()) { + input_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + } + if (columnwise_data.has_value()) { + input_cpp.set_columnwise_data(columnwise_data->data_ptr(), + GetTransformerEngineDType(columnwise_data->scalar_type()), + getTensorShape(*columnwise_data)); + if (columnwise_scale_inv.has_value()) { + input_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + } + if (first_dims.has_value()) { + input_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + input_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + 
getTensorShape(*tensor_offsets)); + } + + // Create output GroupedTensor using NoneQuantizer. + NoneQuantizer q{py::none()}; + auto [out_cpp, out_py] = q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(), + first_dims, logical_first_dim, logical_last_dim); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_dequantize(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(out_py); +} + namespace { void multi_tensor_quantize_impl(const std::vector &input_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8302a13010..369e0a3b0d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -141,6 +141,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("otype")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + m.def("group_dequantize", transformer_engine::pytorch::group_dequantize, + "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",