
[#11932][fix] Filter CUTLASS MoE GEMM tile configs by device shared memory on SM121#12704

Open
mihai-chiorean wants to merge 2 commits into NVIDIA:main from mihai-chiorean:fix/cutlass-moe-smem-sm121

Conversation


@mihai-chiorean mihai-chiorean commented Apr 2, 2026

Summary

SM121 (DGX Spark GB10) has 99 KiB shared memory per block vs 228 KiB on SM120 (B200). CUTLASS StageCountAutoCarveout computes pipeline stages assuming 228 KiB, causing TMA warp-specialized MoE grouped GEMM tactics to fail with opaque "Error Internal" on gemm.initialize().

This patch adds:

  1. Runtime SMEM guard in moe_gemm_tma_ws_launcher.inl: checks kernel SharedStorage size against cudaDevAttrMaxSharedMemoryPerBlockOptin before launch, converting opaque CUTLASS errors into clear diagnostics that the autotuner can skip.

  2. Heuristic filter in get_candidate_configs_sm120(): on devices with < 120 KiB SMEM, removes tile configs whose SharedStorage exceeds the device limit. Keeps only CtaShape128x128x64B which fits within 99 KiB including FINALIZE epilogue overhead.

Impact

  • Without this fix: 13/16 MoE GEMM tactics fail on SM121, falling back to a slow tactic
  • With this fix: 0 failed tactics, autotuner selects optimal configs for 99 KiB SMEM
  • In our testing on DGX Spark, we have seen significantly better MoE performance with TRTLLM_MOE_BACKEND=TRITON (32-40 tok/s vs ~4.8 tok/s with CUTLASS), as the Triton backend JIT-compiles kernels adapted to SM121 constraints. This fix ensures CUTLASS does not crash when used, but TRITON may be the better path for SM121 MoE workloads.

Test plan

  • Verified on DGX Spark GB10 (SM121, 128GB UMA) with Qwen3-30B-A3B-NVFP4
  • Autotuner reports 0 failed MoE GEMM tactics (was 13)
  • Runtime SMEM guard produces clear diagnostic: "requires N bytes shared memory but device supports M"


coderabbitai bot commented Apr 2, 2026

📝 Walkthrough

Two runtime shared memory validation checks were added: one filters CUTLASS GEMM candidate configurations when max shared memory is below 120 KiB, and another validates kernel shared memory requirements against device limits before kernel launch.

Changes

Cohort / File(s) — Summary

  • CUTLASS Config Filtering — cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
    Added an <algorithm> include and a runtime device query for cudaDevAttrMaxSharedMemoryPerBlockOptin in get_candidate_configs_sm120. When max shared memory is below 120 KiB, candidate configs with the CtaShape128x128x128B tile are filtered out, retaining only the 64B K-tile variants.

  • MoE GEMM Launcher Validation — cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl
    Added a pre-launch validation check verifying that the kernel's required shared memory (sizeof(GemmKernel::SharedStorage)) does not exceed the device's maximum shared memory per block with opt-in capability. Triggered before the gemm.can_implement()/gemm.initialize() calls.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Description check — ⚠️ Warning. The PR description provides a clear problem statement, a technical solution with two specific changes, and test results, but is missing a formal PR title, a Test Coverage section, and the PR Checklist per the template. Resolution: add a PR title following the template format [ticket][type], include a Test Coverage section listing relevant tests, and complete the PR Checklist with checkbox acknowledgment.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed. The title clearly summarizes the main change: filtering CUTLASS MoE GEMM tile configs based on device shared memory constraints for SM121, which directly addresses the core issue described in the PR.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Around line 561-568: The SM<120 KiB pruning only removes CtaShape128x128x128B
but leaves 64B K-tile configs that the FP4 dispatcher
(dispatchNVFP4xNVFP4GemmCTAShapeSm120) doesn't handle, causing an invalid-config
throw; update the pruning logic around candidate_configs and
CutlassGemmConfig::tile_config_sm120 to either remove the additional 64B shapes
(CtaShape128x128x64B, CtaShape128x256x64B, CtaShape256x128x64B) or, better,
whitelist only the shapes the dispatcher supports (CtaShape128x128x128B,
CtaShape128x128x256B, CtaShape256x128x128B) so candidate_configs contains only
dispatchable CutlassTileConfigSM120 values and the autotuner no longer hits the
default invalid-config path.
- Around line 554-559: The static kMaxSmem initializer currently queries device
0; change it to call cudaGetDevice() to obtain the current device and pass that
device ID into cudaDeviceGetAttribute(), adding explicit error checks for both
cudaGetDevice and cudaDeviceGetAttribute and falling back or logging on error.
Replace the literal 120 * 1024 with a named constant (e.g., kSmemThresholdBytes)
and use that constant where the literal appears. Declare the iterator variable
named it as const (e.g., const auto it) where it is defined. Also update the
file copyright year to 2026.

In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl`:
- Around line 673-682: Replace the current function-static kMaxSmem that calls
cudaDeviceGetAttribute(..., 0) with a runtime query that calls cudaGetDevice()
to obtain the current device id and then calls cudaDeviceGetAttribute(&val,
cudaDevAttrMaxSharedMemoryPerBlockOptin, device) so the shared-memory check for
GemmGrouped/GemmKernel_ uses the active GPU; remove the lambda-cached static and
compute smem limit per invocation and keep the TLLM_CHECK_WITH_INFO(smem_size <=
kMaxSmem, ...) guard unchanged except that kMaxSmem is now the per-call value.
Also add the same per-device SMEM preflight guard (compute current device via
cudaGetDevice and compare sizeof(typename GemmKernel_::SharedStorage) against
that device's limit) into moe_gemm_tma_ws_mixed_input_launcher.inl before
calling gemm.can_implement().

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ee34cc95-0326-4ec6-978d-b97646526a47

📥 Commits

Reviewing files that changed from the base of the PR and between 11c40bb and 93b9c60.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl

Comment on lines +554 to +559
    static int const kMaxSmem = []()
    {
        int val = 0;
        cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
        return val;
    }();

⚠️ Potential issue | 🟠 Major


Query the current CUDA device instead of hardcoding device 0.

Line 554–559 caches the SMEM limit in a process-wide static variable using device 0. On multi-GPU hosts, tuning or running on any non-zero GPU will use the wrong capacity threshold, either keeping invalid tiles or pruning valid ones. Call cudaGetDevice() and pass the current device to cudaDeviceGetAttribute(), with explicit error handling.

Additionally, extract the literal 120 * 1024 on line 561 to a named constant (e.g., kSmemThresholdBytes), declare the iterator it on line 565 as const, and update the copyright year to 2026.


Comment on lines +561 to +568
    if (kMaxSmem < 120 * 1024)
    {
        // Remove 128B K-tile configs that exceed 99 KiB at typical stage counts.
        // Keep only 64B K-tile configs which fit within 99 KiB.
        auto it = std::remove_if(candidate_configs.begin(), candidate_configs.end(),
            [](CutlassGemmConfig const& config)
            { return config.tile_config_sm120 == CutlassTileConfigSM120::CtaShape128x128x128B; });
        candidate_configs.erase(it, candidate_configs.end());

⚠️ Potential issue | 🟠 Major



The low-SMEM filter for SM121 is incomplete.

The <120 KiB filter removes only CtaShape128x128x128B, leaving CtaShape128x128x64B, CtaShape128x256x64B, and CtaShape256x128x64B as candidates. However, the FP4 GEMM dispatcher (dispatchNVFP4xNVFP4GemmCTAShapeSm120) only handles CtaShape128x128x128B, CtaShape128x128x256B, and CtaShape256x128x128B. The CtaShape128x256x64B and CtaShape256x128x64B configs will hit the default case and throw "Config is invalid for FP4 GEMM", preventing the autotuner from reaching a "0 failed tactics" state on SM121.


Comment on lines +673 to +682
    using GemmKernel_ = typename GemmGrouped::GemmKernel;                                                  \
    int smem_size = static_cast<int>(sizeof(typename GemmKernel_::SharedStorage));                         \
    static int const kMaxSmem = []()                                                                       \
    {                                                                                                      \
        int val = 0;                                                                                       \
        cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);                          \
        return val;                                                                                        \
    }();                                                                                                   \
    TLLM_CHECK_WITH_INFO(smem_size <= kMaxSmem,                                                            \
        "MoE grouped GEMM requires %d bytes shared memory but device supports %d", smem_size, kMaxSmem);   \

⚠️ Potential issue | 🟠 Major



This guard still reads the wrong GPU on multi-device hosts.

The hardcoded device 0 in cudaDeviceGetAttribute() and function-static caching of kMaxSmem means a launch on a non-zero device can compare its kernel requirements against the wrong GPU's SMEM limit. This allows unsupported kernels to reach gemm.initialize() or valid kernels to be incorrectly rejected. Use the pattern from fused_moe_gemm_launcher_sm80.inl:53–68, which calls cudaGetDevice() to query the current device's actual limit.

Additionally, moe_gemm_tma_ws_mixed_input_launcher.inl also lacks an equivalent SMEM preflight guard before its gemm.can_implement() call and may encounter the same issue.


@mihai-chiorean mihai-chiorean force-pushed the fix/cutlass-moe-smem-sm121 branch from 93b9c60 to 556d505 Compare April 2, 2026 19:03
…ared memory on SM121

SM121 (DGX Spark GB10) has 99 KiB shared memory per block vs 228 KiB
on SM120 (B200). CUTLASS StageCountAutoCarveout computes pipeline stages
assuming 228 KiB, causing TMA warp-specialized MoE grouped GEMM tactics
to fail with opaque "Error Internal" on gemm.initialize().

This patch adds two fixes:

1. Runtime SMEM guard in moe_gemm_tma_ws_launcher.inl: checks kernel
   SharedStorage size against cudaDevAttrMaxSharedMemoryPerBlockOptin
   before launch, converting opaque CUTLASS errors into clear
   diagnostics that the autotuner can skip.

2. Heuristic filter in get_candidate_configs_sm120(): on devices with
   < 120 KiB SMEM, removes tile configs whose SharedStorage exceeds
   the device limit. Keeps only CtaShape128x128x64B which fits within
   99 KiB including FINALIZE epilogue overhead (~80 KiB total).

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
@mihai-chiorean mihai-chiorean force-pushed the fix/cutlass-moe-smem-sm121 branch from 556d505 to b1f4372 Compare April 2, 2026 19:10
@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Apr 2, 2026
@mihai-chiorean mihai-chiorean marked this pull request as ready for review April 3, 2026 04:05

Labels

Community want to contribute PRs initiated from Community


2 participants