
[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes#2748

Open
jberchtold-nvidia wants to merge 10 commits into NVIDIA:main from
jberchtold-nvidia:jberchtold/te-common-mxfp8-grouped-gemm-plus-fixes

Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 2 commits March 9, 2026 15:47
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR extends the TransformerEngine grouped GEMM implementation with two features: (1) MXFP8 grouped GEMM support via cuBLAS 13.3+, including MXFP8 operand layout selection (VEC32_UE8M0 scale mode, swizzled scale pointer arrays), and (2) tensor-scaled FP8 grouped GEMM fixes that correctly set PER_BATCH_SCALAR_32F scale mode and switch from a single scale array pointer to a per-tensor pointer array required by the grouped GEMM API.

Key changes:

  • cublaslt_grouped_gemm.cu: Expands setup workspace to 8 pointer arrays (adds a_scale_inv_ptrs / b_scale_inv_ptrs), applies 16-byte alignment per array as required by cuBLAS, adds MXFP8 operand selection logic, and fixes tensor-scaled FP8 by explicitly setting CUBLASLT_MATMUL_MATRIX_SCALE_PER_BATCH_SCALAR_32F. Also unconditionally enables split accumulator (FAST_ACCUM=0) for all grouped GEMMs (tracked separately in PR #2669, "Pytorch binding for cublas grouped gemm").
  • test_common.cu / test_common.h: Refactors build_grouped_tensor to handle columnwise-only tensors (needed for MXFP8 non-transposed A) and adds MXFP8 scale gathering; disables random padding for MXFP8 (avoids a_offset/32 alignment issues).
  • transformer_engine.h / transformer_engine.cpp: Adds experimental nvte_set/get_grouped_tensor_swizzled_scales API.
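As a rough host-side illustration of the workspace change in the first bullet (hypothetical names, not the PR's actual code), carving eight 16-byte-aligned per-tensor pointer arrays out of a single setup-workspace buffer might look like:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical sketch: carve kNumPtrArrays pointer arrays out of one
// contiguous workspace, aligning each array to 16 bytes as cuBLAS requires.
// Names (kNumPtrArrays, carve_ptr_arrays) are illustrative only.
constexpr size_t kNumPtrArrays = 8;  // A, B, D, alpha, beta, ..., a/b_scale_inv

inline uintptr_t align16(uintptr_t p) { return (p + 15) & ~uintptr_t{15}; }

// Fills `arrays` with aligned sub-buffers and returns total bytes consumed.
size_t carve_ptr_arrays(void* workspace, size_t num_tensors,
                        void** arrays[kNumPtrArrays]) {
  uintptr_t cur = reinterpret_cast<uintptr_t>(workspace);
  const size_t bytes_per_array = num_tensors * sizeof(void*);
  for (size_t i = 0; i < kNumPtrArrays; ++i) {
    cur = align16(cur);                        // each array starts 16B-aligned
    arrays[i] = reinterpret_cast<void**>(cur);
    cur += bytes_per_array;
  }
  return cur - reinterpret_cast<uintptr_t>(workspace);
}
```

Note that with this layout the required workspace size must use the aligned calculation, which is the mismatch the review flags between the test helper and production code.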

Issues found:

  • No cross-validation that A and B use the same FP8 scaling mode — passing MXFP8 A with tensor-scaled B silently sets conflicting cuBLAS scale mode attributes.
  • The runtime cuBLAS version check error message prints the raw version integer (130300) rather than a human-readable string (13.3.0).
  • make_mxfp8_operand uses cudaDeviceSynchronize() where a targeted cudaStreamSynchronize(0) suffices.
  • gather_scales performs an unnecessary GPU→CPU→GPU round-trip for scale data that could use direct device-to-device copies.

Confidence Score: 3/5

  • Merging is risky until the missing A/B scaling-mode cross-validation is addressed; the remaining issues are test-code quality concerns.
  • The core cuBLAS grouped GEMM plumbing (workspace alignment, pointer array filling, scale mode attributes) is well-structured and the logic for selecting MXFP8 vs tensor-scaled paths per operand is correct. However, the absence of a guard against mixing MXFP8 and tensor-scaled FP8 operands in the same call is a latent production correctness bug. The test shapes are carefully chosen to sidestep the known a_offset/32 divisibility constraint but don't cover the general case. Test code also has inefficiencies (round-trip copies, over-broad sync) that reduce confidence in test reliability at scale.
  • Pay close attention to transformer_engine/common/gemm/cublaslt_grouped_gemm.cu (mixed scaling mode validation) and tests/cpp/test_common.cu (gather_scales offset consistency).

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Core MXFP8 grouped GEMM implementation: adds scale pointer arrays to workspace, MXFP8 operand selection logic, VEC32_UE8M0/PER_BATCH_SCALAR_32F scale mode setup, and split-accumulator mode. Missing cross-validation that A and B use the same scaling mode; version code in error messages is not human-readable.
tests/cpp/operator/test_grouped_gemm.cu Adds kMXFP8 test case, make_mxfp8_operand helper, and updated test shapes (all multiples of 128 which avoids the a_offset/32 divisibility edge case). Test workspace size helper still uses unaligned calculation vs production's aligned calculation. Broad cudaDeviceSynchronize used where stream-specific sync suffices.
tests/cpp/test_common.cu Refactors build_grouped_tensor to support columnwise-only tensors and adds MXFP8 scale gathering. The gather_scales lambda does a GPU→CPU→GPU round-trip that could be replaced with a direct device-to-device copy. Scale offset consistency between gather_scales and the kernel depends on exact data_numel/32 divisibility (only safe for the chosen test shapes).
tests/cpp/test_common.h Adds columnwise_scale_inv field to GroupedBuffers struct to hold the separate columnwise MXFP8 scale buffer. Straightforward addition with no issues.
transformer_engine/common/common.h Adds with_gemm_swizzled_scales field initialization to GroupedTensor constructor. Minimal, correct change.
transformer_engine/common/include/transformer_engine/transformer_engine.h Adds nvte_set_grouped_tensor_swizzled_scales / nvte_get_grouped_tensor_swizzled_scales API declarations, clearly marked EXPERIMENTAL. No issues.
transformer_engine/common/transformer_engine.cpp Implements the two new swizzled-scales getter/setter functions with null-pointer guards consistent with the rest of the file. Also handles kNVTEGroupedWithGEMMSwizzledScales in get/set param dispatch. Clean implementation.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant nvte_grouped_gemm
    participant select_grouped_operand
    participant setup_kernel as setup_grouped_gemm_kernel (GPU)
    participant set_fp8_scale_pointers
    participant cuBLAS as cublasLtMatmul (GPU)

    Caller->>nvte_grouped_gemm: A (MXFP8/FP8), B (MXFP8/FP8), transa, transb
    nvte_grouped_gemm->>select_grouped_operand: select rowwise/columnwise data for A
    select_grouped_operand-->>nvte_grouped_gemm: A_sel (dptr, scale_inv, shape, scaling_mode)
    nvte_grouped_gemm->>select_grouped_operand: select rowwise/columnwise data for B
    select_grouped_operand-->>nvte_grouped_gemm: B_sel (dptr, scale_inv, shape, scaling_mode)

    nvte_grouped_gemm->>nvte_grouped_gemm: GroupedGemmSetupWorkspace::from_buffers()<br/>(16-byte aligned ptr arrays for 8 arrays)
    nvte_grouped_gemm->>nvte_grouped_gemm: A_sel.scale_inv_ptrs = ws.a_scale_inv_ptrs<br/>B_sel.scale_inv_ptrs = ws.b_scale_inv_ptrs

    nvte_grouped_gemm->>setup_kernel: launch (A_ptrs, B_ptrs, ..., a_scale_inv_ptrs,<br/>b_scale_inv_ptrs, a_tensor_scale/a_mxfp8_base, ...)
    Note over setup_kernel: For each tensor idx:<br/>  A_ptrs[idx] = a_base + a_offset * elem_size<br/>  For MXFP8: scale_ptrs[idx] = scale_base + a_offset/32<br/>  For tensor-scaled: scale_ptrs[idx] = float_array + idx

    nvte_grouped_gemm->>set_fp8_scale_pointers: matmulDesc, A_sel, B_sel
    Note over set_fp8_scale_pointers: MXFP8: set VEC32_UE8M0 scale mode<br/>Tensor-scaled: set PER_BATCH_SCALAR_32F<br/>Both: set SCALE_POINTER to ptr array
    set_fp8_scale_pointers-->>nvte_grouped_gemm: matmulDesc configured

    nvte_grouped_gemm->>cuBLAS: cublasLtMatmul(handle, matmulDesc,<br/>alpha_ptrs, A_ptrs, descA, B_ptrs, descB, ...)
    Note over cuBLAS: Reads per-tensor scale pointers<br/>from ws.a/b_scale_inv_ptrs<br/>(filled by setup_kernel on same stream)
    cuBLAS-->>Caller: D (grouped output)
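The per-tensor scale-pointer derivation in the setup-kernel note above can be sketched host-side for clarity; the names here (fill_scale_ptrs, data_offsets) are hypothetical. The MXFP8 branch assumes the UE8M0 layout of one scale byte per 32 data elements, which is where the a_offset/32 divisibility constraint discussed in the review comes from.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side sketch of the setup kernel's per-tensor scale
// pointer fill for grouped GEMM:
//  - MXFP8: one UE8M0 scale byte covers 32 data elements, so a tensor whose
//    data starts at element offset `off` has its scales at scale_base + off/32
//    (requires off % 32 == 0).
//  - Tensor-scaled FP8: one float per tensor, so the pointer is simply
//    &tensor_scales[idx] in a contiguous float array.
void fill_scale_ptrs(bool mxfp8, uint8_t* mxfp8_scale_base,
                     float* tensor_scales,
                     const std::vector<size_t>& data_offsets,
                     std::vector<void*>* scale_ptrs) {
  scale_ptrs->resize(data_offsets.size());
  for (size_t idx = 0; idx < data_offsets.size(); ++idx) {
    if (mxfp8) {
      (*scale_ptrs)[idx] = mxfp8_scale_base + data_offsets[idx] / 32;
    } else {
      (*scale_ptrs)[idx] = tensor_scales + idx;
    }
  }
}
```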

Comments Outside Diff (3)

  1. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 464-476 (link)

    Missing validation: mixed MXFP8 / tensor-scaled operands

    set_fp8_scale_pointers validates each operand's scale mode independently but never checks that A and B use the same mode. If a caller passes MXFP8 A with tensor-scaled B (or vice versa), the function will silently set conflicting A_SCALE_MODE / B_SCALE_MODE attributes (VEC32_UE8M0 vs PER_BATCH_SCALAR_32F) on the same matmulDesc. cuBLAS does not document support for mixed scale modes in a single grouped GEMM, so the resulting computation may silently produce incorrect values or crash at runtime with a cuBLAS error.

    A defensive check before the existing per-operand branches would catch this early:

    if ((mxfp8_a && !mxfp8_b && is_fp8_b) || (!mxfp8_a && is_fp8_a && mxfp8_b)) {
        NVTE_CHECK(false,
            "Grouped GEMM: A and B must use the same FP8 scaling mode. "
            "Mixing MXFP8 and tensor-scaled FP8 is not supported.");
    }
  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 589-591 (link)

    Runtime cuBLAS version value is not human-readable in the error message

    CUBLAS_MXFP8_GROUPED_GEMM_VERSION is 130300, so the error message the user sees will read "MXFP8 grouped GEMM requires cuBLAS 130300+, but run-time cuBLAS version is 130100", which is opaque. Most developers expect a major.minor.patch string (e.g., "13.3.0").

    Consider a small helper or a more descriptive literal in the message:

    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_MXFP8_GROUPED_GEMM_VERSION,
               "MXFP8 grouped GEMM requires cuBLAS 13.3.0+ (version code ",
               CUBLAS_MXFP8_GROUPED_GEMM_VERSION, "), but run-time cuBLAS version code is ",
               transformer_engine::cuda::cublas_version());
  3. tests/cpp/operator/test_grouped_gemm.cu, line 110-112 (link)

    cudaDeviceSynchronize blocks all streams unnecessarily

    Both nvte_quantize and nvte_swizzle_scaling_factors are issued on stream 0. cudaDeviceSynchronize() blocks every CUDA stream, which is unnecessarily broad (and slow in test suites that parallelize across streams). A targeted sync on stream 0 is sufficient:

    NVTE_CHECK_CUDA(cudaStreamSynchronize(0));
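The "small helper" suggested in comment 2 above could be sketched as follows, assuming the standard cuBLAS version encoding major*10000 + minor*100 + patch (e.g. 130300 → "13.3.0"); the function name is hypothetical:

```cpp
#include <string>

// Hypothetical helper: format a cuBLAS integer version code (e.g. 130300)
// into a human-readable "major.minor.patch" string, assuming the standard
// encoding CUBLAS_VER_MAJOR*10000 + CUBLAS_VER_MINOR*100 + CUBLAS_VER_PATCH.
std::string cublas_version_string(int v) {
  return std::to_string(v / 10000) + "." + std::to_string((v / 100) % 100) +
         "." + std::to_string(v % 100);
}
```

Used in the NVTE_CHECK message, this would turn "requires cuBLAS 130300+" into "requires cuBLAS 13.3.0+".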

Last reviewed commit: 853d093

@jberchtold-nvidia
Collaborator Author

/te-ci

vthumbe1503 previously approved these changes Mar 9, 2026
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Comment on lines +1262 to +1268
for (size_t i = 0; i < num_tensors; ++i) {
  tensors[i]->to_cpu();
  NVTE_CHECK_CUDA(cudaGetLastError());
  void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
  const void* src = get_cpu_ptr_fn(tensors[i]);
  NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}

Redundant CPU sync for swizzled MXFP8 scales.

The loop calls tensors[i]->to_cpu() on line 1263, then immediately passes the tensor to get_cpu_ptr_fn(tensors[i]) on line 1267. However, both rowwise_cpu_scale_inv_ptr<uint8_t>() and columnwise_cpu_scale_inv_ptr<uint8_t>() internally call to_cpu() themselves (test_common.h lines 249 and 264), making the explicit call on line 1263 redundant.

Additionally, the GPU pointers are available directly via get_rowwise_scale_inv().data_ptr and get_columnwise_scale_inv().data_ptr, allowing a device-to-device copy that avoids the round-trip entirely:

Suggested change

Before:

    for (size_t i = 0; i < num_tensors; ++i) {
      tensors[i]->to_cpu();
      NVTE_CHECK_CUDA(cudaGetLastError());
      void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
      const void* src = get_cpu_ptr_fn(tensors[i]);
      NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
    }

After (the copy inside the same loop becomes a direct device-to-device memcpy):

    for (size_t i = 0; i < num_tensors; ++i) {
      void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
      NVTE_CHECK_CUDA(cudaMemcpy(dst,
          has_rowwise ? tensors[i]->tensor_.get_rowwise_scale_inv().data_ptr
                      : tensors[i]->tensor_.get_columnwise_scale_inv().data_ptr,
          numels[i],
          cudaMemcpyDeviceToDevice));
    }

This improves both clarity and efficiency in test code.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Add documentation for scaling factors in common.h

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

1 similar comment
@vthumbe1503
Copy link
Collaborator

/te-ci

