
[None][feat] Optimize super-v3 nvfp4 for better perf #11273

Merged
Wanli-Jiang merged 16 commits into NVIDIA:main from Wanli-Jiang:user/williamj/super-v3-opt2
Feb 12, 2026

Conversation

@Wanli-Jiang (Collaborator) commented Feb 4, 2026

Detailed Kernel Analysis

1. Kernels Completely Removed (Optimized Out)

These kernels were present in the base profile but are absent from the optimized version:

| Kernel | Instances (Base) | Total Time (Base) | Notes |
|---|---|---|---|
| elementwise_kernel (direct_copy BFloat16) | 960 | 860.73 ms | Memory copy operations eliminated |
| cutlass FP4 GEMM (4_4_1_2SM, 256x256x256) | 320 | 113.77 ms | Kernel variant consolidated |
| cutlass3x_sm100 bias_bf16_relu | 320 | 164.30 ms | Fused into other operations |
| flashinfer RMSNormKernel | 712 | 162.15 ms | Replaced by fused variant |
| vectorized_elementwise_kernel (clamp) | 320 | 57.09 ms | Fused into other operations |
| vectorized_elementwise_kernel (pow) | 320 | 52.72 ms | Fused into other operations |

Total time saved from removed kernels: ~1,410.76 ms

2. Kernels Added (New in Optimized Version)

These kernels appear only in the optimized profile:

| Kernel | Instances | Total Time | Purpose |
|---|---|---|---|
| _fused_conv_output_transpose_kernel | 320 | 102.86 ms | Fused convolution + transpose |
| _extract_transpose_prefill_kernel | 320 | 100.99 ms | Optimized transpose for prefill |
| warpSpecializedInvoker (FP4AddBiasResidualPreLayerNorm, variant 1) | 320 | 98.01 ms | Fused LN + bias + residual |
| warpSpecializedInvoker (FP4AddBiasResidualPreLayerNorm, variant 2) | 320 | 87.52 ms | Fused LN + bias + residual |
| fusedRelu2QuantizeKernel | 320 | 69.22 ms | Fused ReLU2 + FP4 quantize |
| FusedAddRMSNormKernel | 8 | 2.45 ms | Fused add + RMS norm |
| FillFunctor | 8 | 0.93 ms | Memory initialization |

Total time for new kernels: ~462.00 ms

Net savings from kernel fusion: ~948.76 ms
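
As a rough illustration of what the new fusedRelu2QuantizeKernel collapses into a single pass, the sketch below spells out the unfused reference math: a ReLU2 activation (relu(x)**2) followed by a simplified per-16-element FP4 block quantization. This is only an intuition aid under stated assumptions; it ignores the swizzled scale-factor layout, the FP8 (E4M3) scale encoding, and the global scale handled by the real kernel, and the names are illustrative rather than taken from the PR.

```python
import torch

# Representable E2M1 (FP4) magnitudes; 6.0 is the largest.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def relu2_quantize_ref(x: torch.Tensor, sf_vec_size: int = 16):
    """Unfused reference: y = relu(x)**2, then per-block FP4 quantization."""
    y = torch.relu(x).square()                      # activation the kernel fuses in
    m, n = y.shape
    blocks = y.view(m, n // sf_vec_size, sf_vec_size)
    amax = blocks.amax(dim=-1, keepdim=True)        # non-negative after relu^2
    scale = torch.where(amax > 0, amax / 6.0, torch.ones_like(amax))
    scaled = (blocks / scale).clamp(0.0, 6.0)
    # Snap each value to the nearest representable E2M1 magnitude
    # (no sign bit needed here because relu^2 outputs are non-negative).
    idx = (scaled.unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)
    q = E2M1_VALUES[idx]
    return q.view(m, n), scale.squeeze(-1)


x = torch.randn(4, 64)
q, sf = relu2_quantize_ref(x)
```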

3. Kernels with Significant Time Changes

| Kernel | Base (ms) | Opt (ms) | Change (ms) | % Change | Instances |
|---|---|---|---|---|---|
| _chunk_scan_fwd_kernel | 670.56 | 476.10 | -194.46 | -29.0% | 320 |
| quantize_with_block_size | 398.67 | 134.05 | -264.62 | -66.4% | 2304/960 |
| vectorized_elementwise (add) | 183.62 | 58.38 | -125.24 | -68.2% | 1024/320 |
| _layer_norm_fwd_1pass_kernel (code unchanged) | 164.13 | 171.13 | +6.99 | +4.3% | 320 |
| _state_passing_fwd_kernel | 127.91 | 128.88 | +0.97 | +0.8% | 320 |
| _chunk_state_fwd_kernel | 323.61 | 312.06 | -11.55 | -3.6% | 320 |
| routingMainKernel (DeepSeek) (code unchanged) | 225.77 | 237.02 | +11.25 | +5.0% | 320 |
| bmm_Bfloat16_E2m1E2m1 (code unchanged) | 413.52 | 424.28 | +10.76 | +2.6% | 320 |
| causal_conv1d_fwd_kernel | 321.42 | 317.19 | -4.24 | -1.3% | 320 |
| finalizeKernelVecLoad (code unchanged) | 139.88 | 142.41 | +2.52 | +1.8% | 320 |
| nvjet_sm100_tss (code unchanged) | 48.17 | 50.06 | +1.89 | +3.9% | 320 |
| _bmm_chunk_fwd_kernel | 31.09 | 38.20 | +7.11 | +22.9% | 320 |

4. Core GEMM Operations (not modified in this PR; the slight regressions are likely run-to-run variance)

| Kernel | Base (ms) | Opt (ms) | Change | Notes |
|---|---|---|---|---|
| bmm_E2m1_E2m1E2m1 (Mamba BMM) | 865.87 | 910.97 | +5.2% | Slight regression |
| fmhaSm100f (Flash Attention) | 775.96 | 818.86 | +5.5% | Slight regression |
| FP4 GEMM 2_1_1_2SM | 568.28 | 571.95 | +0.6% | Stable |

Summary by CodeRabbit

Release Notes

  • New Features

    • Added fused ReLU2 activation with FP4 quantization for efficient inference.
    • Extended layer normalization with optional high-precision output mode.
  • Performance Improvements

    • Optimized causal convolution with CUDA enhancements and warp-level operations.
    • Improved Mamba2 prefill with fused elementwise operations and Triton acceleration.
    • Enhanced Triton kernel configurations for better parallelism and reduced register pressure.
    • Nemotron model now supports FP4 fused quantization paths.
  • Tests

    • Added comprehensive test coverage for causal convolution, fused activation quantization, and layer normalization features.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update the tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline, ensuring that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
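
For example, a typical invocation combining the options above (mirroring usage later in this thread) might be:

/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast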

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since a lack of care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since a lack of care and validation can break the top of tree.

@Wanli-Jiang (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #34754 [ run ] triggered by Bot. Commit: 81b7dcc

@tensorrt-cicd (Collaborator)

PR_Github #34820 [ run ] triggered by Bot. Commit: 81b7dcc

@tensorrt-cicd (Collaborator)

PR_Github #34820 [ run ] completed with state SUCCESS. Commit: 81b7dcc
/LLM/main/L0_MergeRequest_PR pipeline #26860 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@Wanli-Jiang (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #34964 [ run ] triggered by Bot. Commit: 082ba85

@tensorrt-cicd (Collaborator)

PR_Github #34964 [ run ] completed with state SUCCESS. Commit: 082ba85
/LLM/main/L0_MergeRequest_PR pipeline #26974 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@Wanli-Jiang force-pushed the user/williamj/super-v3-opt2 branch 3 times, most recently from 15c7bf2 to 6cd2cae on February 6, 2026 08:38
@Wanli-Jiang marked this pull request as ready for review on February 6, 2026 08:42
@Wanli-Jiang requested review from a team as code owners on February 6, 2026 08:42
@Wanli-Jiang (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #35091 [ run ] triggered by Bot. Commit: 6cd2cae

@coderabbitai bot (Contributor) commented Feb 6, 2026

📝 Walkthrough

This pull request introduces CUDA kernel optimizations for causal convolution, implements fused ReLU-2 activation with FP4 quantization, extends layernorm kernels with high-precision output handling, and integrates FP4 quantization support into PyTorch modules and the Nemotron model, along with comprehensive test coverage for new and modified operations.

Changes

  • CUDA Kernel Optimizations (cpp/tensorrt_llm/kernels/causalConv1d.cu): Optimizes forward kernel with warp-shuffle-based data exchange, per-iteration dual output processing using fused multiply-add, and fast math SiLU activation (__expf, __frcp_rn) for two outputs simultaneously. Adds static_assert constraints on thread count and vector sizes.
  • Fused Activation Quantization Kernel (cpp/tensorrt_llm/kernels/fusedActivationQuant.cu): Introduces new CUDA kernel and host launcher for fused ReLU-2 activation with FP4/SF quantization, including per-row processing, warp-level reductions, SF scale calculation, and conditional FP4/SF output writing with swizzled layout.
  • Layernorm Parameter & Header Updates (cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h, low_latency_layernorm.cuh, ws_layernorm.cuh, ws_layernorm.h): Adds high_precision_normed_output field to param struct; adjusts storage type from PackType<AccumulatorType> to PackType<InputType> across low-latency and warp-specialized kernels; adds output_hp_norm parameter to invokeWSLayerNorm signature.
  • Layernorm Trait & Implementation (cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu): Extends FP4AddBiasResidualPreLayerNormTraits with _HIGH_PRECISION_NORMED_OUTPUT template parameter; updates invokeWSLayerNormImpl and public invokeWSLayerNorm specializations to accept output_hp_norm and branch kernel variants accordingly.
  • TorchScript Extension Integration (cpp/tensorrt_llm/thop/CMakeLists.txt, fusedActivationQuant.cpp): Adds fusedActivationQuant.cpp to build; implements fused_relu2_quantize TorchScript op with FP16/BF16 support, allocating FP4 and SF output tensors with proper size calculations and guarding BF16 path.
  • RMSNorm Quantization Integration (cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp): Extends fused_add_rms_norm_quant with optional output_hp_norm parameter and high_precision_normed_output tensor return; updates kernel launcher signature and allocates high-precision output conditionally.
  • PyTorch Custom Ops (tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py): Updates fused_add_rms_norm_quant signature to include output_hp_norm parameter and hp_output return; introduces new fused_relu2_quantize op with FP4/SF tensor outputs.
  • Mamba2 Prefill Fusion (tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py): Introduces two Triton-based kernels and Python wrappers for extracting/transposing XBC data and fused split/rearrange after convolution, with grid/block sizing for optimal prefill performance.
  • Mamba2 Metadata Optimization (tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py): Adds Triton-accelerated cu_seqlens_to_chunk_indices_offsets path; introduces pre-allocated buffers (_arange_buffer, _arange_buffer_long, _cu_seqlens_long) for reduced host-device allocations in prepare().
  • Mamba2 Mixer Fused Ops (tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py): Replaces explicit split/transpose sequences with fused workflows using extract_transpose_xbc_prefill and fused_split_rearrange_after_conv1d; adjusts preallocation dtype to match zxbcdt.
  • Mamba2 Autotune Configs (tensorrt_llm/_torch/modules/mamba/ssd_bmm.py, ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py): Expands Triton autotune configuration space with smaller block sizes (32/32/32), reduced register pressure configs, and mixed-size coverage; ssd_chunk_scan.py replaces exp() with exp2() for faster computation.
  • MLP Fused ReLU2 Path (tensorrt_llm/_torch/modules/mlp.py): Introduces fused_relu2_quantize path scaffolding via _fused_checked flag; conditionally dispatches activation through fused kernel when enabled, wrapping result in Fp4QuantizedTensor.
  • RMSNorm High-Precision Output (tensorrt_llm/_torch/modules/rms_norm.py): Adds return_hp_output parameter to RMSNorm; captures fused results tuple and conditionally extracts/returns high-precision normed output alongside residual when enabled.
  • Nemotron Model FP4 Support (tensorrt_llm/_torch/models/modeling_nemotron_h.py): Integrates NVFP4 fused paths with scale attachment logic; extends forward signatures to accept FP4-quantized tensors; propagates residuals through layers returning (hidden_states, residual) tuples; routes MoE logits through high-precision inputs.
  • Benchmarking Tool (tensorrt_llm/tools/layer_wise_benchmarks/runner.py): Updates non-residual-fusion forward loop to pass residual kwarg and handle dual return types (tuple or single value) from layers.
  • Causal Conv1D Tests (tests/unittest/_torch/modules/mamba/test_causal_conv1d.py): Adds comprehensive test suite validating causal_conv1d_fwd against a PyTorch reference, covering correctness, batch sizes, kernel widths, and initial state handling (see the reference sketch after this list).
  • Fused Elementwise Ops Tests (tests/unittest/_torch/modules/mamba/test_fuse_elementwise_ops.py): Validates extract_transpose_xbc_prefill and fused_split_rearrange_after_conv1d against references across multiple dimension configurations and data types.
  • Mamba2 Metadata Tests (tests/unittest/_torch/modules/mamba/test_mamba2_metadata.py): Tests cu_seqlens_to_chunk_indices_offsets Triton implementation against a reference for empty, single, and multiple sequences, and various chunk sizes.
  • Fused Activation Quantization Tests (tests/unittest/_torch/modules/test_fused_activation_quant.py): Validates fused_relu2_quantize kernel against separate relu2 + fp4_quantize operations; includes zero-input edge cases and scale sensitivity testing.
  • Fused RMSNorm Quantization Tests (tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py): Comprehensive tests for fused_add_rms_norm_quant with/without high-precision output; validates shapes, dtypes, and numerical correctness across configurations and batch sizes.
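
Several entries above revolve around the causal convolution path (causalConv1d.cu and its new unit test, which checks the kernel against a PyTorch reference). As background, here is a minimal, illustrative PyTorch sketch of the depthwise causal conv1d + SiLU pattern involved; the function name and signature are assumptions, not the actual kernel or test code.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def causal_conv1d_ref(x: torch.Tensor, weight: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      activation: bool = True) -> torch.Tensor:
    """Depthwise causal conv1d, optionally followed by SiLU.

    x: (batch, dim, seqlen), weight: (dim, 1, dconv) -- each channel is
    convolved only with its own short filter over current and past positions.
    """
    dim, _, dconv = weight.shape
    out = F.conv1d(x, weight, bias=bias, padding=dconv - 1, groups=dim)
    out = out[..., : x.shape[-1]]        # drop trailing positions that peek ahead
    return F.silu(out) if activation else out


x = torch.randn(2, 128, 64)
w = torch.randn(128, 1, 4)
y = causal_conv1d_ref(x, w)
assert y.shape == x.shape
```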

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: 1 passed ✅, 2 failed ❌

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 54.65%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ⚠️ Warning | The PR description is incomplete; it only contains the kernel profiling tables without addressing the required template sections for summary, test coverage, or checklist completion. | Add a clear Description section explaining what changed and why, explicitly list Test Coverage confirming that existing/new tests validate the changes, and complete the PR Checklist (guidelines, test cases, dependencies, CODEOWNERS, documentation). |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title describes a feature optimizing nvfp4 for super-v3 models and aligns with the actual changes across multiple files (quantization kernels, fused operations, MLP/RMSNorm modules). |


@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py (1)

4-4: ⚠️ Potential issue | 🟡 Minor

Copyright year should be updated to 2026.

This file (and the other three files in this review) has 2022-2024 in the NVIDIA copyright header, but the files are being meaningfully modified in 2026. As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification."

Update to 2022-2026 across all four modified files.

tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Copyright year should be updated to 2026.

The file has meaningful modifications in this PR. As per coding guidelines, source files should "contain an NVIDIA copyright header with the year of latest meaningful modification."

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu (1)

1-2: ⚠️ Potential issue | 🟡 Minor

Copyright year should be updated to 2026.

This file has meaningful modifications (new template parameter, new branching logic). As per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.

tensorrt_llm/_torch/models/modeling_nemotron_h.py (2)

707-709: ⚠️ Potential issue | 🟡 Minor

Python 3.10+ type syntax breaks 3.8 compatibility.

Lines 707 and 709 use torch.Tensor | None and lowercase tuple[...], which require Python 3.10+ at runtime. The rest of the file (and the imports at line 17) consistently uses Optional[...] and Tuple[...] from typing. There is no from __future__ import annotations to enable deferred evaluation.

Proposed fix
-        residual: torch.Tensor | None = None,
+        residual: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

As per coding guidelines, "The code developed for TensorRT-LLM should conform to Python 3.8+."


623-652: ⚠️ Potential issue | 🔴 Critical

Remove dead code that references non-existent attributes.

Lines 623–638 contain code copied from NemotronHForCausalLM.__init__ that references self.model, self.draft_model, self.epilogue, and self.spec_worker—none of which exist in NemotronHMTPDecoderLayer (which inherits from NemotronHLayerDecoderLayer).

This block is currently unreachable because sublayer_model_config (lines 785–791) does not include spec_config, so model_config.spec_config is always None. However, if spec_config were ever passed to the sublayer config, this code would crash with AttributeError. Remove this block to prevent confusion and latent bugs.

🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py`:
- Around line 253-261: The branch that handles non-tuple hidden_states can receive an Fp4QuantizedTensor, which lacks .view(). In the else branch around hidden_states / hidden_states_hp (before creating hidden_states_hp_2d), either tighten the accepted type union so only tuple paths reach this code, or detect and unwrap the quantized wrapper: if isinstance(hidden_states_hp, Fp4QuantizedTensor) (or a hasattr check on its backing attribute such as "tensor"/"values"/"to_torch"), assign hidden_states_hp to that backing torch.Tensor, then assert the shape and call .view(-1, self.hidden_dim) to produce hidden_states_hp_2d.

In `@tensorrt_llm/_torch/modules/mlp.py`:
- Around line 127-142: _fused_relu2_quant always casts inputs to torch.bfloat16, which can change precision for float16 models and adds overhead. Preserve the input dtype when it is torch.float16 or torch.bfloat16 (only convert when it is neither), then call torch.ops.trtllm.fused_relu2_quantize on that preserved-dtype path (the C++ invokeFusedRelu2Quantize supports both fp16 and bf16), and optionally emit a debug log when an unexpected dtype is converted. The relevant symbols are _fused_relu2_quant, torch.ops.trtllm.fused_relu2_quantize, and self.down_proj.input_scale.

In `@tensorrt_llm/_torch/modules/rms_norm.py`:
- Around line 140-152: Update the forward method's return type annotation to reflect the three signatures that can actually occur: Fp4QuantizedTensor, (Fp4QuantizedTensor, Tensor), and (Fp4QuantizedTensor, Tensor, Tensor); the 3-tuple variant is returned when both return_hp_output and has_residual are True. Remove the impossible (no residual + return_hp_output) variant from the union. The relevant symbols are return_hp_output, has_residual, hidden_states_fused, residual_out, and high_precision_normed_output.

In `@tests/unittest/_torch/modules/mamba/test_causal_conv1d.py`:
- Around line 44-47: In test_causal_conv1d, remove the duplicated assertion on conv_weight.shape[0] == dim so only the unique shape checks remain (assert conv_weight.shape[0] == dim, assert conv_weight.shape[1] == 1, assert conv_weight.shape[2] == dconv).

In `@tests/unittest/_torch/modules/test_fused_activation_quant.py`:
- Around line 72-87: The helper dequantize_fp4 ignores its scale and scale2 parameters, only unpacks FP4 indices via E2M1_VALUES, and is never called, so it produces incorrect values and is dead code. Either delete it, or implement proper dequantization: map the low/high nibbles through E2M1_VALUES, then apply scale and scale2 per element so the returned tensor holds correct dequantized floats, and make the tests call it (a rough sketch follows below).
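
To make the previous point concrete, here is a rough, hypothetical sketch of the kind of dequantization being asked for. The nibble packing order, the sign-magnitude E2M1 layout, and the row-major (non-swizzled) scale layout are all assumptions for illustration; the real test would have to de-swizzle the SF tensor and reuse its own E2M1_VALUES table.

```python
import torch

# Assumed E2M1 magnitude table; the test file defines its own E2M1_VALUES.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequantize_fp4(packed: torch.Tensor, scale: torch.Tensor, scale2: float,
                   sf_vec_size: int = 16) -> torch.Tensor:
    """Unpack two FP4 values per byte and apply per-block and global scales.

    Assumes element 2*i sits in the low nibble of byte i (bit 3 = sign,
    bits 0-2 = magnitude index) and that `scale` is laid out row-major per
    16-element block, i.e. already de-swizzled.
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    nibbles = torch.stack([lo, hi], dim=-1).flatten(start_dim=-2).long()
    sign = 1.0 - 2.0 * ((nibbles >> 3) & 1).float()   # bit 3 is the sign bit
    vals = sign * E2M1_VALUES[nibbles & 0x7]          # bits 0-2 index the magnitude
    m, n = vals.shape
    blocks = vals.view(m, n // sf_vec_size, sf_vec_size)
    return (blocks * scale.view(m, -1, 1).float() * scale2).view(m, n)
```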
🧹 Nitpick comments (12)
tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py (1)

375-457: Autotune expansion for _chunk_state_varlen_kernel is consistent with the other kernels.

Same config structure as _chunk_state_fwd_kernel. No issues.

Note: the tl.exp() calls in this file (e.g., lines 344, 349, 564, 611) were not converted to tl.math.exp2(... * 1.4426950408889634) like in ssd_chunk_scan.py. If the exp2 optimization yields measurable gains in ssd_chunk_scan, it may be worth applying here as well for consistency.
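
For context, the exp2 rewrite referenced here relies only on the identity exp(x) = 2**(x * log2(e)); hardware exp2 is typically cheaper than exp on GPUs. A quick PyTorch check of the identity (illustrative only, not the Triton code):

```python
import torch

LOG2_E = 1.4426950408889634  # log2(e), the constant used in the kernel rewrite

x = torch.randn(1024)
assert torch.allclose(torch.exp(x), torch.exp2(x * LOG2_E), rtol=1e-5, atol=1e-6)
```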

cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh (1)

1-2: Copyright year may need updating.

The copyright header says 2024, but this file contains meaningful modifications for 2026. As per coding guidelines, source files should have the year of latest meaningful modification in the copyright header.

Suggested fix
-* Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+* Copyright (c) 2024-2026, NVIDIA CORPORATION.  All rights reserved.
tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py (1)

34-63: tuple[...] syntax requires Python 3.9+.

Lines 39 and 66 use lowercase tuple[...] for type hints. Per coding guidelines, Python 3.8+ compatibility is expected. Use Tuple from typing for consistency with the rest of the codebase (e.g., cpp_custom_ops.py uses Tuple).

Proposed fix
+from typing import Tuple
+
 def rms_norm_ref(
     hidden_states: torch.Tensor,
     residual: torch.Tensor,
     gamma: torch.Tensor,
     eps: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:

And similarly for get_swizzled_sf_indices:

-def get_swizzled_sf_indices(m: int, n: int, sf_vec_size: int = 16) -> list[int]:
+def get_swizzled_sf_indices(m: int, n: int, sf_vec_size: int = 16) -> List[int]:

As per coding guidelines, "The code developed for TensorRT-LLM should conform to Python 3.8+".

cpp/tensorrt_llm/thop/fusedActivationQuant.cpp (1)

42-88: int64_t → int narrowing cast on dimensions m and n.

Lines 69–70 (and 76–77) cast m and n from int64_t to int via static_cast<int>(). If a tensor dimension ever exceeds INT_MAX, this silently truncates. While unlikely for FP4 quantization workloads, it's worth adding a guard — similar to the existing TORCH_CHECK assertions — to reject oversized inputs early.

Proposed guard
     TORCH_CHECK(n % sf_vec_size == 0, "N must be divisible by sf_vec_size.");
+    TORCH_CHECK(m <= INT_MAX && n <= INT_MAX, "Dimensions exceed int range.");
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (1)

146-146: Unused unpacked variable conv_dim.

Ruff flags conv_dim as unused (RUF059). Prefix with _ to clarify intent.

Proposed fix
-    conv_dim, num_prefill_tokens = xbc.shape
+    _conv_dim, num_prefill_tokens = xbc.shape
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (2)

26-60: num_seqs as tl.constexpr will trigger Triton recompilation for every distinct batch size.

Since num_seqs is declared tl.constexpr, the loop for seq_idx in range(num_seqs - 1) is unrolled at compile time. Every time the batch size changes, Triton must re-specialize and recompile this kernel. If the serving workload sees variable batch sizes (common in production), the recompilation overhead could be noticeable on the first occurrence of each batch size.

Consider whether num_seqs can be a runtime parameter instead, with a tl.range-based loop (or a while loop with a dynamic bound). Alternatively, if the max batch size is small and bounded, the current approach is fine — just be aware of the trade-off.


225-233: Per-element Python-loop access to num_cached_tokens_per_seq causes host-device syncs.

Lines 225–227 index into num_cached_tokens_per_seq (a CUDA tensor) in a Python for loop, triggering a host-device synchronization on each iteration. For small num_contexts this is typically acceptable, but it partially negates the benefit of the Triton-based chunk computation.

If num_cached_tokens_per_seq is a CUDA tensor, a vectorized approach avoids the per-element syncs:

Suggested vectorized path
-            initial_states = [
-                num_cached_tokens_per_seq[i] > 0 for i in range(num_contexts)
-            ]
-            self.use_initial_states = any(initial_states)
+            initial_states_tensor = num_cached_tokens_per_seq[:num_contexts] > 0
+            self.use_initial_states = initial_states_tensor.any().item()
             if self.use_initial_states:
-                self.has_initial_states[:num_contexts] = torch.tensor(
-                    initial_states, dtype=torch.bool)
+                self.has_initial_states[:num_contexts] = initial_states_tensor
cpp/tensorrt_llm/kernels/fusedActivationQuant.cu (2)

42-50: Missing braces and const qualifiers on local variables.

Per coding guidelines:

  • Line 49–50: The if body must be a compound (brace-delimited) statement.
  • Lines 44–47: SFScaleVal, numColThreads, numColVecs, rowIdx are never modified after initialization and should be const.
  • SFScaleVal and other identifiers starting with uppercase violate the camelCase-with-lowercase-first convention for local variables (e.g., sfScaleVal).
Proposed fix for the early-return guard
-    if (rowIdx >= m)
-        return;
+    if (rowIdx >= m)
+    {
+        return;
+    }

As per coding guidelines: "The statement forming the body of a switch, while, do..while, or for statement must be a compound statement", "If and else statements should always be followed by brace-delimited statements", and "A variable that is not modified after its initialization should be declared as const".


70-71: Warp shuffle reduction only reduces across lane pairs — verify this is intentional.

__shfl_xor_sync(0xffffffff, localMax, 1) exchanges with exactly the adjacent lane (XOR with 1). Since kEltsPerThread = 8 and kSfVecSize = 16, each pair of consecutive threads covers one 16-element SF vector, so a single XOR-1 shuffle is sufficient. This is correct but subtle — a brief comment noting the pairing would help future readers.

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu (1)

122-148: Four-way branching is correct but could be simplified.

The {rms, non-rms} × {hp, non-hp} matrix produces 4 nearly identical _invoke(...) calls differing only in two bool template args. A small helper template or constexpr dispatch could deduplicate this, but since template instantiation requires compile-time constants, the current explicit branching is functionally correct and not unusual for CUDA kernel selection code.

tensorrt_llm/_torch/models/modeling_nemotron_h.py (2)

1-1: Copyright year is outdated.

The header says 2022-2024 but this file has meaningful modifications in 2026. As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification."

Proposed fix
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

421-422: Redundant zeros_like allocation — dead code in practice.

NemotronHModel.forward (line 527) already initializes residual = torch.zeros_like(hidden_states) before the layer loop and always passes it, so this fallback is never reached in the normal call path. The NemotronHMTPDecoderLayer also handles None residual in its own forward without calling super().forward().

If this is kept as a defensive guard, consider a cheaper sentinel or an assertion instead, to avoid silently masking caller bugs:

Suggested alternative
-        if residual is None:
-            residual = torch.zeros_like(hidden_states)
+        assert residual is not None, (
+            "NemotronHLayer.forward expects a residual tensor; "
+            "caller must initialize it (e.g. torch.zeros_like(hidden_states))."
+        )

@nv-guomingz requested a review from JadoTu on February 6, 2026 09:02
@tensorrt-cicd (Collaborator)

PR_Github #35091 [ run ] completed with state SUCCESS. Commit: 6cd2cae
/LLM/main/L0_MergeRequest_PR pipeline #27086 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@tensorrt-cicd (Collaborator)

PR_Github #35691 [ run ] triggered by Bot. Commit: d397bbc

@Wanli-Jiang (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #35691 [ run ] completed with state SUCCESS. Commit: d397bbc
/LLM/main/L0_MergeRequest_PR pipeline #27567 (Partly Tested) completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@tensorrt-cicd (Collaborator)

PR_Github #35735 [ run ] triggered by Bot. Commit: d397bbc

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> (16 commits)
@Wanli-Jiang force-pushed the user/williamj/super-v3-opt2 branch from d397bbc to 66f64d0 on February 12, 2026 06:32
@Wanli-Jiang (Collaborator, Author)

/bot run --stage-list "A30-AutoDeploy-1,RTX5090-PackageSanityCheck-PY312-UB2404,A10-PackageSanityCheck-PY310-UB2204,GB200-4_GPUs-PyTorch-1,DGX_B200-4_GPUs-PyTorch-1,DGX_H100-4_GPUs-PyTorch-Others-1" --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #35746 [ run ] triggered by Bot. Commit: 66f64d0

@tensorrt-cicd (Collaborator)

PR_Github #35746 [ run ] completed with state SUCCESS. Commit: 66f64d0
/LLM/main/L0_MergeRequest_PR pipeline #27610 (Partly Tested) completed with status: 'SUCCESS'

@Wanli-Jiang (Collaborator, Author)

/bot skip --comment "all tests are passed within different runs and passed at local runs"

@tensorrt-cicd (Collaborator)

PR_Github #35787 [ skip ] triggered by Bot. Commit: 66f64d0

@tensorrt-cicd (Collaborator)

PR_Github #35787 [ skip ] completed with state SUCCESS. Commit: 66f64d0
Skipping testing for commit 66f64d0

@Wanli-Jiang merged commit 421eb9e into NVIDIA:main on Feb 12, 2026
5 checks passed