From 58906ba21c4b1de0ffe55d13fe8a04f81e28dba9 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 9 Mar 2026 15:47:00 -0700 Subject: [PATCH 1/9] MXFP8 grouped GEMM + tensor-scaled FP8 fixes Signed-off-by: Jeremy Berchtold --- tests/cpp/operator/test_grouped_gemm.cu | 110 +++++++++- tests/cpp/test_common.cu | 111 ++++++++-- tests/cpp/test_common.h | 1 + transformer_engine/common/common.h | 4 +- .../common/gemm/cublaslt_grouped_gemm.cu | 204 ++++++++++++++++-- .../transformer_engine/transformer_engine.h | 17 ++ .../common/transformer_engine.cpp | 16 ++ 7 files changed, 415 insertions(+), 48 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index a7aabbbcb6..cfa6b4c431 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "../test_common.h" @@ -32,6 +33,7 @@ namespace { enum class InputCase { kFP8Current, kBF16, + kMXFP8, }; enum class ShapeCase { @@ -44,8 +46,8 @@ enum class ShapeCase { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) - size_t size = 6 * ptr_bytes + 6 * int_bytes; + // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays + size_t size = 8 * ptr_bytes + 6 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; return size; @@ -53,7 +55,20 @@ size_t grouped_setup_workspace_size(const size_t num_tensors) { Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); - fillUniform(&input_fp32); + + const size_t numel = shape[0] * shape[1]; + std::vector data(numel); + std::mt19937 gen(std::hash{}(name)); + // Random mean and stddev -> different amax per tensor -> different scales + std::uniform_real_distribution param_dis(0.1f, 10.0f); + float mean = param_dis(gen); + float stddev = param_dis(gen); + std::normal_distribution dis(mean, stddev); + for (size_t i = 0; i < numel; ++i) { + data[i] = dis(gen); + } + NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(), + numel * sizeof(float), cudaMemcpyHostToDevice)); Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); @@ -73,6 +88,63 @@ Tensor make_bf16_operand(const std::string& name, const std::vector& sha return t; } +// Creates an MXFP8 operand with the correct data layout for GEMM. +// MXFP8 GEMM requirements (scales are along K dimension): +// A transposed -> needs rowwise data/scales +// A non-transposed -> needs columnwise data/scales +// B transposed -> needs columnwise data/scales +// B non-transposed -> needs rowwise data/scales +Tensor make_mxfp8_operand(const std::string& name, const std::vector& shape, + bool is_A, bool transposed) { + // Determine which data layout we need + bool use_rowwise, use_colwise; + if (is_A) { + // A: transposed -> rowwise, non-transposed -> columnwise + use_rowwise = transposed; + use_colwise = !transposed; + } else { + // B: transposed -> columnwise, non-transposed -> rowwise (opposite of A!) 
+ use_rowwise = !transposed; + use_colwise = transposed; + } + + // Create BF16 input with random data + Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); + fillUniform(&input_bf16); + + // Create MXFP8 tensor with only the required data layout + Tensor mxfp8(name, shape, TypeInfo::dtype, use_rowwise, use_colwise, + NVTE_MXFP8_1D_SCALING); + + // Quantize BF16 -> MXFP8 + nvte_quantize(input_bf16.data(), mxfp8.data(), 0); + + // Create output tensor for swizzled scales (same data shape, same layout) + Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo::dtype, + use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING); + mxfp8_swizzled.set_with_gemm_swizzled_scales(true); // Must be set BEFORE swizzle call + + // Copy quantized data from mxfp8 to mxfp8_swizzled + if (use_rowwise) { + size_t data_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype()); + NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice)); + } + if (use_colwise) { + size_t data_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype()); + NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice)); + } + + // Swizzle scales for GEMM + nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0); + + // Sync to ensure operations are complete + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + return mxfp8_swizzled; +} + struct TestParams { InputCase input_case; bool transa; @@ -88,16 +160,16 @@ struct TestParams { std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: - return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; case ShapeCase::kSameFirst: // Same M (first dim), varying N and K - return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; + return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; case ShapeCase::kSameLast: // Same N (last dim), varying M and K - return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; + return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; + return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; } } @@ -138,6 +210,13 @@ void run_grouped_gemm_case(const TestParams& params) { B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); break; } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } } D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), std::vector{M, N}, @@ -246,7 +325,9 @@ void run_grouped_gemm_case(const TestParams& params) { cublas_ws.data(), nullptr, // config (use defaults) 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + // Compare results for (size_t i = 0; i < num_gemms; ++i) { Tensor grouped_split("grouped_D" + std::to_string(i), std::vector{static_cast(std::get<0>(shapes[i])), @@ -277,7 +358,7 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { } std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; + constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; const std::string layout = std::string("ta") + 
(info.param.transa ? "T" : "N") + "tb" + (info.param.transb ? "T" : "N"); @@ -288,16 +369,27 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo kTestParams = { - // Basic tests + // FP8 tests (each tensor has random mean/stddev -> different scales) {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + // BF16 tests {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, // Test NULL C (valid when beta=0) {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, + // MXFP8 tests + {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false}, + {InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false}, + // MXFP8 with NULL C + {InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index b64ae24131..32338ee801 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1061,7 +1061,14 @@ std::array get_scale_tensor_dims(const size_t rows, GroupedBuffers build_grouped_tensor(const std::vector& tensors, const NVTEScalingMode scaling_mode) { NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); - const NVTEShape shape = tensors[0]->rowwise_shape(); + + // Check which data layouts are available (all tensors must have the same) + const bool has_rowwise = tensors[0]->rowwise(); + const bool has_columnwise = tensors[0]->columnwise(); + NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout."); + + const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape() + : tensors[0]->columnwise_shape(); const DType dtype = tensors[0]->dtype(); const size_t num_tensors = tensors.size(); const size_t elem_size = typeToNumBits(dtype) / 8; @@ -1076,7 +1083,8 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, std::vector first_dims(num_tensors); std::vector last_dims(num_tensors); for (size_t i = 0; i < num_tensors; ++i) { - const auto s = tensors[i]->rowwise_shape(); + const auto s = has_rowwise ? 
tensors[i]->rowwise_shape() + : tensors[i]->columnwise_shape(); NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); first_dims[i] = static_cast(s.data[0]); last_dims[i] = static_cast(s.data[1]); @@ -1146,21 +1154,24 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, : (logical_first * logical_last); const size_t total_bytes = static_cast(total_elems) * elem_size; - grouped.data = cuda_alloc(total_bytes); - for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, - tensors[i]->rowwise_dptr(), - grouped.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - } - - NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; NVTEGroupedTensor h = grouped.handle.get(); - nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); - const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); - if (include_columnwise) { + // Copy rowwise data if available + if (has_rowwise) { + grouped.data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); + } + + // Copy columnwise data if available + if (has_columnwise) { grouped.columnwise_data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; @@ -1202,11 +1213,17 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor)); } - if (isFp8Type(dtype)) { + if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + // FP8 tensor scaling: one float scale_inv per tensor + // For delayed scaling, rowwise and columnwise share the same scale std::vector scale_inv_cpu(num_tensors, 1.f); for (size_t i = 0; i < num_tensors; ++i) { tensors[i]->to_cpu(); - scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + if (has_rowwise) { + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } else { + scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr()[0]; + } } grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), @@ -1217,6 +1234,66 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, sizeof(scale_tensor)); nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, sizeof(scale_tensor)); + } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + // MXFP8: E8M0 scale_inv per block of 32 elements + // Helper to gather scale_inv from individual tensors into a contiguous buffer + auto gather_scales = [&]( + auto get_shape_fn, + auto get_cpu_ptr_fn) -> std::pair, size_t> { + // Compute total size and offsets + size_t total_bytes = 0; + std::vector scale_offsets(num_tensors); + std::vector numels(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + scale_offsets[i] = total_bytes; + const NVTEShape shape = get_shape_fn(tensors[i]); + size_t numel = 1; + for (size_t d = 0; d 
< shape.ndim; ++d) { + numel *= shape.data[d]; + } + numels[i] = numel; + total_bytes += numel; // E8M0 is 1 byte per element + } + + // Allocate and copy + CudaPtr<> buffer = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + NVTE_CHECK_CUDA(cudaGetLastError()); + void* dst = static_cast(buffer.get()) + scale_offsets[i]; + const void* src = get_cpu_ptr_fn(tensors[i]); + NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice)); + } + return {std::move(buffer), total_bytes}; + }; + + // Gather rowwise scale_inv if available + if (has_rowwise) { + auto [row_buffer, row_total] = gather_scales( + [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, + [](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr(); }); + grouped.scale_inv = std::move(row_buffer); + + NVTEShape row_shape = nvte_make_shape(&row_total, 1); + NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); + } + + // Gather columnwise scale_inv if available + if (has_columnwise) { + auto [col_buffer, col_total] = gather_scales( + [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, + [](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr(); }); + grouped.columnwise_scale_inv = std::move(col_buffer); + + NVTEShape col_shape = nvte_make_shape(&col_total, 1); + NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); + } + + // Mark as having swizzled scales (required for GEMM) + nvte_set_grouped_tensor_swizzled_scales(h, 1); } return grouped; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 5bb6400629..927407f478 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -535,6 +535,7 @@ struct GroupedBuffers { GroupedTensorHandle handle; CudaPtr<> data; CudaPtr<> scale_inv; + CudaPtr<> columnwise_scale_inv; CudaPtr first_dims_dev; CudaPtr last_dims_dev; CudaPtr offsets_dev; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 1749b5734a..3de661a0ad 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -334,7 +334,6 @@ struct GroupedTensor { NVTEShape logical_shape; NVTEGroupedTensor nvte_tensor; - /*! \brief Whether scaling factors are in format expected by GEMM * * Only meaningful for MXFP8 and NVFP4. 
@@ -370,7 +369,8 @@ struct GroupedTensor { last_dims(nullptr, std::vector{0}, DType::kInt64), tensor_offsets(nullptr, std::vector{0}, DType::kInt64), logical_shape(nvte_make_shape(nullptr, 1)), - nvte_tensor(0) {} + nvte_tensor(0), + with_gemm_swizzled_scales(false) {} explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index dc4757ab90..a7dae5176a 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include +#include #include "../common.h" #include "../util/cuda_runtime.h" @@ -26,6 +27,9 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { } // namespace +// MXFP8 support for grouped GEMM requires cuBLAS 13.3+ +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130100 + #if CUBLAS_VERSION >= 130200 namespace { @@ -109,6 +113,8 @@ struct GroupedGemmSetupWorkspace { void **D_ptrs; float **alpha_ptrs; float **beta_ptrs; + void **a_scale_inv_ptrs; // Per-tensor FP8 scale pointers for A (float* for tensor scaling, E8M0* for MXFP8) + void **b_scale_inv_ptrs; // Per-tensor FP8 scale pointers for B (float* for tensor scaling, E8M0* for MXFP8) // Storage dimensions for cuBLAS matrix layouts int *a_rows; int *a_cols; @@ -118,28 +124,47 @@ struct GroupedGemmSetupWorkspace { int *d_cols; // N (last dim) - also used for C // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + // Layout: all pointer arrays first (16-byte aligned for cuBLAS), then int arrays static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); + constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays + + // Helper to align offset to kPtrAlignment + auto align_offset = [&]() { + offset = (offset + kPtrAlignment - 1) / kPtrAlignment * kPtrAlignment; + }; - // Pointer arrays first (all 8-byte aligned) + // Pointer arrays first (all 16-byte aligned for cuBLAS grouped GEMM) + align_offset(); ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); + ws.a_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + align_offset(); + ws.b_scale_inv_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; - // Int arrays for storage dimensions (4-byte aligned) + // Int arrays for storage dimensions (4-byte aligned is fine) + align_offset(); ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); @@ -159,8 +184,12 @@ struct GroupedGemmSetupWorkspace { static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors 
* sizeof(int); - // Layout: 6 ptr arrays, then 6 int arrays - size_t size = 6 * ptr_size + 6 * int_size; + constexpr size_t kPtrAlignment = 16; // Must match from_buffers + + // Layout: 8 ptr arrays (each 16-byte aligned), then 6 int arrays + // Each ptr array takes ptr_size bytes but needs to start at 16-byte boundary + auto aligned_ptr_size = ((ptr_size + kPtrAlignment - 1) / kPtrAlignment) * kPtrAlignment; + size_t size = 8 * aligned_ptr_size + 6 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -226,8 +255,11 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor struct GroupedOperandSelection { TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed char *dptr = nullptr; - void *scale_inv = nullptr; + void *scale_inv = nullptr; // Contiguous array of scales (input) + void **scale_inv_ptrs = nullptr; // Array of pointers to scales (output, for cuBLAS) transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + bool with_gemm_swizzled_scales = false; bool trans = false; }; @@ -266,15 +298,19 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); - // Currently only unquantized data and tensor-scaled FP8 are supported. const auto sm = t->scaling_mode; - NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, - "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); + const bool mxfp8 = is_mxfp_scaling(sm); + + // Validate scaling mode + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING || mxfp8, + "Grouped GEMM is only supported with tensor scaling and MXFP8"); const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; GroupedOperandSelection sel; sel.trans = trans; + sel.scaling_mode = sm; + sel.with_gemm_swizzled_scales = t->with_gemm_swizzled_scales; const DType rep_dtype = has_row ? row_dtype : col_dtype; const bool is_fp8 = is_fp8_dtype(rep_dtype); @@ -296,6 +332,36 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: sel.shape = create_shape_info(t, /*swap_dims=*/false); }; + // MXFP8: Row-wise and column-wise data are scaled along different dimensions. + if (mxfp8) { + if (is_A) { + if (trans) { + NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 transposed A is missing row-wise data"); + use_rowwise(); + } else { + NVTE_CHECK(has_col, "Grouped GEMM: MXFP8 non-transposed A is missing column-wise data"); + // Use columnwise data/scales but keep dims un-swapped and trans unchanged. + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.dtype = col_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/false); + } + } else { // B + if (trans) { + NVTE_CHECK(has_col, "Grouped GEMM: MXFP8 transposed B is missing column-wise data"); + // Use columnwise data/scales but keep dims un-swapped and trans unchanged. + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.dtype = col_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/false); + } else { + NVTE_CHECK(has_row, "Grouped GEMM: MXFP8 non-transposed B is missing row-wise data"); + use_rowwise(); + } + } + return sel; + } + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. 
if (is_fp8 && !non_tn_fp8_ok) { if (is_A) { @@ -383,6 +449,11 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); + + // Fast accumulation mode: 0 = split accumulator (more accurate), 1 = fast accumulator + int8_t fastAccuMode = 0; // Use split accumulator for accuracy + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); } inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, @@ -392,17 +463,62 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); if (!is_fp8_a && !is_fp8_b) return; + const bool mxfp8_a = transformer_engine::is_mxfp_scaling(A_sel.scaling_mode); + const bool mxfp8_b = transformer_engine::is_mxfp_scaling(B_sel.scaling_mode); + +#if CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION + // For MXFP8, verify scales are swizzled and set scale mode + if (mxfp8_a || mxfp8_b) { + NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "MXFP8 grouped GEMM requires cuBLAS ", CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "+, but run-time cuBLAS version is ", transformer_engine::cuda::cublas_version()); + } + + if (mxfp8_a) { + NVTE_CHECK(A_sel.with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: A scales must be swizzled for GEMM"); + cublasLtMatmulMatrixScale_t scale_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scale_mode_a, sizeof(scale_mode_a))); + } + if (mxfp8_b) { + NVTE_CHECK(B_sel.with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: B scales must be swizzled for GEMM"); + cublasLtMatmulMatrixScale_t scale_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scale_mode_b, sizeof(scale_mode_b))); + } +#else + NVTE_CHECK(!mxfp8_a && !mxfp8_b, + "MXFP8 grouped GEMM requires cuBLAS ", CUBLAS_MXFP8_GROUPED_GEMM_VERSION, + "+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION + if (is_fp8_a) { - void *a_scale_inv = A_sel.scale_inv; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK(A_sel.scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK(A_sel.scale_inv_ptrs != nullptr, "FP8 grouped GEMM: A scale_inv_ptrs is required"); + if (!mxfp8_a) { + // Tensor scaling: PER_BATCH_SCALAR_32F for grouped GEMM with float** pointer array + cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_PER_BATCH_SCALAR_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scale_mode, sizeof(scale_mode))); + } + void *a_scale_ptrs = A_sel.scale_inv_ptrs; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_ptrs, sizeof(a_scale_ptrs))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.scale_inv; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK(B_sel.scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + 
NVTE_CHECK(B_sel.scale_inv_ptrs != nullptr, "FP8 grouped GEMM: B scale_inv_ptrs is required"); + if (!mxfp8_b) { + // Tensor scaling: PER_BATCH_SCALAR_32F for grouped GEMM with float** pointer array + cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_PER_BATCH_SCALAR_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scale_mode, sizeof(scale_mode))); + } + void *b_scale_ptrs = B_sel.scale_inv_ptrs; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_ptrs, sizeof(b_scale_ptrs))); } } @@ -471,11 +587,15 @@ __global__ void setup_grouped_gemm_kernel( // Output arrays void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols, int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, + void **a_scale_inv_ptrs, void **b_scale_inv_ptrs, // Inputs char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, - size_t num_tensors) { + // Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr + // For MXFP8, pass nullptr for tensor_scale and set mxfp8_base + float *a_tensor_scale, float *b_tensor_scale, + char *a_mxfp8_scale_base, char *b_mxfp8_scale_base, size_t num_tensors) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -506,12 +626,29 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); + // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). 
d_rows[idx] = static_cast(d_last); d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; beta_ptrs[idx] = beta_ptr + idx; + + // Fill FP8 scale pointers (per-matrix) + // For tensor scaling: one float per tensor, indexed by tensor index + // For MXFP8: E8M0 blocks, offset computed from data offset (1 scale byte per 32 elements) + if (a_tensor_scale) { + a_scale_inv_ptrs[idx] = a_tensor_scale + idx; + } else if (a_mxfp8_scale_base) { + int64_t a_scale_offset = a_offset / 32; + a_scale_inv_ptrs[idx] = a_mxfp8_scale_base + a_scale_offset; + } + if (b_tensor_scale) { + b_scale_inv_ptrs[idx] = b_tensor_scale + idx; + } else if (b_mxfp8_scale_base) { + int64_t b_scale_offset = b_offset / 32; + b_scale_inv_ptrs[idx] = b_mxfp8_scale_base + b_scale_offset; + } } // Launch the setup kernel to populate workspace arrays @@ -537,12 +674,34 @@ inline void launch_grouped_gemm_setup( const int threads_per_block = 256; const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + // Get scale pointers for FP8 + // For tensor scaling: float* array indexed by tensor + // For MXFP8: char* base, kernel computes offsets from data offsets + float *a_tensor_scale = nullptr; + float *b_tensor_scale = nullptr; + char *a_mxfp8_scale_base = nullptr; + char *b_mxfp8_scale_base = nullptr; + + if (transformer_engine::is_mxfp_scaling(A_sel.scaling_mode)) { + a_mxfp8_scale_base = static_cast(A_sel.scale_inv); + } else if (A_sel.scale_inv) { + a_tensor_scale = static_cast(A_sel.scale_inv); + } + if (transformer_engine::is_mxfp_scaling(B_sel.scaling_mode)) { + b_mxfp8_scale_base = static_cast(B_sel.scale_inv); + } else if (B_sel.scale_inv) { + b_tensor_scale = static_cast(B_sel.scale_inv); + } + setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, - ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, + ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, + ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), - num_tensors); + a_tensor_scale, b_tensor_scale, + a_mxfp8_scale_base, b_mxfp8_scale_base, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -599,8 +758,8 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT // Select operand storage (row-wise vs column-wise) and adjust transpose flags to // mirror the non-grouped GEMM logic for FP8 layout constraints. 
- const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); // Workspaces: setup (pointer arrays) and cuBLAS const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); @@ -613,6 +772,11 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( static_cast(setup_workspace_ptr), num_tensors); + + // Set scale_inv_ptrs from workspace (kernel will fill these arrays for both tensor scaling and MXFP8) + A_sel.scale_inv_ptrs = setup_workspace.a_scale_inv_ptrs; + B_sel.scale_inv_ptrs = setup_workspace.b_scale_inv_ptrs; + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, beta_tensor, num_tensors, stream); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e316f8be8c..aed3ecc1c7 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -542,6 +542,23 @@ NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) */ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Set whether the grouped tensor has GEMM-swizzled scales. + * + * \param[in] tensor Grouped tensor. + * \param[in] val 1 if scales are swizzled, 0 otherwise. + */ +void nvte_set_grouped_tensor_swizzled_scales(NVTEGroupedTensor tensor, uint8_t val); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Get whether the grouped tensor has GEMM-swizzled scales. + * + * \param[in] tensor Grouped tensor. + * + * \return 1 if scales are swizzled, 0 otherwise. + */ +uint8_t nvte_get_grouped_tensor_swizzled_scales(const NVTEGroupedTensor tensor); + #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index cd02074fbd..d9708fd5e8 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1358,3 +1358,19 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } + +void nvte_set_grouped_tensor_swizzled_scales(NVTEGroupedTensor tensor, uint8_t val) { + if (tensor == nullptr) { + return; + } + auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + t.with_gemm_swizzled_scales = (val != 0); +} + +uint8_t nvte_get_grouped_tensor_swizzled_scales(const NVTEGroupedTensor tensor) { + if (tensor == nullptr) { + return 0; + } + const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + return t.with_gemm_swizzled_scales ? 
1 : 0; +} From 3ec6decf1901e0323e9d63ca7d4655cf734dcf55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:50:20 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index a7dae5176a..b154307772 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -113,8 +113,10 @@ struct GroupedGemmSetupWorkspace { void **D_ptrs; float **alpha_ptrs; float **beta_ptrs; - void **a_scale_inv_ptrs; // Per-tensor FP8 scale pointers for A (float* for tensor scaling, E8M0* for MXFP8) - void **b_scale_inv_ptrs; // Per-tensor FP8 scale pointers for B (float* for tensor scaling, E8M0* for MXFP8) + void ** + a_scale_inv_ptrs; // Per-tensor FP8 scale pointers for A (float* for tensor scaling, E8M0* for MXFP8) + void ** + b_scale_inv_ptrs; // Per-tensor FP8 scale pointers for B (float* for tensor scaling, E8M0* for MXFP8) // Storage dimensions for cuBLAS matrix layouts int *a_rows; int *a_cols; @@ -255,7 +257,7 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor struct GroupedOperandSelection { TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed char *dptr = nullptr; - void *scale_inv = nullptr; // Contiguous array of scales (input) + void *scale_inv = nullptr; // Contiguous array of scales (input) void **scale_inv_ptrs = nullptr; // Array of pointers to scales (output, for cuBLAS) transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; @@ -478,20 +480,20 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, NVTE_CHECK(A_sel.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: A scales must be swizzled for GEMM"); cublasLtMatmulMatrixScale_t scale_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scale_mode_a, sizeof(scale_mode_a))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &scale_mode_a, sizeof(scale_mode_a))); } if (mxfp8_b) { NVTE_CHECK(B_sel.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: B scales must be swizzled for GEMM"); cublasLtMatmulMatrixScale_t scale_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scale_mode_b, sizeof(scale_mode_b))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &scale_mode_b, sizeof(scale_mode_b))); } #else - NVTE_CHECK(!mxfp8_a && !mxfp8_b, - "MXFP8 grouped GEMM requires cuBLAS ", CUBLAS_MXFP8_GROUPED_GEMM_VERSION, - "+, but compile-time cuBLAS version is ", CUBLAS_VERSION); + NVTE_CHECK(!mxfp8_a && !mxfp8_b, "MXFP8 grouped GEMM requires cuBLAS ", + CUBLAS_MXFP8_GROUPED_GEMM_VERSION, "+, but compile-time cuBLAS version is ", + CUBLAS_VERSION); #endif // CUBLAS_VERSION >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION if (is_fp8_a) { @@ -594,8 +596,8 @@ __global__ void setup_grouped_gemm_kernel( size_t b_elem_size, size_t c_elem_size, size_t 
d_elem_size, float *alpha_ptr, float *beta_ptr, // Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr // For MXFP8, pass nullptr for tensor_scale and set mxfp8_base - float *a_tensor_scale, float *b_tensor_scale, - char *a_mxfp8_scale_base, char *b_mxfp8_scale_base, size_t num_tensors) { + float *a_tensor_scale, float *b_tensor_scale, char *a_mxfp8_scale_base, + char *b_mxfp8_scale_base, size_t num_tensors) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -695,12 +697,10 @@ inline void launch_grouped_gemm_setup( setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, - ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, - ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, - A_sel.dptr, B_sel.dptr, c_base, d_base, - A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), - a_tensor_scale, b_tensor_scale, + ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), a_tensor_scale, b_tensor_scale, a_mxfp8_scale_base, b_mxfp8_scale_base, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); From d1e42bbdf5a35d9da3e56f27e300a6a93a29e3ad Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 9 Mar 2026 16:28:03 -0700 Subject: [PATCH 3/9] Change version to 13.3 Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b154307772..90dafc5733 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -28,7 +28,7 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { } // namespace // MXFP8 support for grouped GEMM requires cuBLAS 13.3+ -#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130100 +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 #if CUBLAS_VERSION >= 130200 From d5eaa1f23411ad4b0fad1f1d40fcb9589bb06724 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 9 Mar 2026 19:40:44 -0700 Subject: [PATCH 4/9] Random padding condition shouldnt be done for mxfp8 Signed-off-by: vthumbe1503 --- tests/cpp/test_common.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 32338ee801..09d33eb5ee 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1113,10 +1113,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, }; const bool need_offsets = !same_first || !same_last; + const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING; if (need_offsets) { offsets[0] = 0; for (size_t i = 1; i < num_tensors; ++i) { - offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? 
random_padding() : 0); } } else { for (size_t i = 0; i < num_tensors; ++i) { From 62c21c85c2f2ad71cda42bd467ba145d7f0ed42d Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 9 Mar 2026 19:46:26 -0700 Subject: [PATCH 5/9] Remove incorrect comment Signed-off-by: vthumbe1503 --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 90dafc5733..5afe1858f7 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -628,7 +628,6 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). d_rows[idx] = static_cast(d_last); d_cols[idx] = static_cast(d_first); From 0a06edb96fee16f5776c3044f4008412705750ec Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 9 Mar 2026 20:51:30 -0700 Subject: [PATCH 6/9] CUBLAS > 13.2 is enough Signed-off-by: vthumbe1503 --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 5afe1858f7..e167283b9c 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -27,8 +27,8 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { } // namespace -// MXFP8 support for grouped GEMM requires cuBLAS 13.3+ -#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 +// MXFP8 support for grouped GEMM requires cuBLAS 13.2+ +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200 #if CUBLAS_VERSION >= 130200 From d50b71132a864e55b2c51a2ddf6c21c21a327662 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 9 Mar 2026 23:43:29 -0700 Subject: [PATCH 7/9] CUBLAS version needed for MXFP8 indeed seems to be 13.3 Signed-off-by: vthumbe1503 --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index e167283b9c..5afe1858f7 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -27,8 +27,8 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { } // namespace -// MXFP8 support for grouped GEMM requires cuBLAS 13.2+ -#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200 +// MXFP8 support for grouped GEMM requires cuBLAS 13.3+ +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 #if CUBLAS_VERSION >= 130200 From 853d093ad6660f453d7426939c9c63fe1d4aa26b Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 10 Mar 2026 10:01:29 -0700 Subject: [PATCH 8/9] Accidental line removal added back. 
Plus need changes to trigger CI.

Add documentation for scaling factors in common.h

Signed-off-by: vthumbe1503
---
 transformer_engine/common/common.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 3de661a0ad..09623aafe7 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -334,6 +334,7 @@ struct GroupedTensor {
   NVTEShape logical_shape;
   NVTEGroupedTensor nvte_tensor;
+
   /*! \brief Whether scaling factors are in format expected by GEMM
    *
    * Only meaningful for MXFP8 and NVFP4.

From dfd85b79e493088aad741b60d64c4bdf9b39fde8 Mon Sep 17 00:00:00 2001
From: vthumbe1503
Date: Tue, 10 Mar 2026 23:11:48 -0700
Subject: [PATCH 9/9] Update cuBLAS version requirement for MXFP8 support

Signed-off-by: vthumbe1503
---
 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
index 5afe1858f7..e167283b9c 100644
--- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -27,8 +27,8 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) {
 }  // namespace
 
-// MXFP8 support for grouped GEMM requires cuBLAS 13.3+
-#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300
+// MXFP8 support for grouped GEMM requires cuBLAS 13.2+
+#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200
 
 #if CUBLAS_VERSION >= 130200
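
Usage note (illustrative sketch, not part of the patches above): the code below shows how a caller might attach pre-swizzled MXFP8 scale factors to a grouped tensor before invoking the grouped GEMM path added in this series. It mirrors what build_grouped_tensor() in tests/cpp/test_common.cu does for NVTE_MXFP8_1D_SCALING. The function name attach_mxfp8_scales and the scale_inv_dev/scale_bytes parameters are hypothetical; the NVTE calls (nvte_make_shape, nvte_set_grouped_tensor_param, nvte_set_grouped_tensor_swizzled_scales, nvte_get_grouped_tensor_swizzled_scales) are the ones used or introduced in these patches.

#include <cassert>
#include <transformer_engine/transformer_engine.h>

// Illustrative sketch under stated assumptions: `grouped` already holds row-wise
// MXFP8 data; `scale_inv_dev` points to the concatenated per-tensor E8M0 scale
// factors on the device, already swizzled via nvte_swizzle_scaling_factors;
// `scale_bytes` is the total size of that buffer in bytes (1 byte per 32 elements).
void attach_mxfp8_scales(NVTEGroupedTensor grouped, void *scale_inv_dev, size_t scale_bytes) {
  // Register the row-wise scale_inv buffer on the grouped tensor.
  NVTEShape scale_shape = nvte_make_shape(&scale_bytes, 1);
  NVTEBasicTensor scales{scale_inv_dev, kNVTEFloat8E8M0, scale_shape};
  nvte_set_grouped_tensor_param(grouped, kNVTEGroupedRowwiseScaleInv, &scales, sizeof(scales));

  // Mark the scales as GEMM-swizzled; the MXFP8 checks in set_fp8_scale_pointers()
  // reject grouped tensors that do not carry this flag.
  nvte_set_grouped_tensor_swizzled_scales(grouped, 1);
  assert(nvte_get_grouped_tensor_swizzled_scales(grouped) == 1);
}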