[None][feat] Optimize by fusing nvfp4_quant into layernorm_gated for mamba2_mixer #11473
Conversation
📝 Walkthrough

This pull request introduces NVIDIA FP4 (NVFP4) quantization support throughout TensorRT-LLM, including new fused CUDA kernels for quantization-aware operations (ReLU2, GatedRMSNorm), PyTorch bindings, integration with the existing layernorm infrastructure for optional high-precision outputs, and performance optimizations for Mamba2 prefill via Triton-accelerated fusion kernels.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host/PyTorch
    participant ToOp as Torch Op<br/>(fusedRelu2Quantize)
    participant Kernel as CUDA Kernel<br/>(fusedRelu2QuantizeKernel)
    participant Memory as GPU Memory
    Host->>ToOp: fused_relu2_quantize(input, sfScale,<br/>sfVecSize)
    Note over ToOp: Validate input shape (2D [M,N]),<br/>dtype (FP16/BF16),<br/>N divisible by sfVecSize
    ToOp->>Memory: Allocate output_fp4 [M, N/4]<br/>(swizzled FP4 layout)
    ToOp->>Memory: Allocate output_sf (swizzled<br/>scale-factor layout)
    ToOp->>Kernel: Launch with grid/block,<br/>CUDA stream
    Kernel->>Kernel: Per-row processing:<br/>1. Pack & compute ReLU2(input)
    Kernel->>Kernel: 2. Per-thread local max,<br/>cross-thread reduction
    Kernel->>Kernel: 3. Compute SF (scale factor),<br/>quantize to FP8 (e4m3)
    Kernel->>Memory: Write SF to output_sf
    Kernel->>Kernel: 4. Quantize ReLU2 output<br/>to FP4 using SF scale
    Kernel->>Memory: Write FP4-encoded output<br/>to output_fp4
    ToOp->>Host: Return (output_fp4, output_sf)<br/>tuple
```
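The per-group quantization math in steps 1–4 can be sketched in plain Python. This is a hedged reference model of the NVFP4 semantics described above (E2M1 value grid, amax mapped to the FP4 maximum of 6.0), not the kernel's actual API; `quantize_group`, `sf_scale`, and `FP4_POS_GRID` are illustrative names.

```python
# Reference sketch (assumptions, not the CUDA kernel): fused ReLU2 + NVFP4
# quantization over one 16-element scale-factor group.
FP4_POS_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 positive values

def relu2(v: float) -> float:
    # ReLU2(x) = max(x, 0)^2 is always non-negative, so only the
    # positive half of the FP4 grid is needed.
    return max(v, 0.0) ** 2

def quantize_group(group, sf_scale: float = 1.0):
    """Quantize one group: the group amax maps to 6.0, the FP4 max."""
    vals = [relu2(v) for v in group]
    amax = max(vals)
    sf = (amax / 6.0) * sf_scale if amax > 0.0 else 1.0
    # Round each scaled value to the nearest representable FP4 value.
    codes = [min(FP4_POS_GRID, key=lambda g: abs(v / sf - g)) for v in vals]
    return codes, sf

def dequantize_group(codes, sf: float):
    return [c * sf for c in codes]
```

The real kernel additionally packs codes into 4-bit pairs and stores `sf` in FP8 (e4m3) with a swizzled layout; those details are omitted here.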
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (9)
tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py (1)
4-4: ⚠️ Potential issue | 🟡 Minor — Update copyright year to include 2026.

The copyright header reads `2022-2024` but this file is being modified in 2026. As per coding guidelines, "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py (1)
4-4: ⚠️ Potential issue | 🟡 Minor — Update the copyright year to include 2026.

The NVIDIA copyright header shows `2022-2024` but this file has meaningful modifications in 2026.

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
tensorrt_llm/_torch/models/modeling_nemotron_h.py (2)
1-1: ⚠️ Potential issue | 🟡 Minor — Update copyright year to 2026.

The copyright header states `2022-2024` but this file is being meaningfully modified in 2026. As per coding guidelines, "Include NVIDIA copyright header on ALL new files and update year on modified files."

Proposed fix:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
243-299: ⚠️ Potential issue | 🟡 Minor — Update the `MLP.forward` type annotation to accept `Fp4QuantizedTensor`.

`MLP.forward()` (line 263: `self.shared_experts(hidden_states)`) receives a parameter annotated `x: torch.Tensor`, but when `hidden_states` is unpacked from the tuple (line 252) it can be an `Fp4QuantizedTensor` per the type annotation on lines 245–246. The code works at runtime because `Linear.forward` (used internally by MLP) accepts `Union[torch.Tensor, Fp4QuantizedTensor]`, but the annotation on `MLP.forward` should be widened to `Union[torch.Tensor, Fp4QuantizedTensor]` for type consistency. Similarly, verify that `GatedMLP.forward` has the same annotation if used in similar contexts.

tensorrt_llm/_torch/modules/rms_norm.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor — Copyright year should be updated to 2026.

The header says `Copyright (c) 2025` but this file is being meaningfully modified in 2026. As per coding guidelines, update the year on modified files.

Fix copyright year:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor — Copyright year should be updated to include 2026.

The header says `Copyright (c) 2022-2024` but this file is being meaningfully modified in 2026. As per coding guidelines, update the year on modified files.

Fix copyright year:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

cpp/tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h (1)
1-2: ⚠️ Potential issue | 🟡 Minor — Copyright year should be updated to include 2026.

The header says `Copyright (c) 2024` but this file is being meaningfully modified in 2026.

Fix copyright year:

```diff
-* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
```

cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh (1)
1-2: ⚠️ Potential issue | 🟡 Minor — Copyright year should be updated to include 2026.

Fix copyright year:

```diff
-* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
```

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu (1)
1-2: ⚠️ Potential issue | 🟡 Minor — Update copyright year to 2026.

The file has meaningful modifications (new template parameter, new branching logic) but the copyright header still says 2024.

```diff
-* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Fix all issues with AI agents
In `@cpp/tensorrt_llm/kernels/fusedGatedRMSNormQuant/CMakeLists.txt`:
- Around line 1-2: Update the SPDX header in CMakeLists.txt to reflect the
latest modification year by changing the copyright year range "2022-2024" to
include 2026 (e.g., "2022-2026" or simply "2026" per project convention) so the
new file uses the required NVIDIA copyright header with the correct latest year.
In `@cpp/tensorrt_llm/kernels/fusedGatedRMSNormQuant/fusedGatedRMSNormQuant.cu`:
- Around line 716-762: Add explicit validation that params.N is divisible by
params.groupSize (and that params.groupSize is compatible with the per-thread
element granularity) before selecting grouped kernels: in
invokeFusedGatedRMSNormQuant check if (params.N % params.groupSize != 0) and
fail fast (assert/log + return or throw) so ngroups = params.N /
params.groupSize cannot drop trailing elements; also add/ensure a check that
params.groupSize % ELTS_PER_THREAD == 0 (or equivalent constant used by
fusedGatedRMSNormQuantKernelGrouped) and reject invalid groupSize values so the
grouped kernel assumptions hold. Use the existing symbols params.N,
params.groupSize, invokeFusedGatedRMSNormQuant, and
fusedGatedRMSNormQuantKernelGrouped when locating where to insert the checks.
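The validation described in this item can be sketched in Python before porting it into the C++ launcher. This is an assumption-laden sketch: `ELTS_PER_THREAD = 8` is a placeholder for the kernel's actual per-thread granularity, and `validate_grouped_launch` is an illustrative name, not the launcher's API.

```python
ELTS_PER_THREAD = 8  # assumed per-thread element granularity, for illustration

def validate_grouped_launch(n: int, group_size: int) -> int:
    """Return ngroups after checking the grouped-kernel preconditions."""
    # Fail fast if N is not a whole number of groups, so trailing
    # elements are never silently dropped by integer division.
    if group_size <= 0 or n % group_size != 0:
        raise ValueError(f"N ({n}) must be divisible by groupSize ({group_size})")
    # The grouped kernel also assumes groupSize is a multiple of the
    # per-thread element count.
    if group_size % ELTS_PER_THREAD != 0:
        raise ValueError(
            f"groupSize ({group_size}) must be a multiple of {ELTS_PER_THREAD}")
    return n // group_size
```

In the real launcher these would be `TLLM_CHECK_WITH_INFO`-style checks on `params.N` and `params.groupSize` before kernel selection.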
In `@tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py`:
- Around line 1023-1035: The creation of output_sf incorrectly wraps scale_shape
in a tuple causing a shape mismatch; call fp4_utils.get_fp4_shape to get
scale_shape (already an int) and allocate output_sf with
input.new_empty(scale_shape, dtype=torch.uint8) instead of
input.new_empty((scale_shape,), ...). Update the allocation in the
fused_relu2_quantize registered function (trtllm::fused_relu2_quantize) to pass
scale_shape directly to new_empty.
In `@tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py`:
- Around line 751-753: The hardcoded BLOCK_SIZE_DSTATE=32 drops dstate
dimensions >32; instead compute BLOCK_SIZE_DSTATE based on dstate: if dstate <=
32 set BLOCK_SIZE_DSTATE=32 (to keep the optimization), otherwise compute
BLOCK_SIZE_DSTATE = max(triton.next_power_of_2(dstate), 16) (the original
expression) so larger dstate values are handled correctly; update the assignment
near the existing BLOCK_SIZE_DSTATE line and ensure any downstream logic (e.g.,
the single-load path guarded by "if BLOCK_SIZE_DSTATE <= 128" and uses
offs_k_dstate) will see the correct value at runtime.
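The block-size selection this item asks for can be sketched as follows. The pure-Python `next_power_of_2` is a stand-in for `triton.next_power_of_2` (assumed here so the sketch runs without Triton installed); `pick_block_size_dstate` is an illustrative name.

```python
def next_power_of_2(n: int) -> int:
    # Stand-in for triton.next_power_of_2: smallest power of two >= n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_block_size_dstate(dstate: int) -> int:
    # Keep the fast-path block of 32 for small dstate, but fall back to the
    # original next-power-of-2 sizing so dstate > 32 is not truncated.
    if dstate <= 32:
        return 32
    return max(next_power_of_2(dstate), 16)
```

With this, the single-load path guarded by `BLOCK_SIZE_DSTATE <= 128` still triggers for small state sizes, while larger `dstate` values get a correctly sized block at runtime.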
In `@tests/unittest/_torch/modules/test_fused_activation_quant.py`:
- Around line 71-73: The type annotations in quantize_nvfp4_ref (and the other
signature at line 121) use the Python 3.9+ built-in generic syntax tuple[...],
which breaks 3.8 compatibility; fix by either adding "from __future__ import
annotations" at the top of the file (after imports) to defer evaluation of
annotations, or change the annotations to use typing.Tuple[torch.Tensor,
torch.Tensor] and add an import for typing as needed so quantize_nvfp4_ref and
the other function use typing.Tuple instead of tuple[...].
In `@tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py`:
- Line 39: Replace Python 3.9-style builtin generics with typing generics to
restore 3.8 compatibility: import Tuple and List from typing (or use
typing.Tuple / typing.List) and change the return annotation "->
tuple[torch.Tensor, torch.Tensor]" to "-> Tuple[torch.Tensor, torch.Tensor]" and
any "list[...]" occurrences (e.g., the usage at the location around line 66) to
"List[...]" so the function signatures and type hints use
typing.Tuple/typing.List instead of builtin tuple/list generics.
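Both test-file fixes above amount to the same pattern; a minimal sketch of the Python 3.8-compatible style (the function body is illustrative only, standing in for the real `quantize_nvfp4_ref` signature):

```python
# Python 3.8-compatible annotations: typing.Tuple / typing.List instead of
# the builtin generics tuple[...] / list[...] (PEP 585, 3.9+ only).
from typing import List, Tuple

def quantize_ref_sketch(values: List[float]) -> Tuple[float, float]:
    # Placeholder body: returns (amax, scale) as an example of a
    # two-tuple return annotated the 3.8-safe way.
    amax = max(abs(v) for v in values)
    return amax, amax / 6.0
```

The alternative fix, `from __future__ import annotations` at the top of the file, defers annotation evaluation and lets the `tuple[...]` spelling remain.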
🧹 Nitpick comments (26)
tensorrt_llm/tools/layer_wise_benchmarks/runner.py (1)
445-456: `inspect.signature()` called on every forward pass — precompute outside the hot path.

`inspect.signature(layer.forward)` is evaluated on every invocation of `forward()`, but the result is invariant across calls. Since this is a benchmarking runner where `forward` is called in tight loops, the overhead accumulates unnecessarily. Precompute the per-layer residual fusion flag once during `__init__`:

♻️ Suggested refactor

```diff
+    residual_fusion_flags = {
+        idx: "residual" in inspect.signature(model.model.layers[idx].forward).parameters
+        for idx in layer_indices
+    }
+
     def forward(position_ids, hidden_states, attn_metadata, residual, **kwargs):
         # TODO: to be more general, we should call DecoderModel.forward
         for layer_idx in layer_indices:
             layer = model.model.layers[layer_idx]
-            residual_fusion = "residual" in inspect.signature(layer.forward).parameters
+            residual_fusion = residual_fusion_flags[layer_idx]
             if residual_fusion:
                 hidden_states, residual = layer(
                     position_ids, hidden_states, attn_metadata, residual, **kwargs
                 )
             else:
                 hidden_states = layer(position_ids, hidden_states, attn_metadata, **kwargs)
         return hidden_states, residual
```

tensorrt_llm/_torch/modules/mamba/ssd_chunk_state.py (1)
373-457: `_chunk_state_varlen_kernel` is missing the `(64, 128, 64)` low-register-pressure config present in `_chunk_state_fwd_kernel`.

`_chunk_state_fwd_kernel` has two configs under "Low register pressure" (lines 159–177): `(64, 64, 64)` and `(64, 128, 64)`, both with `num_stages=2`. The varlen variant only has `(64, 64, 64)` before jumping to "Original configs." If this was intentional, a comment explaining the difference would help; otherwise, consider adding the missing entry for consistency.

Suggested addition after line 408:

```diff
     triton.Config(
         {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 64
         },
         num_stages=2,
         num_warps=4),
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 64
+        },
+        num_stages=2,
+        num_warps=4),
     # Original configs for larger dimensions
```

cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu (1)
424-429: Consider using `__ldg` for weight loads in the update kernel as well.

The forward kernel (line 128) uses `__ldg` for read-only cache hints on weight loads, but the update kernel does a plain load. While the update kernel is typically on the latency-insensitive single-token path, applying `__ldg` here would be trivially consistent and can't hurt.

♻️ Optional: add `__ldg` to update kernel weight loads

```diff
 #pragma unroll
     for (int i = 0; i < kWidth; ++i)
     {
-        weight_vals[i] = float(weight[i * params.weight_width_stride]);
+        weight_vals[i] = float(__ldg(&weight[i * params.weight_width_stride]));
     }
```

tests/unittest/_torch/modules/mamba/test_causal_conv1d.py (2)
40-40: Unused variable `seq_len` flagged by static analysis.

The unpacked `seq_len` is not referenced in the function body. Prefix it with `_` to signal it's intentionally unused.

♻️ Proposed fix

```diff
-    batch_size, dim, seq_len = x.shape
+    batch_size, dim, _seq_len = x.shape
```
74-109: Good parametric coverage for basic correctness.

3 dtypes × 2 SiLU modes × 4 dimensions = 24 test configurations, with appropriate tolerance for mixed-precision comparison. One thing to consider: the CUDA kernel has two code paths — vectorized (when `seqlen % kNElts == 0`) and non-vectorized. All tests here use sequence lengths that are multiples of 8, so the non-vectorized (`kIsVecLoad = false`) path is never exercised. Adding a test with an odd sequence length (e.g., `seq_len = 37`) would cover the `BlockLoadT`/`BlockStoreT` (warp-transpose) fallback path.
476-486: Consider extracting the magic constant `1.4426950408889634` (log₂(e)).

This constant appears 4 times in the kernel. Defining it once at module level (e.g., `LOG2_E = 1.4426950408889634`) would improve readability and reduce duplication. The `exp` → `exp2` transformation is mathematically correct.

Suggested change, at module level:

```python
LOG2_E = 1.4426950408889634
```

Then in the kernel, e.g.:

```diff
-    tl.math.exp2(dA_cs_m * 1.4426950408889634),
+    tl.math.exp2(dA_cs_m * LOG2_E),
```
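For reference, the `exp` → `exp2` rewrite rests on the identity exp(x) = 2^(x * log2(e)), which is what the kernel exploits because `exp2` maps to a cheaper GPU instruction; a quick runnable check:

```python
import math

LOG2_E = 1.4426950408889634  # log2(e), the constant repeated in the kernel

def exp_via_exp2(x: float) -> float:
    # exp(x) == 2**(x * log2(e)); this mirrors what
    # tl.math.exp2(x * LOG2_E) computes on the device.
    return 2.0 ** (x * LOG2_E)
```

Up to floating-point rounding of the product `x * LOG2_E`, the two forms agree.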
375-409: Scale fallback disables NVFP4 correctly, but the MoE scale approximation deserves a note.

The logic is well-structured: attempt scale attachment from the downstream linear, and gracefully disable the fused path if unavailable. For the MoE path (line 403), using `shared_experts.up_proj.input_scale` as the representative scale for the entire fused RMSNorm+Quant is an approximation — routed experts could theoretically have different scale distributions. The inline comment at lines 394–395 partially explains this, but a brief note on why this approximation is acceptable (e.g., scales are calibrated to be similar, or routing absorbs the difference) would help future maintainers.
122-147: `has_residual` on line 142 is always `True` in this code path.

The NVFP4 fused branch is entered only when `has_residual` is `True` (line 91: `self.is_nvfp4 and has_residual and ...`). The `if has_residual:` guard on line 142 is therefore redundant—it will never be `False` here. Not a bug, but removing the check (or asserting) would make the control flow clearer.

Simplify output assembly:

```diff
     outputs = [hidden_states_fused]
-    if has_residual:
-        outputs.append(residual_out)
+    outputs.append(residual_out)  # always present on NVFP4 fused path
     if self.return_hp_output:
         high_precision_normed_output = results[3].reshape(orig_shape)
         outputs.append(high_precision_normed_output)
     return outputs[0] if len(outputs) == 1 else tuple(outputs)
```
84-86: Return type annotation is complex and hard to maintain.

The union of four return variants is getting unwieldy. Consider introducing a `TypeAlias` or a lightweight data class to represent the various return shapes, which would also improve IDE support and self-documentation.

tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (2)
146-146: Prefix unused unpacked variable with underscore.

Static analysis flags `conv_dim` as unused. Since it's unpacked only for documentation, prefix it.

Fix unused variable warning:

```diff
-    conv_dim, num_prefill_tokens = xbc.shape
+    _conv_dim, num_prefill_tokens = xbc.shape
```
130-176: No assertion that `d_inner == nheads * head_dim` or `conv_dim == d_inner + 2 * bc_size`.

The kernel silently produces wrong results if the caller passes inconsistent dimensions. A cheap assert at the top of `fused_split_rearrange_after_conv1d` would catch misuse early.

Add dimension consistency checks:

```diff
     conv_dim, num_prefill_tokens = xbc.shape
     bc_size = n_groups * d_state
+    assert d_inner == nheads * head_dim, (
+        f"d_inner ({d_inner}) != nheads * head_dim ({nheads} * {head_dim})"
+    )
+    assert conv_dim == d_inner + 2 * bc_size, (
+        f"conv_dim ({conv_dim}) != d_inner + 2 * bc_size ({d_inner} + 2 * {bc_size})"
+    )
```

tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
67-107: Two `.item()` host-device syncs remain (lines 79 and 90).

The "optimized" Triton path still calls `cu[-1].item()` and `p[-1].item()`, each causing a host-device synchronization. For the multi-sequence case, this means two sync points before the kernel launch. This is acceptable when `num_seqs` is large (since the kernel avoids a Python loop), but for small `num_seqs` it may not be faster than the CPU path. Consider documenting this trade-off, or adding a heuristic to fall back to the CPU path for small `num_seqs`.

cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp (1)
121-128: Local variable naming: `high_precision_normed_output` uses snake_case.

Per C++ coding guidelines, local variables should use camelCase (e.g., `highPrecisionNormedOutput`). However, this matches the struct field name in `layernorm_param.h`, so I understand the consistency choice.

tensorrt_llm/_torch/modules/mlp.py (1)
122-138: Consider adding a dtype check for the `input_scale` tensor.

The method correctly guards against non-fp16/bf16 input, but `self.down_proj.input_scale` is passed directly to the C++ kernel, which expects `float32` (see `CHECK_INPUT(sf_scale, torch::kFloat32)` in `fusedActivationQuant.cpp`, line ~38). If `input_scale` has an unexpected dtype, the error will surface deep in the C++ binding rather than at the Python level. This is a minor robustness concern — in practice, `input_scale` will likely always be float32 from the weight loading path.

cpp/tensorrt_llm/kernels/fusedActivationQuant.cu (1)
164-175: `sfVecSize` parameter is accepted but unused in the launcher.

The function signature takes `sfVecSize`, but the body uses `constexpr int kSfVecSize = 16` unconditionally. Since the thop layer enforces `sf_vec_size == 16`, this is not a runtime bug, but it makes the API misleading.

Option: remove the parameter or add an assertion:

```diff
 template <typename T>
 void invokeFusedRelu2Quantize(T const* input, float const* sfScale, std::uint8_t* outputFp4,
     std::uint8_t* outputSf, int m, int n, int sfVecSize, cudaStream_t stream)
 {
     constexpr int kSfVecSize = 16;
+    TLLM_CHECK_WITH_INFO(sfVecSize == kSfVecSize, "sfVecSize must be 16 for NVFP4.");
     int const numColThreadsPadded = ((n + 4 * kSfVecSize - 1) / (4 * kSfVecSize)) * (4 * kSfVecSize) / kEltsPerThread;
```

tensorrt_llm/_torch/modules/mamba/layernorm_gated.py (1)
210-230: NVFP4 fused path is well-integrated; minor redundancy on line 219.

The gating condition (`is_nvfp4 and z is not None and not norm_before_gate`) correctly matches the fused kernel's semantics. The `fp4_out.view(torch.uint8)` and reshape logic for non-2D inputs are correct. Line 219 re-assigns `weight = self.weight.contiguous()`, which was already done on line 205. This is harmless but unnecessary.

Remove redundant weight assignment:

```diff
     if self.is_nvfp4 and z is not None and not self.norm_before_gate:
         if self.nvfp4_scale is None:
             raise ValueError(
                 "RMSNormGated NVFP4 output requested but no `nvfp4_scale` is attached. "
                 "Please set module.nvfp4_scale = input_scale from the next linear layer."
             )
-        weight = self.weight.contiguous()
         sf_scale = self.nvfp4_scale.contiguous()
         fp4_out, sf_out = torch.ops.trtllm.fused_gated_rmsnorm_quant(
             x, z, weight, self.group_size, self.eps, sf_scale)
```

cpp/tensorrt_llm/thop/fusedGatedRMSNormQuant.cpp (1)
114-116: Static variable should use `s` prefix per TRT-LLM naming convention.

Locally scoped static variables in TRT-LLM use camelCase with an `s` prefix (e.g., `sMultiProcessorCount`).

Rename to follow convention:

```diff
-    static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+    static int const sMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
```

Update the reference in the macro accordingly.

Based on learnings: "TensorRT-LLM C++ style: Locally scoped static variables (e.g., inside functions) use camelCase with an 's' prefix."
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
1-1: Copyright year needs update.

The copyright header says 2022-2024 but this file has significant modifications in 2026. As per coding guidelines, the year should be updated.

Update copyright year:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
tests/unittest/_torch/modules/test_fused_activation_quant.py (3)
21-21: Import style: prefer importing the module, not the function.

Per the coding guideline, use `from package.subpackage import module` instead of importing individual symbols. Consider:

```diff
-from tests.unittest.utils.util import getSMVersion
+from tests.unittest.utils import util
```

Then reference as `util.getSMVersion()` at the call sites (lines 35, 40). As per coding guidelines, "Always maintain the namespace when importing. Use `from package.subpackage import foo` instead of `from package.subpackage.foo import SomeClass`."
149-149: Prefix unused unpacked variables with `_` to satisfy linter.

`sf_fused` (lines 149, 185, 216) and `sf_separate` (lines 178, 213) are never used. Prefix them with `_` to signal intentional discard.

Example fix:

```diff
-    fp4_fused, sf_fused = torch.ops.trtllm.fused_relu2_quantize(input_tensor, sf_scale, 16)
+    fp4_fused, _sf_fused = torch.ops.trtllm.fused_relu2_quantize(input_tensor, sf_scale, 16)
-    fp4_separate, sf_separate = torch.ops.trtllm.fp4_quantize(
+    fp4_separate, _sf_separate = torch.ops.trtllm.fp4_quantize(
```

Also applies to: 185-185, 213-216
155-192: Consider also asserting scale factor equality, not just FP4 packed values.

The test compares `fp4_fused` vs `fp4_separate` but ignores `sf_fused` vs `sf_separate`. Since scale factors are part of the quantization output and would affect dequantized accuracy, consider adding a match-rate assertion on the scale factors too.

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm_fp4_traits.cu (1)
122-148: Consider collapsing the 4-way `use_rms_norm × output_hp_norm` branch.

The two nested `if/else` blocks can be reduced to a single dispatch with a helper, improving readability and reducing template-variant duplication. That said, the current structure matches existing patterns in this file and is correct.

Optional simplification:

```diff
-    // Select kernel variant based on use_rms_norm and output_hp_norm
-    if (use_rms_norm)
-    {
-        if (output_hp_norm)
-        {
-            _invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
-                true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
-        }
-        else
-        {
-            _invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
-                true, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
-        }
-    }
-    else
-    {
-        if (output_hp_norm)
-        {
-            _invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
-                false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, true>{});
-        }
-        else
-        {
-            _invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
-                false, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE, false>{});
-        }
-    }
+    auto _dispatch = [&](auto rmsNorm, auto hpNorm)
+    {
+        _invoke(FP4AddBiasResidualPreLayerNormTraits<GeneralFP4AddBiasResidualPreLayerNormParam<T>, T, T, float,
+            decltype(rmsNorm)::value, M_BLOCK, N_BLOCK, STAGES, PERSISTENT, LOW_LATENCY_MODE,
+            decltype(hpNorm)::value>{});
+    };
+    if (use_rms_norm)
+    {
+        output_hp_norm ? _dispatch(ConstBool<true>{}, ConstBool<true>{})
+                       : _dispatch(ConstBool<true>{}, ConstBool<false>{});
+    }
+    else
+    {
+        output_hp_norm ? _dispatch(ConstBool<false>{}, ConstBool<true>{})
+                       : _dispatch(ConstBool<false>{}, ConstBool<false>{});
+    }
```

cpp/tensorrt_llm/kernels/fusedGatedRMSNormQuant/fusedGatedRMSNormQuant.cu (4)
518-608: Significant code duplication in phase 1 of the full-row kernel.

The four blocks at lines 536–553, 554–571, 572–589, and 590–607 are nearly identical, differing only in the `xVec2`/`zVec2` index and accumulator variable. This is a deliberate ILP optimization with separate accumulators, but the duplication can be reduced using a `#pragma unroll` loop with an accumulator array, which the compiler will still optimize for ILP.

Optional refactor:

```diff
-    float localSqSum0 = 0.0f;
-    float localSqSum1 = 0.0f;
-    float localSqSum2 = 0.0f;
-    float localSqSum3 = 0.0f;
+    float localSqSums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
     ...
     for (int vec8Idx = tid; vec8Idx < numColVecs; vec8Idx += BLOCK_SIZE)
     {
         ...
-        // ~70 lines of manually unrolled blocks
+#pragma unroll
+        for (int i = 0; i < 4; i++)
+        {
+            float2 xf2, zf2;
+            if constexpr (std::is_same_v<T, half>)
+            {
+                xf2 = __half22float2(xVec2[i]);
+                zf2 = __half22float2(zVec2[i]);
+            }
+            else
+            {
+                xf2 = __bfloat1622float2(xVec2[i]);
+                zf2 = __bfloat1622float2(zVec2[i]);
+            }
+            float sig0 = fast_sigmoid(zf2.x);
+            float sig1 = fast_sigmoid(zf2.y);
+            float gated0 = xf2.x * zf2.x * sig0;
+            float gated1 = xf2.y * zf2.y * sig1;
+            localSqSums[i] += gated0 * gated0 + gated1 * gated1;
+        }
     }
-    float localSqSum = localSqSum0 + localSqSum1 + localSqSum2 + localSqSum3;
+    float localSqSum = localSqSums[0] + localSqSums[1] + localSqSums[2] + localSqSums[3];
```
54-57: Consider using the `k` prefix for file-scope constants.

Per C++ coding guidelines, static constants should use uppercase snake_case with a `k` prefix (e.g., `kELTS_PER_THREAD`, `kSF_VEC_SIZE`, `kNUM_THREADS_PER_SF`). The current naming without a prefix is common in CUDA code, so this is a minor style nit. As per coding guidelines, "Use uppercase snakecase with prefix 'k' for... static constants at class-scope, and function-scope magic-number/literal constants."
314-480: Grouped kernel re-reads x and z from HBM in phase 2 — consider documenting the trade-off.Unlike the optimized kernel (which stores gated values in registers), the grouped kernel recomputes gated values in phase 2 by re-reading x and z from global memory. This doubles the HBM traffic for these tensors. A brief inline comment noting this is a deliberate trade-off (register pressure vs. memory traffic) for arbitrary group sizes would help future readers.
485-714: Full-row kernel also re-reads x and z in phase 2.Same trade-off as the grouped kernel — gated values are recomputed in phase 2 (line 652–689) rather than stored in registers. For large N, this means each row's x and z are loaded from HBM twice. This is acceptable given the register constraints for large N, but worth a comment.
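The re-read trade-off noted in the last two comments is easy to quantify with back-of-envelope arithmetic; a sketch assuming fp16/bf16 inputs (2 bytes per element), with an illustrative function name:

```python
def phase2_reread_bytes(m: int, n: int, elem_bytes: int = 2) -> int:
    # Extra HBM read traffic when phase 2 re-reads both x and z (each m*n
    # elements) instead of keeping gated values in registers.
    # elem_bytes=2 assumes fp16/bf16 inputs.
    return 2 * m * n * elem_bytes
```

For example, an (m=4096, n=8192) activation would re-read 128 MiB per launch, which is the cost being traded against register pressure for large N.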
Force-pushed from accec5e to f826195
Force-pushed from f826195 to 0bd2bfd
/bot run --only-multi-gpu-test --disable-fail-fast

PR_Github #37063 [ run ] triggered by Bot. Commit:

PR_Github #37063 [ run ] completed with state
Force-pushed from 0a39d74 to 86157b2
/bot run --only-multi-gpu-test --disable-fail-fast

PR_Github #37304 [ run ] triggered by Bot. Commit:

PR_Github #37304 [ run ] completed with state
/bot run --only-multi-gpu-test --disable-fail-fast

PR_Github #37424 [ run ] triggered by Bot. Commit:

PR_Github #37424 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #37473 [ run ] triggered by Bot. Commit:

PR_Github #37473 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #37504 [ run ] triggered by Bot. Commit:

PR_Github #37504 [ run ] completed with state
…a2_mixer Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Force-pushed from 86157b2 to ad4502c
/bot run --disable-fail-fast

PR_Github #37652 [ run ] triggered by Bot. Commit:

PR_Github #37652 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #37799 [ run ] triggered by Bot. Commit:

PR_Github #37799 [ run ] completed with state
…a2_mixer (NVIDIA#11473) Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Feature
Summary by CodeRabbit
Release Notes
New Features
Performance Improvements
Testing
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.
Details
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.
- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill
`kill`

Kill all running builds associated with pull request.
skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.