Skip to content
Open
110 changes: 101 additions & 9 deletions tests/cpp/operator/test_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>

#include "../test_common.h"
Expand All @@ -32,6 +33,7 @@ namespace {
// Operand precision/recipe variants exercised by the grouped GEMM tests.
enum class InputCase {
kFP8Current,  // FP8 e4m3 operands with a per-tensor scale (see make_fp8_operand)
kBF16,        // plain BF16 operands, no quantization
kMXFP8,       // MXFP8 operands with NVTE_MXFP8_1D_SCALING block scales
};

enum class ShapeCase {
Expand All @@ -44,16 +46,29 @@ enum class ShapeCase {
// Returns the setup-workspace size (in bytes) needed for a grouped GEMM over
// `num_tensors` GEMMs, rounded up to a 256-byte alignment boundary.
// NOTE: the pasted diff retained both the pre-change (6 pointer arrays) and
// post-change (8 pointer arrays) layout lines; only the post-change layout,
// which includes the a_scale/b_scale pointer arrays, is kept here.
size_t grouped_setup_workspace_size(const size_t num_tensors) {
  const size_t ptr_bytes = num_tensors * sizeof(void*);
  const size_t int_bytes = num_tensors * sizeof(int);
  // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays
  // (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
  size_t size = 8 * ptr_bytes + 6 * int_bytes;
  // Round up to the alignment the grouped-GEMM setup kernel expects.
  const size_t alignment = 256;
  size = ((size + alignment - 1) / alignment) * alignment;
  return size;
}

Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);

const size_t numel = shape[0] * shape[1];
std::vector<float> data(numel);
std::mt19937 gen(std::hash<std::string>{}(name));
// Random mean and stddev -> different amax per tensor -> different scales
std::uniform_real_distribution<float> param_dis(0.1f, 10.0f);
float mean = param_dis(gen);
float stddev = param_dis(gen);
std::normal_distribution<float> dis(mean, stddev);
for (size_t i = 0; i < numel; ++i) {
data[i] = dis(gen);
}
NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(),
numel * sizeof(float), cudaMemcpyHostToDevice));

Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);

Expand All @@ -73,6 +88,63 @@ Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& sha
return t;
}

// Creates an MXFP8 operand with the correct data layout for GEMM.
// MXFP8 GEMM requirements (scales are along K dimension):
// A transposed -> needs rowwise data/scales
// A non-transposed -> needs columnwise data/scales
// B transposed -> needs columnwise data/scales
// B non-transposed -> needs rowwise data/scales
// Builds an MXFP8 operand whose quantized data and block scales live in the
// layout the grouped GEMM requires (scales run along the K dimension):
//   A: transposed -> rowwise,    non-transposed -> columnwise
//   B: transposed -> columnwise, non-transposed -> rowwise (opposite of A)
// Returns a tensor whose scales have been swizzled into GEMM order.
Tensor make_mxfp8_operand(const std::string& name, const std::vector<size_t>& shape,
                          bool is_A, bool transposed) {
  // A needs the rowwise layout exactly when transposed; B is the mirror image.
  const bool need_rowwise = is_A ? transposed : !transposed;
  const bool need_colwise = !need_rowwise;

  // Random BF16 source, quantized into an MXFP8 tensor carrying only the
  // layout that the GEMM will actually read.
  Tensor source(name + "_bf16", shape, DType::kBFloat16);
  fillUniform(&source);
  Tensor quantized(name, shape, TypeInfo<fp8e4m3>::dtype, need_rowwise, need_colwise,
                   NVTE_MXFP8_1D_SCALING);
  nvte_quantize(source.data(), quantized.data(), 0);

  // Destination tensor: identical shape/layout, but flagged so the swizzle
  // kernel writes GEMM-ordered scales. The flag must be set before swizzling.
  Tensor swizzled(name + "_swizzled", shape, TypeInfo<fp8e4m3>::dtype,
                  need_rowwise, need_colwise, NVTE_MXFP8_1D_SCALING);
  swizzled.set_with_gemm_swizzled_scales(true);

  // Mirror the quantized payload into the swizzled tensor, layout by layout.
  if (need_rowwise) {
    const size_t row_bytes = test::bytes(quantized.rowwise_shape(), quantized.dtype());
    NVTE_CHECK_CUDA(cudaMemcpy(swizzled.rowwise_dptr(), quantized.rowwise_dptr(),
                               row_bytes, cudaMemcpyDeviceToDevice));
  }
  if (need_colwise) {
    const size_t col_bytes = test::bytes(quantized.columnwise_shape(), quantized.dtype());
    NVTE_CHECK_CUDA(cudaMemcpy(swizzled.columnwise_dptr(), quantized.columnwise_dptr(),
                               col_bytes, cudaMemcpyDeviceToDevice));
  }

  // Produce the GEMM-ready (swizzled) scaling factors, then wait so the
  // returned tensor is fully populated.
  nvte_swizzle_scaling_factors(quantized.data(), swizzled.data(), 0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  return swizzled;
}

struct TestParams {
InputCase input_case;
bool transa;
Expand All @@ -88,16 +160,16 @@ struct TestParams {
// Returns the per-GEMM (M, N, K) problem shapes for each shape-sharing
// scenario (three GEMMs per scenario).
// NOTE: the pasted diff kept both the stale pre-change return statements and
// their replacements in every case, so the stale shapes shadowed the intended
// ones; only the post-change (larger, MXFP8-block-friendly) shapes are kept.
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  switch (scase) {
    case ShapeCase::kAllSame:
      return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
    case ShapeCase::kSameFirst:
      // Same M (first dim), varying N and K
      return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
    case ShapeCase::kSameLast:
      // Same N (last dim), varying M and K
      return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
    case ShapeCase::kAllDifferent:
    default:
      return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
  }
}

Expand Down Expand Up @@ -138,6 +210,13 @@ void run_grouped_gemm_case(const TestParams& params) {
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kMXFP8: {
A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape,
/*is_A=*/true, params.transa));
B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape,
/*is_A=*/false, params.transb));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
Expand Down Expand Up @@ -246,7 +325,9 @@ void run_grouped_gemm_case(const TestParams& params) {
cublas_ws.data(),
nullptr, // config (use defaults)
0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Compare results
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
Expand Down Expand Up @@ -277,7 +358,7 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
}

