Skip to content

[None][feat] Optimize mamba2 _chunk_scan_fwd_kernel#11345

Open
JadoTu wants to merge 1 commit intoNVIDIA:mainfrom
JadoTu:mamba2_prefill_kernel_tune
Open

[None][feat] Optimize mamba2 _chunk_scan_fwd_kernel#11345
JadoTu wants to merge 1 commit intoNVIDIA:mainfrom
JadoTu:mamba2_prefill_kernel_tune

Conversation

@JadoTu
Copy link
Collaborator

@JadoTu JadoTu commented Feb 6, 2026

Summary by CodeRabbit

  • Refactor
    • Optimized internal kernel configuration management for improved code organization and maintainability.

Description

  1. The insight comes from the suggestion "The 2.00 theoretical warps per scheduler this kernel can issue according to its occupancy are below the hardware maximum of 16. This kernel's theoretical occupancy (12.5%) is limited by the number of required registers, and the required amount of shared memory." from ncu.
  2. Optimize the triton _chunk_scan_fwd_kernel by adding BLOCK_SIZE_DSTATE into autotune configs. This variable will be 128 in current Nemotron SuperV3 model but is not the best in long context scene (50k 2k).
  3. About 30% perf gain on this kernel with this PR.
image image

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu JadoTu requested review from a team as code owners February 6, 2026 09:38
@nv-guomingz
Copy link
Collaborator

Let's merge #11273 firstly

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

Walkthrough

Refactored Triton autotune configuration in SSD chunk scan module from static explicit configs to a base configs list that dynamically generates configurations for multiple DSTATE values (32, 64, 128), replacing hardcoded BLOCK_SIZE_DSTATE values.

Changes

Cohort / File(s) Summary
Triton Configuration Refactoring
tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py
Introduced centralized _BASE_CONFIGS list with base Triton configuration tuples and replaced static autotune decorator configs with dynamic generator that expands configurations for each DSTATE value. Removed inline BLOCK_SIZE_DSTATE from kernel launch, delegating to dynamically generated config.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically identifies the main change: optimizing the mamba2 _chunk_scan_fwd_kernel, which aligns with the actual code changes and performance improvements shown in the PR.
Description check ✅ Passed The PR description explains the motivation (NCU profiler suggestion), the change (BLOCK_SIZE_DSTATE autotune), and includes performance results (~30% gain), but is missing explicit test coverage information.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py (1)

4-4: ⚠️ Potential issue | 🟡 Minor

Update NVIDIA copyright year to 2026.

This PR makes a meaningful modification to the file, so the copyright end-year should be updated. As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification."

