- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format.
- Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels.
- Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively.

These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements test the changes introduced for unswizzling.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format.
- Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes.
- Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization.
- Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format.
- Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality.
- Updated the launch function to handle multiple tensor unswizzling operations efficiently.

These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds unswizzle (GEMM-swizzled → compact) support for MXFP8/NVFP4 1D scaling factors, mirroring the existing swizzle infrastructure. It introduces
Key findings from this review:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant nvte_unswizzle as nvte_unswizzle_scaling_factors
participant unswizzle as unswizzle_scaling_factors
participant row_kernel as unswizzle_row_scaling_kernel
participant col_kernel as unswizzle_col_scaling_kernel
Caller->>nvte_unswizzle: (swizzled_tensor, compact_tensor, stream)
nvte_unswizzle->>unswizzle: validate scaling_mode, dtype, flags
unswizzle->>unswizzle: derive m, k from swizzled input shape
unswizzle->>unswizzle: choose rowwise_unswizzle / columnwise_unswizzle
alt rowwise_unswizzle
unswizzle->>row_kernel: launch<<<(K/tiles,M_tiles), (32,32), slm>>>
row_kernel->>row_kernel: load swizzled tiles into SLM
row_kernel->>row_kernel: regs_unshuffle (inverse of regs_shuffle)
row_kernel->>row_kernel: write compact bytes (bounds-checked)
end
alt columnwise_unswizzle
unswizzle->>col_kernel: launch<<<(K_tiles,M/tiles), (32,32), slm>>>
col_kernel->>col_kernel: load swizzled tiles into SLM
col_kernel->>col_kernel: regs_unshuffle_with_bit_shifts
col_kernel->>col_kernel: write compact bytes (bounds-checked)
end
unswizzle-->>Caller: output compact tensor populated
Last reviewed commit: d7b6d2d
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Force-pushed from 85ea04b to 17dbb33
…ather than casting Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
NVTE_CHECK(static_cast<size_t>(original_M) * original_K == output->scale_inv.numel(),
           "Expected output tensor to have ", static_cast<size_t>(original_M) * original_K,
           " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
break;
Output size validation uses unpadded dimensions instead of padded
`unswizzle_scaling_factors` validates the rowwise output size with `original_M * original_K`, where `original_M = output->flat_first_dim()` is the actual data-matrix row count, not the padded scale row count `m`. This check will incorrectly reject a perfectly valid compact output tensor whenever M is not already a multiple of 128.
For example, with a matrix of shape [200, 4096] (M=200, K=4096):
- `m = ceil(200/128)*128 = 256` (required by the swizzle padding constraint)
- `original_M = 200`, `original_K = 128`
- `output->scale_inv.numel() = 256 * 128 = 32768` (padded compact tensor)
- But this check would require `200 * 128 = 25600`, so it fails.
The equivalent check in `swizzle_scaling_factors` correctly uses `m * k` (see lines 672-673). The corresponding check in `multi_tensor_unswizzle_scaling_factors` also uses `m * k` (line 1463), making this single-tensor path the outlier.
- NVTE_CHECK(static_cast<size_t>(original_M) * original_K == output->scale_inv.numel(),
-            "Expected output tensor to have ", static_cast<size_t>(original_M) * original_K,
-            " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
- break;
+ NVTE_CHECK(static_cast<size_t>(m) * k == output->scale_inv.numel(),
+            "Expected output tensor to have ", static_cast<size_t>(m) * k,
+            " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
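The arithmetic in the example above can be reproduced with a small host-side sketch. The helper names `round_up` and `padded_rowwise_scale_shape` are hypothetical and not part of Transformer Engine; the sketch assumes MXFP8-style 1D scaling with one scale per 32 elements along K, scale rows padded to a multiple of 128, and scale columns padded to a multiple of 4.

```cpp
#include <cstddef>

// Hypothetical helper: round x up to the next multiple of `multiple`.
constexpr std::size_t round_up(std::size_t x, std::size_t multiple) {
  return (x + multiple - 1) / multiple * multiple;
}

struct ScaleShape {
  std::size_t m;  // padded number of scale rows
  std::size_t k;  // padded number of scale columns
};

// Assumption: one scale per 32 elements along K, rows padded to 128,
// scale columns padded to 4 (the padding the swizzled layout relies on).
constexpr ScaleShape padded_rowwise_scale_shape(std::size_t rows, std::size_t cols) {
  return {round_up(rows, 128), round_up((cols + 31) / 32, 4)};
}

// Element count the compact scale tensor actually holds (padded).
constexpr std::size_t padded_numel(std::size_t rows, std::size_t cols) {
  return padded_rowwise_scale_shape(rows, cols).m *
         padded_rowwise_scale_shape(rows, cols).k;
}

// Element count the check under review expects (unpadded rows).
constexpr std::size_t unpadded_numel(std::size_t rows, std::size_t cols) {
  return rows * ((cols + 31) / 32);
}

static_assert(padded_numel(200, 4096) == 32768, "256 * 128 padded scales");
static_assert(unpadded_numel(200, 4096) == 25600, "200 * 128 unpadded scales");
```

For M=200 the two counts differ, which is why a check written against the unpadded count rejects a correctly sized buffer. When M is already a multiple of 128 (for example M=256) the counts agree, so the problem would only surface for unaligned shapes.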
NVTE_CHECK(static_cast<size_t>(original_M) * original_K == output->columnwise_scale_inv.numel(),
           "Expected output tensor to have ", static_cast<size_t>(original_M) * original_K,
           " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape,
           ".");
Same output-size validation bug in the columnwise unswizzle path
The same issue as in the rowwise path above: the check uses `original_M * original_K` (where `original_M = output->flat_last_dim()`) instead of the padded scale dimensions `m * k`. For any tensor where the column-wise scale M dimension is not already a multiple of 128, this will incorrectly fail even though the output buffer is correctly sized.
The equivalent check in `multi_tensor_unswizzle_scaling_factors` (line 1522) correctly accumulates the output shape, and `swizzle_scaling_factors` (lines 678-681) uses `m * k`. This path should follow the same pattern:
- NVTE_CHECK(static_cast<size_t>(original_M) * original_K == output->columnwise_scale_inv.numel(),
-            "Expected output tensor to have ", static_cast<size_t>(original_M) * original_K,
-            " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape,
-            ".");
+ NVTE_CHECK(static_cast<size_t>(m) * k == output->columnwise_scale_inv.numel(),
+            "Expected output tensor to have ", static_cast<size_t>(m) * k,
+            " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape,
+            ".");
void multi_tensor_unswizzle_scaling_factors(const std::vector<Tensor*>& input,
                                            const std::vector<Tensor*>& output,
                                            cudaStream_t stream) {
Inconsistent const qualifier on output parameter
The `output` parameter here is `const std::vector<Tensor*>&`, but the analogous `multi_tensor_swizzle_scaling_factors` (line 990) uses `std::vector<Tensor*>&` (non-const). Adding `const` to the vector reference doesn't prevent mutating the pointed-to `Tensor` objects, so it has no effect on correctness, but the inconsistency is surprising to callers and departs from the established pattern in this file. Consider aligning the signatures:
- void multi_tensor_unswizzle_scaling_factors(const std::vector<Tensor*>& input,
-                                             const std::vector<Tensor*>& output,
-                                             cudaStream_t stream) {
+ void multi_tensor_unswizzle_scaling_factors(const std::vector<Tensor*>& input,
+                                             std::vector<Tensor*>& output,
+                                             cudaStream_t stream) {
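The point about `const` not reaching the pointed-to objects can be illustrated with a standalone sketch. The `Tensor` struct and `fill_outputs` function below are hypothetical stand-ins and are unrelated to the real Transformer Engine types.

```cpp
#include <vector>

// Hypothetical stand-in for the library's Tensor type.
struct Tensor {
  int value = 0;
};

// The vector is const, but its elements are plain `Tensor*`, so the
// pointees can still be modified through them.
void fill_outputs(const std::vector<Tensor*>& output) {
  for (Tensor* t : output) {
    t->value = 42;  // compiles: only the container is const
  }
  // output.push_back(nullptr);  // would not compile: container is const
}
```

The `const` here only stops the function from resizing or reseating the vector's elements, so choosing between the two signatures is a matter of API consistency rather than safety.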
Description
This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.
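As a rough illustration of what a swizzle→unswizzle round trip checks, here is a host-side reference sketch. It is not the PR's `compute_ref_unswizzle`. It assumes the 128x4 tile layout documented for cuBLAS 1D block scaling factors (offset within a tile `(r % 32) * 16 + (r / 32) * 4 + c`), tiles stored in row-major tile order, and `m` and `k` already padded to multiples of 128 and 4. The names `swizzled_offset`, `ref_swizzle`, and `ref_unswizzle` are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Offset of compact element (row, col) inside the swizzled buffer.
// Assumes m % 128 == 0 and k % 4 == 0 (already padded).
inline std::size_t swizzled_offset(std::size_t row, std::size_t col, std::size_t k) {
  constexpr std::size_t kTileRows = 128, kTileCols = 4;
  const std::size_t tiles_per_row = k / kTileCols;
  const std::size_t tile = (row / kTileRows) * tiles_per_row + col / kTileCols;
  const std::size_t r = row % kTileRows;
  const std::size_t c = col % kTileCols;
  // Each 128x4 tile is stored as 32 groups of 16 bytes.
  const std::size_t within = (r % 32) * 16 + (r / 32) * 4 + c;
  return tile * (kTileRows * kTileCols) + within;
}

// Compact (row-major) -> swizzled.
std::vector<std::uint8_t> ref_swizzle(const std::vector<std::uint8_t>& compact,
                                      std::size_t m, std::size_t k) {
  std::vector<std::uint8_t> out(m * k);
  for (std::size_t row = 0; row < m; ++row)
    for (std::size_t col = 0; col < k; ++col)
      out[swizzled_offset(row, col, k)] = compact[row * k + col];
  return out;
}

// Swizzled -> compact (row-major): the inverse permutation.
std::vector<std::uint8_t> ref_unswizzle(const std::vector<std::uint8_t>& swizzled,
                                        std::size_t m, std::size_t k) {
  std::vector<std::uint8_t> out(m * k);
  for (std::size_t row = 0; row < m; ++row)
    for (std::size_t col = 0; col < k; ++col)
      out[row * k + col] = swizzled[swizzled_offset(row, col, k)];
  return out;
}
```

Because the offset mapping is a permutation of the padded buffer, `ref_unswizzle(ref_swizzle(x, m, k), m, k)` returns `x` for any padded input, which is the property a round-trip test relies on.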
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `transformer_engine/common/swizzle/swizzle.cu` and declarations in `transformer_engine/common/include/transformer_engine/swizzle.h`
- `tests/cpp/operator/test_swizzle.cu`, including standalone unswizzle and swizzle→unswizzle round-trip coverage

Checklist: