[None][feat] Optimize MoE DenseGEMM FC2 kernel #12686

Draft: mingyangHao wants to merge 1 commit into NVIDIA:main
Description
Optimizes the MoE DenseGEMM FC2 CuTe DSL kernel (fc2.py) by removing the dedicated alpha-scale pipeline and enabling a 2CTA MMA configuration in the autotuner.

Kernel changes (fc2.py):
- Removed the PipelineCpAsync alpha-scale pipeline, its dedicated warp (warp 6), its dummy warp (warp 7), and all associated SMEM buffers. Alpha values are now loaded directly from GMEM in the epilogue with a prefetch pattern, reducing the CTA thread count from 256 to 192.
- Added tiles_per_expert: accumulates an expert's k-tiles (2 tiles per expert) in a single accumulator buffer before committing, halving the acc pipeline round-trips (from 512 to 256 for the target shape).
- The alpha loop now iterates over num_experts instead of k_tile_cnt, loading alpha once per expert instead of once per k-tile and reducing redundant GMEM alpha loads by 2x.
- Increased num_acc_stage from 2 to 3: allows the MMA to run further ahead of the epilogue, improving pipeline overlap.
- The SMEM freed by the changes above makes the (256,128) cluster=(2,1) 2CTA MMA configuration viable (it previously crashed due to SMEM overflow).

Autotuner changes (cute_dsl_custom_ops.py):
- Added (256, 128) to mma_tiler_mn_candidates for the FC2 runner.
- Added (2, 1) to cluster_shape_mn_candidates to enable B-multicast via 2CTA clustering.

Performance (CUPTI, B200, cold L2, warmup=10, iter=50):
- Config (128,128) cluster=(1,1):
- Config (128,64) cluster=(1,1):
- Config (256,128) cluster=(2,1) (new; only works after this change):

The previous code crashes with an illegal memory access on (256,128) cluster=(2,1) because the PipelineCpAsync alpha pipeline does not support 2CTA cluster barrier semantics.

Test Coverage
- tests/unittest/_torch/thop/parallel/test_moe_densegemm.py covers functional correctness.
- tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py covers performance benchmarking across M=32..384 with configurable tile/cluster settings.
- The existing can_implement guard in fc2.py already validates (256,128)+(2,1) as a legal tactic; no new validation code is needed.
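To illustrate the tiles_per_expert change, here is a toy Python sketch (not the actual CuTe DSL kernel) of how batching both of an expert's k-tiles into one accumulator buffer before committing halves the acc-pipeline round-trips. The 256-expert iteration count is an illustrative assumption chosen to reproduce the 512-to-256 figure quoted above; the function name is hypothetical.

```python
def fc2_commits(num_experts, tiles_per_expert, batch_tiles):
    """Count acc-pipeline commits when `batch_tiles` k-tiles are
    accumulated into one buffer per commit (toy model, not the kernel)."""
    commits = 0
    for _ in range(num_experts):
        acc = 0.0
        for t in range(tiles_per_expert):
            acc += 1.0  # stand-in for one k-tile's MMA contribution
            if (t + 1) % batch_tiles == 0:
                commits += 1  # commit accumulator to the epilogue pipeline
    return commits

# Committing per k-tile vs. per expert (2 k-tiles batched):
print(fc2_commits(256, 2, 1))  # 512
print(fc2_commits(256, 2, 2))  # 256
```

With 2 k-tiles per expert, committing once per expert instead of once per tile exactly halves the number of accumulator-buffer round-trips, which is why num_acc_stage=3 can then cover the remaining latency.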
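The alpha-loop restructuring can be sketched the same way: a toy Python model (names and the 256-expert count are illustrative assumptions, not the kernel's API) counting GMEM alpha loads when iterating over num_experts with a simple prefetch, versus loading once per k-tile as the removed PipelineCpAsync path did.

```python
def alpha_loads_per_pass(num_experts, tiles_per_expert, per_expert):
    """Count GMEM alpha loads for one pass over all experts (toy model)."""
    loads = 0
    if per_expert:
        # New scheme: load expert 0's alpha up front, then prefetch the
        # next expert's alpha while the current expert's tiles are processed.
        loads += 1
        for e in range(num_experts):
            if e + 1 < num_experts:
                loads += 1  # prefetch next expert's alpha from GMEM
            for _ in range(tiles_per_expert):
                pass  # epilogue reuses the same alpha for every k-tile
    else:
        # Old scheme: one alpha load per k-tile.
        for _ in range(num_experts):
            for _ in range(tiles_per_expert):
                loads += 1
    return loads

print(alpha_loads_per_pass(256, 2, per_expert=False))  # 512
print(alpha_loads_per_pass(256, 2, per_expert=True))   # 256
```

With 2 k-tiles per expert this is exactly the 2x reduction in redundant GMEM alpha loads described above, and it removes the need for the dedicated alpha warp and its SMEM buffers entirely.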