Proposed fix
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py`:
- Around line 129-137: The autotune config generator currently yields
BLOCK_SIZE_DSTATE values (d) that can be smaller than the kernel's runtime
dstate, causing silent incorrect results; update the triton.autotune configs
comprehension (the triton.autotune decorator that iterates over _BASE_CONFIGS
and d in [32,64,128]) to filter out any config where BLOCK_SIZE_DSTATE (d) < the
kernel's dstate (available in each base config's kwargs, e.g.,
cfg.kwargs['dstate'] or cfg.kwargs.get('dstate')), so only configs with d >=
dstate are included; modify the list comprehension to check that condition
before constructing triton.Config to ensure correctness.

Comment on lines 129 to 137
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=2,
),
triton.Config({
**cfg.kwargs, 'BLOCK_SIZE_DSTATE': d
},
num_stages=cfg.num_stages,
num_warps=cfg.num_warps) for cfg in _BASE_CONFIGS
for d in [32, 64, 128]
],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py | head -200

Repository: NVIDIA/TensorRT-LLM

Length of output: 6043


🏁 Script executed:

cat -n tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py | tail -250 | head -200

Repository: NVIDIA/TensorRT-LLM

Length of output: 8578


🌐 Web query:

Triton autotune prune_configs_by API documentation

💡 Result:

In Triton, prune_configs_by is not a standalone API—it’s an optional parameter to the triton.autotune(...) decorator.

Where it’s documented

It’s documented on the official triton.autotune page. (triton-lang.org)

prune_configs_by parameter (what Triton expects)

prune_configs_by should be a dict that can contain:

  • "perf_model": a performance model used to predict runtime for configs (returns predicted runtime).
  • "top_k": how many configs to keep for actual benchmarking.
  • "early_config_prune": an optional function to prune configs early.

The pruning function signature is:

  • prune_configs_by(configs: List[triton.Config], named_args: Dict[str, Any], **kwargs) -> List[triton.Config]

and it must return at least one config. (triton-lang.org)

Citations:


🏁 Script executed:

rg "prune_configs_by" --type py -A 3 -B 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

sed -n '340,420p' tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 4120


🏁 Script executed:

sed -n '280,360p' tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py | cat -n

Repository: NVIDIA/TensorRT-LLM

Length of output: 4102


🏁 Script executed:

grep -r "triton" setup.py requirements.txt pyproject.toml 2>/dev/null | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1613


🏁 Script executed:

grep -A 5 -B 5 "version.parse.*triton" tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 302


Autotuner can select a BLOCK_SIZE_DSTATE < dstate config, producing silently incorrect results.

When BLOCK_SIZE_DSTATE <= 128 (true for all 33 generated configs), the kernel takes the single-tile branch (lines 36–51) which loads and accumulates only [0, BLOCK_SIZE_DSTATE) elements along the dstate axis—there is no loop. If the autotuner selects, e.g., BLOCK_SIZE_DSTATE=32 for a model with dstate=128, the dot product accumulates only 32 of 128 elements, silently producing wrong output. Triton's autotuner benchmarks latency only—it does not validate correctness—so the smaller (faster, but wrong) tile can win even with per-dstate caching, since initial tuning for a given dstate could select an incompatible BLOCK_SIZE_DSTATE.

Add config pruning to discard configs where BLOCK_SIZE_DSTATE < dstate:

Proposed fix
+def _prune_chunk_scan_configs(configs, named_args, **kwargs):
+    dstate = named_args['dstate']
+    return [c for c in configs if c.kwargs['BLOCK_SIZE_DSTATE'] >= dstate]
+
+
 `@triton.autotune`(
     configs=[
         triton.Config({
             **cfg.kwargs, 'BLOCK_SIZE_DSTATE': d
         },
                       num_stages=cfg.num_stages,
                       num_warps=cfg.num_warps) for cfg in _BASE_CONFIGS
         for d in [32, 64, 128]
     ],
     key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
+    prune_configs_by={'early_config_prune': _prune_chunk_scan_configs},
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32
},
num_stages=4,
num_warps=2,
),
triton.Config({
**cfg.kwargs, 'BLOCK_SIZE_DSTATE': d
},
num_stages=cfg.num_stages,
num_warps=cfg.num_warps) for cfg in _BASE_CONFIGS
for d in [32, 64, 128]
],
def _prune_chunk_scan_configs(configs, named_args, **kwargs):
dstate = named_args['dstate']
return [c for c in configs if c.kwargs['BLOCK_SIZE_DSTATE'] >= dstate]
`@triton.autotune`(
configs=[
triton.Config({
**cfg.kwargs, 'BLOCK_SIZE_DSTATE': d
},
num_stages=cfg.num_stages,
num_warps=cfg.num_warps) for cfg in _BASE_CONFIGS
for d in [32, 64, 128]
],
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
prune_configs_by={'early_config_prune': _prune_chunk_scan_configs},
)
🤖 Prompt for AI Agents
In `@tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py` around lines 129 - 137,
The autotune config generator currently yields BLOCK_SIZE_DSTATE values (d) that
can be smaller than the kernel's runtime dstate, causing silent incorrect
results; update the triton.autotune configs comprehension (the triton.autotune
decorator that iterates over _BASE_CONFIGS and d in [32,64,128]) to filter out
any config where BLOCK_SIZE_DSTATE (d) < the kernel's dstate (available in each
base config's kwargs, e.g., cfg.kwargs['dstate'] or cfg.kwargs.get('dstate')),
so only configs with d >= dstate are included; modify the list comprehension to
check that condition before constructing triton.Config to ensure correctness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants