[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes #2748
jberchtold-nvidia wants to merge 10 commits into NVIDIA:main
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR extends the TransformerEngine grouped GEMM implementation with two features: (1) MXFP8 grouped GEMM support via cuBLAS 13.3+, including MXFP8 operand layout selection, and (2) fixes for tensor-scaled FP8 grouped GEMM.

Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant nvte_grouped_gemm
    participant select_grouped_operand
    participant setup_kernel as setup_grouped_gemm_kernel (GPU)
    participant set_fp8_scale_pointers
    participant cuBLAS as cublasLtMatmul (GPU)
    Caller->>nvte_grouped_gemm: A (MXFP8/FP8), B (MXFP8/FP8), transa, transb
    nvte_grouped_gemm->>select_grouped_operand: select rowwise/columnwise data for A
    select_grouped_operand-->>nvte_grouped_gemm: A_sel (dptr, scale_inv, shape, scaling_mode)
    nvte_grouped_gemm->>select_grouped_operand: select rowwise/columnwise data for B
    select_grouped_operand-->>nvte_grouped_gemm: B_sel (dptr, scale_inv, shape, scaling_mode)
    nvte_grouped_gemm->>nvte_grouped_gemm: GroupedGemmSetupWorkspace::from_buffers()<br/>(16-byte aligned ptr arrays for 8 arrays)
    nvte_grouped_gemm->>nvte_grouped_gemm: A_sel.scale_inv_ptrs = ws.a_scale_inv_ptrs<br/>B_sel.scale_inv_ptrs = ws.b_scale_inv_ptrs
    nvte_grouped_gemm->>setup_kernel: launch (A_ptrs, B_ptrs, ..., a_scale_inv_ptrs,<br/>b_scale_inv_ptrs, a_tensor_scale/a_mxfp8_base, ...)
    Note over setup_kernel: For each tensor idx:<br/> A_ptrs[idx] = a_base + a_offset * elem_size<br/> For MXFP8: scale_ptrs[idx] = scale_base + a_offset/32<br/> For tensor-scaled: scale_ptrs[idx] = float_array + idx
    nvte_grouped_gemm->>set_fp8_scale_pointers: matmulDesc, A_sel, B_sel
    Note over set_fp8_scale_pointers: MXFP8: set VEC32_UE8M0 scale mode<br/>Tensor-scaled: set PER_BATCH_SCALAR_32F<br/>Both: set SCALE_POINTER to ptr array
    set_fp8_scale_pointers-->>nvte_grouped_gemm: matmulDesc configured
    nvte_grouped_gemm->>cuBLAS: cublasLtMatmul(handle, matmulDesc,<br/>alpha_ptrs, A_ptrs, descA, B_ptrs, descB, ...)
    Note over cuBLAS: Reads per-tensor scale pointers<br/>from ws.a/b_scale_inv_ptrs<br/>(filled by setup_kernel on same stream)
    cuBLAS-->>Caller: D (grouped output)
```
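The per-tensor pointer arithmetic from the setup step can be sketched as a host-side model. This is illustrative only, not the actual TransformerEngine GPU kernel: `GroupedOperand` and `setup_grouped_gemm_ptrs` are hypothetical names, and the real setup runs as a CUDA kernel on the same stream as the matmul so cuBLAS sees the filled pointer arrays.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side model of what the setup kernel computes per tensor: a data
// pointer from a base pointer plus an element offset, and a scale-inverse
// pointer whose layout depends on the scaling mode.
struct GroupedOperand {
  const uint8_t* base;          // packed FP8 data for all tensors in the group
  const uint8_t* mxfp8_scales;  // packed UE8M0 scales (one per 32 elements), MXFP8 only
  const float* tensor_scales;   // one float scale-inverse per tensor, tensor-scaled only
  bool is_mxfp8;
};

static void setup_grouped_gemm_ptrs(const GroupedOperand& op,
                                    const std::vector<size_t>& elem_offsets,
                                    std::vector<const void*>& data_ptrs,
                                    std::vector<const void*>& scale_ptrs) {
  const size_t n = elem_offsets.size();
  data_ptrs.resize(n);
  scale_ptrs.resize(n);
  for (size_t idx = 0; idx < n; ++idx) {
    // FP8 elements are one byte, so the byte offset equals the element offset.
    data_ptrs[idx] = op.base + elem_offsets[idx];
    if (op.is_mxfp8) {
      // MXFP8: one scale byte per 32-element block, packed contiguously,
      // so the scale offset is the element offset divided by 32.
      scale_ptrs[idx] = op.mxfp8_scales + elem_offsets[idx] / 32;
    } else {
      // Tensor-scaled FP8: one float scale-inverse per tensor index.
      scale_ptrs[idx] = op.tensor_scales + idx;
    }
  }
}
```

The key point the diagram's note makes is that both modes end up as arrays of per-tensor pointers, which is what the `SCALE_POINTER` attribute on the matmul descriptor consumes.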
/te-ci
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
```cpp
for (size_t i = 0; i < num_tensors; ++i) {
  tensors[i]->to_cpu();
  NVTE_CHECK_CUDA(cudaGetLastError());
  void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
  const void* src = get_cpu_ptr_fn(tensors[i]);
  NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
```
Redundant CPU sync for swizzled MXFP8 scales.
The loop calls tensors[i]->to_cpu() on line 1263, then immediately passes the tensor to get_cpu_ptr_fn(tensors[i]) on line 1267. However, both rowwise_cpu_scale_inv_ptr<uint8_t>() and columnwise_cpu_scale_inv_ptr<uint8_t>() internally call to_cpu() themselves (test_common.h lines 249 and 264), making the explicit call on line 1263 redundant.
Additionally, the GPU pointers are available directly via get_rowwise_scale_inv().data_ptr and get_columnwise_scale_inv().data_ptr, allowing a device-to-device copy that avoids the round-trip entirely:
```diff
 for (size_t i = 0; i < num_tensors; ++i) {
-  tensors[i]->to_cpu();
   NVTE_CHECK_CUDA(cudaGetLastError());
   void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
-  const void* src = get_cpu_ptr_fn(tensors[i]);
-  NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
+  NVTE_CHECK_CUDA(cudaMemcpy(dst,
+                             has_rowwise ? tensors[i]->tensor_.get_rowwise_scale_inv().data_ptr
+                                         : tensors[i]->tensor_.get_columnwise_scale_inv().data_ptr,
+                             numels[i],
+                             cudaMemcpyDeviceToDevice));
 }
```
This improves both clarity and efficiency in test code.
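The packing pattern that loop implements (one shared buffer, each tensor's scale bytes landing at a precomputed byte offset) can be modeled without CUDA. This is an illustrative host-only sketch, not the test helper itself: `make_offsets` and `pack_scales` are hypothetical names, and `std::memcpy` stands in for `cudaMemcpy`.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Byte offsets are cumulative sizes: tensor i starts where tensor i-1 ends.
static std::vector<size_t> make_offsets(const std::vector<size_t>& numels) {
  std::vector<size_t> offsets(numels.size(), 0);
  for (size_t i = 1; i < numels.size(); ++i)
    offsets[i] = offsets[i - 1] + numels[i - 1];
  return offsets;
}

// Copy each tensor's scale-inverse bytes into one shared buffer at its offset,
// mirroring dst = buffer.get() + scale_offsets[i] in the helper under review.
static std::vector<uint8_t> pack_scales(const std::vector<std::vector<uint8_t>>& scales) {
  std::vector<size_t> numels;
  for (const auto& s : scales) numels.push_back(s.size());
  const std::vector<size_t> offsets = make_offsets(numels);
  std::vector<uint8_t> buffer(offsets.back() + numels.back());
  for (size_t i = 0; i < scales.size(); ++i) {
    uint8_t* dst = buffer.data() + offsets[i];      // buffer.get() + scale_offsets[i]
    std::memcpy(dst, scales[i].data(), numels[i]);  // cudaMemcpy(dst, src, numels[i], ...)
  }
  return buffer;
}
```

Whether the `src` side comes from a CPU staging copy or directly from the device scale-inverse tensor only changes the copy kind (`cudaMemcpyHostToDevice` vs `cudaMemcpyDeviceToDevice`); the offset arithmetic is identical, which is why the suggested change is a pure simplification.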
Add documentation for scaling factors in common.h
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
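The two scaling modes the PR configures (`PER_BATCH_SCALAR_32F` for tensor-scaled FP8, `VEC32_UE8M0` for MXFP8) differ in how a stored scale maps to a multiplier. A hedged sketch of the dequantization arithmetic follows; the helper names are illustrative, and the UE8M0 decoding (a byte `e` encodes 2^(e-127), power-of-two scales only, block size 32) follows the OCP Microscaling (MX) format convention rather than anything specific to this codebase.

```cpp
#include <cmath>
#include <cstdint>

// Tensor-scaled FP8: one FP32 scale-inverse applies to every element of the
// tensor, so dequantization is a single multiply.
inline float dequant_tensor_scaled(float fp8_value_as_float, float scale_inv) {
  return fp8_value_as_float * scale_inv;
}

// MXFP8: one UE8M0 scale byte per 32-element block. A UE8M0 byte e encodes
// the power-of-two scale 2^(e-127); there is no mantissa, so 127 means 1.0.
inline float decode_ue8m0(uint8_t e) {
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

inline float dequant_mxfp8(float fp8_value_as_float, uint8_t block_scale_ue8m0) {
  return fp8_value_as_float * decode_ue8m0(block_scale_ue8m0);
}
```

This is why the setup kernel's scale-pointer math differs per mode: MXFP8 needs a pointer into a packed array of one byte per 32 elements, while tensor-scaled FP8 needs a pointer to a single float per tensor.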
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: