[#11932][fix] Filter CUTLASS MoE GEMM tile configs by device shared memory on SM121 #12704
mihai-chiorean wants to merge 2 commits into NVIDIA:main from
Conversation
📝 Walkthrough
Two runtime shared memory validation checks were added: one filters CUTLASS GEMM candidate configurations when max shared memory is below 120 KiB, and another validates kernel shared memory requirements against device limits before kernel launch.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Around line 561-568: The SM<120 KiB pruning only removes CtaShape128x128x128B
but leaves 64B K-tile configs that the FP4 dispatcher
(dispatchNVFP4xNVFP4GemmCTAShapeSm120) doesn't handle, causing an invalid-config
throw; update the pruning logic around candidate_configs and
CutlassGemmConfig::tile_config_sm120 to either remove the additional 64B shapes
(CtaShape128x128x64B, CtaShape128x256x64B, CtaShape256x128x64B) or, better,
whitelist only the shapes the dispatcher supports (CtaShape128x128x128B,
CtaShape128x128x256B, CtaShape256x128x128B) so candidate_configs contains only
dispatchable CutlassTileConfigSM120 values and the autotuner no longer hits the
default invalid-config path.
- Around line 554-559: The static kMaxSmem initializer currently queries device
0; change it to call cudaGetDevice() to obtain the current device and pass that
device ID into cudaDeviceGetAttribute(), adding explicit error checks for both
cudaGetDevice and cudaDeviceGetAttribute and falling back or logging on error.
Replace the literal 120 * 1024 with a named constant (e.g., kSmemThresholdBytes)
and use that constant where the literal appears. Declare the iterator variable
named it as const (e.g., const auto it) where it is defined. Also update the
file copyright year to 2026.
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl`:
- Around line 673-682: Replace the current function-static kMaxSmem that calls
cudaDeviceGetAttribute(..., 0) with a runtime query that calls cudaGetDevice()
to obtain the current device id and then calls cudaDeviceGetAttribute(&val,
cudaDevAttrMaxSharedMemoryPerBlockOptin, device) so the shared-memory check for
GemmGrouped/GemmKernel_ uses the active GPU; remove the lambda-cached static and
compute smem limit per invocation and keep the TLLM_CHECK_WITH_INFO(smem_size <=
kMaxSmem, ...) guard unchanged except that kMaxSmem is now the per-call value.
Also add the same per-device SMEM preflight guard (compute current device via
cudaGetDevice and compare sizeof(typename GemmKernel_::SharedStorage) against
that device's limit) into moe_gemm_tma_ws_mixed_input_launcher.inl before
calling gemm.can_implement().
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ee34cc95-0326-4ec6-978d-b97646526a47
📒 Files selected for processing (2)
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl
static int const kMaxSmem = []()
{
    int val = 0;
    cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
    return val;
}();
Query the current CUDA device instead of hardcoding device 0.
Line 554–559 caches the SMEM limit in a process-wide static variable using device 0. On multi-GPU hosts, tuning or running on any non-zero GPU will use the wrong capacity threshold, either keeping invalid tiles or pruning valid ones. Call cudaGetDevice() and pass the current device to cudaDeviceGetAttribute(), with explicit error handling.
Additionally, extract the literal 120 * 1024 on line 561 to a named constant (e.g., kSmemThresholdBytes), declare the iterator it on line 565 as const, and update the copyright year to 2026.
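A minimal sketch of the corrected query, assuming hypothetical stand-ins (fakeGetDevice, fakeGetAttribute, kSmemFallbackBytes) in place of the CUDA runtime calls so the error-handling shape runs off-GPU; the real fix would call cudaGetDevice and cudaDeviceGetAttribute directly with TLLM-style error checks:

```cpp
#include <cstdio>

// Hypothetical stand-ins for cudaGetDevice / cudaDeviceGetAttribute so this
// sketch runs without a GPU; the real code calls the CUDA runtime directly.
static int fakeGetDevice(int* device) { *device = 1; return 0; } // 0 == cudaSuccess
static int fakeGetAttribute(int* val, int device)
{
    *val = (device == 0) ? 227 * 1024 : 99 * 1024; // e.g. B200-like vs GB10-like
    return 0;
}

static constexpr int kSmemThresholdBytes = 120 * 1024; // named constant per the review
static constexpr int kSmemFallbackBytes = 48 * 1024;   // conservative floor on query failure

// Query the *current* device's opt-in SMEM limit instead of hardcoding device 0,
// logging and falling back to a safe default if either call fails.
static int maxSmemForCurrentDevice()
{
    int device = 0;
    if (fakeGetDevice(&device) != 0)
    {
        std::fprintf(stderr, "cudaGetDevice failed; assuming %d bytes SMEM\n", kSmemFallbackBytes);
        return kSmemFallbackBytes;
    }
    int val = 0;
    if (fakeGetAttribute(&val, device) != 0)
    {
        std::fprintf(stderr, "cudaDeviceGetAttribute failed; assuming %d bytes SMEM\n", kSmemFallbackBytes);
        return kSmemFallbackBytes;
    }
    return val;
}
```

With the mock reporting a GB10-like current device, the query returns 99 KiB, which falls under kSmemThresholdBytes and would trip the pruning path.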
if (kMaxSmem < 120 * 1024)
{
    // Remove 128B K-tile configs that exceed 99 KiB at typical stage counts.
    // Keep only 64B K-tile configs which fit within 99 KiB.
    auto it = std::remove_if(candidate_configs.begin(), candidate_configs.end(),
        [](CutlassGemmConfig const& config)
        { return config.tile_config_sm120 == CutlassTileConfigSM120::CtaShape128x128x128B; });
    candidate_configs.erase(it, candidate_configs.end());
The low-SMEM filter for SM121 is incomplete.
The <120 KiB filter removes only CtaShape128x128x128B, leaving CtaShape128x128x64B, CtaShape128x256x64B, and CtaShape256x128x64B as candidates. However, the FP4 GEMM dispatcher (dispatchNVFP4xNVFP4GemmCTAShapeSm120) only handles CtaShape128x128x128B, CtaShape128x128x256B, and CtaShape256x128x128B. The CtaShape128x256x64B and CtaShape256x128x64B configs will hit the default case and throw "Config is invalid for FP4 GEMM", preventing the autotuner from reaching a "0 failed tactics" state on SM121.
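A sketch of the whitelist approach the review suggests, using a hypothetical local mirror of CutlassTileConfigSM120 and a bare Config struct so it runs standalone; the real change would operate on candidate_configs and CutlassGemmConfig::tile_config_sm120 in cutlass_heuristic.cpp:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical mirror of the CutlassTileConfigSM120 values named in the review.
enum class TileSM120
{
    CtaShape128x128x64B, CtaShape128x128x128B, CtaShape128x128x256B,
    CtaShape128x256x64B, CtaShape256x128x64B, CtaShape256x128x128B
};

struct Config { TileSM120 tile_config_sm120; };

// Whitelist only the shapes dispatchNVFP4xNVFP4GemmCTAShapeSm120 handles, so the
// autotuner never sees a config that would hit the default invalid-config path.
static void pruneToDispatchable(std::vector<Config>& candidate_configs)
{
    auto const isDispatchable = [](Config const& config)
    {
        return config.tile_config_sm120 == TileSM120::CtaShape128x128x128B
            || config.tile_config_sm120 == TileSM120::CtaShape128x128x256B
            || config.tile_config_sm120 == TileSM120::CtaShape256x128x128B;
    };
    // Erase-remove: drop every candidate the dispatcher cannot handle.
    auto const it = std::remove_if(candidate_configs.begin(), candidate_configs.end(),
        [&](Config const& config) { return !isDispatchable(config); });
    candidate_configs.erase(it, candidate_configs.end());
}
```

A whitelist is the safer choice here: new undispatchable shapes added to the candidate list later are excluded by default, instead of requiring the blacklist to be kept in sync.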
using GemmKernel_ = typename GemmGrouped::GemmKernel;                                                \
int smem_size = static_cast<int>(sizeof(typename GemmKernel_::SharedStorage));                       \
static int const kMaxSmem = []()                                                                     \
{                                                                                                    \
    int val = 0;                                                                                     \
    cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);                        \
    return val;                                                                                      \
}();                                                                                                 \
TLLM_CHECK_WITH_INFO(smem_size <= kMaxSmem,                                                          \
    "MoE grouped GEMM requires %d bytes shared memory but device supports %d", smem_size, kMaxSmem); \
This guard still reads the wrong GPU on multi-device hosts.
The hardcoded device 0 in cudaDeviceGetAttribute() and function-static caching of kMaxSmem means a launch on a non-zero device can compare its kernel requirements against the wrong GPU's SMEM limit. This allows unsupported kernels to reach gemm.initialize() or valid kernels to be incorrectly rejected. Use the pattern from fused_moe_gemm_launcher_sm80.inl:53–68, which calls cudaGetDevice() to query the current device's actual limit.
Additionally, moe_gemm_tma_ws_mixed_input_launcher.inl also lacks an equivalent SMEM preflight guard before its gemm.can_implement() call and may encounter the same issue.
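A host-side sketch of the per-call preflight guard, assuming a hypothetical currentDeviceMaxSmem() in place of the cudaGetDevice + cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin) pair, and a plain exception in place of TLLM_CHECK_WITH_INFO, so the control flow is runnable off-GPU:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for querying the *active* device's opt-in SMEM limit
// via cudaGetDevice + cudaDeviceGetAttribute; fixed here to an SM121-like 99 KiB.
static int currentDeviceMaxSmem() { return 99 * 1024; }

// Preflight guard: the limit is computed per invocation (no function-static
// cache), so a launch on a non-zero device is checked against that device's
// actual capacity. Rejecting here turns an opaque CUTLASS "Error Internal"
// into a clear diagnostic the autotuner can skip.
static void checkSmemFits(int smem_size)
{
    int const maxSmem = currentDeviceMaxSmem();
    if (smem_size > maxSmem)
    {
        throw std::runtime_error("MoE grouped GEMM requires "
            + std::to_string(smem_size) + " bytes shared memory but device supports "
            + std::to_string(maxSmem));
    }
}
```

In the launcher, smem_size would come from sizeof(typename GemmKernel_::SharedStorage), and the same guard would be placed in moe_gemm_tma_ws_mixed_input_launcher.inl before gemm.can_implement().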
93b9c60 to 556d505
…ared memory on SM121

SM121 (DGX Spark GB10) has 99 KiB shared memory per block vs 228 KiB on SM120 (B200). CUTLASS StageCountAutoCarveout computes pipeline stages assuming 228 KiB, causing TMA warp-specialized MoE grouped GEMM tactics to fail with an opaque "Error Internal" on gemm.initialize(). This patch adds two fixes:

1. Runtime SMEM guard in moe_gemm_tma_ws_launcher.inl: checks kernel SharedStorage size against cudaDevAttrMaxSharedMemoryPerBlockOptin before launch, converting opaque CUTLASS errors into clear diagnostics that the autotuner can skip.
2. Heuristic filter in get_candidate_configs_sm120(): on devices with < 120 KiB SMEM, removes tile configs whose SharedStorage exceeds the device limit. Keeps only CtaShape128x128x64B, which fits within 99 KiB including FINALIZE epilogue overhead (~80 KiB total).

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
556d505 to b1f4372
Summary
SM121 (DGX Spark GB10) has 99 KiB shared memory per block vs 228 KiB on SM120 (B200). CUTLASS StageCountAutoCarveout computes pipeline stages assuming 228 KiB, causing TMA warp-specialized MoE grouped GEMM tactics to fail with an opaque "Error Internal" on gemm.initialize(). This patch adds:
- Runtime SMEM guard in moe_gemm_tma_ws_launcher.inl: checks kernel SharedStorage size against cudaDevAttrMaxSharedMemoryPerBlockOptin before launch, converting opaque CUTLASS errors into clear diagnostics that the autotuner can skip.
- Heuristic filter in get_candidate_configs_sm120(): on devices with < 120 KiB SMEM, removes tile configs whose SharedStorage exceeds the device limit. Keeps only CtaShape128x128x64B, which fits within 99 KiB including FINALIZE epilogue overhead.

Impact
TRTLLM_MOE_BACKEND=TRITON remains faster on SM121 (32-40 tok/s vs ~4.8 tok/s with CUTLASS), as the Triton backend JIT-compiles kernels adapted to SM121 constraints. This fix ensures CUTLASS does not crash when used, but TRITON may be the better path for SM121 MoE workloads.

Test plan