Skip to content

[None][feat] Optimize MoE DenseGEMM FC2 kernel#12686

Draft
mingyangHao wants to merge 1 commit intoNVIDIA:mainfrom
mingyangHao:opt-fc2
Draft

[None][feat] Optimize MoE DenseGEMM FC2 kernel#12686
mingyangHao wants to merge 1 commit intoNVIDIA:mainfrom
mingyangHao:opt-fc2

Conversation

@mingyangHao
Copy link
Copy Markdown

@coderabbitai summary

Description

Optimizes the MoE DenseGEMM FC2 CuTe DSL kernel (fc2.py) by removing the dedicated alpha-scale pipeline and enabling 2CTA MMA configuration in the autotuner.

Kernel changes (fc2.py):

  1. Remove alpha-scale pipeline: Eliminates the 10-stage PipelineCpAsync alpha-scale pipeline, its dedicated warp (warp 6), dummy warp (warp 7), and all associated SMEM buffers. Alpha values are now loaded directly from GMEM in the epilogue with a prefetch pattern, reducing CTA thread count from 256 to 192.
  2. Expert-grouped MMA accumulation: The MMA mainloop now accumulates tiles_per_expert k-tiles (2 tiles per expert) within a single accumulator buffer before committing, halving the acc pipeline round-trips (from 512 to 256 for the target shape).
  3. Expert-iterated epilogue: The epilogue loop iterates over num_experts instead of k_tile_cnt, loading alpha once per expert instead of once per k-tile, reducing redundant GMEM alpha loads by 2x.
  4. Increase num_acc_stage from 2 to 3: Allows MMA to run further ahead of the epilogue, improving pipeline overlap.
  5. Free SMEM budget: Removing the alpha pipeline frees ~5KB SMEM, enabling more A/B pipeline stages and — critically — making the (256,128) cluster=(2,1) 2CTA MMA configuration viable (previously crashed due to SMEM overflow).

Autotuner changes (cute_dsl_custom_ops.py):

  • Add (256, 128) to mma_tiler_mn_candidates for the FC2 runner.
  • Add (2, 1) to cluster_shape_mn_candidates to enable B-multicast via 2CTA clustering.
  • This enables the 2CTA MMA + B-multicast config, which is the optimal configuration for M≥128 in the FC2 shape (A[M,65536] × B[7168,65536]).

Performance (CUPTI, B200, cold L2, warmup=10, iter=50):

Config (128,128) cluster=(1,1):

M Before (us) After (us) Speedup
128 76.54 75.94 1.01x
256 80.59 76.48 1.05x
288 151.40 138.26 1.09x
384 152.01 142.82 1.06x

Config (128,64) cluster=(1,1):

M Before (us) After (us) Speedup
128 65.80 62.63 1.05x
192 127.81 111.89 1.14x
288 190.41 162.79 1.17x
384 190.97 169.95 1.12x

Config (256,128) cluster=(2,1) — NEW, only works after this change:

M After (us) TFLOPS
128 70.22 1712.7
256 72.43 3320.5
288 133.01 2034.2
384 135.64 2659.9

The previous code crashes with illegal memory access on (256,128) cluster=(2,1) because the PipelineCpAsync alpha pipeline does not support 2CTA cluster barrier semantics.

Test Coverage

  • Existing unit tests in tests/unittest/_torch/thop/parallel/test_moe_densegemm.py cover functional correctness.
  • Existing test script tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py covers performance benchmarking across M=32..384 with configurable tile/cluster settings.
  • The can_implement guard in fc2.py already validates (256,128)+(2,1) as a legal tactic; no new validation code needed.

PR Checklist

  • Please check this after reviewing the above items as appropriate for this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant