
Commit 6cd2cae

Add unittests for new added features
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Parent: 5b364aa

8 files changed: 1,092 additions & 9 deletions

cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp
Lines changed: 6 additions & 6 deletions
@@ -52,9 +52,9 @@ namespace torch_ext
 //
 // NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
 // NOTE: Hidden dimension N must be >= 2048 and <= 16384.
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
-    at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
-    double eps, bool output_hp_norm)
+std::tuple<at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>> fused_add_rms_norm_quant(
+    at::Tensor const& input, at::Tensor const& residual, at::Tensor const& gamma,
+    std::optional<at::Tensor> const& sf_scale, bool use_rms_norm, double eps, bool output_hp_norm)
 {
     CHECK_TH_CUDA(input);
     CHECK_CONTIGUOUS(input);
@@ -118,7 +118,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant
     int64_t const sfSizePadded = tensorrt_llm::computeSwizzledLayoutSFSize(m_padded, n / sfVecSize);
     at::Tensor sf_out_padded = at::detail::empty_cuda({sfSizePadded}, SF_DTYPE, input.device(), std::nullopt);
     at::Tensor sf_out = (m_padded == m) ? sf_out_padded : sf_out_padded.narrow(0, 0, sfSize);
-    at::Tensor high_precision_normed_output;
+    std::optional<at::Tensor> high_precision_normed_output = std::nullopt;
     if (output_hp_norm)
     {
         at::Tensor hp_normed_output_padded
@@ -163,7 +163,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant
     param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
     param.beta = nullptr; \
     param.high_precision_normed_output \
-        = output_hp_norm ? reinterpret_cast<T*>(high_precision_normed_output.data_ptr()) : nullptr; \
+        = output_hp_norm ? reinterpret_cast<T*>(high_precision_normed_output.value().data_ptr()) : nullptr; \
     param.m = static_cast<int>(m); \
     param.n = static_cast<int>(n); \
     param.layernorm_eps = static_cast<float>(eps); \
@@ -204,7 +204,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
         "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-6, bool output_hp_norm=False) -> (Tensor, Tensor, "
-        "Tensor, Tensor)");
+        "Tensor, Tensor?)");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
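
With this change the op's fourth output is optional (Tensor? in the schema), so Python callers must handle a possible None. A minimal caller sketch, assuming a TensorRT-LLM build that registers this op and an SM90/SM100 GPU; shapes, dtypes, and the sf_scale value below are illustrative only:

    import torch

    m, n = 128, 4096  # n must satisfy 2048 <= n <= 16384 per the NOTE above
    x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
    residual = torch.randn_like(x)
    gamma = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")  # illustrative

    # output_hp_norm=True: the fourth result is a real tensor
    fp4, new_residual, sf, hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        x, residual, gamma, sf_scale, use_rms_norm=True, eps=1e-6, output_hp_norm=True
    )
    assert hp is not None

    # output_hp_norm=False (the default): the fourth result is None
    *_, hp = torch.ops.trtllm.fused_add_rms_norm_quant(
        x, residual, gamma, sf_scale, use_rms_norm=True, eps=1e-6, output_hp_norm=False
    )
    assert hp is None
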

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
Lines changed: 7 additions & 2 deletions
@@ -1003,7 +1003,9 @@ def _(
     sf_scale: Optional[torch.Tensor],
     use_rms_norm: bool = True,
     eps: float = 1e-5,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    output_hp_norm: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+           Optional[torch.Tensor]]:
     m, n = input.shape
     # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
     normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
@@ -1013,7 +1015,10 @@ def _(
     sf_vec_size = 16
     sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
     sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
-    return normed_output_fp4, output, sf_out
+    # high_precision_normed_output: [M, N] optional, only when output_hp_norm=True
+    hp_output = input.new_empty(
+        (m, n), dtype=input.dtype) if output_hp_norm else None
+    return normed_output_fp4, output, sf_out, hp_output
 
 
 @torch.library.register_fake("trtllm::fused_relu2_quantize")
 def _(
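
The fake op mirrors the kernel's swizzled scale-factor layout: rows are padded to a multiple of 128 and the n / sf_vec_size scale columns to a multiple of 4. A quick worked check of that arithmetic (the m and n values here are arbitrary examples):

    m, n, sf_vec_size = 300, 4096, 16
    rows_padded = ((m + 127) // 128) * 128           # 300 -> 384
    cols_padded = ((n // sf_vec_size + 3) // 4) * 4  # 256 -> 256 (already aligned)
    assert rows_padded * cols_padded == 98304        # matches sf_size above
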

tensorrt_llm/tools/layer_wise_benchmarks/runner.py
Lines changed: 8 additions & 1 deletion
@@ -451,7 +451,14 @@ def forward(position_ids, hidden_states, attn_metadata, residual, **kwargs):
                 position_ids, hidden_states, attn_metadata, residual, **kwargs
             )
         else:
-            hidden_states = layer(position_ids, hidden_states, attn_metadata, **kwargs)
+            result = layer(
+                position_ids, hidden_states, attn_metadata, residual=residual, **kwargs
+            )
+            # Some layers (e.g., NemotronH) return (hidden_states, residual) tuple
+            if isinstance(result, tuple):
+                hidden_states, residual = result
+            else:
+                hidden_states = result
         return hidden_states, residual
 
     model.forward = forward
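
The tuple handling above can also be read as a small normalization helper; a sketch of the same pattern, with a hypothetical helper name:

    def unpack_layer_output(result, residual):
        # Layers such as NemotronH return (hidden_states, residual);
        # plain layers return hidden_states alone and residual passes through.
        if isinstance(result, tuple):
            return result
        return result, residual
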
Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
import torch.nn.functional as F

from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID


def mamba_conv1d_ref(x, past_conv_state, conv_weight, conv_bias, apply_silu):
    """
    Reference implementation for causal conv1d.

    Arguments:
        x: [batch_size, dim, seq_len]
        past_conv_state: [batch_size, dim, dconv-1]
        conv_weight: [dim, 1, dconv]
        conv_bias: [dim]
    Output:
        y: [batch_size, dim, seq_len]
        present_conv_state: [batch_size, dim, dconv-1]
    """
    assert x.dim() == 3
    assert past_conv_state.dim() == 3
    assert conv_weight.dim() == 3
    assert conv_bias.dim() == 1
    batch_size, dim, seq_len = x.shape
    assert past_conv_state.shape[0] == batch_size
    assert past_conv_state.shape[1] == dim
    dconv = past_conv_state.shape[2] + 1
    assert conv_weight.shape[0] == dim
    assert conv_weight.shape[1] == 1
    assert conv_weight.shape[2] == dconv

    padded_x = torch.cat([past_conv_state, x], dim=2)
    present_conv_state = padded_x[:, :, -(dconv - 1):]
    x_conv = F.conv1d(padded_x, conv_weight, bias=conv_bias, groups=dim)

    y = F.silu(x_conv) if apply_silu else x_conv
    return y, present_conv_state


def trtllm_causal_conv1d_available():
    """Check if trtllm.causal_conv1d_fwd is available."""
    return hasattr(torch.ops, "trtllm") and hasattr(torch.ops.trtllm, "causal_conv1d_fwd")


skip_unsupported = pytest.mark.skipif(
    not torch.cuda.is_available() or not trtllm_causal_conv1d_available(),
    reason="Requires CUDA and trtllm.causal_conv1d_fwd op",
)


@skip_unsupported
class TestCausalConv1d:
    """Tests for the causal_conv1d CUDA kernel."""

    @pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"])
    @pytest.mark.parametrize("apply_silu", [True, False])
    @pytest.mark.parametrize("dim", [256, 512, 1024, 2048])
    def test_basic_correctness(self, dtype, apply_silu, dim):
        """Test basic correctness against reference implementation."""
        torch.manual_seed(42)
        device = "cuda"
        torch_dtype = getattr(torch, dtype)

        batch_size = 4
        seq_len = 32
        dconv = 4
        std_dev = 0.5
        x = torch.randn(batch_size, dim, seq_len, dtype=torch_dtype, device=device)
        x = x * std_dev
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=torch_dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=torch_dtype, device=device)
        conv_bias = torch.randn(dim, dtype=torch_dtype, device=device)
        x_kernel = x.clone()
        conv_state_kernel = conv_state.clone()

        conv_weight_input = conv_weight.squeeze(1).contiguous()
        torch.ops.trtllm.causal_conv1d_fwd(
            x_kernel,
            conv_weight_input,
            conv_bias,
            conv_state_kernel,
            None,  # query_start_loc
            None,  # cache_indices
            None,  # has_initial_state
            apply_silu,
            PAD_SLOT_ID,
        )
        out_ref, conv_state_ref = mamba_conv1d_ref(
            x, conv_state, conv_weight, conv_bias, apply_silu
        )

        torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-2)
        torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-2)

    @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
    def test_various_batch_sizes(self, batch_size):
        """Test with various batch sizes."""
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16
        dim = 1024
        seq_len = 64
        dconv = 4
        apply_silu = True

        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)
        x_kernel = x.clone()
        conv_state_kernel = conv_state.clone()

        conv_weight_input = conv_weight.squeeze(1).contiguous()
        torch.ops.trtllm.causal_conv1d_fwd(
            x_kernel,
            conv_weight_input,
            conv_bias,
            conv_state_kernel,
            None,
            None,
            None,
            apply_silu,
            PAD_SLOT_ID,
        )
        out_ref, conv_state_ref = mamba_conv1d_ref(
            x, conv_state, conv_weight, conv_bias, apply_silu
        )

        torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)

    @pytest.mark.parametrize("dconv", [2, 3, 4])
    def test_various_kernel_widths(self, dconv):
        """Test with different convolution kernel widths."""
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16

        batch_size = 4
        dim = 1024
        seq_len = 64
        apply_silu = True
        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)
        x_kernel = x.clone()
        conv_state_kernel = conv_state.clone()

        conv_weight_input = conv_weight.squeeze(1).contiguous()
        torch.ops.trtllm.causal_conv1d_fwd(
            x_kernel,
            conv_weight_input,
            conv_bias,
            conv_state_kernel,
            None,
            None,
            None,
            apply_silu,
            PAD_SLOT_ID,
        )
        out_ref, conv_state_ref = mamba_conv1d_ref(
            x, conv_state, conv_weight, conv_bias, apply_silu
        )

        torch.testing.assert_close(x_kernel, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)

    def test_with_initial_state(self):
        """Test with non-zero initial conv state."""
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16

        batch_size = 4
        dim = 1024
        seq_len = 32
        dconv = 4
        apply_silu = True

        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        # Non-zero initial state
        conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device)
        conv_state = conv_state * 0.5
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)
        conv_state_kernel = conv_state.clone()
        # Need to tell the kernel about initial state
        has_initial_state = torch.ones(batch_size, dtype=torch.bool, device=device)
        query_start_loc = torch.tensor(
            [0] + [seq_len * (i + 1) for i in range(batch_size)],
            dtype=torch.int32,
            device=device,
        )
        # Reshape for varlen format
        x_varlen = x.transpose(1, 2).reshape(-1, dim).T.contiguous()

        conv_weight_input = conv_weight.squeeze(1).contiguous()
        torch.ops.trtllm.causal_conv1d_fwd(
            x_varlen,
            conv_weight_input,
            conv_bias,
            conv_state_kernel,
            query_start_loc,
            None,  # cache_indices
            has_initial_state,
            apply_silu,
            PAD_SLOT_ID,
        )

        out_ref_list = []
        conv_state_ref_list = []
        for b in range(batch_size):
            out_b, state_b = mamba_conv1d_ref(
                x[b : b + 1],
                conv_state[b : b + 1],
                conv_weight,
                conv_bias,
                apply_silu,
            )
            out_ref_list.append(out_b)
            conv_state_ref_list.append(state_b)
        out_ref = torch.cat(out_ref_list, dim=0)
        conv_state_ref = torch.cat(conv_state_ref_list, dim=0)
        x_kernel_reshaped = (
            x_varlen.T.reshape(batch_size, seq_len, dim).transpose(1, 2).contiguous()
        )

        torch.testing.assert_close(x_kernel_reshaped, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(conv_state_kernel, conv_state_ref, rtol=1e-2, atol=1e-1)
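
test_with_initial_state feeds the kernel the varlen layout: tokens from all sequences flattened into a [dim, total_tokens] tensor, with query_start_loc giving each sequence's start offset. A standalone sketch of that round trip (shapes illustrative, CPU-only, no kernel involved):

    import torch

    B, D, L = 4, 8, 32
    x = torch.randn(B, D, L)
    # Batched [B, D, L] -> varlen [D, B*L], as in the test above
    x_varlen = x.transpose(1, 2).reshape(-1, D).T.contiguous()
    # query_start_loc = [0, L, 2L, ...] marks per-sequence token offsets
    query_start_loc = torch.tensor([0] + [L * (i + 1) for i in range(B)])
    # Varlen [D, B*L] -> batched [B, D, L]; the round trip is lossless
    x_back = x_varlen.T.reshape(B, L, D).transpose(1, 2).contiguous()
    assert torch.equal(x, x_back)
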
