-
Notifications
You must be signed in to change notification settings - Fork 658
[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes #2748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
58906ba
3ec6dec
d1e42bb
9dc9463
d5eaa1f
62c21c8
f132773
0a06edb
d50b711
853d093
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1061,7 +1061,14 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, | |||||||||||||||||||||||||
| GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | ||||||||||||||||||||||||||
| const NVTEScalingMode scaling_mode) { | ||||||||||||||||||||||||||
| NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); | ||||||||||||||||||||||||||
| const NVTEShape shape = tensors[0]->rowwise_shape(); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Check which data layouts are available (all tensors must have the same) | ||||||||||||||||||||||||||
| const bool has_rowwise = tensors[0]->rowwise(); | ||||||||||||||||||||||||||
| const bool has_columnwise = tensors[0]->columnwise(); | ||||||||||||||||||||||||||
| NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout."); | ||||||||||||||||||||||||||
vthumbe1503 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape() | ||||||||||||||||||||||||||
| : tensors[0]->columnwise_shape(); | ||||||||||||||||||||||||||
| const DType dtype = tensors[0]->dtype(); | ||||||||||||||||||||||||||
| const size_t num_tensors = tensors.size(); | ||||||||||||||||||||||||||
| const size_t elem_size = typeToNumBits(dtype) / 8; | ||||||||||||||||||||||||||
|
|
@@ -1076,7 +1083,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | |||||||||||||||||||||||||
| std::vector<int64_t> first_dims(num_tensors); | ||||||||||||||||||||||||||
| std::vector<int64_t> last_dims(num_tensors); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| const auto s = tensors[i]->rowwise_shape(); | ||||||||||||||||||||||||||
| const auto s = has_rowwise ? tensors[i]->rowwise_shape() | ||||||||||||||||||||||||||
| : tensors[i]->columnwise_shape(); | ||||||||||||||||||||||||||
| NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); | ||||||||||||||||||||||||||
| first_dims[i] = static_cast<int64_t>(s.data[0]); | ||||||||||||||||||||||||||
| last_dims[i] = static_cast<int64_t>(s.data[1]); | ||||||||||||||||||||||||||
|
|
@@ -1105,10 +1113,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | |||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| const bool need_offsets = !same_first || !same_last; | ||||||||||||||||||||||||||
| const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING; | ||||||||||||||||||||||||||
| if (need_offsets) { | ||||||||||||||||||||||||||
| offsets[0] = 0; | ||||||||||||||||||||||||||
| for (size_t i = 1; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); | ||||||||||||||||||||||||||
| offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
|
|
@@ -1146,21 +1155,24 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | |||||||||||||||||||||||||
| : (logical_first * logical_last); | ||||||||||||||||||||||||||
| const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| grouped.data = cuda_alloc(total_bytes); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size; | ||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes, | ||||||||||||||||||||||||||
| tensors[i]->rowwise_dptr(), | ||||||||||||||||||||||||||
| grouped.tensor_bytes[i], | ||||||||||||||||||||||||||
| cudaMemcpyDeviceToDevice)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape}; | ||||||||||||||||||||||||||
| NVTEGroupedTensor h = grouped.handle.get(); | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); | ||||||||||||||||||||||||||
| if (include_columnwise) { | ||||||||||||||||||||||||||
| // Copy rowwise data if available | ||||||||||||||||||||||||||
| if (has_rowwise) { | ||||||||||||||||||||||||||
| grouped.data = cuda_alloc(total_bytes); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size; | ||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes, | ||||||||||||||||||||||||||
| tensors[i]->rowwise_dptr(), | ||||||||||||||||||||||||||
| grouped.tensor_bytes[i], | ||||||||||||||||||||||||||
| cudaMemcpyDeviceToDevice)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape}; | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Copy columnwise data if available | ||||||||||||||||||||||||||
| if (has_columnwise) { | ||||||||||||||||||||||||||
| grouped.columnwise_data = cuda_alloc(total_bytes); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size; | ||||||||||||||||||||||||||
|
|
@@ -1202,11 +1214,17 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | |||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if (isFp8Type(dtype)) { | ||||||||||||||||||||||||||
| if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { | ||||||||||||||||||||||||||
| // FP8 tensor scaling: one float scale_inv per tensor | ||||||||||||||||||||||||||
| // For delayed scaling, rowwise and columnwise share the same scale | ||||||||||||||||||||||||||
| std::vector<float> scale_inv_cpu(num_tensors, 1.f); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| tensors[i]->to_cpu(); | ||||||||||||||||||||||||||
| scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0]; | ||||||||||||||||||||||||||
| if (has_rowwise) { | ||||||||||||||||||||||||||
| scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0]; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr<float>()[0]; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); | ||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), | ||||||||||||||||||||||||||
|
|
@@ -1217,6 +1235,66 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors, | |||||||||||||||||||||||||
| sizeof(scale_tensor)); | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, | ||||||||||||||||||||||||||
| sizeof(scale_tensor)); | ||||||||||||||||||||||||||
| } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { | ||||||||||||||||||||||||||
| // MXFP8: E8M0 scale_inv per block of 32 elements | ||||||||||||||||||||||||||
| // Helper to gather scale_inv from individual tensors into a contiguous buffer | ||||||||||||||||||||||||||
| auto gather_scales = [&]( | ||||||||||||||||||||||||||
| auto get_shape_fn, | ||||||||||||||||||||||||||
| auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> { | ||||||||||||||||||||||||||
| // Compute total size and offsets | ||||||||||||||||||||||||||
| size_t total_bytes = 0; | ||||||||||||||||||||||||||
| std::vector<size_t> scale_offsets(num_tensors); | ||||||||||||||||||||||||||
| std::vector<size_t> numels(num_tensors); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| scale_offsets[i] = total_bytes; | ||||||||||||||||||||||||||
| const NVTEShape shape = get_shape_fn(tensors[i]); | ||||||||||||||||||||||||||
| size_t numel = 1; | ||||||||||||||||||||||||||
| for (size_t d = 0; d < shape.ndim; ++d) { | ||||||||||||||||||||||||||
| numel *= shape.data[d]; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| numels[i] = numel; | ||||||||||||||||||||||||||
| total_bytes += numel; // E8M0 is 1 byte per element | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Allocate and copy | ||||||||||||||||||||||||||
| CudaPtr<> buffer = cuda_alloc(total_bytes); | ||||||||||||||||||||||||||
| for (size_t i = 0; i < num_tensors; ++i) { | ||||||||||||||||||||||||||
| tensors[i]->to_cpu(); | ||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||||||||||||||||||||||||||
| void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i]; | ||||||||||||||||||||||||||
| const void* src = get_cpu_ptr_fn(tensors[i]); | ||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||
vthumbe1503 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
Comment on lines
+1262
to
+1268
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant CPU sync for swizzled MXFP8 scales. The loop calls Additionally, the GPU pointers are available directly via
Suggested change
This improves both clarity and efficiency in test code. |
||||||||||||||||||||||||||
| return {std::move(buffer), total_bytes}; | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Gather rowwise scale_inv if available | ||||||||||||||||||||||||||
| if (has_rowwise) { | ||||||||||||||||||||||||||
| auto [row_buffer, row_total] = gather_scales( | ||||||||||||||||||||||||||
| [](Tensor* t) { return t->rowwise_scale_inv_shape(); }, | ||||||||||||||||||||||||||
| [](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); }); | ||||||||||||||||||||||||||
| grouped.scale_inv = std::move(row_buffer); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| NVTEShape row_shape = nvte_make_shape(&row_total, 1); | ||||||||||||||||||||||||||
| NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape}; | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Gather columnwise scale_inv if available | ||||||||||||||||||||||||||
| if (has_columnwise) { | ||||||||||||||||||||||||||
| auto [col_buffer, col_total] = gather_scales( | ||||||||||||||||||||||||||
| [](Tensor* t) { return t->columnwise_scale_inv_shape(); }, | ||||||||||||||||||||||||||
| [](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); }); | ||||||||||||||||||||||||||
| grouped.columnwise_scale_inv = std::move(col_buffer); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| NVTEShape col_shape = nvte_make_shape(&col_total, 1); | ||||||||||||||||||||||||||
| NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape}; | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor)); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Mark as having swizzled scales (required for GEMM) | ||||||||||||||||||||||||||
| nvte_set_grouped_tensor_swizzled_scales(h, 1); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return grouped; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.