|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import pytest |
| 17 | +import torch |
| 18 | +import torch.nn.functional as F |
| 19 | + |
| 20 | +from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID |
| 21 | + |
| 22 | + |
def mamba_conv1d_ref(x, past_conv_state, conv_weight, conv_bias, apply_silu):
    """Reference implementation for causal depthwise conv1d with carried state.

    Arguments:
        x: [batch_size, dim, seq_len] input activations.
        past_conv_state: [batch_size, dim, dconv - 1] tail of the previous
            window, prepended so the convolution stays causal across calls.
        conv_weight: [dim, 1, dconv] depthwise filter (one per channel).
        conv_bias: [dim] per-channel bias.
        apply_silu: if True, apply SiLU to the convolution output.

    Returns:
        y: [batch_size, dim, seq_len] convolution output.
        present_conv_state: [batch_size, dim, dconv - 1] state to carry into
            the next call.
    """
    assert x.dim() == 3
    assert past_conv_state.dim() == 3
    assert conv_weight.dim() == 3
    assert conv_bias.dim() == 1
    batch_size, dim, seq_len = x.shape
    assert past_conv_state.shape[0] == batch_size
    assert past_conv_state.shape[1] == dim
    dconv = past_conv_state.shape[2] + 1
    # Single shape check replaces the original's three separate asserts
    # (one of which was duplicated); also validate the bias length.
    assert conv_weight.shape == (dim, 1, dconv)
    assert conv_bias.shape[0] == dim

    # Prepend the carried state so the first positions of x still see a full
    # causal window; groups=dim makes the convolution depthwise.
    padded_x = torch.cat([past_conv_state, x], dim=2)
    # Use an explicit start index instead of -(dconv - 1): a negative-zero
    # slice would return the whole tensor when dconv == 1.
    present_conv_state = padded_x[:, :, padded_x.shape[2] - (dconv - 1) :]
    x_conv = F.conv1d(padded_x, conv_weight, bias=conv_bias, groups=dim)

    y = F.silu(x_conv) if apply_silu else x_conv
    return y, present_conv_state
| 55 | + |
| 56 | + |
def trtllm_causal_conv1d_available():
    """Return True when the trtllm.causal_conv1d_fwd custom op is registered."""
    trtllm_ns = getattr(torch.ops, "trtllm", None)
    return trtllm_ns is not None and hasattr(trtllm_ns, "causal_conv1d_fwd")
