[WIP] Enable MultiCastTranspose for expert weights #628
Draft
sudhu2k wants to merge 3 commits into
This new method allows for efficient workspace management of multiple weights, enabling fused operations for delayed-scaling FP8. It enhances performance by reducing the number of quantization calls and supports caching of workspaces. The grouped_linear module has been updated to utilize this method.
…ed workspace management
This update enhances the quantization process by optimizing workspace allocation and handling cache misses more effectively. It introduces a streamlined approach for both cache hits and misses, ensuring efficient in-place quantization and reducing unnecessary memory reallocations. The changes aim to improve performance during FP8 operations while maintaining compatibility with existing functionality.
This update modifies the multi_tensor_quantize function to accept an optional outputs parameter, allowing for in-place quantization when cached workspaces are provided. The changes improve memory efficiency and performance during FP8 operations by reducing unnecessary tensor reallocations.
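The optional `outputs` pattern described in this commit can be sketched in plain Python. This is a minimal illustration, not the extension's code: `quantize_group`, its arguments, and the use of Python lists in place of tensors are all assumptions, and real FP8 rounding is left out so that only scaling, saturation, and buffer reuse are shown.

```python
from typing import List, Optional, Sequence

# Assumed saturation limit: the maximum finite value of OCP FP8 E4M3.
# Other FP8 variants use a different limit.
CLAMP_MAX = 448.0


def quantize_group(
    tensors: Sequence[Sequence[float]],
    scales: Sequence[float],
    outputs: Optional[List[List[float]]] = None,
) -> List[List[float]]:
    """Scale and clamp each tensor, writing into `outputs` when provided.

    Hypothetical stand-in for a fused multi-tensor quantize call.
    """
    if outputs is None:
        # No cached workspace: allocate one fresh buffer per input tensor.
        outputs = [[0.0] * len(t) for t in tensors]
    elif len(outputs) != len(tensors):
        raise ValueError("outputs must hold one buffer per input tensor")

    for src, scale, dst in zip(tensors, scales, outputs):
        if len(dst) != len(src):
            raise ValueError("output buffer size does not match its input")
        for i, x in enumerate(src):
            # In-place write: the existing buffer object is reused.
            dst[i] = max(-CLAMP_MAX, min(CLAMP_MAX, x * scale))
    return outputs
```

Passing the buffers returned by a first call back in as `outputs` on later calls keeps the same buffer objects alive, which is the property a workspace cache relies on.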
Description
This PR optimizes FP8 weight quantization for `GroupedLinear` by replacing the per-weight quantize loop with a single fused `multi_cast_transpose` kernel for delayed-scaling FP8. Previously each weight in a group was cast and transposed with its own `quantize` call; now the whole group is processed in one fused call, reducing kernel-launch overhead and unnecessary tensor reallocations during FP8 operations.

To make this work efficiently with the existing weight-caching mechanism, `multi_tensor_quantize` now accepts an optional `outputs` argument so the fused kernel can write directly into cached workspace buffers in place rather than allocating fresh tensors on every update.

Fixes #16929
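To make the difference between the per-weight loop and the fused call concrete, here is a small pure-Python sketch. It assumes nothing about the real kernels: `LaunchCounter`, `cast_transpose_one`, and `cast_transpose_group` are hypothetical names, nested lists stand in for weight tensors, and a counter increment stands in for a kernel launch.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


class LaunchCounter:
    """Hypothetical stand-in for counting kernel launches."""

    def __init__(self) -> None:
        self.count = 0


def _cast_transpose(w: Matrix, scale: float) -> Tuple[Matrix, Matrix]:
    # "Cast" is modelled as scaling only; the transpose is exact.
    cast = [[x * scale for x in row] for row in w]
    transposed = [list(col) for col in zip(*cast)]
    return cast, transposed


def cast_transpose_one(
    w: Matrix, scale: float, counter: LaunchCounter
) -> Tuple[Matrix, Matrix]:
    """Per-weight path: one launch for each weight."""
    counter.count += 1
    return _cast_transpose(w, scale)


def cast_transpose_group(
    ws: Sequence[Matrix], scales: Sequence[float], counter: LaunchCounter
) -> List[Tuple[Matrix, Matrix]]:
    """Fused path: one launch covers every weight in the group."""
    counter.count += 1
    return [_cast_transpose(w, s) for w, s in zip(ws, scales)]
```

For a group of N weights the loop path records N launches and the fused path records one, while both produce the same cast and transposed outputs.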
Type of change
Changes
Please list the changes introduced in this PR:
- Added `TransformerEngineBaseModule.get_multi_weight_workspace`, a group-aware analog of `get_weight_workspace` that casts and transposes a whole group of weights with a single fused `multi_tensor_quantize` kernel for delayed-scaling FP8, with caching support. Falls back to the per-tensor `get_weight_workspace` path when fusion is not applicable (other recipes, rowwise-only usage, already-quantized weights, or CUDA-graph weight caching).
- Updated `_GroupedLinear` to use `get_multi_weight_workspace` instead of looping over `get_weight_workspace` per GEMM.
- Extended the `multi_tensor_quantize` C++ extension (`extensions.h`, `cast.cpp`, `pybind.cpp`) with an optional `outputs` parameter, enabling in-place quantization into pre-allocated/cached workspace buffers and avoiding redundant tensor allocations.
- `update_workspace` reuses cached buffers via the new `outputs` argument. Quantizers are temporarily forced to `internal=False` so cached workspaces survive `prepare_for_saving`.

Checklist:
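Returning to the Changes list, the fallback conditions for `get_multi_weight_workspace` can be read as a single predicate. The sketch below is an assumption-laden illustration rather than the PR's code: `WeightInfo`, `choose_workspace_path`, the recipe string `"delayed_scaling"`, and the flag names are all hypothetical.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class WeightInfo:
    """Hypothetical per-weight metadata used only for this sketch."""

    already_quantized: bool = False


def choose_workspace_path(
    recipe: str,
    columnwise_usage: bool,
    weights: Sequence[WeightInfo],
    cuda_graph_weight_caching: bool,
) -> str:
    """Return "fused" when the whole group can share one fused call.

    Each condition mirrors one fallback case from the Changes list:
    other recipes, rowwise-only usage, already-quantized weights,
    and CUDA-graph weight caching.
    """
    can_fuse = (
        recipe == "delayed_scaling"          # assumed recipe identifier
        and columnwise_usage                 # rowwise-only needs no transpose
        and not cuda_graph_weight_caching
        and not any(w.already_quantized for w in weights)
    )
    return "fused" if can_fuse else "per_tensor"
```

Keeping the decision in one predicate means a group either takes the fused path as a whole or falls back as a whole, so a single group never mixes the two workspace layouts.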