std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
Expand All @@ -288,16 +369,27 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest

// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// FP8 tests (each tensor has random mean/stddev -> different scales)
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
// BF16 tests
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
// MXFP8 tests
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false},
// MXFP8 with NULL C
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true},
};

INSTANTIATE_TEST_SUITE_P(OperatorTest,
Expand Down
114 changes: 96 additions & 18 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,14 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
const NVTEShape shape = tensors[0]->rowwise_shape();

// Check which data layouts are available (all tensors must have the same)
const bool has_rowwise = tensors[0]->rowwise();
const bool has_columnwise = tensors[0]->columnwise();
NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout.");

const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape()
: tensors[0]->columnwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
Expand All @@ -1076,7 +1083,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
const auto s = has_rowwise ? tensors[i]->rowwise_shape()
: tensors[i]->columnwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
Expand Down Expand Up @@ -1105,10 +1113,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
};

const bool need_offsets = !same_first || !same_last;
const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -1146,21 +1155,24 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;

grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}

NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));

const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
// Copy rowwise data if available
if (has_rowwise) {
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));
}

// Copy columnwise data if available
if (has_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
Expand Down Expand Up @@ -1202,11 +1214,17 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
}

if (isFp8Type(dtype)) {
if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
// FP8 tensor scaling: one float scale_inv per tensor
// For delayed scaling, rowwise and columnwise share the same scale
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
if (has_rowwise) {
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
} else {
scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr<float>()[0];
}
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
Expand All @@ -1217,6 +1235,66 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
sizeof(scale_tensor));
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
// MXFP8: E8M0 scale_inv per block of 32 elements
// Helper to gather scale_inv from individual tensors into a contiguous buffer
auto gather_scales = [&](
auto get_shape_fn,
auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> {
// Compute total size and offsets
size_t total_bytes = 0;
std::vector<size_t> scale_offsets(num_tensors);
std::vector<size_t> numels(num_tensors);

for (size_t i = 0; i < num_tensors; ++i) {
scale_offsets[i] = total_bytes;
const NVTEShape shape = get_shape_fn(tensors[i]);
size_t numel = 1;
for (size_t d = 0; d < shape.ndim; ++d) {
numel *= shape.data[d];
}
numels[i] = numel;
total_bytes += numel; // E8M0 is 1 byte per element
}

// Allocate and copy
CudaPtr<> buffer = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
const void* src = get_cpu_ptr_fn(tensors[i]);
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
Comment on lines +1262 to +1268
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant CPU sync for swizzled MXFP8 scales.

The loop calls tensors[i]->to_cpu() on line 1263, then immediately passes the tensor to get_cpu_ptr_fn(tensors[i]) on line 1267. However, both rowwise_cpu_scale_inv_ptr<uint8_t>() and columnwise_cpu_scale_inv_ptr<uint8_t>() internally call to_cpu() themselves (test_common.h lines 249 and 264), making the explicit call on line 1263 redundant.

Additionally, the GPU pointers are available directly via get_rowwise_scale_inv().data_ptr and get_columnwise_scale_inv().data_ptr, allowing a device-to-device copy that avoids the round-trip entirely:

Suggested change
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
const void* src = get_cpu_ptr_fn(tensors[i]);
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
NVTE_CHECK_CUDA(cudaMemcpy(dst,
has_rowwise ? tensors[i]->tensor_.get_rowwise_scale_inv().data_ptr
: tensors[i]->tensor_.get_columnwise_scale_inv().data_ptr,
numels[i],
cudaMemcpyDeviceToDevice));

This improves both clarity and efficiency in test code.

return {std::move(buffer), total_bytes};
};

// Gather rowwise scale_inv if available
if (has_rowwise) {
auto [row_buffer, row_total] = gather_scales(
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
[](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.scale_inv = std::move(row_buffer);

NVTEShape row_shape = nvte_make_shape(&row_total, 1);
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
}

// Gather columnwise scale_inv if available
if (has_columnwise) {
auto [col_buffer, col_total] = gather_scales(
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
[](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.columnwise_scale_inv = std::move(col_buffer);

NVTEShape col_shape = nvte_make_shape(&col_total, 1);
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
}

// Mark as having swizzled scales (required for GEMM)
nvte_set_grouped_tensor_swizzled_scales(h, 1);
}

return grouped;
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<> columnwise_scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ struct GroupedTensor {
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 1)),
nvte_tensor(0) {}
nvte_tensor(0),
with_gemm_swizzled_scales(false) {}

explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }

Expand Down
Loading
Loading