| 60 | + |
| 61 | + |
# Shared skip marker: these tests need both a CUDA device and the compiled
# trtllm custom op; skip cleanly on hosts that lack either.
skip_unsupported = pytest.mark.skipif(
    not torch.cuda.is_available() or not trtllm_causal_conv1d_available(),
    reason="Requires CUDA and trtllm.causal_conv1d_fwd op",
)
| 66 | + |
| 67 | + |
@skip_unsupported
class TestCausalConv1d:
    """Tests for the causal_conv1d CUDA kernel against the PyTorch reference."""

    @staticmethod
    def _run_kernel(
        x,
        conv_weight,
        conv_bias,
        conv_state,
        apply_silu,
        query_start_loc=None,
        cache_indices=None,
        has_initial_state=None,
    ):
        """Invoke the in-place trtllm kernel on clones of x / conv_state.

        conv_weight is the [dim, 1, dconv] tensor used by the reference; it
        is squeezed to the [dim, dconv] layout the op expects.

        Returns:
            (output, new_conv_state) — the kernel writes both in place, so
            clones are taken to leave the caller's tensors untouched.
        """
        x_kernel = x.clone()
        conv_state_kernel = conv_state.clone()
        torch.ops.trtllm.causal_conv1d_fwd(
            x_kernel,
            conv_weight.squeeze(1).contiguous(),
            conv_bias,
            conv_state_kernel,
            query_start_loc,
            cache_indices,
            has_initial_state,
            apply_silu,
            PAD_SLOT_ID,
        )
        return x_kernel, conv_state_kernel

    @pytest.mark.parametrize("dtype", ["float16", "bfloat16", "float32"])
    @pytest.mark.parametrize("apply_silu", [True, False])
    @pytest.mark.parametrize("dim", [256, 512, 1024, 2048])
    def test_basic_correctness(self, dtype, apply_silu, dim):
        """Kernel matches the reference for a zero initial state."""
        torch.manual_seed(42)
        device = "cuda"
        torch_dtype = getattr(torch, dtype)
        batch_size, seq_len, dconv = 4, 32, 4

        # Scale inputs down to keep accumulation error within tolerance.
        x = torch.randn(batch_size, dim, seq_len, dtype=torch_dtype, device=device) * 0.5
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=torch_dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=torch_dtype, device=device)
        conv_bias = torch.randn(dim, dtype=torch_dtype, device=device)

        y, new_state = self._run_kernel(x, conv_weight, conv_bias, conv_state, apply_silu)
        out_ref, state_ref = mamba_conv1d_ref(x, conv_state, conv_weight, conv_bias, apply_silu)

        torch.testing.assert_close(y, out_ref, rtol=1e-2, atol=1e-2)
        torch.testing.assert_close(new_state, state_ref, rtol=1e-2, atol=1e-2)

    @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
    def test_various_batch_sizes(self, batch_size):
        """Kernel matches the reference across batch sizes."""
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16
        dim, seq_len, dconv = 1024, 64, 4
        apply_silu = True

        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)

        y, new_state = self._run_kernel(x, conv_weight, conv_bias, conv_state, apply_silu)
        out_ref, state_ref = mamba_conv1d_ref(x, conv_state, conv_weight, conv_bias, apply_silu)

        torch.testing.assert_close(y, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(new_state, state_ref, rtol=1e-2, atol=1e-1)

    @pytest.mark.parametrize("dconv", [2, 3, 4])
    def test_various_kernel_widths(self, dconv):
        """Kernel matches the reference for different convolution kernel widths."""
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16
        batch_size, dim, seq_len = 4, 1024, 64
        apply_silu = True

        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        conv_state = torch.zeros(batch_size, dim, dconv - 1, dtype=dtype, device=device)
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)

        y, new_state = self._run_kernel(x, conv_weight, conv_bias, conv_state, apply_silu)
        out_ref, state_ref = mamba_conv1d_ref(x, conv_state, conv_weight, conv_bias, apply_silu)

        torch.testing.assert_close(y, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(new_state, state_ref, rtol=1e-2, atol=1e-1)

    def test_with_initial_state(self):
        """Kernel matches the reference when the initial conv state is non-zero.

        The kernel is driven through the varlen path: x is flattened to
        [dim, batch_size * seq_len] with query_start_loc marking sequence
        boundaries and has_initial_state set for every sequence.
        """
        torch.manual_seed(42)
        device = "cuda"
        dtype = torch.bfloat16
        batch_size, dim, seq_len, dconv = 4, 1024, 32, 4
        apply_silu = True

        x = torch.randn(batch_size, dim, seq_len, dtype=dtype, device=device) * 0.5
        # Non-zero initial state — the kernel must be told it is valid.
        conv_state = torch.randn(batch_size, dim, dconv - 1, dtype=dtype, device=device) * 0.5
        conv_weight = torch.randn(dim, 1, dconv, dtype=dtype, device=device)
        conv_bias = torch.randn(dim, dtype=dtype, device=device)

        has_initial_state = torch.ones(batch_size, dtype=torch.bool, device=device)
        # Exclusive prefix sums of sequence lengths: [0, L, 2L, ..., B*L].
        query_start_loc = torch.arange(
            0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device
        )
        # Varlen layout: sequences concatenated along time, channels first.
        x_varlen = x.transpose(1, 2).reshape(-1, dim).T.contiguous()

        y_varlen, new_state = self._run_kernel(
            x_varlen,
            conv_weight,
            conv_bias,
            conv_state,
            apply_silu,
            query_start_loc=query_start_loc,
            has_initial_state=has_initial_state,
        )
        # The reference is batched, so a single call covers all sequences
        # (the original per-sequence loop was equivalent but redundant).
        out_ref, state_ref = mamba_conv1d_ref(x, conv_state, conv_weight, conv_bias, apply_silu)
        y = y_varlen.T.reshape(batch_size, seq_len, dim).transpose(1, 2).contiguous()

        torch.testing.assert_close(y, out_ref, rtol=1e-2, atol=1e-1)
        torch.testing.assert_close(new_state, state_ref, rtol=1e-2, atol=1e-1)
0 commit comments